ENH: add top_k#446
Conversation
There was a problem hiding this comment.
Pull request overview
Adds an Array API-compatible top_k wrapper to the array_api_compat.torch namespace, mapping the standardized signature to torch.topk.
Changes:
- Introduces
top_k(a, k, *, axis, mode, **kwargs)wrapper in the Torch backend. - Validates
modeand forwards arguments totorch.topk. - Exports
top_kvia the Torch backend__all__.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
So Next is what to do about older numpies and |
|
Older numpy/cupy: I can see the case either way. Pinging @betatim for scikit-learn needs, and @MaanasArora in case he has good ideas after working on the numpy implementation. |
|
What are the options? My two thoughts/inputs:
|
As far as I understand, for NumPy at least (and the link you referenced) we are using Given that the most expensive underlying operation is the same, I wouldn't expect an outsized performance difference for a compatibility layer. Whether the complexity of manual nan and kth handling is worth supporting older NumPy is a different question. Given the performance should be in a similar order, I'd lean towards supporting if it is very useful for client libraries. |
Add basic
top_kwrappers, per data-apis/array-api#722-tests tracker: data-apis/array-api-tests#438
torchnumpy: works locally with ENH: Addtop_kbased on partitioning numpy/numpy#31659cupydaskSupersedes and closes gh-158
To test locally: with this PR and data-apis/array-api-tests#438, use