import time
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl


class FastMatryoshkaLinear(nn.Module):
    """Nested ("Matryoshka") linear projection backed by one packed weight.

    ``matryoshka_setting`` lists, per group, how many output heads the group
    owns. With ``total_heads = sum(matryoshka_setting)``:

    * every head produces ``hidden_dim // total_heads`` output features;
    * group ``i`` reads only the first
      ``cumsum(matryoshka_setting)[i] * (dim // total_heads)`` input features.

    The weights of all heads are stacked row-wise in ``self.weight`` (shape
    ``(sum(block_dims), output_head_dim)``): groups in order, and inside a
    group one ``(input_dim, output_head_dim)`` slab per head.

    Args:
        dim: Size of the last dimension of the input.
        hidden_dim: Size of the last dimension of the output. Falls back to
            ``dim`` when ``None`` (or ``0``).
        matryoshka_setting: Number of heads per group. ``None`` selects
            ``[1, 1, 2, 4, 8, 16]``.

    Raises:
        AssertionError: If ``dim`` or ``hidden_dim`` is not divisible by the
            total number of heads.
    """

    def __init__(
        self,
        dim: int = 4096,
        hidden_dim: int = None,
        matryoshka_setting: List[int] = None,
    ):
        super().__init__()
        # Build the default per call instead of sharing one list object
        # between every instance through the signature.
        if matryoshka_setting is None:
            matryoshka_setting = [1, 1, 2, 4, 8, 16]

        self.dim = dim
        self.hidden_dim = hidden_dim or dim
        self.total_heads, self.matryoshka_heads = self._setup_dimensions(
            dim, self.hidden_dim, matryoshka_setting
        )

        # One row per (group, head, input feature) triple.
        input_dims = self.block_dims.sum().item()
        self.weight = nn.Parameter(torch.zeros(input_dims, self.output_head_dim))
        nn.init.xavier_normal_(self.weight)

    def _setup_dimensions(
        self, dim: int, hidden_dim: int, matryoshka_setting: List[int]
    ):
        """Derive and store the per-group sizes.

        Sets ``input_head_dim``, ``output_head_dim``, ``output_dims`` (heads
        per group), ``input_dims`` (input features visible to each group) and
        ``block_dims`` (weight rows used by each group) on ``self``.

        Returns:
            Tuple ``(total_heads, number_of_groups)``.
        """
        total_heads = sum(matryoshka_setting)
        assert dim % total_heads == 0, (
            f"dim ({dim}) should be divisible by the sum of "
            f"{matryoshka_setting} ({total_heads})."
        )
        assert hidden_dim % total_heads == 0, (
            f"hidden_dim ({hidden_dim}) should be divisible by the sum of "
            f"{matryoshka_setting} ({total_heads})."
        )
        self.input_head_dim = dim // total_heads
        self.output_head_dim = hidden_dim // total_heads

        heads_per_group = torch.tensor(matryoshka_setting, dtype=torch.long)
        self.output_dims = heads_per_group
        self.input_dims = torch.cumsum(heads_per_group, dim=0) * self.input_head_dim
        self.block_dims = self.input_dims * self.output_dims

        return total_heads, len(matryoshka_setting)

    def pytorch_forward(
        self, x: torch.Tensor, matryoshka_head: int = None
    ) -> torch.Tensor:
        """Reference forward pass: one ``F.linear`` call per head.

        Args:
            x: Input whose last dimension holds at least the features read by
                the evaluated groups.
            matryoshka_head: Number of leading groups to evaluate. ``None``
                (or ``0``) evaluates every group.

        Returns:
            Tensor of shape ``(*x.shape[:-1], hidden_dim)``. Columns that
            belong to groups which were not evaluated stay zero.
        """
        matryoshka_head = matryoshka_head or self.matryoshka_heads
        batch_dims = x.shape[:-1]
        outputs = torch.zeros(
            *batch_dims, self.hidden_dim, device=x.device, dtype=x.dtype
        )
        head_dim = self.output_head_dim

        input_offset, output_offset = 0, 0
        for i in range(matryoshka_head):
            input_dim = int(self.input_dims[i])
            num_heads = int(self.output_dims[i])

            # Every head of the group reads the same input prefix, so slice once.
            x_slice = x[..., :input_dim]

            for _ in range(num_heads):
                head_weight = self.weight[input_offset : input_offset + input_dim]
                # F.linear multiplies by the transpose of its weight argument,
                # hence the explicit .T to obtain x_slice @ head_weight.
                out = F.linear(x_slice, head_weight.T)

                outputs[..., output_offset : output_offset + head_dim].add_(out)

                output_offset += head_dim
                input_offset += input_dim

        return outputs

    def optimized_forward(
        self, x: torch.Tensor, matryoshka_head: int = None
    ) -> torch.Tensor:
        """Forward pass that evaluates all heads of a group in one einsum.

        Computes the same values as :meth:`pytorch_forward`.

        Args:
            x: Input whose last dimension holds at least the features read by
                the evaluated groups.
            matryoshka_head: Number of leading groups to evaluate. ``None``
                (or ``0``) evaluates every group.

        Returns:
            Tensor of shape ``(*x.shape[:-1], hidden_dim)``. Columns that
            belong to groups which were not evaluated stay zero.
        """
        matryoshka_head = matryoshka_head or self.matryoshka_heads
        batch_dims = x.shape[:-1]
        outputs = torch.zeros(
            *batch_dims, self.hidden_dim, device=x.device, dtype=x.dtype
        )
        head_dim = self.output_head_dim

        weight_offset = 0
        output_offset = 0

        for i in range(matryoshka_head):
            input_dim = int(self.input_dims[i])
            num_heads = int(self.output_dims[i])
            block_rows = input_dim * num_heads

            # Rows are stored head after head, so the group's slab views
            # directly as (num_heads, input_dim, head_dim).
            weight_block = self.weight[
                weight_offset : weight_offset + block_rows
            ].view(num_heads, input_dim, head_dim)

            x_slice = x[..., :input_dim]  # Shape: (..., input_dim)

            group_out = torch.einsum(
                "...i,oih->...oh", x_slice, weight_block
            )  # Shape: (..., num_heads, head_dim)

            # Merge the two trailing dims -> (..., num_heads * head_dim).
            group_out = group_out.flatten(-2)
            width = group_out.shape[-1]

            outputs[..., output_offset : output_offset + width].add_(group_out)

            weight_offset += block_rows
            output_offset += width

        return outputs


class MatryoshkaLinearTriton(nn.Module):
    """Matryoshka linear layer stored as one ``(dim, dim)`` weight matrix.

    ``matryoshka_setting`` gives the number of heads per layer; each head is
    ``dim // sum(matryoshka_setting)`` features wide. Layer ``i`` owns the
    output rows ``output_starts[i]:output_ends[i]`` of ``self.weight`` and
    reads only the first ``input_dims[i]`` input columns. Entries outside
    these blocks are initialised to zero and are not read by
    :meth:`pytorch_forward`.

    Args:
        dim: Input and output feature size.
        matryoshka_setting: Number of heads per layer. ``None`` selects
            ``[1, 1, 2, 4, 8, 16]``.
        bias: When true, ``bias_param`` is a trainable parameter; otherwise
            it is a constant zero vector.

    Raises:
        AssertionError: If ``dim`` is not divisible by the total head count.
    """

    def __init__(
        self,
        dim: int = 4096,
        matryoshka_setting: List[int] = None,
        bias=False,
    ):
        super().__init__()
        # Build the default per call instead of sharing one list object
        # between every instance through the signature.
        if matryoshka_setting is None:
            matryoshka_setting = [1, 1, 2, 4, 8, 16]

        num_head = sum(matryoshka_setting)
        assert (
            dim % num_head == 0
        ), "dim should be divisible by the sum of the Matryoshka heads."
        head_dim = dim // num_head

        self.output_dims = np.array(matryoshka_setting) * head_dim
        self.input_dims = np.cumsum(self.output_dims)
        self.dim = dim
        self.num_layers = len(self.output_dims)

        self.output_starts = np.cumsum([0] + self.output_dims[:-1].tolist())
        self.output_ends = self.output_starts + self.output_dims

        # One large weight matrix; only the per-layer blocks are filled below.
        self.weight = nn.Parameter(torch.zeros(dim, dim))
        if bias:
            self.bias_param = nn.Parameter(torch.zeros(dim))
        else:
            # A buffer (rather than a plain tensor attribute) follows the
            # module through .to()/.cuda()/.half(). persistent=False keeps it
            # out of state_dict, so checkpoints are unaffected.
            self.register_buffer("bias_param", torch.zeros(dim), persistent=False)

        # Xavier-initialise each layer's block as if it were its own matrix.
        with torch.no_grad():
            for out_start, out_end, in_end in zip(
                self.output_starts, self.output_ends, self.input_dims
            ):
                out_start, out_end, in_end = int(out_start), int(out_end), int(in_end)
                block = torch.empty((out_end - out_start, in_end))
                nn.init.xavier_uniform_(block)
                self.weight[out_start:out_end, :in_end] = block

    def pytorch_forward(self, x):
        """Apply every layer to its input prefix and concatenate the results.

        Args:
            x: Tensor whose last dimension has size ``dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``dim``, with ``bias_param`` added.
        """
        pieces = []
        for indim, out_start, out_end in zip(
            self.input_dims, self.output_starts, self.output_ends
        ):
            indim, out_start, out_end = int(indim), int(out_start), int(out_end)
            layer_weight = self.weight[out_start:out_end, :indim]
            pieces.append(torch.matmul(x[..., :indim], layer_weight.T))
        return torch.cat(pieces, dim=-1) + self.bias_param


if __name__ == "__main__":
    # Benchmark: average forward time of both Matryoshka layers against a
    # dense nn.Linear of the same size. Requires a CUDA device.

    def _average_seconds(forward_fn, inputs, repeats):
        """Return the mean wall time in seconds of ``forward_fn(inputs)``.

        The device is synchronised before starting and before stopping the
        clock so that queued GPU work is included in the measurement.
        """
        # NOTE(review): no warm-up call is made, so the first timed iteration
        # may include one-time start-up costs - consider discarding it.
        total = 0.0
        for _ in range(repeats):
            torch.cuda.synchronize()
            # perf_counter is monotonic; time.time() can jump with clock changes.
            start = time.perf_counter()
            with torch.no_grad():
                forward_fn(inputs)
            torch.cuda.synchronize()
            total += time.perf_counter() - start
        return total / repeats

    dim = 4096
    matryoshka_setting = [1, 1, 2, 4, 8, 16]
    batch_size = 32
    dtype = torch.float16

    # Initialize models
    model_v1 = (
        MatryoshkaLinearTriton(dim, matryoshka_setting, bias=True).to("cuda").to(dtype)
    )
    model_v2 = (
        FastMatryoshkaLinear(dim, matryoshka_setting=matryoshka_setting)
        .to("cuda")
        .to(dtype)
    )
    model_linear = nn.Linear(dim, dim).to("cuda").to(dtype)

    # Input tensor
    x = torch.randn(batch_size, 40960, dim, device="cuda", dtype=dtype)

    # Number of timed repetitions per implementation
    test_time = 10

    benchmarks = [
        ("Pytorch V1 Time", model_v1.pytorch_forward),
        ("Pytorch V2.1 Time", model_v2.pytorch_forward),
        ("Pytorch V2.2 Time", model_v2.optimized_forward),
        ("Pytorch Linear Time", model_linear),
    ]
    for label, forward_fn in benchmarks:
        print(label, _average_seconds(forward_fn, x, test_time))
