
Dynamo fails to export data-dependent Tensor.tolist #126685

Open
a-gardner1 opened this issue May 20, 2024 · 5 comments
Labels
module: dynamic shapes module: sparse Related to torch.sparse oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@a-gardner1
Contributor

a-gardner1 commented May 20, 2024

🐛 Describe the bug

I am running into a bug, potentially related to #114483, where Dynamo fails to export Tensor.tolist on a tensor with a data-dependent size. A minimal repro is below:

import logging

import torch
onnx_program = torch.onnx.dynamo_export(
    lambda x: torch.unique(x).tolist(),
    torch.arange(10),
    export_options=torch.onnx.ExportOptions(
        dynamic_shapes=True,
        diagnostic_options=torch.onnx.DiagnosticOptions(
            verbosity_level=logging.DEBUG)))
Traceback
.../lib/python3.11/site-packages/torch/onnx/_internal/exporter.py:137: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
  warnings.warn(
E0518 18:50:33.363000 140001975370368 torch/fx/experimental/recording.py:280] [0/0] failed while running evaluate_expr(*(u0, None), **{'fx_node': None})
Traceback (most recent call last):
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/exporter.py", line 1503, in dynamo_export
    ).export()
      ^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/exporter.py", line 1236, in export
    graph_module = self.options.fx_tracer.generate_fx(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 213, in generate_fx
    graph_module, graph_guard = torch._dynamo.export(
                                ^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1268, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 420, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 168, in wrapped
    return output_adapter.apply(model_func(*args, **kwargs), model=model)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 986, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 381, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../.localpython-3.11.6/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 708, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 543, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File ".../lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 172, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 490, in transform
    tracer.run()
  File ".../lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2234, in run
    super().run()
  File ".../lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
          ^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File ".../lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1845, in CALL
    self.call_function(fn, args, kwargs)
  File ".../lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 737, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 667, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 449, in call_method
    result = handler_method(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 679, in method_tolist
    out = tolist(tensor, self.as_proxy())
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 671, in tolist
    return [wrap(val, sub_proxy[i]) for i, val in enumerate(tensor)]
                                                  ^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_tensor.py", line 1055, in __iter__
    return iter(self.unbind(0))
                ^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 650, in __torch_dispatch__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_ops.py", line 630, in __call__
    return self_._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 973, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1362, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1065, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1602, in _dispatch_impl
    return decomposition_table[func](*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_refs/__init__.py", line 3913, in unbind
    torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 973, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1362, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1065, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1606, in _dispatch_impl
    r = func.decompose(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/_ops.py", line 667, in decompose
    return self._op_dk(dk, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/fx/experimental/sym_node.py", line 355, in guard_int
    r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/fx/experimental/recording.py", line 244, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4773, in evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: u0)

Potential framework code culprit (scroll up for full backtrace):
  File ".../lib/python3.11/site-packages/torch/_ops.py", line 667, in decompose
    return self._op_dk(dk, *args, **kwargs)

For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

User Stack (most recent call last):
  (snipped, see stack below for prefix)
  File "...", line 719, in <lambda>
    lambda x: torch.unique(x).tolist(),

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

from user code:
   File "...", line 719, in <lambda>
    lambda x: torch.unique(x).tolist(),

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../lib/python3.11/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, main_globals, None,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File ".../.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File ".../.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File ".../.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File ".../.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File ".../.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "...", line 718, in <module>
    onnx_program = torch.onnx.dynamo_export(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/torch/onnx/_internal/exporter.py", line 1514, in dynamo_export
    raise OnnxExporterError(
torch.onnx.OnnxExporterError: Failed to export the model to ONNX. Generating SARIF report at 'report_dynamo_export.sarif'. SARIF is a standard format for the output of static analysis tools. SARIF logs can be loaded in VS Code SARIF viewer extension, or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). Please report a bug on PyTorch Github: https://github.com/pytorch/pytorch/issues
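The frames above show the path taken: `Tensor.tolist` iterates the tensor, `Tensor.__iter__` calls `unbind(0)`, and the `unbind` reference in `torch/_refs` splits the tensor into `t.shape[dim]` sections with `tensor_split` before squeezing each one. A minimal eager-mode sketch of that equivalence (run on concrete tensors, not under Dynamo) shows that the number of Python elements is exactly the data-dependent length, which is the unbacked `u0` during export:

```python
import torch

# The length of x depends on the values in the input, not just its shape.
x = torch.unique(torch.tensor([3, 1, 3, 2]))

# tolist() on a 1-D tensor yields one Python scalar per element.
as_list = x.tolist()

# Iterating a tensor goes through unbind(0): one 0-d tensor per element.
via_unbind = [t.item() for t in x.unbind(0)]

# The unbind reference splits into x.shape[0] sections, then squeezes each.
via_split = [s.squeeze(0).item() for s in torch.tensor_split(x, x.shape[0], 0)]

print(as_list, via_unbind, via_split)
```

All three produce the same list; under export, `x.shape[0]` has no concrete value, so there is no fixed number of outputs to trace.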

The offending operator ultimately appears to be aten.tensor_split.sections. I tried adding a decomposition for it (below), alongside the one for aten.tensor_split.tensor_indices_or_sections, but I still get blocked by an error similar to the one above.

`aten.tensor_split.sections` decomposition
from typing import Tuple

import torch
from torch import Tensor
from torch._decomp import register_decomposition

aten = torch.ops.aten


@register_decomposition(aten.tensor_split.sections)
# @aten.tensor_split.sections.py_impl(
#     DispatchKey.CompositeImplicitAutograd
# )
def tensor_split_sections(
    self: Tensor,
    sections: int,
    dim: int = 0,
) -> Tuple[Tensor, ...]:
    dim_size = self.shape[dim]
    # The first m sections get k + 1 elements; the remaining ones get k.
    k = dim_size // sections
    m = dim_size % sections
    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, expect_true

    if guard_size_oblivious(m == 0):
        # extra_long_length = 0
        # (a list, not a tuple, so it can be concatenated with normal_sections)
        long_sections = []
    else:
        # long_sections = torch.split(self.narrow(dim, 0, extra_long_length), k + 1, dim)
        i = 0
        long_sections = []
        while guard_size_oblivious(i < m):
            long_sections.append(self.narrow(dim, (k + 1) * i, k + 1))
            i += 1
    # The normal sections start right after the m long ones.
    extra_long_length = (k + 1) * m
    i = 0
    normal_sections = []
    expect_true(m < sections)
    while guard_size_oblivious(i < (sections - m)):
        normal_sections.append(self.narrow(dim, extra_long_length + k * i, k))
        i += 1
    # normal_sections = torch.split(
    #     self.narrow(dim, extra_long_length, dim_size - extra_long_length),
    #     k,
    #     dim
    # )
    return tuple(long_sections + normal_sections)
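For reference, `torch.tensor_split` with an integer `sections` gives the first `dim_size % sections` pieces one extra element, and the remaining pieces start after that block of long sections. The helper below (`section_bounds` is a hypothetical name, not a PyTorch API) computes the `(start, length)` pairs that the decomposition passes to `narrow`, using plain integers, and checks them against eager `torch.tensor_split`:

```python
import torch


def section_bounds(dim_size: int, sections: int) -> list:
    # Hypothetical helper: (start, length) of each piece produced by
    # tensor_split(t, sections) along a dimension of size dim_size.
    k, m = divmod(dim_size, sections)
    bounds = []
    for i in range(sections):
        if i < m:
            # The first m "long" sections have k + 1 elements each.
            bounds.append(((k + 1) * i, k + 1))
        else:
            # The rest start after the long block and have k elements each.
            bounds.append(((k + 1) * m + k * (i - m), k))
    return bounds


t = torch.arange(7)
expected = [t.narrow(0, start, length) for start, length in section_bounds(7, 3)]
actual = torch.tensor_split(t, 3)
print(section_bounds(7, 3))  # [(0, 3), (3, 2), (5, 2)]
```

With concrete sizes the loop count is just `sections`; in the failing export, `sections` is itself the unbacked `u0`, so a loop over it cannot be unrolled.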

Versions

Collecting environment information...
PyTorch version: 2.4.0.dev20240513+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.11.6 (main, Oct 24 2023, 16:49:32) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-105-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: Tesla V100-SXM2-32GB
GPU 1: Tesla V100-SXM2-32GB
GPU 2: Tesla V100-SXM2-32GB
GPU 3: Tesla V100-SXM2-32GB

Nvidia driver version: 545.23.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      46 bits physical, 48 bits virtual
CPU(s):                             112
On-line CPU(s) list:                0-111
Thread(s) per core:                 2
Core(s) per socket:                 28
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              85
Model name:                         Intel(R) Xeon(R) Gold 6238R CPU @ 2.20GHz
Stepping:                           7
CPU MHz:                            1000.000
CPU max MHz:                        4000.0000
CPU min MHz:                        1000.0000
BogoMIPS:                           4400.00
Virtualization:                     VT-x
L1d cache:                          1.8 MiB
L1i cache:                          1.8 MiB
L2 cache:                           56 MiB
L3 cache:                           77 MiB
NUMA node0 CPU(s):                  0-27,56-83
NUMA node1 CPU(s):                  28-55,84-111
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit:        KVM: Mitigation: VMX disabled
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; TSX disabled
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] flake8-bugbear==24.1.17
[pip3] flake8-docstrings==1.7.0
[pip3] mypy==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] onnx==1.16.0
[pip3] onnxruntime==1.17.3
[pip3] onnxscript==0.1.0.dev20240516
[pip3] pytorch-lightning==2.1.4
[pip3] pytorch-triton==3.0.0+45fff310c8
[pip3] torch==2.4.0.dev20240513+cu121
[pip3] torchaudio==2.2.0.dev20240513+cu121
[pip3] torchmetrics==1.3.2
[pip3] torchtyping==0.1.4
[pip3] torchvision==0.19.0.dev20240513+cu121
[pip3] triton==2.2.0
[conda] Could not collect

cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @jcaip @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang

mlazos added the triaged label on May 23, 2024
@ezyang
Contributor

ezyang commented May 23, 2024

While we can support tensors with variable size, for fundamental reasons we cannot support Python lists with variable sizes. You should see if you can avoid the tolist and split calls and do your operations without generating N tensors.
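As an illustration of that advice applied to the original repro: keeping the result of `torch.unique` as a single tensor means the data-dependent size stays a tensor dimension instead of becoming the length of a Python list, and follow-up work (here, counting occurrences) can stay vectorized. This is an eager-mode sketch only; whether a particular exporter accepts the tensor-returning variants is not verified here.

```python
import torch


def unique_as_list(x: torch.Tensor) -> list:
    # Repro pattern: the Python list length depends on the values in x.
    return torch.unique(x).tolist()


def unique_as_tensor(x: torch.Tensor) -> torch.Tensor:
    # Tensor-only variant: one output tensor whose length is data-dependent.
    return torch.unique(x)


def unique_with_counts(x: torch.Tensor):
    # Further processing stays in tensor form instead of looping over a list.
    values, counts = torch.unique(x, return_counts=True)
    return values, counts


x = torch.tensor([4, 4, 1, 7])
print(unique_as_list(x))      # [1, 4, 7]
print(unique_as_tensor(x))    # tensor([1, 4, 7])
print(unique_with_counts(x))  # (tensor([1, 4, 7]), tensor([1, 2, 1]))
```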

@a-gardner1
Contributor Author

That makes sense; I wasn't sure if this was intended to work or not. In this case, I think I can avoid the call to tolist.

I can imagine a way to potentially make this work (e.g., using symbolic lists and tuples), but it would be a major lift if there is not already a structure in place anticipating it.

@ezyang
Contributor

ezyang commented May 24, 2024

This is something that I am a bit sensitive to, because there's an overarching problem where people want to do data-dependent-ish computation, and sometimes there aren't convenient abstractions for avoiding the device-to-host (DtoH) sync that a tolist() would entail. If you're able to share more details about your use case, that would help.
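To make "data-dependent" concrete: the output shape of `torch.unique` depends on the input's values, not only on its shape, so two inputs of identical shape can produce outputs of different lengths. A tracer therefore has to represent that length symbolically (the unbacked `u0` in the traceback above), and for a CUDA tensor, `tolist()` additionally needs a device-to-host copy before Python can see the values.

```python
import torch

a = torch.tensor([5, 5, 5, 5])
b = torch.tensor([1, 2, 3, 4])

# Same input shape ...
assert a.shape == b.shape
# ... but different output lengths, determined by the values.
print(torch.unique(a).shape, torch.unique(b).shape)  # torch.Size([1]) torch.Size([4])
```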

@a-gardner1
Contributor Author

The basic idea is that I have a sparse tensor parameter in a model that I keep decomposed as an nn.ParameterDict of values (with a self-defined BufferDict of indices). Indices are derived from inputs to the forward pass, and each dict is expected to grow during training as new indices are encountered (the indices cannot be known beforehand).

I had been iterating over new indices to add them to the dict, using torch.unique to cut down on the number of iterations. A tolist conversion helped in converting indices to strings that could be used as keys to the dicts.
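A hypothetical reconstruction of that pattern shows where `tolist` comes in: the number of loop iterations, and of new parameters, is only known once the index values are read back into Python. The class and method names here are invented for illustration, since the actual model code was not shared.

```python
import torch
from torch import nn


class GrowingSparseParam(nn.Module):
    # Hypothetical: one scalar parameter per index seen so far, keyed by str(index).
    def __init__(self) -> None:
        super().__init__()
        self.value_dict = nn.ParameterDict()

    def register_indices(self, indices: torch.Tensor) -> None:
        # unique() reduces iterations; tolist() yields Python ints for the string keys.
        for idx in torch.unique(indices).tolist():
            key = str(idx)
            if key not in self.value_dict:
                self.value_dict[key] = nn.Parameter(torch.zeros(()))

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        self.register_indices(indices)
        # Another Python loop whose length depends on the data.
        return torch.stack([self.value_dict[str(i)] for i in indices.tolist()])


model = GrowingSparseParam()
out = model(torch.tensor([3, 1, 3]))
print(sorted(model.value_dict.keys()))  # ['1', '3']
```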

I believe that I can work around this particular issue relatively easily with some advanced indexing. At the moment, I've simply cut this part of my model out as I work on getting the rest exported to ONNX.
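One possible shape of an advanced-indexing workaround, sketched under assumptions (the commenter's actual workaround was not shared): keys are kept as a sorted 1-D int64 tensor with an aligned values tensor instead of string-keyed dicts, new indices are found with `torch.isin`, and lookups use `torch.searchsorted` plus indexing, so no per-index Python loop or string conversion is needed. The helper names are hypothetical. The number of new keys is still data-dependent, so this alone does not guarantee an exportable graph.

```python
import torch


def add_new_keys(keys, values, batch_indices, init_value=0.0):
    # keys: sorted 1-D int64 tensor of indices seen so far (assumed layout).
    # values: 1-D tensor aligned with keys.
    candidates = torch.unique(batch_indices)  # sorted and deduplicated
    new = candidates[~torch.isin(candidates, keys)]
    new_values = torch.full(
        (new.numel(),), init_value, dtype=values.dtype, device=values.device
    )
    merged_keys = torch.cat([keys, new])
    merged_values = torch.cat([values, new_values])
    # Restore sorted order so searchsorted can be used for lookups.
    order = torch.argsort(merged_keys)
    return merged_keys[order], merged_values[order]


def lookup(keys, values, query):
    # Assumes keys is sorted and contains every queried index.
    pos = torch.searchsorted(keys, query)
    return values[pos]


keys = torch.tensor([2, 5])
values = torch.tensor([1.0, 2.0])
keys, values = add_new_keys(keys, values, torch.tensor([5, 3, 3, 9]))
print(keys, values)  # tensor([2, 3, 5, 9]) tensor([1., 0., 2., 0.])
```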

ezyang added the module: sparse label on May 25, 2024
@ezyang
Copy link
Contributor

ezyang commented May 25, 2024

Thanks for the explanation! Appending rows onto a Tensor is indeed something plain Tensors don't do so well at.
