라이트닝의 어느 부분에 추가해야 하는가?
관련 이슈
Clarify where to apply `torch.compile` in Fabric and add more tests · Issue #17250 · Lightning-AI/pytorch-lightning
Description & Motivation Users currently ask questions where
torch.compile should be applied in Fabric. There are two main choices: Before calling fabric.setup() After calling fabric.setup() Which ...
github.com
라이트닝 독스, 컴파일의 한계점
Speed up models by compiling them — PyTorch Lightning 2.4.0 documentation
Speed up models by compiling them Compiling your LightningModule can result in significant speedups, especially on the latest generations of GPUs. This guide shows you how to apply torch.compile correctly in your code. Apply torch.compile to your Lightning
lightning.ai
라이트닝 독스, DDP 시 컴파일 된 모델 재적용 하는 법
Speed up models by compiling them — lightning 2.4.0 documentation
Speed up models by compiling them Compiling your PyTorch model can result in significant speedups, especially on the latest generations of GPUs. This guide shows you how to apply torch.compile correctly in your code. Apply torch.compile to your model Compi
lightning.ai
`torch.compile`, The Missing Manual
torch.compile, the missing manual
torch.compile, the missing manual You are here because you want to use torch.compile to make your PyTorch model run faster. torch.compile is a complex and relatively new piece of software, and so you are likely to have growing pains. This manual is all abo
docs.google.com
결론: 라이트닝 모듈의 `__init__` 안에서
라이트닝 독스 윗쪽에서는 라이트닝 모듈 전체를 torch.compile 하라고 하는데 실제론 warning이 엄청 뜬다. 허깅페이스 backbone 등 컴파일에 친화적임이 보장된 모듈만 따로 컴파일 해줘야 한다. 그리고 DDP 환경에서는 라이트닝 모듈의 `__init__` 안에서 진행 돼야 한다. DDP는 GPU 각각을 서브 프로세스 하나씩 생성해서 관리하는데 컴파일을 프로세스 밖에서 하면 문제가 되나보다. 나는 gradient update가 안 되는 문제를 겪었다. 모듈 밖에서 이미 instance 화 된 라이트닝 모듈의 backbone만 컴파일 해줬다. 그랬더니 로스가 전혀 안 떨어진다. 컴파일 위치를 위와 같이 바꾸니 해결됐다.
해결 중
딥스피드와 어떻게 결합하는가? (그냥 합치면 오류 엄청 뜸)
How to use torch.compile in DeepSpeed? · Issue #3375 · microsoft/DeepSpeed
If I want to use the new feature of Pytorch2.0——torch.compile, what should I do? Where should I put the following code or just pass a command line parameter? model = torch.compile(model)
github.com