# mypy: ignore-errors

import unittest

from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_GPU
from torch.utils._triton import has_triton


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
requires_gpu = unittest.skipUnless(HAS_GPU, "requires gpu")

if has_triton():
    import triton
    from triton import language as tl

    # Define here so that multiple tests can take advantage of it
    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def sub_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x - y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_with_optional_param(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        ARGS_PASSED: "tl.constexpr",
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        if ARGS_PASSED == "two":
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
        else:
            output = x
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_with_none_param_and_equal_to_1_arg(
        in_ptr0,
        in_ptr1,  # in_ptr1 could be None
        out_ptr,
        n_elements,
        stride,
        ARGS_PASSED: "tl.constexpr",
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets * stride, mask=mask)
        if ARGS_PASSED == "two":
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
        else:
            output = x
        tl.store(out_ptr + offsets * stride, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def sub_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x - y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned_weird_param_order(
        in_ptr0,
        in_ptr1,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
        out_ptr,
    ):
        # out_ptr is after an autotuned param that's declared as tl.constexpr.
        # This param ordering can create bugs if not handled correctly.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config(
                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=3, num_warps=8
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=4, num_warps=4
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=3, num_warps=8
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=4, num_warps=4
            ),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_2d_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        x_elements,
        y_elements,
        BLOCK_SIZE_X: "tl.constexpr",
        BLOCK_SIZE_Y: "tl.constexpr",
    ):
        xoffset = tl.program_id(0) * BLOCK_SIZE_X
        xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]
        xmask = xindex < x_elements
        yoffset = tl.program_id(1) * BLOCK_SIZE_Y
        yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]
        ymask = yindex < y_elements
        x1 = xindex
        y0 = yindex
        tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)
        tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)
        tmp2 = tmp0 + tmp1
        tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)

    def _dummy_early_config_prune(configs, *_, **__):
        return configs

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
        ],
        key=[],
        warmup=10,
        rep=20,
        prune_configs_by={"early_config_prune": _dummy_early_config_prune},
    )
    @triton.jit
    def add_kernel_autotuned_with_unsupported_args(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_with_scaling(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        scaling_factor,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = (x + y) * scaling_factor
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_with_tma_1d_old_api(
        in_desc_ptr0,
        in_desc_ptr1,
        out_desc_ptr,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        offset = pid * BLOCK_SIZE

        a = tl._experimental_descriptor_load(
            in_desc_ptr0,
            [offset],
            [BLOCK_SIZE],
            tl.float32,
        )
        b = tl._experimental_descriptor_load(
            in_desc_ptr1,
            [offset],
            [BLOCK_SIZE],
            tl.float32,
        )

        output = a + b

        tl._experimental_descriptor_store(
            out_desc_ptr,
            output,
            [offset],
        )

    @triton.jit
    def add_kernel_with_tma_2d_old_api(
        in_desc_ptr0,
        in_desc_ptr1,
        out_desc_ptr,
        BLOCK_SIZE_X: "tl.constexpr",
        BLOCK_SIZE_Y: "tl.constexpr",
    ):
        pid_x = tl.program_id(axis=0)
        pid_y = tl.program_id(axis=1)
        offset_x = pid_x * BLOCK_SIZE_X
        offset_y = pid_y * BLOCK_SIZE_Y

        x = tl._experimental_descriptor_load(
            in_desc_ptr0,
            [offset_x, offset_y],
            [BLOCK_SIZE_X, BLOCK_SIZE_Y],
            tl.float32,
        )
        y = tl._experimental_descriptor_load(
            in_desc_ptr1,
            [offset_x, offset_y],
            [BLOCK_SIZE_X, BLOCK_SIZE_Y],
            tl.float32,
        )

        output = x + y

        tl._experimental_descriptor_store(
            out_desc_ptr,
            output,
            [offset_x, offset_y],
        )

    @triton.jit
    def add_kernel_with_tma_1d_new_api(
        in_desc_ptr0,
        in_desc_ptr1,
        out_desc_ptr,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        offset = pid * BLOCK_SIZE

        a = tl.load_tensor_descriptor(
            in_desc_ptr0,
            [offset],
        )
        b = tl.load_tensor_descriptor(
            in_desc_ptr1,
            [offset],
        )

        output = a + b

        tl.store_tensor_descriptor(
            out_desc_ptr,
            [offset],
            output,
        )

    @triton.jit
    def add_kernel_with_tma_2d_new_api(
        in_desc_ptr0,
        in_desc_ptr1,
        out_desc_ptr,
        BLOCK_SIZE_X: "tl.constexpr",
        BLOCK_SIZE_Y: "tl.constexpr",
    ):
        pid_x = tl.program_id(axis=0)
        pid_y = tl.program_id(axis=1)
        offset_x = pid_x * BLOCK_SIZE_X
        offset_y = pid_y * BLOCK_SIZE_Y

        x = tl.load_tensor_descriptor(
            in_desc_ptr0,
            [offset_x, offset_y],
        )
        y = tl.load_tensor_descriptor(
            in_desc_ptr1,
            [offset_x, offset_y],
        )

        output = x + y

        tl.store_tensor_descriptor(
            out_desc_ptr,
            [offset_x, offset_y],
            output,
        )

    @triton.jit
    def add_kernel_on_device_tma_old_api(
        a_ptr,
        b_ptr,
        c_ptr,
        m,
        n,
        workspace,
        BLOCK_SIZE: "tl.constexpr",
    ):
        a_desc_ptr = workspace
        b_desc_ptr = workspace + 128
        c_desc_ptr = workspace + 256
        tl.extra.cuda.experimental_device_tensormap_create2d(
            desc_ptr=a_desc_ptr,
            global_address=a_ptr,
            load_size=[BLOCK_SIZE, BLOCK_SIZE],
            global_size=[m, n],
            element_ty=a_ptr.dtype.element_ty,
        )
        tl.extra.cuda.experimental_device_tensormap_create2d(
            desc_ptr=b_desc_ptr,
            global_address=b_ptr,
            load_size=[BLOCK_SIZE, BLOCK_SIZE],
            global_size=[m, n],
            element_ty=b_ptr.dtype.element_ty,
        )
        tl.extra.cuda.experimental_device_tensormap_create2d(
            desc_ptr=c_desc_ptr,
            global_address=c_ptr,
            load_size=[BLOCK_SIZE, BLOCK_SIZE],
            global_size=[m, n],
            element_ty=c_ptr.dtype.element_ty,
        )

        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

        pid_x = tl.program_id(axis=0)
        pid_y = tl.program_id(axis=1)
        offset_x = pid_x * BLOCK_SIZE
        offset_y = pid_y * BLOCK_SIZE

        # Load data using the tensor descriptors
        a = tl._experimental_descriptor_load(
            a_desc_ptr,
            [offset_x, offset_y],
            [BLOCK_SIZE, BLOCK_SIZE],
            tl.float32,
        )
        b = tl._experimental_descriptor_load(
            b_desc_ptr,
            [offset_x, offset_y],
            [BLOCK_SIZE, BLOCK_SIZE],
            tl.float32,
        )

        # Perform addition
        output = a + b

        # Store the result
        tl._experimental_descriptor_store(
            c_desc_ptr,
            output,
            [offset_x, offset_y],
        )

    @triton.jit
    def add_kernel_on_device_tma_new_api(
        a_ptr,
        b_ptr,
        c_ptr,
        m,
        n,
        workspace,  # unused but left here to match the old API kernel
        BLOCK_SIZE: "tl.constexpr",
    ):
        # Create tensor descriptors using the new API
        a_desc = tl.make_tensor_descriptor(
            base=a_ptr,
            shape=[m, n],
            strides=[n, 1],
            block_shape=[BLOCK_SIZE, BLOCK_SIZE],
        )
        b_desc = tl.make_tensor_descriptor(
            base=b_ptr,
            shape=[m, n],
            strides=[n, 1],
            block_shape=[BLOCK_SIZE, BLOCK_SIZE],
        )
        c_desc = tl.make_tensor_descriptor(
            base=c_ptr,
            shape=[m, n],
            strides=[n, 1],
            block_shape=[BLOCK_SIZE, BLOCK_SIZE],
        )

        pid_x = tl.program_id(axis=0)
        pid_y = tl.program_id(axis=1)
        offset_x = pid_x * BLOCK_SIZE
        offset_y = pid_y * BLOCK_SIZE

        # Load data using the tensor descriptors with the new API
        a = tl.load_tensor_descriptor(
            a_desc,
            [offset_x, offset_y],
        )
        b = tl.load_tensor_descriptor(
            b_desc,
            [offset_x, offset_y],
        )

        # Perform addition
        output = a + b

        # Store the result with the new API
        tl.store_tensor_descriptor(
            c_desc,
            [offset_x, offset_y],
            output,
        )

    @triton.jit
    def mul2_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        output = 2 * x
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def mul2_inplace_kernel(
        ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(ptr + offsets, mask=mask)
        output = 2 * x
        tl.store(ptr + offsets, output, mask=mask)

    @triton.jit
    def zero_negs(x):
        return tl.where(x >= 0, x, 0)

    @triton.jit
    def indirection_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
        ACTIVATION: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        if ACTIVATION == "mul2_inplace_kernel":
            mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        elif ACTIVATION == "add_kernel":
            add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        x = tl.load(in_ptr0 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x, mask=mask)

    @triton.jit
    def double_strided_kernel(
        in_ptr,
        out_ptr,
        in_y_stride,
        out_y_stride,
        X_BLOCK_SIZE: "tl.constexpr",
        Y_BLOCK_SIZE: "tl.constexpr",
    ):
        xid = tl.program_id(axis=0)
        yid = tl.program_id(axis=1)
        x_start = xid * X_BLOCK_SIZE
        y_start = yid * Y_BLOCK_SIZE
        x_offsets = x_start + tl.arange(0, X_BLOCK_SIZE)
        y_offsets = y_start + tl.arange(0, Y_BLOCK_SIZE)
        src_offsets = y_offsets[:, None] * in_y_stride + x_offsets[None, :]
        dst_offsets = y_offsets[:, None] * out_y_stride + x_offsets[None, :]
        src = tl.load(in_ptr + src_offsets)
        tl.store(out_ptr + dst_offsets, src * 2.0)

    @triton.jit
    def inline_asm_kernel_is_pure_true(
        X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
    ):
        x = tl.load(X + tl.arange(0, BLOCK))
        y = tl.load(Y + tl.arange(0, BLOCK))
        s = tl.full([BLOCK], n, tl.int32)
        z = tl.inline_asm_elementwise(
            "shf.l.wrap.b32 $0, $1, $2, $3;",
            "=r,r, r, r",
            [x, y, s],
            dtype=tl.int32,
            is_pure=True,
            pack=1,
        )
        tl.store(Z + tl.arange(0, BLOCK), z)

    @triton.jit
    def inline_asm_kernel_is_pure_false(
        X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"
    ):
        x = tl.load(X + tl.arange(0, BLOCK))
        y = tl.load(Y + tl.arange(0, BLOCK))
        s = tl.full([BLOCK], n, tl.int32)
        z = tl.inline_asm_elementwise(
            "shf.l.wrap.b32 $0, $1, $2, $3;",
            "=r,r, r, r",
            [x, y, s],
            dtype=tl.int32,
            is_pure=False,
            pack=1,
        )
        tl.store(Z + tl.arange(0, BLOCK), z)

    @triton.jit
    def add_kernel_with_block_ptr(
        x_ptr,
        y_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        x = tl.load(
            tl.make_block_ptr(
                base=x_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            boundary_check=[0],
        )
        y = tl.load(
            tl.make_block_ptr(
                base=y_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            boundary_check=[0],
        )
        output = x + y
        tl.store(
            tl.make_block_ptr(
                base=output_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            output,
            boundary_check=[0],
        )

    @triton.jit
    def kernel_with_block_ptr_2d(
        x_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        x = tl.load(
            tl.make_block_ptr(
                base=x_ptr,
                shape=[n_elements, 1],
                strides=[1, 1],
                offsets=[block_start, 0],
                block_shape=[BLOCK_SIZE, 1],
                order=[1, 0],
            ),
            boundary_check=[0],
        )
        output = x
        tl.store(
            tl.make_block_ptr(
                base=output_ptr,
                shape=[n_elements, 1],
                strides=[1, 1],
                offsets=[block_start, 0],
                block_shape=[BLOCK_SIZE, 1],
                order=[1, 0],
            ),
            output,
            boundary_check=[0],
        )

    from triton.language import load, store

    @triton.jit
    def add_kernel_with_import(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = load(in_ptr0 + offsets, mask=mask)
        y = load(in_ptr1 + offsets, mask=mask)
        output = x + y
        store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def cond_op_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        if tl.program_id(0) == 0:
            output = x + y
        else:
            output = x * y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def atomic_add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.atomic_add(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_4_times_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        for i in range(2):
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)
        i = 2
        while i > 0:
            i -= 1
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_out_of_order_fn2(
        in_ptr0,
        in_ptr1,
        n_elements,
        out_ptr,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config(
                {
                    "BLOCK_SIZE_M": 16,
                    "BLOCK_SIZE_N": 16,
                    "BLOCK_SIZE_K": 16,
                    "GROUP_SIZE_M": 4,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_SIZE_M": 128,
                    "BLOCK_SIZE_N": 64,
                    "BLOCK_SIZE_K": 32,
                    "GROUP_SIZE_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
        ],
        key=["M_ptr", "N", "K"],
    )
    @triton.jit
    def strange_config_matmul_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        M_ptr,
        N,
        K,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
    ):
        # This is a simplified matmul from Triton tutorial.
        pid = tl.program_id(axis=0)
        M = tl.load(M_ptr)
        if M == 0 and BLOCK_SIZE_M > 32:
            # This will run the full matmul if BLOCK_SIZE_M > 32
            M = 4096
        elif M == 0:
            # This directly returns, which will cut short the bad config of 16-block size.
            return
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (offs_am[:, None] + offs_k[None, :])
        b_ptrs = b_ptr + (offs_k[:, None] + offs_bn[None, :])

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
            accumulator = tl.dot(a, b, accumulator)
            a_ptrs += BLOCK_SIZE_K
            b_ptrs += BLOCK_SIZE_K
        c = accumulator.to(tl.float16)

        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + offs_cm[:, None] + offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)

    @triton.jit
    def kernel_with_docstring_double_quotes(out_ptr, numel, BLOCK_SIZE: tl.constexpr):
        """
        This kernel contains a triple-quote docstring w/ double quotes.
        Make sure that codegen sanitizes the docstring.
        """
        pid = tl.program_id(axis=0)
        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
        ones = tl.full([BLOCK_SIZE], 1.0, dtype=tl.float32)
        tl.store(out_ptr + offsets, ones, mask=offsets < numel)

    @triton.jit
    def kernel_with_docstring_single_quotes(out_ptr, numel, BLOCK_SIZE: tl.constexpr):
        '''
        This kernel contains a triple-quote docstring w/ single quotes
        Make sure that codegen sanitizes the docstring.
        To prevent it from being linted to double quotes: """!!!"""
        '''
        pid = tl.program_id(axis=0)
        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
        ones = tl.full([BLOCK_SIZE], 1.0, dtype=tl.float32)
        tl.store(out_ptr + offsets, ones, mask=offsets < numel)

    @triton.jit
    def kernel_inline_asm_double_quotes(
        in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr
    ):
        pid = tl.program_id(axis=0)
        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
        data = tl.load(in_ptr + offsets, mask=offsets < numel)
        cos_pow = tl.inline_asm_elementwise(
            asm="""
            {
                cos.approx.f32 $0, $1;
                ex2.approx.f32 $0, $0;
            }
                """,
            constraints=("=r, r"),
            args=[data],
            dtype=tl.float32,
            is_pure=True,
            pack=1,
        )
        tl.store(out_ptr + offsets, cos_pow, mask=offsets < numel)

    @triton.jit
    def kernel_inline_asm_single_quotes(
        in_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr
    ):
        pid = tl.program_id(axis=0)
        offsets = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
        data = tl.load(in_ptr + offsets, mask=offsets < numel)
        cos_pow = tl.inline_asm_elementwise(
            asm='''
            {
                // double quotes to pacify the linter """!!!"""
                cos.approx.f32 $0, $1;
                ex2.approx.f32 $0, $0;
            }
                ''',
            constraints=("=r, r"),
            args=[data],
            dtype=tl.float32,
            is_pure=True,
            pack=1,
        )
        tl.store(out_ptr + offsets, cos_pow, mask=offsets < numel)

    # support the old (experimental) and new (tensor_descriptor) APIs
    def create_tensor_descriptor_shim(
        tensor, block_sizes: list[int], new_api: bool = True
    ):
        if new_api:
            return triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
                tensor, block_sizes
            )
        else:
            if len(block_sizes) == 1:
                return triton.tools.experimental_descriptor.create_1d_tma_descriptor(
                    tensor.data_ptr(),
                    tensor.size(0),
                    block_sizes[0],
                    tensor.element_size(),
                )
            else:
                assert len(block_sizes) == 2
                return triton.tools.experimental_descriptor.create_2d_tma_descriptor(
                    tensor.data_ptr(),
                    tensor.size(0),
                    tensor.size(1),
                    block_sizes[0],
                    block_sizes[1],
                    tensor.element_size(),
                )
