AX生态这两年在LLM训练这块追赶得挺快。PyTorch虽然还是主流但JAX在并行计算、TPU加速和API组合性上确实有些独特的优势。Google今天放出了Tunix这个库,专门做LLM的后训练——微调、强化学习、知识蒸馏这些都能搞。
Tunix是什么
这是个构建在JAX之上的后训练库,和Flax NNX集成得比较紧密。主要解决三类问题:
- 监督微调(Supervised Fine-Tuning)
- 强化学习(Reinforcement Learning)
- 知识蒸馏(Knowledge Distillation)
现在还在早期开发阶段,功能在持续迭代,支持的模型也在慢慢扩展。
https://avoid.overfit.cn/post/c434311d8a894922b6c52ea179cf8d97