torch.nn.functional.scaled_dot_product_attention returns NaN values after backward pass. #126654
Comments
Hi, could you provide code that reproduces the issue? Thanks.
Hi, I have now edited the issue and added the code.
There is still not enough information to reproduce the issue. Can you provide a standalone script that can be run with something like
Hi, I was unable to reproduce the NaN values, but I did find that the loss would explode (for example, to 1111107968.0), which could help in investigating the problem. I ran the code below to get the exploding loss; the data was downloaded from Nasdaq (ALTR). Code:
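A hypothetical diagnostic sketch, not the reporter's script: to pinpoint the step at which the loss blows up, check that the loss is finite before each `backward()` and log the total gradient norm. The model (`nn.Linear`), the random data, and the learning rate below are placeholders chosen only so the loop runs; in the real case they would be the attention model and the ALTR data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder model and data (hypothetical; not the reporter's attention model)
model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
x = torch.randn(32, 8)
y = torch.randn(32, 1)

history = []  # (step, loss, total gradient L2 norm)
for step in range(100):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    # Fail fast once the loss is inf/NaN, before backward() propagates it
    if not torch.isfinite(loss):
        raise RuntimeError(f"non-finite loss at step {step}: {loss.item()}")
    loss.backward()
    # L2 norm over all parameter gradients; a steady rise often precedes a blow-up
    grad_norm = torch.linalg.vector_norm(
        torch.stack([p.grad.norm() for p in model.parameters()])
    )
    history.append((step, loss.item(), grad_norm.item()))
    optimizer.step()

print(history[-1])
```

Logging the gradient norm alongside the loss makes it easier to tell a gradual divergence (norm grows over many steps) from a single bad batch (sudden spike).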
Can you try wrapping the call to SDPA with this context manager (pytorch/torch/nn/attention/__init__.py, line 70 in cd3a71f) and enabling only the `SDPBackend.MATH` backend?
🐛 Describe the bug
When using `torch.nn.functional.scaled_dot_product_attention` with autograd, a tensor filled with NaN values is returned after a few backward passes. With `torch.autograd.set_detect_anomaly(True)` enabled, `loss.backward()` raised this error:
`RuntimeError: Function 'MseLossBackward0' returned nan values in its 0th output.`
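A small sketch of how anomaly mode reports this, using hypothetical tensors rather than the model from this issue. Anomaly mode raises at the first backward function whose output contains NaN. Because `MseLossBackward0` is the first node executed in the backward pass, an error there is consistent with the predictions (or targets) already being non-finite after the forward pass; it does not by itself show that the NaN came from SDPA's backward.

```python
import torch
import torch.nn.functional as F

# Hypothetical model output that already contains a NaN after the forward pass
pred = torch.tensor([1.0, float("nan")], requires_grad=True)
target = torch.tensor([1.0, 2.0])

message = ""
with torch.autograd.set_detect_anomaly(True):
    loss = F.mse_loss(pred, target)  # the forward result is already NaN
    try:
        loss.backward()  # MseLossBackward0 runs first and emits a NaN gradient
    except RuntimeError as err:
        message = str(err)

print(message)
```

Checking `torch.isfinite(pred).all()` right before the loss call is a cheaper way to tell whether the NaN appears in the forward pass.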
This issue has also been brought up by another user in the forum.
Edit:
Reproducing code with attention class and forward function:
Training loop:
Versions
PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.1.0.2.5)
CMake version: version 3.24.4
Libc version: N/A
Python version: 3.12.2 (main, Mar 1 2024, 19:22:10) [Clang 15.0.0 (clang-1500.1.0.2.5)] (64-bit runtime)
Python platform: macOS-14.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.3.0
[pip3] torch-tb-profiler==0.4.3
[pip3] torchviz==0.0.2
[conda] Could not collect
cc @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki