import torch
import torch.nn as nn
import math
from typing import List, Optional, Tuple
import torch.nn.functional as F
from ptflops import get_model_complexity_info


# -------------------------
# Patch embedding (unchanged)
# -------------------------
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each one linearly.

    A convolution whose kernel size and stride both equal ``patch_size`` acts
    as the per-patch projection; its spatial output grid is then flattened
    into a token sequence of shape [B, N, D].
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=384):
        super().__init__()
        self.img_size, self.patch_size = img_size, patch_size
        grid = img_size // patch_size
        self.n_patches = grid * grid
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # [B, D, H/p, W/p] -> [B, D, N] -> [B, N, D]
        return self.proj(x).flatten(start_dim=2).permute(0, 2, 1)


# -------------------------
# Custom multi-head self-attention core (split into QKV and output stages)
# -------------------------
class CustomMultiHeadAttention(nn.Module):
    """Multi-head self-attention split into two stages.

    ``forward_qkv`` projects tokens into per-head queries/keys/values, and
    ``forward_from_qkv`` turns them into the output projection. The second
    stage can reuse attention probabilities computed elsewhere (e.g. after
    logits have been mixed across branches).
    """

    def __init__(self, embed_dim, num_heads=6, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Creation order (q, k, v, out) is kept so seeded initialisation is unchanged.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t):
        """Reshape a projected tensor [B, N, C] into [B, heads, N, head_dim]."""
        batch, tokens, _ = t.shape
        return t.view(batch, tokens, self.num_heads, self.head_dim).transpose(1, 2)

    def forward_qkv(self, x):
        """
        Project x [B, N, C] and return (q, k, v), each shaped [B, heads, N, head_dim].
        """
        return tuple(self._split_heads(proj(x)) for proj in (self.q_proj, self.k_proj, self.v_proj))

    def forward_from_qkv(self, q, k, v, attn_probs: Optional[torch.Tensor] = None):
        """
        Compute the attention output [B, N, C] from per-head q, k, v.

        If ``attn_probs`` ([B, heads, N, N]) is supplied it is used as-is (no
        softmax, no dropout) and returned unchanged; otherwise scaled
        dot-product probabilities are computed and dropout is applied to them.

        Returns:
            (output [B, N, C], attention probabilities actually used)
        """
        batch, _, tokens, _ = q.shape
        if attn_probs is None:
            logits = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
            attn_probs = self.dropout(logits.softmax(dim=-1))
        merged = (attn_probs @ v).transpose(1, 2).reshape(batch, tokens, self.embed_dim)
        return self.out_proj(merged), attn_probs


# -------------------------
# Multi-branch MHSA with progressive lambda-weighting
# -------------------------
class LambdaMultiBranchMHSA(nn.Module):
    """
    Multi-branch MHSA with progressive lambda mixing.

    Every branch owns its own Q/K/V/output projections. Before the softmax,
    the attention logits of branch ``i`` are combined with the logits of
    every other branch ``j`` weighted by ``lambda_matrix[i, j]`` (the diagonal
    is ignored; a branch always contributes its own logits with weight 1).
    The sum is divided by ``1 + sum_j lambda_ij**2``. Branch outputs are
    averaged.
    """

    def __init__(self, embed_dim, num_heads=6, num_branches=2, dropout=0.0,
                 temperature_mode="adaptive", base_temperature=1.0):
        super().__init__()
        self.num_branches = num_branches
        self.temperature_mode = temperature_mode
        self.base_temperature = base_temperature

        self.branches: nn.ModuleList = nn.ModuleList([
            CustomMultiHeadAttention(embed_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(num_branches)
        ])
        # Pre-norm shared by all branches (applied once to the block input).
        self.norm = nn.LayerNorm(embed_dim)
        # NOTE(review): not used by forward(); kept so existing state_dicts still
        # load. Unused parameters may require find_unused_parameters=True under DDP.
        self.logits_norm = nn.LayerNorm(self.branches[0].num_heads)

        # Learnable temperature parameter. NOTE(review): forward() does not use
        # _get_temperature(), so this parameter currently receives no gradient.
        if temperature_mode == "learnable":
            self.temperature = nn.Parameter(torch.tensor(base_temperature))

    def _get_temperature(self, avg_lambda: float):
        """Calculate temperature based on the chosen mode (not used by forward)."""
        if self.temperature_mode == "fixed":
            return self.base_temperature
        elif self.temperature_mode == "adaptive":
            # Adaptive temperature based on average lambda
            return self.base_temperature * (1.0 + avg_lambda * (self.num_branches - 1.0))
        elif self.temperature_mode == "learnable":
            return torch.clamp(self.temperature, min=0.1, max=10.0)
        else:
            return self.base_temperature

    @staticmethod
    def _pairwise_diversity(feats_list: List[torch.Tensor]) -> torch.Tensor:
        """
        Mean squared cosine similarity over all unordered pairs of features.

        feats_list: list of tensors with shape [B, N, D] (or [B, D])
        returns: scalar tensor (mean over batch/tokens/pairs); 0 for a single feature
        """
        n = len(feats_list)
        if n < 2:
            return feats_list[0].new_tensor(0.0)
        div = 0.0
        cnt = 0
        for i in range(n):
            for j in range(i + 1, n):
                sim = F.cosine_similarity(feats_list[i], feats_list[j], dim=-1)
                div = div + sim.pow(2).mean()
                cnt += 1
        return div / cnt

    def forward(self, x, lambda_matrix: torch.Tensor, get_diversity_loss=False):
        """
        x: [B,N,D]
        lambda_matrix: [num_branches, num_branches] matrix where lambda_matrix[i,j]
                      controls how much branch j influences branch i

        Returns:
            (out [B,N,D], raw per-branch logits, per-branch attention probabilities)
            plus the pairwise diversity of branch outputs when get_diversity_loss.
            The returned probabilities are taken before attention dropout.
        """
        x_ = self.norm(x)
        # Per-branch q, k, v and raw scaled dot-product logits.
        q_list, k_list, v_list = [], [], []
        Lj_list = []
        for branch in self.branches:
            q, k, v = branch.forward_qkv(x_)
            q_list.append(q)
            k_list.append(k)
            v_list.append(v)
            Lj_list.append(torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(branch.head_dim))

        # NOTE(review): the temperature from _get_temperature() is deliberately not
        # computed here. It used lambda_matrix.mean().item() (a host-device sync
        # on every forward) and its result was never applied. The per-row `scale`
        # below is the normalisation that actually takes effect.

        outs = []
        attn_probs_list = []
        for i in range(self.num_branches):
            L_i_weighted = Lj_list[i].clone()  # start with own logits (weight 1)
            scale = 1.0

            # Add lambda-weighted logits from every other branch.
            for j in range(self.num_branches):
                if i == j:
                    continue
                lambda_ij = lambda_matrix[i, j]
                L_i_weighted += lambda_ij * Lj_list[j]
                scale += lambda_ij ** 2

            attn_probs = torch.softmax(L_i_weighted / scale, dim=-1)

            branch = self.branches[i]
            # forward_from_qkv skips dropout when given precomputed probabilities,
            # so apply this branch's attention dropout explicitly.
            out_i, _ = branch.forward_from_qkv(q_list[i], k_list[i], v_list[i],
                                               attn_probs=branch.dropout(attn_probs))
            outs.append(out_i)
            attn_probs_list.append(attn_probs)

        # Average outputs across branches
        out = sum(outs) / float(self.num_branches)
        if get_diversity_loss:
            div_attn = self._pairwise_diversity(outs)
            return out, Lj_list, attn_probs_list, div_attn
        return out, Lj_list, attn_probs_list


# -------------------------
# Multi-branch MLP with progressive lambda mixing
# -------------------------
class LambdaMultiBranchMLP(nn.Module):
    """Parallel MLP branches whose hidden pre-activations are lambda-mixed.

    Each branch is LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.
    After the first Linear, branch ``i``'s hidden state receives
    ``lambda_matrix[i, j] * hidden_j`` from every other branch ``j``. The
    remaining layers of branch ``i`` run on that mixed state, and the branch
    outputs are averaged.
    """

    def __init__(self, embed_dim, mlp_ratio=4.0, num_branches=2, dropout=0.0):
        super().__init__()
        self.num_branches = num_branches
        hidden_dim = int(embed_dim * mlp_ratio)

        def make_branch():
            return nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(embed_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, embed_dim),
                nn.Dropout(dropout),
            )

        self.branches = nn.ModuleList([make_branch() for _ in range(num_branches)])

    @staticmethod
    def _pairwise_diversity(feats_list: List[torch.Tensor]) -> torch.Tensor:
        """Mean squared cosine similarity over all unordered feature pairs (0 if fewer than two)."""
        if len(feats_list) < 2:
            return feats_list[0].new_tensor(0.0)
        pairs = [(a, b) for idx, a in enumerate(feats_list) for b in feats_list[idx + 1:]]
        scores = [F.cosine_similarity(a, b, dim=-1).pow(2).mean() for a, b in pairs]
        return sum(scores) / len(scores)

    def forward(self, x, lambda_matrix: torch.Tensor, get_diversity_loss=False):
        """Return (mean output, per-branch outputs[, pairwise diversity])."""
        # Stage 1: LayerNorm + first projection in every branch.
        hiddens = [branch[1](branch[0](x)) for branch in self.branches]

        # Stage 2: branch i accumulates lambda_matrix[i, j] * hidden_j for j != i.
        mixed_hiddens = []
        for i, own in enumerate(hiddens):
            acc = own.clone()
            for j, other in enumerate(hiddens):
                if j != i:
                    acc = acc + lambda_matrix[i, j] * other
            mixed_hiddens.append(acc)

        # Stage 3: GELU -> Dropout -> Linear -> Dropout on each mixed hidden state.
        ys = [
            branch[5](branch[4](branch[3](branch[2](h))))
            for branch, h in zip(self.branches, mixed_hiddens)
        ]

        out = sum(ys) / float(self.num_branches)
        if not get_diversity_loss:
            return out, ys
        return out, ys, self._pairwise_diversity(ys)


# -------------------------
# Transformer block with progressive lambda-controlled branches
# -------------------------
class ParallelTransformerBlockWithLambda(nn.Module):
    """
    Transformer block: residual multi-branch attention followed by a residual
    multi-branch MLP, each mixed across branches by its own lambda matrix.

    With ``get_diversity_loss=True`` the block also returns the weighted mean
    of the attention and MLP branch-diversity terms.
    """

    def __init__(self, embed_dim=384, num_heads=6, mlp_ratio=4.0,
                 dropout=0.0, attn_branches=2, mlp_branches=2,
                 temperature_mode="adaptive", base_temperature=1.0, drop_path=0.1,
                 attn_div_weight: float = 1.0, mlp_div_weight: float = 1.0):
        super().__init__()
        self.attn = LambdaMultiBranchMHSA(embed_dim, num_heads=num_heads,
                                          num_branches=attn_branches, dropout=dropout,
                                          temperature_mode=temperature_mode,
                                          base_temperature=base_temperature)
        self.mlp = LambdaMultiBranchMLP(embed_dim, mlp_ratio=mlp_ratio,
                                        num_branches=mlp_branches, dropout=dropout)
        # NOTE(review): `drop_path` is accepted but stochastic depth is not
        # implemented; this Identity is never called in forward().
        self.drop_path = nn.Identity()
        self.attn_div_weight = attn_div_weight
        self.mlp_div_weight = mlp_div_weight

    def forward(self, x, attn_lambda_matrix: torch.Tensor, mlp_lambda_matrix: torch.Tensor, get_diversity_loss=False):
        """Return x after both residual sub-layers, plus the block diversity loss if requested."""
        if get_diversity_loss:
            attn_out, _, _, div_attn = self.attn(x, attn_lambda_matrix, get_diversity_loss)
            x = x + attn_out
            mlp_out, _, div_mlp = self.mlp(x, mlp_lambda_matrix, get_diversity_loss)
            x = x + mlp_out
            # Parenthesised so BOTH weighted terms are averaged; without the
            # parentheses operator precedence would halve only the MLP term.
            block_div_loss = (self.attn_div_weight * div_attn + self.mlp_div_weight * div_mlp) / 2
            return x, block_div_loss
        x = x + self.attn(x, attn_lambda_matrix)[0]
        x = x + self.mlp(x, mlp_lambda_matrix)[0]
        return x


# -------------------------
# Custom ModuleList for transformer blocks
# -------------------------
class LambdaTransformerBlocks(nn.Module):
    """
    Container that runs lambda-controlled transformer blocks in order,
    passing the same attention/MLP lambda matrices to every block.
    """

    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x, attn_lambda_matrix: torch.Tensor, mlp_lambda_matrix: torch.Tensor, get_diversity_loss=False):
        """
        Apply every block sequentially.

        Returns:
            x, or (x, mean of the per-block diversity losses) when
            get_diversity_loss is True. With no blocks the loss is a zero tensor.
        """
        if not get_diversity_loss:
            for block in self.blocks:
                x = block(x, attn_lambda_matrix, mlp_lambda_matrix)
            return x

        if len(self.blocks) == 0:
            # Avoid ZeroDivisionError (e.g. depth=0): no blocks -> no diversity penalty.
            return x, x.new_zeros(())

        total_div = 0.0
        for block in self.blocks:
            x, div_loss = block(x, attn_lambda_matrix, mlp_lambda_matrix, get_diversity_loss)
            total_div = total_div + div_loss
        return x, total_div / len(self.blocks)


# -------------------------
# Main ViT model
# -------------------------
class ParallelViT(nn.Module):
    """
    Vision Transformer built from lambda-controlled multi-branch blocks.

    A CLS token is prepended and positional embeddings are added, but the
    classifier head reads the mean of the final normalised patch tokens; the
    CLS token is excluded from that pooling.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=200,
                 embed_dim=192, depth=6, num_heads=12, mlp_ratio=4.0, dropout=0.1,
                 attn_branches=2, mlp_branches=2, temperature_mode="adaptive",
                 base_temperature=1.0):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.n_patches
        self.attn_branches = attn_branches
        self.mlp_branches = mlp_branches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)

        blocks = [
            ParallelTransformerBlockWithLambda(embed_dim, num_heads, mlp_ratio, dropout,
                                               attn_branches=attn_branches, mlp_branches=mlp_branches,
                                               temperature_mode=temperature_mode,
                                               base_temperature=base_temperature)
            for _ in range(depth)
        ]
        self.blocks = LambdaTransformerBlocks(blocks)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal Linear weights, zero biases, unit-weight LayerNorms.

        The classifier head is an nn.Linear, so its bias is zeroed here too.
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @staticmethod
    def _scalar_lambda_matrix(value: float, num_branches: int, device) -> torch.Tensor:
        """Matrix with 1 on the diagonal and `value` elsewhere (0 -> identity, 1 -> all ones)."""
        matrix = torch.full((num_branches, num_branches), value, device=device)
        matrix.fill_diagonal_(1.0)
        return matrix

    def _resolve_lambda_matrices(self, attn_lambda_matrix, mlp_lambda_matrix, device):
        """Turn None / scalar lambda arguments into [branches, branches] matrices.

        - A scalar ``attn_lambda_matrix`` sets BOTH matrices to the
          corresponding scalar matrix (kept from the original 0 / 1 shortcut).
        - ``None`` means independent branches (identity matrix).
        - A scalar ``mlp_lambda_matrix`` is expanded the same way.
        Tensors are passed through unchanged.
        """
        if isinstance(attn_lambda_matrix, (int, float)):
            value = float(attn_lambda_matrix)
            return (self._scalar_lambda_matrix(value, self.attn_branches, device),
                    self._scalar_lambda_matrix(value, self.mlp_branches, device))
        if attn_lambda_matrix is None:
            attn_lambda_matrix = torch.eye(self.attn_branches, device=device)
        if mlp_lambda_matrix is None:
            mlp_lambda_matrix = torch.eye(self.mlp_branches, device=device)
        elif isinstance(mlp_lambda_matrix, (int, float)):
            mlp_lambda_matrix = self._scalar_lambda_matrix(float(mlp_lambda_matrix),
                                                           self.mlp_branches, device)
        return attn_lambda_matrix, mlp_lambda_matrix

    def forward(self, x, attn_lambda_matrix=None, mlp_lambda_matrix=None, get_diversity_loss=False):
        """
        Classify images x [B, C, H, W].

        Lambda arguments may be tensors, scalars, or None (identity = no mixing).
        Returns logits [B, num_classes], plus the diversity loss when requested.
        """
        B = x.shape[0]
        attn_lambda_matrix, mlp_lambda_matrix = self._resolve_lambda_matrices(
            attn_lambda_matrix, mlp_lambda_matrix, x.device)

        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        if get_diversity_loss:
            x, diversity_loss = self.blocks(x, attn_lambda_matrix, mlp_lambda_matrix, get_diversity_loss)
            x = self.norm(x)
            return self.head(x[:, 1:].mean(dim=1)), diversity_loss
        x = self.blocks(x, attn_lambda_matrix, mlp_lambda_matrix)
        x = self.norm(x)
        return self.head(x[:, 1:].mean(dim=1))


def get_parallel_vitv2(num_classes=200, attn_branches=2, mlp_branches=2, depth=6,
                       temperature_mode="adaptive", base_temperature=1.0, dropout=0.0):
    """Build a ParallelViT, forwarding only the listed options.

    Every other ParallelViT setting (image/patch size, embed_dim, heads,
    mlp_ratio, ...) keeps its constructor default.
    """
    config = dict(
        num_classes=num_classes,
        attn_branches=attn_branches,
        mlp_branches=mlp_branches,
        depth=depth,
        temperature_mode=temperature_mode,
        base_temperature=base_temperature,
        dropout=dropout,
    )
    return ParallelViT(**config)


# -------------------------
# Progressive Lambda scheduler
# -------------------------
class ProgressiveLambdaScheduler:
    """
    Progressive lambda scheduler that merges branches sequentially.

    For n branches there are n-1 merge phases spread over ``warmup_steps``.
    In phase p, branch ``merge_order[p + 1]`` merges with the group
    ``merge_order[:p + 1]``; within a phase lambda rises from 0 to 1
    following ``mode``.

    Returns lambda matrices where lambda_matrix[i,j] controls how much branch j
    influences branch i.
    """

    def __init__(self, num_branches: int, warmup_steps: int, mode: str = "linear",
                 merge_order: Optional[List[int]] = None):
        assert num_branches >= 2
        assert warmup_steps >= 1
        assert mode in ("linear", "cosine", "exponential", "sqrt", "sine", "smoothstep")

        self.num_branches = num_branches
        self.warmup_steps = warmup_steps
        self.mode = mode

        # Default merge order: 0->1, then 2->01, then 3->012, etc.
        if merge_order is None:
            self.merge_order = list(range(num_branches))
        else:
            assert len(merge_order) == num_branches
            assert set(merge_order) == set(range(num_branches))
            # Copy so later mutation of the caller's list cannot alter the schedule.
            self.merge_order = list(merge_order)

        # Equal-length phases; the last one absorbs the integer-division remainder.
        self.steps_per_phase = warmup_steps // (num_branches - 1)
        self.phase_starts = [i * self.steps_per_phase for i in range(num_branches - 1)]
        self.phase_ends = [(i + 1) * self.steps_per_phase for i in range(num_branches - 1)]
        self.phase_ends[-1] = warmup_steps  # Ensure last phase ends exactly at warmup_steps

    def _get_lambda_value(self, t: float) -> float:
        """Map normalized phase time t to lambda in [0, 1]; every mode maps 0 -> 0 and 1 -> 1."""
        if t <= 0:
            return 0.0
        if t >= 1.0:
            return 1.0

        if self.mode == "linear":
            return float(t)
        elif self.mode == 'cosine':
            return float(0.5 * (1 - math.cos(math.pi * t)))
        elif self.mode == 'exponential':
            # Normalised so the curve reaches exactly 1 at t = 1; the raw
            # 1 - e^{-5t} tops out near 0.9933, causing a jump at phase end.
            return float((1 - math.exp(-5 * t)) / (1 - math.exp(-5)))
        elif self.mode == "sqrt":
            return float(math.sqrt(t))
        elif self.mode == "sine":
            return float(math.sin(0.5 * math.pi * t))
        elif self.mode == "smoothstep":
            return float(3 * t ** 2 - 2 * t ** 3)
        else:
            raise ValueError(f"Unknown mode: {self.mode}")

    def get_lambda_matrices(self, step: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns (attn_lambda_matrix, mlp_lambda_matrix) for the given step.
        Both matrices are [num_branches, num_branches] float32 where element [i,j]
        controls how much branch j influences branch i. The two matrices are
        currently identical.

        Merging logic: When branches merge, they influence each other bidirectionally.
        - Phase 0: Branch 0 <-> Branch 1 (mutual influence)
        - Phase 1: {Branch 0, Branch 1} <-> Branch 2 (Branch 2 exchanges with both 0 and 1)
        - Phase 2: {Branch 0, Branch 1, Branch 2} <-> Branch 3 (Branch 3 exchanges with 0,1,2)
        (branch ids are taken through merge_order)

        This creates a symmetric pattern in the active region.
        """
        # Identity matrices: diagonal = 1, off-diagonal = 0 (no mixing yet).
        attn_matrix = torch.eye(self.num_branches, dtype=torch.float32)
        mlp_matrix = torch.eye(self.num_branches, dtype=torch.float32)

        if step <= 0:
            return attn_matrix, mlp_matrix

        for phase_idx in range(self.num_branches - 1):
            phase_start = self.phase_starts[phase_idx]
            phase_end = self.phase_ends[phase_idx]

            if step >= phase_end:
                lambda_val = 1.0  # phase complete
            elif step >= phase_start:
                phase_progress = (step - phase_start) / (phase_end - phase_start)
                lambda_val = self._get_lambda_value(phase_progress)
            else:
                continue  # phase not started yet

            # The branch joining the group in this phase, and the group itself.
            merging_branch = self.merge_order[phase_idx + 1]
            existing_group = self.merge_order[:phase_idx + 1]

            # Apply bidirectional lambda values.
            for existing_branch in existing_group:
                attn_matrix[existing_branch, merging_branch] = lambda_val
                mlp_matrix[existing_branch, merging_branch] = lambda_val
                attn_matrix[merging_branch, existing_branch] = lambda_val
                mlp_matrix[merging_branch, existing_branch] = lambda_val

        return attn_matrix, mlp_matrix

    def get_current_phase_info(self, step: int) -> dict:
        """Return information about the merge phase active at ``step``.

        Negative steps are treated as step 0 (phase 0, lambda 0), consistent
        with get_lambda_matrices returning identity matrices for step <= 0.
        """
        step = max(step, 0)
        for phase_idx in range(self.num_branches - 1):
            phase_start = self.phase_starts[phase_idx]
            phase_end = self.phase_ends[phase_idx]

            if phase_start <= step < phase_end:
                progress = (step - phase_start) / (phase_end - phase_start)
                return {
                    'phase': phase_idx,
                    'merging_branch': self.merge_order[phase_idx + 1],
                    'target_branches': self.merge_order[:phase_idx + 1],
                    'progress': progress,
                    'lambda_value': self._get_lambda_value(progress)
                }

        # All phases complete
        return {
            'phase': self.num_branches - 1,
            'merging_branch': None,
            'target_branches': list(self.merge_order),
            'progress': 1.0,
            'lambda_value': 1.0
        }


# -------------------------
# Example usage
# -------------------------
def example_usage():
    """Print the scheduled lambda matrices at several steps and run a forward pass with each."""
    # Model with 4 attention and 4 MLP branches.
    model = get_parallel_vitv2(num_classes=200, attn_branches=4, mlp_branches=4, depth=6)
    # Inference-only demo: eval() disables dropout, no_grad() below skips building graphs.
    model.eval()

    # Scheduler for 4 branches over 1000 warmup steps.
    scheduler = ProgressiveLambdaScheduler(num_branches=4, warmup_steps=1000, mode="cosine")

    # One fixed input is enough to demonstrate the forward pass at every step.
    dummy_input = torch.randn(2, 3, 224, 224)

    for step in [0, 100, 333, 500, 666, 800, 1000]:
        attn_lambda, mlp_lambda = scheduler.get_lambda_matrices(step)
        phase_info = scheduler.get_current_phase_info(step)

        print(f"Step {step}: Phase {phase_info['phase']}, Progress: {phase_info['progress']:.3f}")
        print(f"Attention Lambda Matrix:\n{attn_lambda}")
        print(f"MLP Lambda Matrix:\n{mlp_lambda}")

        with torch.no_grad():
            output = model(dummy_input, attn_lambda_matrix=attn_lambda, mlp_lambda_matrix=mlp_lambda)
        print(f"Output shape: {tuple(output.shape)}")
        print("-" * 50)


if __name__ == "__main__":
    example_usage()