import torch
import torch.nn.functional as F
from torch.nn import LayerNorm as RMSNorm  # Replace with actual RMSNorm if needed
import math
import pdb
import tqdm
import time
from torch.nn import Module
from torch import nn, einsum, Tensor
from components import ensure_batched, _kaiming_init, _kaiming_init_bias
#from .helpers import *
from einops import rearrange, pack, unpack, repeat
#import GPUtil as GPU  # For GPU utilization
#from fvcore.nn import FlopCountAnalysis  # For FLOPs
import time


# helper functions

def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return `val` unless it is None, in which case fall back to `d`."""
    if exists(val):
        return val
    return d

def divisible_by(num, den):
    """Return whether `num` is an exact multiple of `den`."""
    remainder = num % den
    return remainder == 0

def chunk_num(num, chunks):
    """Split `num` into `chunks` integer parts that differ by at most one.

    The first `num % chunks` parts receive the extra unit, so the parts
    always sum back to `num`.
    """
    base, extra = divmod(num, chunks)
    return [base + int(index < extra) for index in range(chunks)]

def pack_one(t, pattern):
    """Pack the single tensor `t` with einops `pack`.

    Returns the `(packed, packed_shapes)` pair produced by `pack`, so the
    shapes can later be handed to `unpack_one`.
    """
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes

def unpack_one(t, ps, pattern):
    """Undo `pack_one`: unpack `t` with shapes `ps` and return the first piece."""
    pieces = unpack(t, ps, pattern)
    return pieces[0]

def l2norm(t):
    """Scale `t` to unit Euclidean (L2) length along its last dimension."""
    unit = F.normalize(t, p = 2.0, dim = -1)
    return unit

def cumsum_exclusive(t, dim = -3):
    """Exclusive cumulative sum of `t` along the (negative) dimension `dim`.

    Element i of the result is the sum of all elements strictly before i,
    so the first element along `dim` is always zero.
    """
    assert dim < 0
    trailing_dims = -dim - 1
    # F.pad takes (before, after) pairs starting from the last dimension:
    # leave every dimension after `dim` untouched, then shift `dim` right by
    # one (pad 1 in front, crop 1 at the end).
    pad_spec = [0, 0] * trailing_dims + [1, -1]
    shifted = F.pad(t, pad_spec)
    return shifted.cumsum(dim = dim)

def log(t, eps = 1e-20):
    """Natural log of `t`, with inputs clamped to at least `eps` to avoid -inf/NaN."""
    safe = t.clamp(min = eps)
    return safe.log()

def gumbel_noise(t):
    """Draw Gumbel noise shaped like `t` via -log(-log(U)) with U ~ Uniform(0, 1)."""
    uniform = torch.zeros_like(t).uniform_(0, 1)
    inner = -log(uniform)
    return -log(inner)


# norm

class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable gain only.

    `gamma` is a trainable parameter; `beta` is registered as a buffer of
    zeros, so it is not returned by `parameters()` and is never optimized.
    """

    def __init__(self, dim):
        super().__init__()
        gain = torch.ones(dim)
        shift = torch.zeros(dim)
        self.gamma = nn.Parameter(gain)
        self.register_buffer("beta", shift)

    def forward(self, x):
        # Normalize over the trailing feature dimension only.
        feature_shape = x.shape[-1:]
        return F.layer_norm(x, feature_shape, weight = self.gamma, bias = self.beta)

class RMSNorm(Module):
    """RMS-style normalization: unit-normalize the last dim, rescale by sqrt(dim) and a learned gain."""

    def __init__(self, dim):
        super().__init__()
        # Multiplying a unit-L2 vector by sqrt(dim) gives it unit root-mean-square.
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        unit = l2norm(x)
        rescaled = unit * self.scale
        return rescaled * self.gamma

# expert

def FeedForward(
    dim,
    mult = 4,
    dropout = 0.
):
    """Build a two-layer MLP: Linear -> GELU -> Dropout -> Linear.

    The hidden width is `int(dim * mult)`; input and output width are `dim`.
    """
    hidden = int(dim * mult)
    layers = [
        nn.Linear(dim, hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, dim),
    ]
    return nn.Sequential(*layers)

class GEGLU(Module):
    """GELU-gated linear unit: split the last dim in two and gate one half with GELU of the other."""

    def forward(self, x):
        value, gate = x.chunk(2, dim = -1)
        activated_gate = F.gelu(gate)
        return value * activated_gate

def GLUFeedForward(
    dim,
    mult = 4,
    dropout = 0.
):
    """Build a gated MLP: Linear -> GEGLU -> Dropout -> Linear.

    The hidden width is `int(dim * mult * 2 / 3)`. The first projection is
    twice that wide because GEGLU consumes one half as values and the other
    half as gates.
    """
    hidden = int(dim * mult * 2 / 3)

    layers = [
        nn.Linear(dim, hidden * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, dim),
    ]
    return nn.Sequential(*layers)




class FactorizedLinear(nn.Module):
    """Rank-`rank` factorized linear map applied per expert and per head.

    Each (expert, head) pair maps `in_dim -> rank -> out_dim` and adds a
    per-(expert, head) bias. The first factor is either shared by all experts
    (`share_weights=True`, shape (H, D, R)) or expert-specific (shape
    (E, H, D, R)); the second factor is always expert-specific, shape
    (E, H, R, O).
    """

    def __init__(self, in_dim, out_dim, rank, num_experts, num_heads, share_weights=True):
        super().__init__()
        self.share_weight = share_weights

        # First factor (D -> R): one per head, optionally also one per expert.
        if share_weights:
            first_factor_shape = (num_heads, in_dim, rank)
        else:
            first_factor_shape = (num_experts, num_heads, in_dim, rank)
        self.shared_weight = nn.Parameter(torch.randn(*first_factor_shape))

        # Second factor (R -> O) and bias are always per expert and per head.
        self.expert_weights = nn.Parameter(torch.randn(num_experts, num_heads, rank, out_dim))
        self.bias = nn.Parameter(torch.zeros(num_experts, num_heads, out_dim))  # (E, H, O)

        # Initialization order is kept exactly as-is because it determines the
        # RNG stream. `_kaiming_init*` come from `components`; their effect is
        # defined outside this file.
        _kaiming_init_bias(self.expert_weights, self.bias)
        for weight in (self.shared_weight, self.expert_weights):
            nn.init.orthogonal_(weight.to(torch.float32))
            _kaiming_init(weight)

    def forward(self, x):
        """Apply the factorized map.

        Args:
            x: tensor laid out as (E, S, H, D) -- experts, samples/slots,
                heads, per-head input dim. A 3-D input gets a leading axis
                of size 1 added, which the pattern below then treats as the
                expert axis.

        Returns:
            Tensor laid out as (E, S, H, O). For 3-D input the leading axis is
            squeezed again, which only has an effect when it has size 1.
        """
        added_leading_axis = x.dim() == 3
        if added_leading_axis:
            x = x.unsqueeze(0)
        # Put samples first so the expert/head axes line up with the weights.
        x = rearrange(x, 'e s h d -> s e h d')

        # Treat every feature vector as a 1 x D row for the batched matmuls.
        rows = x.unsqueeze(-2)  # (S, E, H, 1, D)

        if self.share_weight:
            # (H, D, R) -> (1, 1, H, D, R): broadcast over samples and experts.
            first_factor = self.shared_weight[None, None]
        else:
            # (E, H, D, R): broadcast over samples only.
            first_factor = self.shared_weight
        low_rank = rows @ first_factor  # (S, E, H, 1, R)

        # (S, E, H, 1, R) @ (1, E, H, R, O) -> (S, E, H, 1, O) -> (S, E, H, O)
        projected = (low_rank @ self.expert_weights.unsqueeze(0)).squeeze(-2)

        # Bias (E, H, O) broadcasts over the leading sample axis.
        result = projected + self.bias
        result = rearrange(result, 's e h d -> e s h d')
        if added_leading_axis:
            result = result.squeeze(0)
        return result

class LinearLayer(nn.Module):
    """Dense (non-factorized) linear map with separate weights per expert and head.

    `rank`, `share_weights` and any extra keyword arguments are accepted but
    unused, so the constructor can be called the same way as
    `FactorizedLinear`.
    """

    def __init__(self, in_dim, out_dim, rank, num_experts, num_heads, share_weights, **kwargs):
        super().__init__()

        # One (D, O) matrix and one bias vector per (expert, head).
        dense_weight = torch.randn(num_experts, num_heads, in_dim, out_dim)  # (E, H, D, O)
        self.expert_weights = nn.Parameter(dense_weight)
        nn.init.kaiming_normal_(self.expert_weights)
        self.bias = nn.Parameter(torch.zeros(num_experts, num_heads, out_dim))  # (E, H, O)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Apply the per-expert, per-head linear map.

        Args:
            x: tensor laid out as (E, S, H, D) -- experts, samples/slots,
                heads, per-head input dim. A 3-D input gets a leading axis
                of size 1 added, which the pattern below then treats as the
                expert axis.

        Returns:
            Tensor laid out as (E, S, H, O). For 3-D input the leading axis is
            squeezed again, which only has an effect when it has size 1.
        """
        added_leading_axis = x.dim() == 3
        if added_leading_axis:
            x = x.unsqueeze(0)
        # Put samples first so the expert/head axes line up with the weights.
        x = rearrange(x, 'e s h d -> s e h d')

        # (S, E, H, 1, D) @ (1, E, H, D, O) -> (S, E, H, 1, O) -> (S, E, H, O)
        rows = x.unsqueeze(-2)
        projected = (rows @ self.expert_weights.unsqueeze(0)).squeeze(-2)

        # Bias (E, H, O) broadcasts over the leading sample axis.
        result = projected + self.bias
        result = rearrange(result, 's e h d -> e s h d')
        if added_leading_axis:
            result = result.squeeze(0)
        return result

class Experts(nn.Module):
    """Bank of per-expert, per-head projections followed by norm, ReLU and dropout.

    The input is laid out as (E, S, H, D): experts, slots (or samples) per
    expert, heads, per-head input dim. Every (expert, head) pair owns its own
    projection and bias; every head owns its own RMSNorm.
    """

    def __init__(self, in_features, out_features, num_heads, num_experts, lora_rank, dropout=0.1,
                 lora_method='linear', share_weights=False):
        """
        Args:
            in_features (int): Total input feature dimension; split evenly across heads.
            out_features (int): Total output feature dimension; split evenly across heads.
            num_heads (int): Number of segments (heads).
            num_experts (int): Number of experts.
            lora_rank (int): Rank passed to the projection layer.
            dropout (float): Dropout probability applied to the output.
            lora_method (str): Projection type, matched case-insensitively by
                substring: 'lora' -> LoRA, 'dense_linear' -> LinearLayer,
                'linear' -> FactorizedLinear.
            share_weights (bool): Passed through to the projection layer.

        Raises:
            AssertionError: If `in_features` or `out_features` is not divisible by `num_heads`.
            ValueError: If `lora_method` matches none of the known projection types.
        """
        super().__init__()

        # Check that both feature dimensions split evenly across heads
        assert in_features % num_heads == 0, "in_features must be divisible by num_heads"
        assert out_features % num_heads == 0, "out_features must be divisible by num_heads"
        # Dimension per head
        self.segment_dim = in_features // num_heads
        self.num_heads = num_heads
        self.out_features = out_features // num_heads  # per-head output dim
        self.num_experts = num_experts

        # Select the projection layer. Order matters: 'dense_linear' must be
        # tested before 'linear' because the latter is a substring of the former.
        method = lora_method.lower()
        if 'lora' in method:
            # NOTE(review): `LoRA` is neither defined nor imported in this file,
            # so this branch raises NameError unless it is provided elsewhere.
            nn_module = LoRA
        elif 'dense_linear' in method:
            nn_module = LinearLayer
        elif 'linear' in method:
            nn_module = FactorizedLinear
        else:
            raise ValueError(f"Unknown LoRA method: {lora_method}: select from ['lora', 'dense_linear', 'linear']")

        self.lora_layers = nn_module(self.segment_dim, self.out_features, lora_rank, num_experts,
                                     num_heads, share_weights)
        self.activation = F.relu
        self.norm = nn.ModuleList([RMSNorm(self.out_features) for _ in range(num_heads)])
        self.bias = nn.Parameter(torch.zeros(num_experts, num_heads, self.out_features))  # (E, H, O)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor laid out as (E, S, H, D).

        Returns:
            torch.Tensor: Output tensor laid out as (E, S, H, O), where O is
            the per-head output dimension.
        """
        # Only the head and feature axes are validated here.
        _, _, num_heads, segment_dim = x.shape
        assert num_heads == self.num_heads and segment_dim == self.segment_dim, "Input shape mismatch"

        # Project all experts and heads in parallel
        out = self.lora_layers(x)  # (E, S, H, O)
        if self.bias is not None:
            # The bias is (E, H, O). Insert a slot axis so it lines up with
            # (E, S, H, O); without it the expert axis would broadcast against
            # the slot axis, which errors or silently changes the output shape.
            out = out + self.bias.unsqueeze(1)

        # Each head is normalized by its own RMSNorm.
        out = torch.stack(
            [norm(out[:, :, head]) for head, norm in enumerate(self.norm)],
            dim=2,
        )
        out = self.activation(out)
        out = self.dropout(out)
        return out

class Mammoth(nn.Module):
    """Multi-head, slot-based mixture-of-experts layer with softmax routing.

    Tokens are projected to `slot_dim`, split into heads and softly assigned
    to `num_experts * num_slots` learned slots per head. Each expert processes
    its slots through `Experts`. The slot outputs are either returned directly
    (`keep_slots=True`) or softly redistributed back to the tokens.

    einstein notation used below
    b - batch
    n - sequence length
    e - number of experts
    s - number of slots per expert
    p - all slots, i.e. (e s)
    h - number of heads
    d - feature dimension per head
    """

    def __init__(
        self,
        input_dim,
        dim,
        *,
        num_experts=30,
        num_slots=10,  # num slots per expert. larger num means greater computational load
        keep_slots=True,  # if true, keep slot representations instead of redistributing to tokens
        num_heads=16,
        dropout=0.1,
        device=None,
        lora_rank=16,
        share_lora_weights=True,
        slot_dim=256,  # dimension of the slot embeddings
        query_activation='identity',  # swish may be better
        lora_method='linear',
        **kwargs
    ):
        """
        Args:
            input_dim (int): Feature dimension of the incoming tokens.
            dim (int): Total output feature dimension (split across heads by `Experts`).
            num_experts (int): Number of experts.
            num_slots (int): Number of slots per expert.
            keep_slots (bool): Return slot outputs instead of per-token outputs.
            num_heads (int): Number of heads.
            dropout (float): Dropout probability inside `Experts`.
            device: Device the expert heads are moved to. Only `expert_heads`
                is moved here; the remaining submodules are not.
            lora_rank (int): Rank passed to `Experts`.
            share_lora_weights (bool): Passed to `Experts` as `share_weights`.
            slot_dim (int): Width of the query projection; split across heads.
            query_activation (str): Activation applied to the queries before
                computing routing logits (see `get_activation`).
            lora_method (str): Projection type passed to `Experts`.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        assert num_experts >= 1, 'expected at least 1 expert to use MAMMOTH'
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.num_heads = num_heads
        assert divisible_by(input_dim, num_heads), 'dimension must be divisible by number of heads'
        # The query projection is what actually gets split into heads.
        assert divisible_by(slot_dim, num_heads), 'slot_dim must be divisible by number of heads'
        self.num_slots = num_slots
        self.keep_slots = keep_slots
        self.wq = nn.Linear(input_dim, slot_dim)
        self.logit_activation = self.get_activation(query_activation)  # applied to queries before the logits

        self.head_dim = slot_dim // num_heads
        self.norm = LayerNorm(slot_dim)
        self.slot_norm = LayerNorm(self.head_dim)
        #  todo: could make these embeds into a linear layer which should be faster
        self.slot_embeds = nn.Parameter(torch.randn(num_experts, num_heads, num_slots, self.head_dim))
        self.lora_rank = lora_rank
        self.expert_heads = Experts(
            slot_dim, dim, num_heads, num_experts, lora_rank, dropout,
            lora_method=lora_method, share_weights=share_lora_weights,
        ).to(device)

    def get_query_fn(self, input_dim, dim):
        """Replace `self.wq` with a fresh Xavier-initialized `Linear(input_dim, dim)`.

        Returns:
            Tuple of (`dim`, the new `self.wq` module).
        """
        expert_input_dim = dim
        self.wq = nn.Linear(input_dim, expert_input_dim, bias=True)
        # xavier init on the weight only; the bias keeps its default init
        for name, param in self.wq.named_parameters():
            if 'weight' in name:
                nn.init.xavier_uniform_(param)

        return expert_input_dim, self.wq

    def get_activation(self, activation):
        """Map an activation name to a freshly constructed module.

        Raises:
            ValueError: If `activation` is not one of the recognized names.
        """
        if activation == 'relu':
            return nn.ReLU()
        elif activation == 'gelu':
            return nn.GELU()
        elif activation == 'tanh':
            return nn.Tanh()
        elif activation == 'sigmoid':
            return nn.Sigmoid()
        elif activation == 'swish':
            return nn.SiLU()
        elif activation in ['none', '', 'identity']:
            return nn.Identity()
        raise ValueError(f'Activation function {activation} not recognized')

    @staticmethod
    def apply_expert_heads(slot: torch.Tensor, norm: torch.nn.Module, expert: torch.nn.Module,
                           use_residual: bool) -> torch.Tensor:
        """Run `expert` on `slot`, optionally add a residual, then apply `norm` if given."""
        out = expert(slot)
        if use_residual:
            out = out + slot
        if norm is None:
            return out
        return norm(out)

    def apply_wq(self, feats):
        """Project tokens to `slot_dim` and layer-normalize the result."""
        return self.norm(self.wq(feats))

    def forward(self, feats):
        """Route tokens to expert slots and return the expert outputs.

        Args:
            feats: Token features; passed through `ensure_batched` first
                (assumed to yield a (b, n, input_dim) tensor -- TODO confirm
                against `components.ensure_batched`).

        Returns:
            (b, e * s, dim) slot outputs when `keep_slots` is True, otherwise
            (b, n, dim) per-token outputs.
        """
        feats, _ = ensure_batched(feats)
        x = self.apply_wq(feats)  # (b, n, input_dim) -> (b, n, slot_dim)
        x = rearrange(x, 'b n (h d) -> b n h d', h=self.num_heads)
        b = x.shape[0]

        logits = self.get_logits(x)  # b n e h s
        combine_weights, dispatch_weights = self.get_weights(logits)

        # Every slot is a convex combination of the tokens of its own batch element.
        slots = einsum('b n h d, b n e h s -> b e h s d', x, dispatch_weights)
        # Experts expects the expert axis first; batch and slot axes are merged.
        slots = rearrange(slots, 'b e h s d -> e (b s) h d')

        out = self.expert_heads(slots)  # e (b s) h d

        # Undo the packing above with the matching pattern so outputs stay
        # attached to their own batch element, then merge heads and list all
        # slots expert-major: b (e s) (h d).
        out = rearrange(out, 'e (b s) h d -> b (e s) (h d)', b=b)

        if not self.keep_slots:
            # Redistribute slot outputs back to the tokens.
            out = rearrange(out, 'b p (h d) -> b h p d', h=self.num_heads)
            out = einsum('b h p d, b n h p -> b n h d', out, combine_weights)
            out = rearrange(out, 'b n h d -> b n (h d)')

        return out

    def get_logits(self, x):
        """
        Compute routing logits between queries and normalized slot embeddings.

        x: b n h d
        slot_embeds: e h s d
        returns: b n e h s
        """
        slot_embeds = self.slot_norm(self.slot_embeds)  # e, h, s, d
        x_active = self.logit_activation(x)
        logits = einsum('b n h d, e h s d -> b n e h s', x_active, slot_embeds)
        # scaled logits
        #logits = logits / (self.head_dim ** 0.5)
        return logits

    def get_weights(self, logits):
        """Turn logits (b n e h s) into routing weights.

        Returns:
            combine_weights: (b, n, h, e * s), softmax over all slots -- how
                each token mixes the slot outputs.
            dispatch_weights: (b, n, e, h, s), softmax over tokens -- how each
                slot mixes the tokens.
        """
        dispatch_weights = logits.softmax(dim=1)
        combine_weights = rearrange(logits, 'b n e h s -> b n h (e s)').softmax(dim=-1)

        return combine_weights, dispatch_weights

if __name__ == '__main__':
    # Smoke test and latency benchmark for a single Mammoth layer.
    model_kwargs = {
        'input_dim': 1024,
        'dim': 512,
        'num_experts': 30,
        'num_slots': 1,
        'num_heads': 16,
        'lora_rank': 13,
        'lora_method': 'linear',
        'share_lora_weights': True,
        'slot_dim': 256,
    }
    # The benchmark runs on whatever device `mammoth` and `x` live on (CPU as
    # written). To benchmark on GPU, move both to the same CUDA device.
    mammoth = Mammoth(**model_kwargs)
    x = torch.randn(1, 10000, 1024)
    print(mammoth(x).shape)
    num_iters = 100

    def _sync_if_cuda():
        """Wait for queued CUDA work so wall-clock timings cover the whole forward pass."""
        if x.is_cuda:
            torch.cuda.synchronize()

    # --- GPU Usage ---
    def get_gpu_utilization():
        """Return the load of the first GPU in percent, or None when it cannot be queried."""
        # `GPU` only exists when the GPUtil import at the top of the file is
        # enabled; look it up explicitly instead of relying on a NameError.
        gpu_lib = globals().get('GPU')
        if gpu_lib is None:
            return None
        try:
            gpus = gpu_lib.getGPUs()
            if gpus:
                return gpus[0].load * 100
            else:
                return None
        except Exception as e:
            print(f"Error getting GPU utilization: {e}")
            return None

    # --- Latency Measurement ---
    latencies = []
    num_warmup = 10
    with torch.no_grad():
        for i in range(num_iters + num_warmup):
            _sync_if_cuda()
            start_time = time.perf_counter()
            _ = mammoth(x)
            _sync_if_cuda()
            end_time = time.perf_counter()
            # Warm-up iterations are executed but not recorded.
            if i >= num_warmup:
                latencies.append(end_time - start_time)

    avg_latency_ms = (sum(latencies) / len(latencies)) * 1000

    # --- FLOPs Measurement (using fvcore) ---
    # `FlopCountAnalysis` only exists when the fvcore import at the top of the
    # file is enabled.
    flop_counter_cls = globals().get('FlopCountAnalysis')
    if flop_counter_cls is None:
        print("Skipping FLOPs estimate: FlopCountAnalysis is not imported "
              "(enable the fvcore import at the top of this file; `pip install fvcore`).")
        total_flops = None
    else:
        try:
            flops_analyzer = flop_counter_cls(mammoth, x)
            total_flops = flops_analyzer.total()
            print(f"Estimated FLOPs: {total_flops / 1e9:.2f} GFLOPs")
        except Exception as e:
            print(f"Error calculating FLOPs with fvcore: {e}. Make sure fvcore is installed (`pip install fvcore`).")
            total_flops = None

    # --- Print Results ---
    avg_gpu_utilization = get_gpu_utilization()
    if avg_gpu_utilization is not None:
        print(f"Average GPU Utilization: {avg_gpu_utilization:.2f}%")
    print(f"Average Latency per iteration: {avg_latency_ms:.2f} ms")


