So a lot has happened since last week. Today is a holiday, and besides the nice lunch meetup with an old acquaintance, my focus today is just to make progress on this side project.
personal reflections
it’s cool that so far this year, I’ve had more (non-work) GitHub activity than in any previous year, at least in terms of rate. this is pretty good! I wonder if I’ll be able to keep that up.
but that’s a good reminder that, no, I’m not getting lazier or less inspired with software work, even though it sometimes feels that way. I’m just pushing myself to do bigger things!
it matters
when you’re small, your elementary school looks so big, and then you come back as an adult and marvel at how small everything was. this is kind of like that
relevant to jax-js plans
um, so we had Chenyu from tinygrad come over to NYSRG, and after reading the codebase and taking notes on related cuda things, I think I understand the picture of compiling operations into kernels a lot better now
the rewrite rules / lazy pattern matchers are just a less PL-jargon-infused (or should I say, less PL-aware) way of talking about staging, like what JAX does with HLO/XLA
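to make that concrete, here’s a minimal sketch of the idea in TypeScript (all names invented for illustration, not tinygrad’s or jax-js’s actual types): staging builds a small expression tree instead of executing eagerly, and a “rewrite rule” is just a pattern match over that tree.

```ts
// Toy IR: invented for illustration, not a real tinygrad/jax-js type.
type Exp =
  | { op: "input"; name: string }
  | { op: "mul"; lhs: Exp; rhs: Exp }
  | { op: "add"; lhs: Exp; rhs: Exp }
  | { op: "fma"; a: Exp; b: Exp; c: Exp };

// One rewrite rule, applied bottom-up: a*b + c  =>  fma(a, b, c).
function rewrite(e: Exp): Exp {
  switch (e.op) {
    case "input":
      return e;
    case "mul":
      return { op: "mul", lhs: rewrite(e.lhs), rhs: rewrite(e.rhs) };
    case "add": {
      const lhs = rewrite(e.lhs);
      const rhs = rewrite(e.rhs);
      if (lhs.op === "mul") return { op: "fma", a: lhs.lhs, b: lhs.rhs, c: rhs };
      return { op: "add", lhs, rhs };
    }
    case "fma":
      return { op: "fma", a: rewrite(e.a), b: rewrite(e.b), c: rewrite(e.c) };
  }
}
```

the “staging” framing is the same thing from the other direction: the tracer records the tree, and rules like this run before any kernel is emitted.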
my takeaway from the tinygrad paper is that you can get a long way with pretty simple kernels and just a couple of hand-rolled heuristics
in retrospect this should be pretty obvious. like, automatic heuristics should at least beat a static library of a couple of precompiled kernels: it’s smaller and more flexible, with low development resources
and gpus can’t be that complicated. there are memory hierarchies, but even complex problems tend to have fairly parsimonious solutions
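for a flavor of what “a couple of hand-rolled heuristics” might look like, here’s a hedged sketch (invented names and thresholds, not tinygrad’s real code) of picking a vectorization factor and a threadgroup size for one axis:

```ts
// Invented heuristic for illustration -- not tinygrad's actual logic.
function pickFactors(axisSize: number): { upcast: number; local: number } {
  // prefer float4-style vectorization when the axis divides evenly
  const upcast = axisSize % 4 === 0 ? 4 : 1;
  // largest power of two <= 16 that divides the remaining extent
  let local = 1;
  for (const cand of [16, 8, 4, 2]) {
    if ((axisSize / upcast) % cand === 0) {
      local = cand;
      break;
    }
  }
  return { upcast, local };
}

// e.g. pickFactors(2048) -> { upcast: 4, local: 16 }, the same flavor
// of UPCAST/LOCAL choices as the Opt list in the kernel dump below.
```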
this means that I am pretty confident (overconfident??) in being able to get rid of the dependency on tfjs-core at some point in the future
which is huge, since then I’m not limited to a couple of dtypes, and I can optimize any operations of my choice and extend the project arbitrarily to support even more operations or algorithms on the way to numpy API compatibility
you want a QR decomposition from numpy.linalg.qr()? sure, have it. as a taste of what this kind of codegen produces, here’s the Metal matmul kernel tinygrad generates with tensor cores disabled (TC=0) and kernel-source dumping on (DEBUG=4):
```
$ TC=0 DEBUG=4 python3 test.py
# ... stuff
0: (64, 32, 8, 16, 1, 4, 4, 1) float.ptr(4194304) (65536, 64, 8192, 4, 0, 1, 2048, 0)
1: (64, 32, 8, 16, 512, 4, 4, 4) float.ptr(4194304) (0, 64, 0, 4, 8192, 1, 0, 2048)
2: (64, 32, 8, 16, 512, 4, 4, 4) float.ptr(4194304) (65536, 0, 8192, 0, 4, 0, 2048, 1)
[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.LOCAL, axis=1, arg=16)]
#include <metal_stdlib>
using namespace metal;
kernel void r_64_32_8_16_512_4_4_4(device float* data0, device float* data1, device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
int gidx0 = gid.x; /* 32 */
int gidx1 = gid.y; /* 64 */
int lidx0 = lid.x; /* 8 */
int lidx1 = lid.y; /* 16 */
int alu0 = (gidx0<<6);
int alu1 = (gidx1<<16);
int alu2 = (lidx0<<13);
int alu3 = (lidx1<<2);
float acc0 = 0.0f;
float acc1 = 0.0f;
float acc2 = 0.0f;
float acc3 = 0.0f;
float acc4 = 0.0f;
float acc5 = 0.0f;
float acc6 = 0.0f;
float acc7 = 0.0f;
float acc8 = 0.0f;
float acc9 = 0.0f;
float acc10 = 0.0f;
float acc11 = 0.0f;
float acc12 = 0.0f;
float acc13 = 0.0f;
float acc14 = 0.0f;
float acc15 = 0.0f;
for (int ridx0 = 0; ridx0 < 512; ridx0++) {
int alu4 = (alu1+alu2+(ridx0<<2));
float4 val0 = *((device float4*)((data1+alu4)));
float4 val1 = *((device float4*)((data1+(alu4+2048))));
float4 val2 = *((device float4*)((data1+(alu4+4096))));
float4 val3 = *((device float4*)((data1+(alu4+6144))));
int alu5 = (alu0+alu3+(ridx0<<13));
float4 val4 = *((device float4*)((data2+alu5)));
float4 val5 = *((device float4*)((data2+(alu5+2048))));
float4 val6 = *((device float4*)((data2+(alu5+4096))));
float4 val7 = *((device float4*)((data2+(alu5+6144))));
acc0 = (acc0+(val0.x*val4.x)+(val0.y*val5.x)+(val0.z*val6.x)+(val0.w*val7.x));
acc1 = (acc1+(val1.x*val4.x)+(val1.y*val5.x)+(val1.z*val6.x)+(val1.w*val7.x));
acc2 = (acc2+(val2.x*val4.x)+(val2.y*val5.x)+(val2.z*val6.x)+(val2.w*val7.x));
acc3 = (acc3+(val3.x*val4.x)+(val3.y*val5.x)+(val3.z*val6.x)+(val3.w*val7.x));
acc4 = (acc4+(val0.x*val4.y)+(val0.y*val5.y)+(val0.z*val6.y)+(val0.w*val7.y));
acc5 = (acc5+(val1.x*val4.y)+(val1.y*val5.y)+(val1.z*val6.y)+(val1.w*val7.y));
acc6 = (acc6+(val2.x*val4.y)+(val2.y*val5.y)+(val2.z*val6.y)+(val2.w*val7.y));
acc7 = (acc7+(val3.x*val4.y)+(val3.y*val5.y)+(val3.z*val6.y)+(val3.w*val7.y));
acc8 = (acc8+(val0.x*val4.z)+(val0.y*val5.z)+(val0.z*val6.z)+(val0.w*val7.z));
acc9 = (acc9+(val1.x*val4.z)+(val1.y*val5.z)+(val1.z*val6.z)+(val1.w*val7.z));
acc10 = (acc10+(val2.x*val4.z)+(val2.y*val5.z)+(val2.z*val6.z)+(val2.w*val7.z));
acc11 = (acc11+(val3.x*val4.z)+(val3.y*val5.z)+(val3.z*val6.z)+(val3.w*val7.z));
acc12 = (acc12+(val0.x*val4.w)+(val0.y*val5.w)+(val0.z*val6.w)+(val0.w*val7.w));
acc13 = (acc13+(val1.x*val4.w)+(val1.y*val5.w)+(val1.z*val6.w)+(val1.w*val7.w));
acc14 = (acc14+(val2.x*val4.w)+(val2.y*val5.w)+(val2.z*val6.w)+(val2.w*val7.w));
acc15 = (acc15+(val3.x*val4.w)+(val3.y*val5.w)+(val3.z*val6.w)+(val3.w*val7.w));
}
int alu23 = (alu0+alu1+alu2+alu3);
*((device float4*)((data0+alu23))) = float4(acc0,acc4,acc8,acc12);
*((device float4*)((data0+(alu23+2048)))) = float4(acc1,acc5,acc9,acc13);
*((device float4*)((data0+(alu23+4096)))) = float4(acc2,acc6,acc10,acc14);
*((device float4*)((data0+(alu23+6144)))) = float4(acc3,acc7,acc11,acc15);
}
*** METAL 9 r_64_32_8_16_512_4_4_4 arg 3 mem 0.05 GB tm 19.86ms/ 22.58ms ( 864.86 GFLOPS 2.5|865.7 GB/s) ['__matmul__']
```
but right now the milestones still look like:
- [ ] It works!
- [ ] Demos: Navier-Stokes, neural networks, statistics
- [ ] We figure out the `dispose()` / linear types stuff
- [ ] Device switching with `.to()` between webgl/webgpu/cpu/wasm
- [ ] First custom kernel
- [ ] numpy/jax API compatibility table
- [ ] Convert Jaxprs into a tree data structure
- [ ] Pattern matchers for kernel fusion
- [ ] Kernel codegen, or synthesis
in particular, I think the pattern matchers, scheduling, and codegen components (the equivalent of ExecItem in tinygrad) will probably end up fitting into the equivalent of JAX’s `xla_call` operation. so we’ll have two separate parts of the codebase: one for compilation, and one for non-jitted code.
this sounds kind of weird at first, but I think it’s the right choice given the design tradeoffs we’re making. we want it to be fast, but we don’t need to squeeze out every drop of performance; after all, we don’t even know what hardware we’re running on, since it’s an in-browser javascript library.
the other advantage of jitting is that we can auto-manage memory (er, we have to predict static memory patterns anyway, so we get this for free), which is important given that javascript has no reliable GC dispose hook (destructor)
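here’s a minimal sketch of that “free” memory management (hypothetical shapes, not jax-js’s actual structures): because a jaxpr body is a static list of equations, each intermediate buffer’s last use is known before we run anything, so we can dispose it at exactly the right step without any GC hook.

```ts
// Hypothetical equation shape for illustration only.
interface Eqn {
  out: string;      // variable this equation defines
  inputs: string[]; // variables it reads
}

// Map each variable to the index of the equation that reads it last.
function lastUses(eqns: Eqn[], outputs: string[]): Map<string, number> {
  const last = new Map<string, number>();
  eqns.forEach((eqn, i) => {
    for (const v of eqn.inputs) last.set(v, i);
  });
  for (const v of outputs) last.delete(v); // outputs outlive the call
  return last;
}

// At execution time: after running equation i, dispose every buffer
// whose last use was i. No destructor or finalizer needed.
```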
anyway this seems pretty solid
development
tests continue to pass and reveal their utility over time. also vitest’s inline snapshot testing is quite fast & awesome.
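for example, an inline snapshot test looks roughly like this (a generic sketch, not the project’s actual pprint tests):

```ts
import { expect, test } from "vitest";

test("inline snapshot example", () => {
  const rendered = ["a", "b", "c"].join(" -> ");
  // on the first run, vitest writes the expected value directly into
  // this file; on later runs it asserts against it. handy for
  // pretty-printer tests like the ones in src/pprint.test.ts.
  expect(rendered).toMatchInlineSnapshot(`"a -> b -> c"`);
});
```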
anyway, it’s 8 PM right now. here’s what we got from today:
```
$ git --no-pager diff --stat "@{1 day ago}"
README.md | 3 +
src/core.ts | 501 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
src/index.ts | 13 ++-
src/pprint.test.ts | 68 +++++++++++++
src/pprint.ts | 57 +++++++++++
src/utils.ts | 2 +-
test/tracing.test.ts | 48 +++++++++
7 files changed, 688 insertions(+), 4 deletions(-)
```
basically just finished implementing the jaxpr logic and tracing. I understand how `xla_call` works now too: the jaxpr is placed into the call primitive’s parameters, and it composes in some interesting ways. mental exercises (see the sketch after this list):
- what happens when you jit() a jit()?
- what happens when you jvp() a jit()?
- what happens when you jit() a jvp()?
- what happens when you makeJaxpr() a jit()?
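here’s my mental model for the first one, with hypothetical TypeScript shapes (not jax-js’s real types): jit(f) doesn’t run f, it traces f into a jaxpr and binds a single xla_call equation carrying that jaxpr as a parameter, so jit-of-jit is just nesting.

```ts
// Hypothetical shapes for illustration, not jax-js's actual types.
type Atom = { name: string } | { lit: number };
interface Eqn {
  primitive: string;
  inputs: Atom[];
  params: Record<string, unknown>;
  outVar: string;
}
interface Jaxpr { inBinders: string[]; eqns: Eqn[]; outs: Atom[] }

// jit(f) for f(x) = x * x stages out:
const inner: Jaxpr = {
  inBinders: ["x"],
  eqns: [{ primitive: "mul", inputs: [{ name: "x" }, { name: "x" }], params: {}, outVar: "y" }],
  outs: [{ name: "y" }],
};

// jit(jit(f)): the outer trace sees exactly one xla_call equation,
// whose params carry the inner jaxpr -- composition is just nesting.
const outer: Jaxpr = {
  inBinders: ["x"],
  eqns: [{ primitive: "xla_call", inputs: [{ name: "x" }], params: { jaxpr: inner }, outVar: "y" }],
  outs: [{ name: "y" }],
};

// The others follow the same logic:
// - jvp(jit(f)): the jvp interpreter hits the xla_call equation and
//   must transform the jaxpr *parameter* itself (the "initial-style"
//   rule in the quote below).
// - jit(jvp(f)): jvp runs during tracing, so the staged jaxpr already
//   contains the tangent computation; jit sees ordinary primitives.
// - makeJaxpr(jit(f)): a one-equation jaxpr, the same shape as `outer`.
```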
I also understand (I think?) this gem of a quote laden with PL terminology, lmao
> There are two options for how to handle higher-order primitives. Each requires a different approach to tracing and engenders different tradeoffs:
>
> 1. **On-the-fly processing**, where `bind` takes a Python callable as an argument. We defer forming a jaxpr until as late as possible, namely until we’re running the final interpreter at the bottom of the interpreter stack. That way we can swap a `JaxprTrace` in at the bottom of the interpreter stack and thus stage out rather than execute all primitive operations. With this approach, transformations in the stack get applied as we execute the Python callable as usual. This approach can be very tricky to implement, but it’s as general as possible because it allows higher-order primitives not to raise the abstraction level of their arguments and thus allows data-dependent Python control flow. We refer to this approach as using a “final-style higher-order primitive” employing the discharge-at-tracing-time “final-style transformations” we’ve used so far.
>
> 2. **Staged processing**, where `bind` takes a jaxpr as an argument. Before we call `bind`, in the primitive wrapper we can just use `make_jaxpr` to form a jaxpr up-front and be done with the Python callable entirely. In this case, `make_jaxpr` puts its `JaxprTrace` at the top of the interpreter stack, and no transformations lower in the stack, which might enter via closed-over Tracers, are applied to the Python callable as we trace it. (Transformations applied within the Python callable are applied as usual, being added to the stack above the `JaxprTrace`.) Instead, the transformations lower in the stack are later applied to the call primitive, and the call primitive’s rules must then transform the jaxpr itself. Because we trace to a jaxpr up-front, this approach can’t support data-dependent Python control flow, but it is more straightforward to implement. We refer to this kind of higher-order primitive as an “initial-style higher-order primitive”, and say that its jaxpr-processing transformation rules are “initial-style transformation rules.”
>
> The latter approach fits for `jit` because we don’t need to support data-dependent Python control flow in the user-provided Python callable, as the whole purpose of `jit` is to stage computation out of Python to be executed by XLA. (In contrast, `custom_jvp` is a higher-order primitive in which we want to support data-dependent Python control flow.)
>
> Historically, we started using the “initial-style” and “final-style” terminology after reading the typed tagless final interpreters paper, and jokingly referring to JAX as an implementation of “untyped tagful final interpreters.” We don’t claim to carry over (or understand) any deep meaning behind these terms; we loosely use “initial style” to mean “build an AST and then transform it”, and we use “final style” to mean “transform as we trace.” But it’s just imprecise yet sticky jargon.
next up is linearize / vjp, which I’m excited about. finally getting a glimpse into Conal Elliott’s mind
anyway, we’re getting there a bit at a time!
concluding
I think side projects are hard, but I’m reminded that, like a lot of things in life, you just need to make a routine. discipline is hard, but routines are easy
if you write 200 lines of code each day for a month, you’ll have written 6000 lines of code in that month
that’s pretty substantial. like, sshx.io-sized! the difference is that sshx.io took nearly 2 years, lol. but to be fair, you’re oftentimes debugging or removing code too. in any case, routines make everything easier, whether it’s organizing meetups like nysrg, running, or learning to play an instrument, so let’s try to find one.
it can be pretty hard to work on something this difficult by yourself, but at the same time, I find it really cool, and I enjoy this kind of creative work :)