
torch.nn.functional.scaled_dot_product_attention returns NaN values after backward pass. #126654

Open
daniel-padban opened this issue May 19, 2024 · 5 comments
Labels
- module: numerical-stability (Problems related to numerical stability of operations)
- needs reproduction (Someone else needs to try reproducing the issue given the instructions. No action needed from user)
- oncall: transformer/mha (Issues related to Transformers and MultiheadAttention)
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@daniel-padban

daniel-padban commented May 19, 2024

🐛 Describe the bug

When using torch.nn.functional.scaled_dot_product_attention with autograd, a tensor filled with NaN values is returned after a few backward passes. Using torch.autograd.set_detect_anomaly(True) produced this error message for loss.backward():

RuntimeError: Function 'MseLossBackward0' returned nan values in its 0th output.

This issue has also been brought up by another user in the forum.
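
For reference, the anomaly detection mentioned above can also be scoped with the equivalent context manager (model, loss_fn, X and y stand in for the objects defined in the training code below):

import torch

# torch.autograd.detect_anomaly() enables NaN checks for every backward op and
# prints the traceback of the forward op whose backward produced the NaN
# (here it pointed at MseLossBackward0).
with torch.autograd.detect_anomaly():
    pred = model(X)          # model, X, y, loss_fn as in the training loop below
    loss = loss_fn(pred, y)
    loss.backward()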

Edit:
Reproducing code with attention class and forward function:

import torch
import torch.nn as nn
from torch.nn.utils import rnn  # for pad_packed_sequence

class QKVAttention(nn.Module):
    def __init__(self, hidden_size, attn_mask=None):
        super().__init__()
        self.hidden = hidden_size
        self.Q_linear = nn.Linear(hidden_size, hidden_size)
        self.K_linear = nn.Linear(hidden_size, hidden_size)
        self.V_linear = nn.Linear(hidden_size, hidden_size)
        self.attn_mask = attn_mask.unsqueeze(-2) if attn_mask is not None else None

    def forward(self, x):
        x = x.unsqueeze(-2)

        Q = self.Q_linear(x)
        K = self.K_linear(x)
        V = self.V_linear(x)

        # scaled_dot_product_attention() is beta and subject to change as of 2024-05-17
        # problem was produced here:
        attn_out = nn.functional.scaled_dot_product_attention(Q, K, V, attn_mask=self.attn_mask, dropout_p=0.1, is_causal=False)

        return attn_out

class MainNet(nn.Module):
    # __init__ omitted; it defines self.lstm1, self.fco, self.num_layers,
    # self.batch_size and self.hidden_size (device is defined globally)
    def forward(self, input):
        h0 = torch.zeros(self.num_layers, self.batch_size, self.hidden_size, requires_grad=True).to(device)
        c0 = torch.zeros(self.num_layers, self.batch_size, self.hidden_size, requires_grad=True).to(device)

        packed_lstm_out, (_, _) = self.lstm1(input, (h0, c0))

        padded_lstm_out, _ = rnn.pad_packed_sequence(packed_lstm_out, padding_value=44444, batch_first=True)

        # creates a row-mask for the attention layer:
        positive_mask = padded_lstm_out == 44444
        positive_row_mask = torch.any(positive_mask, dim=-1, keepdim=True)
        attn_mask = ~positive_row_mask

        attention_layer = QKVAttention(self.hidden_size, attn_mask=attn_mask)
        context = attention_layer(padded_lstm_out)
        squeezed_context = context.squeeze(-2)

        final_output = self.fco(squeezed_context[:, -1, :])

        return final_output

Training loop:

optimizer = torch.optim.AdamW(params=model.parameters(recurse=True), weight_decay=1e-2, lr=1e-4)
loss_fn = torch.nn.MSELoss()

def training_loop():
    # model, train_dataloader, batch_size, report_freq and size are defined elsewhere
    model.train()
    torch.autograd.set_detect_anomaly(True)
    for batch_n, (X, y) in enumerate(train_dataloader):
        pred = model(X)

        # calculate loss & backward pass
        loss = loss_fn(pred, y)
        loss.backward()

        # adjust params
        optimizer.step()
        optimizer.zero_grad()

        # print loss
        running_loss = loss.item()
        if batch_n % report_freq == 0:
            current_values = batch_n * batch_size + len(X)
            print(f"Train loss: {running_loss:>7f} [{current_values:>5d}/{size:>5d}]")

Versions

PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.1.0.2.5)
CMake version: version 3.24.4
Libc version: N/A

Python version: 3.12.2 (main, Mar 1 2024, 19:22:10) [Clang 15.0.0 (clang-1500.1.0.2.5)] (64-bit runtime)
Python platform: macOS-14.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.3.0
[pip3] torch-tb-profiler==0.4.3
[pip3] torchviz==0.0.2
[conda] Could not collect

cc @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki

@ZailiWang
Contributor

Hi, would you provide reproducing code? Thanks.

@drisspg added the module: numerical-stability, triaged, and oncall: transformer/mha labels on May 20, 2024
@daniel-padban
Author

Hi, I have now edited the issue and added the code.

@drisspg
Contributor

drisspg commented May 20, 2024

There is still not enough info to repro. Can you provide a standalone script that can be run with something like python script.py? Specifically, there is no input data, so I still don't know how to reproduce.

@drisspg added the needs reproduction label on May 20, 2024
@daniel-padban
Author

Hi, I was unable to reproduce the issue with the NaN values, but I did find that the loss explodes, for example to 1111107968.0, which could help in investigating the problem. I ran the code below to get the exploding loss; the data was downloaded from Nasdaq (ALTR).

Code:

import torch
from torch.nn.functional import scaled_dot_product_attention
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.nn import MSELoss
from torch.optim import AdamW
from torch.utils.data import Dataset,DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
import pandas as pd


data = pd.read_csv('nc_data_folder/ALTR_data.csv')
data.drop(columns='Date',inplace=True)

class DataSet(Dataset):
    def __init__(self,data,target_col,price_cols,columns,seq_len):
        self.price_cols = price_cols
        self.target_col = target_col
        self.columns = columns
        self.seq_len = seq_len
        scaled_data = self.scaling(data=data)
        self.X = torch.tensor(scaled_data.loc[:,scaled_data.columns!=target_col].values,dtype=torch.float,requires_grad=True)
        self.y = torch.tensor(scaled_data.loc[:,target_col].values,dtype=torch.float,requires_grad=True)

    def scaling(self,data):
        for col in self.price_cols:
            data[col] = data[col].str.replace(pat=r'^\D+',repl='',regex=True).astype(float)
        imputer = SimpleImputer(strategy='mean',)
        scaler = StandardScaler()
        pipe = Pipeline([('standardScaler',scaler),('imputer',imputer),])
        scaled_data = pipe.fit_transform(data)
        standard_scaled_df = pd.DataFrame(scaled_data,columns=self.columns)

        return standard_scaled_df
    
    def __len__(self):
        x_size = self.X.size(0)
        return x_size
    
    def __getitem__(self, idx):
        
        end_idx = min(idx+self.seq_len,len(self))

        #if end_idx-1-idx > len(self)
        
        X = self.X[idx:end_idx,:]

        y = self.y[idx:end_idx]

        return X,y

def seq_collate(batch):
    #split x and y batch
    xx,yy = zip(*batch)
    
    #sequence length for each item (sequence) in batch
    x_lens = [len(x) for x in xx ]
    y_lens = [len(y) for y in yy]

    xx_pad = pad_sequence(sequences=xx,batch_first=True,padding_value=333333)
    yy_pad = pad_sequence(sequences=yy,batch_first=True,padding_value=333333)
    xx_mask = (xx_pad != 333333)
    yy_mask = (yy_pad != 333333)

    return xx_pad, yy_pad, xx_mask, yy_mask

class QKVAttention(nn.Module):
    def __init__(self,input_size):
        super().__init__()
        self.Q_linear = nn.Linear(input_size,input_size,bias=False)
        self.K_linear = nn.Linear(input_size,input_size,bias=False)
        self.V_linear = nn.Linear(input_size,input_size,bias=False)

    def forward(self,x):
        
        Q = self.Q_linear(x)
        K = self.K_linear(x)
        V = self.V_linear(x)
      
        #scaled_dot_product_attention() is beta and subject to change - 2024-05-17
        attn_out = scaled_dot_product_attention(Q,K,V,dropout_p=0.01,is_causal=True)
        
        return attn_out
    
class MainNet(nn.Module):
    def __init__(self,input_size, hidden_size,num_layers, batch_size,):
        super().__init__()
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True,)
        self.attention = QKVAttention(input_size=hidden_size)
        self.fco = nn.Linear(in_features=hidden_size,out_features=seq_len,)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, self.batch_size,self.hidden_size, requires_grad=True)
        c0 = torch.zeros(self.num_layers,self.batch_size,self.hidden_size,requires_grad=True)

        lstm_out,_ = self.lstm(x)
        attention = self.attention(lstm_out)
        final_output = self.fco(attention[:,-1,:])
        return final_output


def training_loop(dataloader,model,loss_fn,optimizer, batch_size,report_freq=10,):
    size = len(dataloader.dataset)

    model.train()
    torch.autograd.set_detect_anomaly(True)
    for batch_n, (X,y,_,y_mask) in enumerate(dataloader):
        pred = model(X)

        #calculate optim loss
        loss = loss_fn(pred,y)
        loss.backward()
        
        #adjust params
        optimizer.step()
        optimizer.zero_grad()
        
        #get loss for each pass
        running_loss = loss.item()

        if batch_n % report_freq == 0:
            current_values = batch_n*batch_size + len(X)
            print(f"Train loss: {running_loss:>7f} [{current_values:>5d}/{size:>5d}]")

price_cols = ['Close/Last','Open','High','Low']
columns = ['Close/Last','Volume','Open','High','Low']
seq_len = 20
batch_size = 10
dataset = DataSet(data=data,target_col='Close/Last',price_cols=price_cols,columns=columns,seq_len=seq_len)
dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=True,collate_fn=seq_collate,drop_last=True)
model = MainNet(4,50,2,batch_size)

epochs = 50
loss_fn = MSELoss()
optimizer = AdamW(params=model.parameters(recurse=True),lr=1e-4,weight_decay=1e-2)

for epoch in range(epochs):
    training_loop(dataloader=dataloader,model=model,loss_fn=loss_fn,optimizer=optimizer,batch_size=batch_size)
    print(f'---------- Epoch {epoch} ----------')
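
One possible contributor to the huge loss values (a hedged guess, not confirmed anywhere in this thread): seq_collate returns y_mask, but the loss above is computed on the padded targets too, so the 333333 padding values enter MSELoss directly. A minimal sketch of excluding them inside training_loop:

# hypothetical replacement for `loss = loss_fn(pred,y)` inside training_loop;
# pred and y are both (batch_size, seq_len) and y_mask marks the real (unpadded) targets
loss = loss_fn(pred[y_mask], y[y_mask])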

@drisspg
Contributor

drisspg commented May 20, 2024

Can you try wrapping the call to SDPA with this context manager:

def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]):

and only enabling the SDPBackend.MATH backend?
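
For reference, a minimal sketch of that suggestion, assuming the torch.nn.attention import path that ships with torch 2.3:

import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import sdpa_kernel, SDPBackend  # available in torch 2.3

q = torch.randn(2, 1, 5, 16)
k = torch.randn(2, 1, 5, 16)
v = torch.randn(2, 1, 5, 16)

# Restrict SDPA to the math (reference) backend only, as suggested above,
# to rule the fused kernels in or out as the source of the NaNs.
with sdpa_kernel(SDPBackend.MATH):
    out = scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)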
