[PJRT][IFRT] Move topology discovery into PJRT-IFRT. #68260

Draft
wants to merge 1 commit into
base: master

copybara-service[bot]

[PJRT][IFRT] Move topology discovery into PJRT-IFRT.

Currently each PJRT backend is responsible for figuring out the global topology. This has several downsides:

  • it is a layering violation: PJRT is responsible for communicating with the devices local to each host, whereas IFRT is the layer that has a global view of a cluster. Topology discovery properly belongs at the IFRT layer. The current topology discovery is really useful only to the PJRT-IFRT implementation of the IFRT API, because other distributed runtimes will have their own way to discover their cluster topology, and hence it should go there.
  • duplication: e.g., the CPU and GPU backends have almost identical code to discover topologies. If we move this code into PJRT-IFRT, we can avoid redundancy.
  • inconsistency: currently the TPU implementation of the PJRT API does its own topology discovery. But this leads to an inconsistency: the process_id values attached to each device are determined by the TPU physical topology, and they don't match the process_id that the user otherwise sees in JAX. While it is easy for each PJRT client to figure out its own process_id, in order to correctly determine the mapping between non-addressable TPUs and their process indices, we would need an additional layer of topology exchange. This change allows a single source of truth for topology exchange.

This change adds support for topology discovery to PJRT-IFRT, and migrates the CPU PJRT plugin to use it instead of performing its own topology discovery. Future changes will migrate other PJRT plugins.
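To make the idea of a single, shared topology exchange concrete, here is a minimal sketch in Python. It is not the PJRT-IFRT implementation: the `LocalTopology` type, the `publish` and `build_global_view` helpers, the key format, and the use of a plain dict in place of a distributed key-value store are all assumptions made for illustration. Each process publishes only what it knows locally, and every process then derives the same global view from the shared entries.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LocalTopology:
    """Hypothetical record of what one process knows about itself."""
    process_id: int
    local_device_ids: tuple[int, ...]


def publish(kv_store: dict, local: LocalTopology) -> None:
    # Assumption: a dict stands in for a distributed key-value store.
    kv_store[f"topology/{local.process_id}"] = local


def build_global_view(kv_store: dict, num_processes: int) -> dict:
    """Return process_id -> local device ids for every process.

    A real exchange would block until each key is available; this
    sketch assumes all processes have already published.
    """
    view = {}
    for pid in range(num_processes):
        view[pid] = kv_store[f"topology/{pid}"].local_device_ids
    return view


# Two processes with two local devices each.
store = {}
publish(store, LocalTopology(process_id=0, local_device_ids=(0, 1)))
publish(store, LocalTopology(process_id=1, local_device_ids=(0, 1)))
global_view = build_global_view(store, num_processes=2)
```

Because every process reads the same entries, the process_id attached to a non-addressable device is the one its owner published, rather than one inferred from physical layout.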

One notable effect of this change is that we no longer assign device IDs in contiguous fashion. It's up to each PJRT plugin to choose globally unique device IDs, but after this change the CPU plugin forms a global device ID as (process_id << 17 | local_device_id). It is hard to say for sure if any code is relying on contiguity of the device ID space, but it is not something we've ever documented as a contract of PJRT or JAX.
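The ID scheme above can be sketched in Python as follows. Only the expression `(process_id << 17 | local_device_id)` comes from this description; the function names, the constant name, and the range check on `local_device_id` are assumptions added for illustration and are not taken from the CPU plugin.

```python
LOCAL_ID_BITS = 17  # shift amount quoted in the PR description


def make_global_device_id(process_id: int, local_device_id: int) -> int:
    """Pack a process id and a local device id into one global id."""
    # Assumption: local ids must fit in the low 17 bits, otherwise
    # they would collide with the process id bits.
    if not 0 <= local_device_id < (1 << LOCAL_ID_BITS):
        raise ValueError("local_device_id does not fit in 17 bits")
    if process_id < 0:
        raise ValueError("process_id must be non-negative")
    return (process_id << LOCAL_ID_BITS) | local_device_id


def split_global_device_id(global_id: int) -> tuple[int, int]:
    """Recover (process_id, local_device_id) from a packed id."""
    mask = (1 << LOCAL_ID_BITS) - 1
    return global_id >> LOCAL_ID_BITS, global_id & mask


# Two processes with two devices each: the ids are unique but not contiguous.
ids = [make_global_device_id(p, d) for p in range(2) for d in range(2)]
print(ids)  # [0, 1, 131072, 131073]
```

Code that indexes an array by device ID, or that assumes IDs run from 0 to N-1, would be affected by the gap between 1 and 131072 shown here.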

Another downside of this change is that, on CPU and GPU, PjRtTopologyDescription may no longer have a complete description of the cluster topology. It is derived from a single PJRT client's view of the cluster, and after this change that view contains only local devices. We leave fixing this to a future change.

PiperOrigin-RevId: 631781607