jax-flash-attn2

Reverse dependencies: 0

0.0.1: jax_flash_attn2-0.0.1-py3-none-any.whl

Wheel Details

Project: jax-flash-attn2
Version: 0.0.1
Filename: jax_flash_attn2-0.0.1-py3-none-any.whl
Download: [link]
Size: 42759 bytes
MD5: 8d7ca7e9095345343bca1488389d2743
SHA256: 161f2baf1bc3a11e80fa30717521769267c5840cabb39af2b5b012f9e1e0ebdb
Uploaded: 2024-10-23 22:37:12 +0000
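A downloaded copy of the wheel can be checked against the MD5 and SHA256 digests listed above. A minimal stdlib-only sketch (the local file path is whatever you saved the wheel as):

```python
import hashlib


def file_digests(path):
    """Return (md5, sha256) hex digests of a file, read in chunks."""
    md5, sha256 = hashlib.md5(), hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(8192), b""):
            md5.update(chunk)
            sha256.update(chunk)
    return md5.hexdigest(), sha256.hexdigest()


# Expected digests for jax_flash_attn2-0.0.1-py3-none-any.whl, from the listing above.
EXPECTED_MD5 = "8d7ca7e9095345343bca1488389d2743"
EXPECTED_SHA256 = "161f2baf1bc3a11e80fa30717521769267c5840cabb39af2b5b012f9e1e0ebdb"
```

Compare the returned pair against the two constants after downloading; any mismatch means the file is not the artifact described here.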

dist-info

METADATA

Metadata-Version: 2.1
Name: jax-flash-attn2
Version: 0.0.1
Summary: Flash Attention implementation with multiple backend support and sharding. This module provides a flexible implementation of Flash Attention with support for different backends (GPU, TPU, CPU) and platforms (Triton, Pallas, JAX).
Author: Erfan Zare Chavoshi
Author-Email: erfanzare810@gmail.com
Home-Page: https://github.com/erfanzar/jax-flash-attn2
Project-Url: Documentation, https://erfanzar.github.io/jax-flash-attn2
Project-Url: Repository, https://github.com/erfanzar/jax-flash-attn2
License: Apache-2.0
Keywords: JAX,Deep Learning,Machine Learning,XLA
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.9
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Requires-Dist: chex
Requires-Dist: einops
Requires-Dist: jax (>=0.4.33)
Requires-Dist: jaxlib (>=0.4.33)
Requires-Dist: scipy (==1.13.1)
Requires-Dist: triton (<3.1.0,>=3.0.0)
Description-Content-Type: text/markdown
[Description omitted; length: 3999 characters]
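Since the exact version and SHA256 are published above, an installation can be pinned and hash-checked. A sketch of the requirements entry (the pip invocation is shown commented because it needs network access):

```shell
# Write a requirements file that pins the release to the SHA256 listed above.
cat > requirements.txt <<'EOF'
jax-flash-attn2==0.0.1 \
    --hash=sha256:161f2baf1bc3a11e80fa30717521769267c5840cabb39af2b5b012f9e1e0ebdb
EOF
# With --require-hashes, pip refuses any artifact whose digest does not match:
# pip install --require-hashes -r requirements.txt
```

Note that Requires-Python is >=3.10, so despite the "Programming Language :: Python :: 3.9" classifier the wheel will not install on Python 3.9.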

WHEEL

Wheel-Version: 1.0
Generator: poetry-core 1.9.1
Root-Is-Purelib: true
Tag: py3-none-any
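The `py3-none-any` tag encodes three fields, (python tag, ABI tag, platform tag); together with `Root-Is-Purelib: true` it marks this as a pure-Python wheel installable on any platform. A stdlib-only sketch of pulling the tag out of the filename, assuming the simple five-field form without a build tag:

```python
def parse_wheel_filename(filename):
    """Split a wheel filename into (distribution, version, python, abi, platform).

    Assumes the five-field form without a build tag, per the PEP 427 naming scheme.
    """
    stem = filename.removesuffix(".whl")
    dist, version, py_tag, abi_tag, plat_tag = stem.split("-")
    return dist, version, py_tag, abi_tag, plat_tag


print(parse_wheel_filename("jax_flash_attn2-0.0.1-py3-none-any.whl"))
# -> ('jax_flash_attn2', '0.0.1', 'py3', 'none', 'any')
```

For general-purpose parsing (build tags, compressed tag sets like `cp310.cp311`), the `packaging.utils.parse_wheel_filename` helper from the `packaging` project is more robust than this sketch.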

RECORD

Path Digest Size
jax_flash_attn2/__init__.py sha256=C1eSdLickhtGrK2x6krBlbclyLfbstL9InsJrfOVPxI 889
jax_flash_attn2/_custom_call_lib/__init__.py sha256=yx0qJwBwFBxBG1hWnuW-vAS2Xt3SFyLMbPjdWOO-0mU 804
jax_flash_attn2/_custom_call_lib/lib.py sha256=nz3NEhnclNU9db11Obwo2ZIEzUNBQTgoOX1pWOncxMQ 19343
jax_flash_attn2/cpu_calls/__init__.py sha256=1fYLO_yep4MJzBNJ28z9WuvVHVJCqmFV-9d4Soq01wE 713
jax_flash_attn2/cpu_calls/mha.py sha256=41-kEZ6MNNLVcSCznwq2kYJ6oJO48FNTJTAuXuCzVmc 13981
jax_flash_attn2/flash_attention.py sha256=AIuCQuDDpAY4lv1jc0fqD5OQFhxu6HclhZmiRkAD8bc 10496
jax_flash_attn2/pallas_kernels/__init__.py sha256=QTfZ3rLLfaJPEYZWjSEpe7GIRtc9BgZanMsNr3XQeO0 746
jax_flash_attn2/pallas_kernels/gpu_mha_kernel.py sha256=-L1Yigme8wxcPTZOGrkEpaCAADrrsmwi6DUpMgw2tLA 13095
jax_flash_attn2/triton_kernels/__init__.py sha256=b41K8LlJ2bqYAvp8yJiBkfV2b-AuyfBtbbmSrhyAzXo 855
jax_flash_attn2/triton_kernels/gqa_kernel.py sha256=N0AgDRK_t85eTO27PNx81BDfuzPKBEA7p2GBNdDmiFM 31694
jax_flash_attn2/triton_kernels/mha_kernel.py sha256=bwInYYjHEl9reStG3tl2T1n39uQ4MYb1BLhnk2dPbMk 31257
jax_flash_attn2/utils.py sha256=gVFj0bJ8ilZCzYbEJJ8hyITfB4S05bMdszPZedUp0C8 3572
jax_flash_attn2-0.0.1.dist-info/LICENSE sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ 11357
jax_flash_attn2-0.0.1.dist-info/METADATA sha256=4J69pSBBpHeGFcdNd_GKLUBAmxe3hLK8cdGcgpabWzk 5410
jax_flash_attn2-0.0.1.dist-info/WHEEL sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs 88
jax_flash_attn2-0.0.1.dist-info/RECORD
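The Digest column in RECORD is not hex: it is the urlsafe base64 encoding of the raw SHA256 digest with the `=` padding stripped. A sketch of reproducing that format for a file's bytes:

```python
import base64
import hashlib


def record_digest(data: bytes) -> str:
    """Format bytes the way RECORD does: sha256=<urlsafe-b64 digest, no padding>."""
    raw = hashlib.sha256(data).digest()
    b64 = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
    return f"sha256={b64}"


# Empty input, for illustration:
print(record_digest(b""))
# -> sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
```

The RECORD file itself is listed without a digest or size, since it cannot contain its own hash.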