import math
import torch
import torch.nn as nn
import triton
import triton.language as tl


def basis_tetra_regular_to_fourier():
    """Build the orthonormal change of basis from the regular rep of A4 to its irreps.

    Each row of the table is a function on the 12 elements of A4, built from
    matrix elements of the irreps; the rows are mutually orthogonal and are
    scaled to unit norm at the end.

    Row layout (the ordering matters: it puts the irreps in block-diagonal form):
      * row 0      -- the invariant (trivial) part;
      * rows 1-2   -- a 2D real irrep that splits into two complex-conjugate irreps;
      * rows 3-11  -- three copies of the 3D real irrep (still irreducible over C),
                      i.e. the faithful action of A4 as rotations of the tetrahedron.

    Returns:
        A (12, 12) float32 tensor whose rows form an orthonormal basis.
    """
    c = -0.5  # cos(2pi/3)
    s = 0.5 * math.sqrt(3)  # sin(2pi/3)
    unnormalized = torch.tensor([
        [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
        [ 1,  1,  1,  1,  c,  c,  c,  c,  c,  c,  c,  c],
        [ 0,  0,  0,  0,  s, -s,  s, -s,  s, -s,  s, -s],
        [ 1,  1, -1, -1,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  1,  0,  1,  0, -1,  0, -1,  0],
        [ 0,  0,  0,  0,  0,  1,  0,  1,  0, -1,  0, -1],
        [ 0,  0,  0,  0,  0,  1,  0, -1,  0, -1,  0,  1],
        [ 1, -1,  1, -1,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  1,  0, -1,  0, -1,  0,  1,  0],
        [ 0,  0,  0,  0,  1,  0, -1,  0,  1,  0, -1,  0],
        [ 0,  0,  0,  0,  0,  1,  0, -1,  0,  1,  0, -1],
        [ 1, -1, -1,  1,  0,  0,  0,  0,  0,  0,  0,  0],
    ], dtype=torch.float32)
    # Scale every row to unit Euclidean length.
    row_lengths = unnormalized.norm(dim=1, keepdim=True)
    return unnormalized / row_lengths


class ToTetraFourier(nn.Module):
    """Project regular-representation features of A4 onto its irrep blocks.

    The last axis of the input has 12 * C features and is viewed as
    (..., 12, C): 12 group elements, each carrying C channels.
    """

    def __init__(self):
        super().__init__()
        # (12, 12) orthonormal basis; rows are grouped 1 + 2 + 9 by irrep.
        self.register_buffer('basis_regular_to_fourier', basis_tetra_regular_to_fourier())

    def forward(self, x):
        """Return the Fourier components of ``x`` (shape (..., 12 * C)).

        Returns:
            A tuple ``(x1, x2, x3)`` with shapes (..., 1, C), (..., 2 * C)
            and (..., 3, 3 * C), holding basis rows 0, 1-2 and 3-11.
        """
        lead = x.shape[:-1]
        width = x.shape[-1]
        n_ch = width // 12
        per_element = x.view(*lead, 12, n_ch)
        coeffs = torch.einsum('ij,...jc->...ic', self.basis_regular_to_fourier, per_element)
        trivial = coeffs[..., 0:1, :]
        two_dim = coeffs[..., 1:3, :]
        three_dim = coeffs[..., 3:12, :]
        return trivial, two_dim.flatten(start_dim=-2), three_dim.reshape(*lead, 3, 3 * n_ch)


class FromTetraFourier(nn.Module):
    """Map A4 Fourier components back to the regular representation.

    Inverse of ``ToTetraFourier``: it uses the transpose of the basis from
    ``basis_tetra_regular_to_fourier``, whose rows are orthonormal, so the
    transpose is the inverse.
    """

    def __init__(self):
        super().__init__()
        self.register_buffer('basis_fourier_to_regular', basis_tetra_regular_to_fourier().mT)

    def forward(self, x1, x2, x3):
        """Reassemble regular-representation features from their irrep blocks.

        Args:
            x1: invariant block, shape (..., 1, C).
            x2: 2D-irrep block, shape (..., 2 * C) (anything with 2 * C
                elements per leading index is accepted).
            x3: 3D-irrep blocks, shape (..., 3, 3 * C) (anything with 9 * C
                elements per leading index is accepted).

        Returns:
            Tensor of shape (..., 12 * C).
        """
        x1_shape = x1.shape
        lead = x1_shape[:-2]
        channels = x1_shape[-1]
        # reshape instead of view: Fourier components produced by other layers
        # (e.g. after a transpose/permute) may be non-contiguous, which made
        # view() raise. reshape() still returns a view whenever view() would.
        stacked = torch.cat([
            x1,
            x2.reshape(*lead, 2, channels),
            x3.reshape(*lead, 9, channels),
        ], dim=-2)
        return torch.einsum(
            'ij,...jc->...ic',
            self.basis_fourier_to_regular,
            stacked,
        ).flatten(start_dim=-2)


class ToTetraFourierQuarterBatch(nn.Module):
    """Regular -> Fourier transform for A4 with the output regrouped as (..., 4, 3 * C).

    The 12 Fourier rows are kept in basis order and split into four
    consecutive groups of three rows; each group is flattened to 3 * C features.
    """

    def __init__(self):
        super().__init__()
        self.register_buffer('basis_regular_to_fourier', basis_tetra_regular_to_fourier())

    def forward(self, x):
        """Transform ``x`` of shape (..., 12 * C) into coefficients of shape (..., 4, 3 * C)."""
        lead = x.shape[:-1]
        n_ch = x.shape[-1] // 12
        per_element = x.view(*lead, 12, n_ch)
        # NOTE(review): an einsum('ij,...jk->...ik') formulation was previously
        # observed to be slower than this plain matmul; the cause was not
        # investigated.
        coeffs = self.basis_regular_to_fourier @ per_element
        return coeffs.reshape(*lead, 4, 3 * n_ch)


class FromTetraFourierQuarterBatch(nn.Module):
    """Fourier -> regular transform for A4, inverse of ``ToTetraFourierQuarterBatch``.

    Consumes coefficients grouped as (..., 4, 3 * C) and returns (..., 12 * C).
    """

    def __init__(self):
        super().__init__()
        # Transposed basis, materialized contiguously for the matmul.
        self.register_buffer('basis_fourier_to_regular', basis_tetra_regular_to_fourier().mT.contiguous())

    def forward(self, x):
        """Map coefficients of shape (..., 4, 3 * C) back to features of shape (..., 12 * C)."""
        lead = x.shape[:-2]
        n_ch = x.shape[-1] // 3
        # Undo the 4 x 3-row grouping: back to 12 Fourier rows of C channels.
        per_row = x.reshape(*lead, 12, n_ch)
        # NOTE(review): matmul was kept over an einsum formulation, which was
        # previously observed to be slower; the cause was not investigated.
        regular = self.basis_fourier_to_regular @ per_row
        return regular.flatten(start_dim=-2)


def coo_basis_tetra_regular_to_fourier():
    """Return the normalized A4 regular -> Fourier basis as an uncoalesced sparse COO tensor.

    Entries are the normalized values of the dense row table in
    ``basis_tetra_regular_to_fourier`` with the zeros dropped. They are
    emitted row by row, with ascending column order inside each row.
    The shape (12, 12) is inferred from the indices.
    """
    a = 1 / math.sqrt(12)  # row 0: 1 / |row|
    b = 1 / math.sqrt(6)   # row 1: 1 / |row|
    c = -1 / math.sqrt(24)  # row 1: cos(2pi/3) / |row|
    d = 1 / math.sqrt(8)   # row 2: sin(2pi/3) / |row|
    e = 0.5                # rows 3-11: 1 / |row|
    first_four = (0, 1, 2, 3)
    even_tail = (4, 6, 8, 10)
    odd_tail = (5, 7, 9, 11)
    # One (columns, values) pair per row, in row order.
    per_row = [
        (range(12), [a] * 12),
        (range(12), [b] * 4 + [c] * 8),
        (range(4, 12), [d, -d] * 4),
        (first_four, [e, e, -e, -e]),
        (even_tail, [e, e, -e, -e]),
        (odd_tail, [e, e, -e, -e]),
        (odd_tail, [e, -e, -e, e]),
        (first_four, [e, -e, e, -e]),
        (even_tail, [e, -e, -e, e]),
        (even_tail, [e, -e, e, -e]),
        (odd_tail, [e, -e, e, -e]),
        (first_four, [e, -e, -e, e]),
    ]
    row_ids, col_ids, entries = [], [], []
    for row, (cols, vals) in enumerate(per_row):
        row_ids.extend([row] * len(cols))
        col_ids.extend(cols)
        entries.extend(vals)
    return torch.sparse_coo_tensor(indices=[row_ids, col_ids], values=entries)


def sparse_kron(input: torch.Tensor, other: torch.Tensor):
    """Kronecker product of two 2-D sparse COO tensors, returned as sparse COO.

    Adapted from https://github.com/pytorch/pytorch/issues/134069.

    Entry ``(i * other.shape[0] + k, j * other.shape[1] + l)`` of the result is
    ``input[i, j] * other[k, l]``. The result is NOT marked coalesced: its
    indices are unique but not lexicographically sorted, so call
    ``.coalesce()`` on it when that matters.

    Args:
        input: coalesced 2-D sparse COO tensor (``.indices()`` requires a
            coalesced tensor).
        other: coalesced 2-D sparse COO tensor on the same device as ``input``.

    Returns:
        Sparse COO tensor of shape
        ``(input.shape[0] * other.shape[0], input.shape[1] * other.shape[1])``
        with ``input``'s dtype and device.
    """
    assert input.ndim == 2  # Added these to be safe
    assert other.ndim == 2
    input_indices = input.indices()  # (2, nnz_input)
    other_indices = other.indices()  # (2, nnz_other)
    nnz = input_indices.shape[1] * other_indices.shape[1]

    # Create the block sizes on the indices' device; a CPU tensor here fails to
    # combine with CUDA indices.
    other_shape = torch.tensor(other.shape, dtype=input_indices.dtype, device=input_indices.device)
    # Top-left corner of the block each input nonzero expands into.
    block_offsets = input_indices * other_shape.unsqueeze(1)
    # (2, nnz_input, 1) + (2, 1, nnz_other) -> (2, nnz_input, nnz_other). Flattening
    # puts pair (i, j) at position i * nnz_other + j, the same order torch.kron
    # uses for the values. The explicit size (not -1) also handles nnz == 0.
    new_indices = (block_offsets.unsqueeze(2) + other_indices.unsqueeze(1)).reshape(input.ndim, nnz)

    new_values = torch.kron(input.values(), other.values())

    new_shape = [n * m for n, m in zip(input.shape, other.shape)]

    # Not passing is_coalesced=True: the indices are not sorted. Callers coalesce.
    return torch.sparse_coo_tensor(new_indices, new_values, new_shape, dtype=input.dtype, device=input.device)


def coo_multi_channel_basis_tetra_regular_to_fourier(channels):
    """Return the A4 regular -> Fourier basis applied independently to ``channels`` channels.

    Computes the sparse Kronecker product of the 12x12 basis with the
    ``channels x channels`` identity. Each basis entry becomes a diagonal
    ``channels x channels`` block, so the result has shape
    (12 * channels, 12 * channels). The result is coalesced.
    """
    channel_ids = range(channels)
    eye = torch.sparse_coo_tensor(indices=[channel_ids, channel_ids], values=torch.ones(channels))
    single_channel = coo_basis_tetra_regular_to_fourier().coalesce()
    expanded = sparse_kron(single_channel, eye.coalesce())
    return expanded.coalesce()


class SparseToTetraFourier(nn.Module):
    """Regular -> Fourier transform for A4 done as one sparse right-multiplication.

    Inputs have 12 * channels features laid out group-element-major (element j,
    channel k at position j * channels + k), matching the Kronecker structure of
    the multi-channel basis.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        # The multi-channel basis maps regular -> Fourier from the left; forward()
        # multiplies from the right, so store its transpose.
        right_basis = coo_multi_channel_basis_tetra_regular_to_fourier(channels).mT.coalesce()
        self.register_buffer('basis', right_basis)

    def forward(self, x):
        """Return ``(x1, x2, x3)`` with shapes (..., C), (..., 2 * C) and (..., 3, 3 * C)."""
        n_ch = self.channels
        coeffs = x @ self.basis
        trivial, two_dim, three_dim = coeffs.split([n_ch, 2 * n_ch, 9 * n_ch], dim=-1)
        return trivial, two_dim, three_dim.view(*x.shape[:-1], 3, 3 * n_ch)


class SparseFromTetraFourier(nn.Module):
    """Fourier -> regular transform for A4 done as one sparse right-multiplication.

    Inverse of ``SparseToTetraFourier``: the irrep blocks are concatenated
    back into a (..., 12 * channels) vector and multiplied from the right by
    the (untransposed) multi-channel basis.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.register_buffer(
            'basis',
            # No transpose because we apply it from the right:
            coo_multi_channel_basis_tetra_regular_to_fourier(channels).coalesce(),
        )

    def forward(self, x1, x2, x3):
        """Reassemble regular-representation features from their irrep blocks.

        Args:
            x1: invariant block, (..., C) as produced by ``SparseToTetraFourier``,
                or (..., 1, C) as produced by the dense ``ToTetraFourier``.
            x2: 2D-irrep block, (..., 2 * C).
            x3: 3D-irrep blocks, (..., 3, 3 * C).

        Returns:
            Tensor of shape (..., 12 * C).
        """
        # Drop the extra singleton axis of the dense-path layout. Decide by rank
        # relative to x2 rather than a hard-coded rank of 4. A 3-D (N, 1, C) x1 is
        # now accepted, and a genuine 4-D x1 whose axis -2 is 1 is left intact.
        if x1.ndim == x2.ndim + 1:
            x1 = x1.squeeze(-2)
        stacked = torch.cat([
            x1,
            x2,
            x3.flatten(start_dim=-2),
        ], dim=-1)
        return stacked @ self.basis


if __name__ == '__main__':
    # --- Setup (Use your actual model and data) ---
    from .linear_fourier import TetraFourierLinear

    def _sync(device):
        """Wait for queued CUDA work to finish; no-op on CPU."""
        # torch.cuda.synchronize() raises when CUDA is unavailable, which made the
        # 'cpu' fallback below unusable.
        if device == 'cuda':
            torch.cuda.synchronize()

    def _profile(label, fn, fn_input, device, n_iters=10):
        """Profile ``n_iters`` calls of ``fn(fn_input)`` under ``record_function(label)``.

        Prints the 10 most expensive ops and writes ``{label}_trace.json``.
        """
        activities = [torch.profiler.ProfilerActivity.CPU]
        if device == 'cuda':
            activities.append(torch.profiler.ProfilerActivity.CUDA)
        with torch.profiler.profile(
            activities=activities,
            record_shapes=True,
            profile_memory=True,
        ) as prof:
            with torch.profiler.record_function(label):
                for _ in range(n_iters):  # Run multiple times for a stable measurement
                    _ = fn(fn_input)
                    _sync(device)  # Wait for the GPU to finish
        sort_key = "cuda_time_total" if device == 'cuda' else "cpu_time_total"
        print(prof.key_averages().table(sort_by=sort_key, row_limit=10))
        prof.export_chrome_trace(f"{label}_trace.json")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.set_float32_matmul_precision("high")
    C = 64
    model = torch.nn.Linear(12*C, 12*C).to(device)
    model2 = TetraFourierLinear(12*C, 12*C).to(device)
    model3 = TetraFourierLinear(12*C, 12*C, transform_to_fourier=False, transform_back_from_fourier=False).to(device)
    input_tensor = torch.randn(256**2, 12*C, device=device)  # Example input
    input_fourier = model2.to_fourier(input_tensor)

    # --- WARM-UP RUNS ---
    # Run every version once to warm up the device and caches
    _ = model(input_tensor)
    _ = model2(input_tensor)
    _ = model3(input_fourier)
    _sync(device)

    print("--- Profiling model 1 Version ---")
    _profile("model1", model, input_tensor, device)

    print("\n--- Profiling model 2 Version ---")
    _profile("model2", model2, input_tensor, device)

    print("\n--- Profiling model 3 Version ---")
    _profile("model3", model3, input_fourier, device)


