
TORCH_CHECK used within torch.compile does not throw legible errors #126691

Open
lezcano opened this issue May 20, 2024 · 6 comments
Labels
oncall: cpu inductor (CPU Inductor issues for Intel team to triage), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@lezcano
Collaborator

lezcano commented May 20, 2024

🐛 Describe the bug

At best, you get a

terminate called after throwing an instance of 'c10::Error'
terminate called after throwing an instance of 'c10::Error'
terminate called after throwing an instance of 'c10::Error'
terminate called recursively
[1]    2964411 IOT instruction (core dumped)

At worst, you just get the "terminate called recursively" message, with no mention of c10::Error.

To repro, you can trigger any out-of-bounds access. I got it with:

import torch

@torch.compile
def copy(a):
    # b runs from 1 down to -a.numel(), so most of its entries are out of
    # bounds for dimension 0 of a
    b = torch.arange(start=1, end=-a.numel() - 1, step=-1, device=a.device)
    return a[b]

x = torch.rand(1024, 128, device="cpu")
copy(x)

when testing #114471 (may not repro in master).

Versions

#114471

@lezcano lezcano added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and oncall: cpu inductor (CPU Inductor issues for Intel team to triage) labels on May 20, 2024
@lezcano
Collaborator Author

lezcano commented May 20, 2024

cc @jgong5

@lezcano
Collaborator Author

lezcano commented May 21, 2024

It doesn't happen for every input. The following input throws a nice error:

import torch

@torch.compile
def fn(x, y):
    # indices run from 3 to x.size(0) + 1, so the last two are out of bounds
    b = torch.arange(x.size(0) - 1, device=x.device) + 3
    return x[b] + y[b]

x = torch.rand(1024, 128, device="cpu")
y = torch.rand(1024, 128, device="cpu")
fn(x, y)

Note that the indexing error in the OP is much more egregious than the one in this second example.

@jgong5
Collaborator

jgong5 commented May 21, 2024

when testing #114471 (may not repro in master).

How can it be reproduced? It cannot be reproduced on master.

@lezcano
Collaborator Author

lezcano commented May 21, 2024

Apply the patch in that PR. There are quite a few issues in master when it comes to issuing device_asserts. Note that for the repro in the OP we don't even generate a TORCH_CHECK; we just read out of bounds.
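For context, this is the kind of guard being talked about. A minimal illustrative sketch, not Inductor's actual generated code: the function, buffer, and parameter names are made up, and only TORCH_CHECK / c10::Error come from the real codebase.

#include <cstdint>
#include <cstring>
#include <c10/util/Exception.h>

void gather_rows(const float* src, const int64_t* index, float* dst,
                 int64_t num_indices, int64_t num_rows, int64_t row_len) {
  for (int64_t i = 0; i < num_indices; ++i) {
    int64_t row = index[i];
    // Without a guard like this the loop silently reads out of bounds;
    // with it, an invalid index throws a c10::Error with a readable message.
    TORCH_CHECK(row >= -num_rows && row < num_rows,
                "index ", row, " is out of bounds for dimension 0 with size ",
                num_rows);
    if (row < 0) {
      row += num_rows;
    }
    std::memcpy(dst + i * row_len, src + row * row_len,
                row_len * sizeof(float));
  }
}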

@zhuhaozhe
Collaborator

zhuhaozhe commented Jun 4, 2024

I have compared the behavior with ATen ops; we need to change two things to throw errors as legible as the ones ATen throws here:

  • The Inductor path does not translate the C++ exception into a Python exception the way ATen does.
  • The Inductor path has no logic to catch an exception thrown inside an OMP parallel region, while at::parallel_for can catch and rethrow it (see the sketch below). The different behavior for the two inputs is also due to this: one hits an OMP parallel region and the other does not.

I will submit a PR for it.
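For reference, a minimal sketch of the second point (not the actual PR): as in at::parallel_for, the idea is to capture any exception raised inside the OMP parallel region via std::exception_ptr and rethrow it on the calling thread, where it can then be translated into a Python exception instead of reaching std::terminate. parallel_kernel and do_work are hypothetical names.

#include <cstdint>
#include <exception>

void do_work(int64_t i);  // hypothetical body that may throw c10::Error via TORCH_CHECK

void parallel_kernel(int64_t n) {
  std::exception_ptr eptr = nullptr;

  #pragma omp parallel for
  for (int64_t i = 0; i < n; ++i) {
    try {
      do_work(i);
    } catch (...) {
      // An exception escaping an OMP parallel region calls std::terminate,
      // so record the first one here instead of letting it propagate.
      #pragma omp critical
      if (!eptr) {
        eptr = std::current_exception();
      }
    }
  }

  // Rethrow on the calling thread, outside the parallel region, so the
  // binding layer can translate the C++ exception into a Python one.
  if (eptr) {
    std::rethrow_exception(eptr);
  }
}

Letting the exception escape the parallel region directly is what produces the repeated "terminate called" / "terminate called recursively" output in the OP, since each worker thread aborts on its own.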

@zhuhaozhe
Collaborator

Submitted a PR here. #127868
