import torch
import sys
import copy
from pathlib import Path
from torch_uncertainty.post_processing import TemperatureScaler
from torch.utils.data import random_split, DataLoader
from torch_uncertainty.routines import ClassificationRoutine
import torch.nn as nn
import torch.nn.functional as F
from torch_uncertainty import TUTrainer
from torch_uncertainty.datamodules import ImageNetDataModule
from prettytable import PrettyTable
import json
from pathlib import Path
from collections import defaultdict

# Put the directory two levels above this file (the parent of this file's
# directory) on sys.path, presumably so the `Pruning.*` imports that follow can
# be resolved when this file is run directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from Pruning.utils.prune_model.structured import structurally_prune_attention_heads

from Pruning.models.block import (
    prune_heads,
    unprune_heads,
    get_active_heads,
    prune_mlp,
    unprune_mlp,
    get_active_mlp,
)
from torchvision.models.vision_transformer import VisionTransformer
from torchvision.models import vit_b_16
try:
    from torchvision.models import ViT_B_16_Weights
except ImportError:  
    ViT_B_16_Weights = None
from typing import Dict, Union, Optional, List, Tuple

# Attach the head/MLP pruning helpers from Pruning.models.block as methods of
# torchvision's VisionTransformer. This is a class-level patch executed at import
# time, so every VisionTransformer instance in the process (including models
# returned by vit_b_16) gains these methods.
VisionTransformer.prune_heads       = prune_heads
VisionTransformer.unprune_heads     = unprune_heads
VisionTransformer.get_active_heads  = get_active_heads
VisionTransformer.prune_mlp         = prune_mlp
VisionTransformer.unprune_mlp       = unprune_mlp
VisionTransformer.get_active_mlp    = get_active_mlp


def print_detailed_vit_summary(vit_model):
    """Print a per-encoder-layer summary table of a (possibly pruned) ViT.

    For each block in ``vit_model.encoder.layers`` the table lists the number of
    attention heads, the shapes of the attention in-/out-projection weights, the
    shapes of every ``nn.Linear`` inside the MLP, and the parameter count (in
    millions) of the attention + MLP sub-modules (LayerNorm parameters are not
    included).

    Shapes are read from the weight tensors themselves, so layers whose heads
    were structurally pruned are reported as they actually are rather than as
    ``3*embed_dim x embed_dim``.

    Args:
        vit_model: torchvision-style ``VisionTransformer``; every encoder layer
            must expose ``self_attention`` (``nn.MultiheadAttention``-like) and
            ``mlp``.
    """
    table = PrettyTable()
    table.field_names = ["Layer", "# Heads", "QKV Shape", "Out Shape", "MLP Shapes", "Total Params (M)"]

    for i, block in enumerate(vit_model.encoder.layers):
        mha = block.self_attention
        mlp = block.mlp
        embed_dim = mha.embed_dim

        # Report the real packed in-projection shape: after structural head
        # pruning it can have fewer than 3 * embed_dim rows. nn.MultiheadAttention
        # sets in_proj_weight to None when it uses separate q/k/v projections,
        # in which case we fall back to the nominal shape.
        in_proj = getattr(mha, "in_proj_weight", None)
        if in_proj is not None:
            qkv_shape = f"{in_proj.shape[0]} x {in_proj.shape[1]}"
        else:
            qkv_shape = f"{3 * embed_dim} x {embed_dim}"

        out_w = mha.out_proj.weight
        out_shape = f"{out_w.shape[0]} x {out_w.shape[1]}"

        # Collect every Linear in the MLP (torchvision's MLPBlock has two, at
        # indices 0 and 3) instead of depending on fixed positions.
        mlp_linears = [m for m in mlp.modules() if isinstance(m, nn.Linear)]
        mlp_shapes = ", ".join(f"{lin.weight.shape[0]} x {lin.weight.shape[1]}" for lin in mlp_linears)

        total_params = sum(p.numel() for p in mha.parameters()) + sum(p.numel() for p in mlp.parameters())
        table.add_row([i, mha.num_heads, qkv_shape, out_shape, mlp_shapes, f"{total_params/1e6:.2f}"])

    print("\nViT Layer-wise Summary:")
    print(table)

    
def load_model(ckpt_path, device):
    """Load a dense (unpruned) torchvision ViT-B/16 from a checkpoint file.

    The checkpoint may be a raw state dict or a dict wrapping one under
    ``'state_dict'``. Key normalisation happens in two passes: first a leading
    ``'model.'`` prefix is stripped, then the flat ``heads.weight`` /
    ``heads.bias`` keys are renamed to torchvision's ``heads.head.*`` names.
    The result is loaded strictly into a fresh 1000-class, 224px ``vit_b_16``.

    Args:
        ckpt_path: path of the checkpoint to read.
        device: device the returned model is moved to.

    Returns:
        The model in eval mode on ``device``.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects; only
    # use this with trusted checkpoint files.
    checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    raw_state = checkpoint.get('state_dict', checkpoint)

    # Pass 1: drop a wrapper prefix such as the one added by Lightning modules.
    stripped = {key.removeprefix('model.'): tensor for key, tensor in raw_state.items()}

    # Pass 2: map the flat classifier keys onto torchvision's heads.head layout.
    head_aliases = {'heads.weight': 'heads.head.weight', 'heads.bias': 'heads.head.bias'}
    state = {head_aliases.get(key, key): tensor for key, tensor in stripped.items()}

    vit = vit_b_16(weights=None, num_classes=1000, image_size=224)
    vit.load_state_dict(state, strict=True)
    return vit.eval().to(device)



def load_zero_shot(
    ckpt_path: str,
    device: torch.device,
    num_heads_to_prune: Union[int, Dict[int, list[int]]],
    prune_strategy: str,
    prune_loader: DataLoader = None,
):
    """Load a dense ViT-B/16 checkpoint and prune attention heads without retraining.

    The checkpoint is normalised exactly like ``load_model`` (strip ``'model.'``,
    rename ``heads.weight``/``heads.bias`` to ``heads.head.*``) and loaded
    strictly. It is then passed through ``structurally_prune_attention_heads``
    with the given budget/strategy, and a layer summary is printed.

    Args:
        ckpt_path: path of the dense checkpoint.
        device: device handed to the pruner and used for the returned model.
        num_heads_to_prune: forwarded to ``structurally_prune_attention_heads``
            (an int or a ``{layer: [heads]}`` mapping).
        prune_strategy: forwarded as ``strategy``.
        prune_loader: optional data loader forwarded as ``dataloader``.

    Returns:
        The pruned model in eval mode on ``device``.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects; only
    # use this with trusted checkpoint files.
    checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    raw_state = checkpoint.get('state_dict', checkpoint)

    stripped = {key.removeprefix('model.'): tensor for key, tensor in raw_state.items()}
    head_aliases = {'heads.weight': 'heads.head.weight', 'heads.bias': 'heads.head.bias'}
    state = {head_aliases.get(key, key): tensor for key, tensor in stripped.items()}

    dense = vit_b_16(weights=None, num_classes=1000, image_size=224)
    dense.load_state_dict(state, strict=True)

    # NOTE(review): the model is still on CPU at this point while `device` is
    # passed to the pruner (load_single_model moves it first) — confirm that
    # structurally_prune_attention_heads handles placement for data-driven strategies.
    pruned = structurally_prune_attention_heads(
        dense,
        num_heads_to_prune=num_heads_to_prune,
        strategy=prune_strategy,
        dataloader=prune_loader,
        device=device,
    )
    print_detailed_vit_summary(pruned)

    pruned.eval().to(device)
    return pruned




def _resolve_vit_b_16_weights(ckpt_path: Optional[str]):
    """Map ``None`` / ``""`` / ``"torchvision://<name>"`` to a ``ViT_B_16_Weights`` member (or ``None``)."""
    import warnings

    if ViT_B_16_Weights is None:
        warnings.warn(
            "torchvision has no ViT_B_16_Weights enum; building vit_b_16 with random weights.",
            stacklevel=3,
        )
        return None

    requested = str(ckpt_path or "").replace("torchvision://", "")
    weight_name = "IMAGENET1K_V1" if requested in ["", "vit_b_16_imagenet1k"] else requested
    weights = getattr(ViT_B_16_Weights, weight_name, None)
    if weights is not None:
        return weights

    # Only the 224px SWAG variant is usable here: IMAGENET1K_SWAG_E2E_V1 is a
    # 384px checkpoint and torchvision raises if it is combined with an explicit
    # image_size=224.
    fallback = getattr(ViT_B_16_Weights, "IMAGENET1K_SWAG_LINEAR_V1", None)
    if fallback is not None:
        warnings.warn(
            f"Unknown ViT_B_16_Weights entry {weight_name!r}; falling back to IMAGENET1K_SWAG_LINEAR_V1.",
            stacklevel=3,
        )
    else:
        warnings.warn(
            f"Unknown ViT_B_16_Weights entry {weight_name!r}; building vit_b_16 with random weights.",
            stacklevel=3,
        )
    return fallback


def load_single_model(
    ckpt_path: Optional[str],
    device: torch.device,
    num_heads_to_prune: Union[int, Dict[int, list[int]]],
    prune_strategy: str,
    prune_loader: DataLoader = None,
):
    """Build a ViT-B/16, structurally prune its attention heads, and optionally load a pruned checkpoint.

    If ``ckpt_path`` is ``None``, ``""`` or starts with ``"torchvision://"``,
    torchvision pretrained weights are used (``IMAGENET1K_V1`` for ``""`` /
    ``"vit_b_16_imagenet1k"``, otherwise the named enum member). Otherwise the
    model is built without pretrained weights, pruned to the requested
    structure, and then the checkpoint is loaded with ``strict=True``. The
    checkpoint must therefore already have the pruned shapes.

    Args:
        ckpt_path: checkpoint path or torchvision weights spec (see above).
        device: device used for pruning and for the returned model.
        num_heads_to_prune: forwarded to ``structurally_prune_attention_heads``.
        prune_strategy: forwarded as ``strategy``.
        prune_loader: optional data loader forwarded as ``dataloader``.

    Returns:
        The model in eval mode on ``device``.
    """
    load_from_torchvision = ckpt_path is None or ckpt_path == "" or str(ckpt_path).startswith("torchvision://")

    weights = _resolve_vit_b_16_weights(ckpt_path) if load_from_torchvision else None

    model = vit_b_16(weights=weights, num_classes=1000, image_size=224)
    model.to(device)
    model = structurally_prune_attention_heads(
        model,
        num_heads_to_prune=num_heads_to_prune,
        strategy=prune_strategy,
        dataloader=prune_loader,
        device=device,
    )
    print_detailed_vit_summary(model)

    if not load_from_torchvision:
        # Deserialize onto CPU (as load_model/load_zero_shot do): a hard-coded
        # "cuda" map_location fails on CPU-only hosts, and load_state_dict
        # copies the tensors onto the model's device anyway.
        # NOTE(review): weights_only=False unpickles arbitrary objects; trusted files only.
        ck = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        state = ck.get("state_dict", ck)

        new_state = {}
        for k, v in state.items():
            nk = k[len("model."):] if k.startswith("model.") else k
            if nk == "heads.weight":
                nk = "heads.head.weight"
            elif nk == "heads.bias":
                nk = "heads.head.bias"
            new_state[nk] = v

        model.load_state_dict(new_state, strict=True)

    return model.eval().to(device)


def heads_from_log(json_path: str | Path) -> dict[int, list[int]]:
    path = Path(json_path)
    if not path.is_file():
        candidate = Path(__file__).resolve().parent / path
        if candidate.is_file():
            path = candidate
    with path.open() as f:
        history = json.load(f)

    pruned = defaultdict(list)

    for entry in history:
        if "pruned" in entry:                      
            layer, head = entry["pruned"]
            pruned[layer].append(head)

    for layer in pruned:
        pruned[layer] = sorted(pruned[layer])

    return dict(pruned)


# ──────────────────────────────────────────────────────────────────────────────
# Hydra ensemble (GFC) utilities
# ──────────────────────────────────────────────────────────────────────────────

def average_modules(mods: List[nn.Module]) -> nn.Module:
    """Return a CPU float32 copy of ``mods[0]`` holding the element-wise mean of all modules' states.

    Floating-point state-dict entries (parameters and float buffers) are
    averaged across ``mods``; non-floating entries (e.g. integer counters) are
    taken from ``mods[0]``. All modules must have identical state-dict keys and
    shapes.

    Args:
        mods: modules with identical architecture.

    Returns:
        A new module (deep copy of ``mods[0]``) on CPU in float32.

    Raises:
        ValueError: if ``mods`` is empty.
    """
    if not mods:
        raise ValueError("average_modules needs at least one module")

    out = copy.deepcopy(mods[0]).cpu().float()
    with torch.no_grad():
        # Build every state_dict once: calling m.state_dict() inside the key loop
        # rebuilt the full dict for every key of every module.
        state_dicts = [m.state_dict() for m in mods]
        mean_sd = {}
        for k, v in state_dicts[0].items():
            if torch.is_floating_point(v):
                stacked = torch.stack([sd[k].detach().cpu().float() for sd in state_dicts], 0)
                mean_sd[k] = stacked.mean(0)
            else:
                mean_sd[k] = v.detach().cpu()
        out.load_state_dict(mean_sd)
    return out


def copy_module_(dst: nn.Module, src: nn.Module):
    """Overwrite ``dst``'s parameters and buffers in place with ``src``'s (default strict key matching)."""
    source_state = src.state_dict()
    dst.load_state_dict(source_state)


def split_qkv_from_inproj(in_proj_weight: torch.Tensor, in_proj_bias: Optional[torch.Tensor]):
    """Split a packed ``nn.MultiheadAttention`` in-projection into its Q, K and V parts.

    Args:
        in_proj_weight: packed weight of shape ``(3 * Dp, d)``, rows ordered Q, K, V.
        in_proj_bias: packed bias of shape ``(3 * Dp,)``, or ``None``.

    Returns:
        ``(Wq, Wk, Wv, bq, bk, bv)``. Each weight is a ``(Dp, d)`` view and each
        bias a ``(Dp,)`` view. The three biases are ``None`` when
        ``in_proj_bias`` is ``None``.

    Raises:
        ValueError: if the weight is not 2-D, its row count is not a multiple
            of 3, or the bias length does not match the weight's row count.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if in_proj_weight.dim() != 2:
        raise ValueError(f"in_proj_weight must be 2-D, got shape {tuple(in_proj_weight.shape)}")
    rows = in_proj_weight.shape[0]
    if rows % 3 != 0:
        raise ValueError(f"in_proj_weight must have 3 * Dp rows, got {rows}")
    Dp = rows // 3
    Wq, Wk, Wv = in_proj_weight.split(Dp, dim=0)
    if in_proj_bias is None:
        return Wq, Wk, Wv, None, None, None
    # A mis-sized bias would otherwise split into uneven (silently wrong) chunks.
    if in_proj_bias.shape[0] != rows:
        raise ValueError(f"in_proj_bias has {in_proj_bias.shape[0]} entries, expected {rows}")
    bq, bk, bv = in_proj_bias.split(Dp, dim=0)
    return Wq, Wk, Wv, bq, bk, bv


class GFCLinearMultiImpl(nn.Module):
    """Block-diagonal ("grouped fully-connected") linear layer over ``M`` feature groups.

    The input's last dimension is the concatenation of ``M`` groups of size
    ``d``. Group ``m`` is mapped by its own ``W_m`` (shape ``(out_m, d)``) and
    bias ``b_m``, and the ``M`` results are concatenated, i.e.
    ``y = blockdiag(W_1, ..., W_M) @ x + cat(b_1, ..., b_M)``.

    Equivalent implementations are selectable with ``impl``:

    * ``"full"``: dense block-diagonal matrix + ``F.linear``.
    * ``"einsum"``: batched per-group matmul on zero-padded weights.
    * ``"conv1d"``: grouped 1x1 convolution (requires equal ``out_m``).
    * ``"csr"``: sparse-CSR block-diagonal matmul.

    All computation is performed in float32.
    """

    def __init__(
        self,
        weight_blocks: List[torch.Tensor],
        bias_blocks: Optional[List[Optional[torch.Tensor]]],
        impl: str = "einsum",
    ):
        """Build the fused layer.

        Args:
            weight_blocks: ``M`` matrices ``W_m`` of shape ``(out_m, d)``; all
                must share the same input dim ``d``.
            bias_blocks: ``M`` biases of shape ``(out_m,)`` (entries may be
                ``None`` for a zero bias), or ``None`` for no biases at all.
            impl: ``"full"``, ``"einsum"``, ``"conv1d"`` or ``"csr"``.

        Raises:
            ValueError: for an unknown ``impl``, or for ``impl="conv1d"`` with
                unequal ``out_m``.
        """
        super().__init__()
        self.impl = impl
        self.M = len(weight_blocks)
        assert self.M >= 1, "Need at least one group"

        d_set = {int(w.shape[1]) for w in weight_blocks}
        assert len(d_set) == 1, "All W_m must have same input dim"
        self.d = int(weight_blocks[0].shape[1])

        self.out_sizes = [int(w.shape[0]) for w in weight_blocks]
        self.sum_out = int(sum(self.out_sizes))

        if bias_blocks is None:
            bias_blocks = [None] * self.M
        assert len(bias_blocks) == self.M

        # float32 copies of the per-group parameters, kept on their source device.
        # NOTE: these lists are not registered buffers, so .to(device) does not move them.
        self._W_blocks = [w.detach().float().contiguous() for w in weight_blocks]
        # Missing biases become zeros on the matching weight's device, so the
        # torch.cat / block_diag calls below never mix CPU and CUDA tensors.
        self._b_blocks = [
            (
                b.detach().float().contiguous()
                if b is not None
                else torch.zeros(w.shape[0], dtype=torch.float32, device=w.device)
            )
            for w, b in zip(weight_blocks, bias_blocks)
        ]

        if impl == "full":
            with torch.no_grad():
                W_full = torch.block_diag(*self._W_blocks)
                b_full = torch.cat(self._b_blocks, dim=0)
            self.register_buffer("W_full", W_full.contiguous())
            self.register_buffer("b_full", b_full.contiguous())

        elif impl == "einsum":
            # Zero-pad every group to the largest out_m so one einsum covers all
            # groups; the padding rows are sliced away again in forward().
            out_max = max(self.out_sizes)
            W_pad = torch.zeros((self.M, out_max, self.d), dtype=torch.float32)
            b_pad = torch.zeros((self.M, out_max), dtype=torch.float32)
            for m, (Wm, bm) in enumerate(zip(self._W_blocks, self._b_blocks)):
                om = Wm.shape[0]
                W_pad[m, :om, :] = Wm
                b_pad[m, :om] = bm
            self.out_max = out_max
            self.register_buffer("W_pad", W_pad.contiguous())
            self.register_buffer("b_pad", b_pad.contiguous())

        elif impl == "conv1d":
            # F.conv1d(groups=M) assigns sum_out / M output channels to every
            # group. That only reproduces the block structure when all out_m are
            # equal; otherwise channels would silently land in the wrong group.
            if len(set(self.out_sizes)) != 1:
                raise ValueError(
                    f"impl='conv1d' requires equal output sizes per group, got {self.out_sizes}"
                )
            Wc = torch.cat([w.view(w.shape[0], self.d, 1) for w in self._W_blocks], dim=0)
            bc = torch.cat(self._b_blocks, dim=0)
            self.groups = self.M
            self.register_buffer("Wc", Wc.contiguous())
            self.register_buffer("bc", bc.contiguous())

        elif impl == "csr":
            W_full = torch.block_diag(*self._W_blocks).to(dtype=torch.float32)
            W_csr = W_full.to_sparse_csr()
            b_full = torch.cat(self._b_blocks, dim=0)
            self.register_buffer("W_csr", W_csr)
            self.register_buffer("b_full", b_full.contiguous())
        else:
            raise ValueError(f"Unknown impl: {impl}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the grouped projection.

        Args:
            x: tensor of shape ``(..., M * d)``; cast to float32.

        Returns:
            Tensor of shape ``(..., sum(out_m))`` with group outputs concatenated in order.
        """
        x = x.float()
        *lead, fin = x.shape
        expected = self.M * self.d
        assert fin == expected, f"GFC expected last dim {expected}, got {fin}"

        if self.impl == "full":
            y = F.linear(x.reshape(-1, fin), self.W_full, self.b_full)
            return y.view(*lead, self.sum_out)

        elif self.impl == "einsum":
            xb = x.reshape(-1, self.M, self.d)
            y_pad = torch.einsum("bmd,mod->bmo", xb, self.W_pad) + self.b_pad.unsqueeze(0)
            parts = [y_pad[:, m, : self.out_sizes[m]] for m in range(self.M)]
            y = torch.cat(parts, dim=1)
            return y.view(*lead, self.sum_out)

        elif self.impl == "conv1d":
            N = int(torch.tensor(lead).prod()) if len(lead) else 1
            xin = x.reshape(N, self.M * self.d, 1)
            y = F.conv1d(xin, self.Wc, self.bc, stride=1, padding=0, dilation=1, groups=self.groups)
            return y.reshape(*lead, self.sum_out)

        elif self.impl == "csr":
            Bstar = int(torch.tensor(lead).prod()) if len(lead) else 1
            X2 = x.reshape(Bstar, self.M * self.d).t().contiguous()
            Y2 = torch.matmul(self.W_csr, X2)
            y = Y2.t().contiguous() + self.b_full
            return y.view(*lead, self.sum_out)

        else:
            raise RuntimeError("Invalid impl in GFCLinearMultiImpl")


class FusedMHA_TokenFeature(nn.Module):
    """Run three (possibly head-pruned) ``nn.MultiheadAttention`` modules as one attention call.

    The Q, K, V and output projections of the three source attentions are
    combined into block-diagonal grouped linear layers (``GFCLinearMultiImpl``)
    acting on a feature axis of size ``3 * d``, where group ``m`` is member
    ``m``'s slice. The heads of all members are concatenated along the head
    axis and evaluated with one ``scaled_dot_product_attention`` call. Each
    head's Q/K/V rows come from exactly one member, so a head only ever uses
    its own member's projections.

    If the sources keep different numbers of heads, every member's projections
    are zero-padded up to the largest head count. Padded heads get zero V rows
    (and zero bias) plus zero out-projection columns, so they add nothing to
    the output.
    """

    def __init__(self, src_attns: List[nn.MultiheadAttention], gfc_impl: str = "einsum"):
        super().__init__()
        assert len(src_attns) == 3
        # Number of ensemble members fused by this module.
        self.M = 3
        # Model width shared by all members (input and output feature size).
        self.d = int(src_attns[0].embed_dim)

        # Per-head dimension, taken from the first source only.
        # NOTE(review): assumes all three sources share the same head_dim — TODO confirm.
        if hasattr(src_attns[0], "head_dim"):
            self.dk = int(src_attns[0].head_dim)
        else:
            self.dk = self.d // int(src_attns[0].num_heads)

        # Dp[m]: number of Q (= K = V) rows member m keeps; a multiple of dk.
        self.Dp = []
        Wq_blocks, bq_blocks = [], []
        Wk_blocks, bk_blocks = [], []
        Wv_blocks, bv_blocks = [], []
        Wo_blocks, bo_blocks = [], []

        max_heads = 0
        for a in src_attns:
            # Split the packed (3*Dp_m, d) in-projection into Q/K/V weights and biases.
            Wi = a.in_proj_weight.detach().float()
            bi = a.in_proj_bias.detach().float() if a.in_proj_bias is not None else None
            Wq, Wk, Wv, bq, bk, bv = split_qkv_from_inproj(Wi, bi)
            assert Wq.shape[1] == self.d and Wk.shape[1] == self.d and Wv.shape[1] == self.d
            Dp_m = Wq.shape[0]
            assert Dp_m % self.dk == 0, "pruned rows must be multiple of head_dim"
            self.Dp.append(Dp_m)
            max_heads = max(max_heads, Dp_m // self.dk)

            Wq_blocks.append(Wq)
            Wk_blocks.append(Wk)
            Wv_blocks.append(Wv)
            # Missing Q/K/V biases become zeros.
            # NOTE(review): these zeros are created on the default device even if Wi lives on another one.
            bq_blocks.append(bq if bq is not None else torch.zeros(Dp_m))
            bk_blocks.append(bk if bk is not None else torch.zeros(Dp_m))
            bv_blocks.append(bv if bv is not None else torch.zeros(Dp_m))

            # Output projection of member m; presumably (d, Dp_m) after pruning
            # (the padding code below pads its columns up to target_rows).
            Wo_blocks.append(a.out_proj.weight.detach().float())
            bo_blocks.append(a.out_proj.bias.detach().float())

        if len(set(self.Dp)) != 1:
            # Members kept different numbers of heads: zero-pad each member's
            # Q/K/V rows and out-projection columns to max_heads * dk so that all
            # GFC groups have the same size. Whole padded heads are appended at
            # the end of each member's head list.
            target_rows = max_heads * self.dk
            def _pad_block(W, target):
                if W.shape[0] == target:
                    return W
                pad_rows = target - W.shape[0]
                return torch.cat([W, torch.zeros(pad_rows, W.shape[1], dtype=W.dtype, device=W.device)], dim=0)
            def _pad_bias(b, target, device):
                if b is None:
                    return torch.zeros(target, dtype=torch.float32, device=device)
                if b.shape[0] == target:
                    return b
                pad_rows = target - b.shape[0]
                return torch.cat([b, torch.zeros(pad_rows, dtype=b.dtype, device=b.device)], dim=0)
            def _pad_out_proj(W, target_cols):
                if W.shape[1] == target_cols:
                    return W
                pad_cols = target_cols - W.shape[1]
                # Cannot shrink; target_cols is the maximum so this should not happen.
                if pad_cols < 0:
                    return W
                pad = torch.zeros(W.shape[0], pad_cols, dtype=W.dtype, device=W.device)
                return torch.cat([W, pad], dim=1)

            Wq_blocks = [_pad_block(W, target_rows) for W in Wq_blocks]
            Wk_blocks = [_pad_block(W, target_rows) for W in Wk_blocks]
            Wv_blocks = [_pad_block(W, target_rows) for W in Wv_blocks]
            Wo_blocks = [_pad_out_proj(W, target_rows) for W in Wo_blocks]
            bq_blocks = [_pad_bias(b, target_rows, W.device if b is None else b.device) for b, W in zip(bq_blocks, Wq_blocks)]
            bk_blocks = [_pad_bias(b, target_rows, W.device if b is None else b.device) for b, W in zip(bk_blocks, Wk_blocks)]
            bv_blocks = [_pad_bias(b, target_rows, W.device if b is None else b.device) for b, W in zip(bv_blocks, Wv_blocks)]
            bo_blocks = [_pad_bias(b, Wo.shape[0], Wo.device if b is None else b.device) for b, Wo in zip(bo_blocks, Wo_blocks)]
            self.Dp = [target_rows for _ in self.Dp]
            max_heads = target_rows // self.dk

        # Block-diagonal projections: Q/K/V map 3*d -> sum(Dp); O maps sum(Dp) -> 3*d.
        self.gfc_Q = GFCLinearMultiImpl(Wq_blocks, bq_blocks, impl=gfc_impl)
        self.gfc_K = GFCLinearMultiImpl(Wk_blocks, bk_blocks, impl=gfc_impl)
        self.gfc_V = GFCLinearMultiImpl(Wv_blocks, bv_blocks, impl=gfc_impl)
        self.gfc_O = GFCLinearMultiImpl(Wo_blocks, bo_blocks, impl=gfc_impl)

        # Total number of heads across all members (including any padded heads).
        self.Dp_total = sum(self.Dp)
        assert self.Dp_total % self.dk == 0
        self.H_total = self.Dp_total // self.dk

    def forward(self, X3_tokens: torch.Tensor) -> torch.Tensor:
        """Apply the fused attention.

        Args:
            X3_tokens: ``(B, 3*T, d)`` tokens laid out so that index ``3*t + m``
                holds token ``t`` of member ``m``.

        Returns:
            ``(B, 3*T, d)`` attention outputs in the same interleaved layout.
        """
        B, TT, d = X3_tokens.shape
        assert d == self.d and TT % 3 == 0
        T = TT // 3

        # (B, T, 3*d): each token position carries the three members' features side by side.
        X_feat = X3_tokens.view(B, T, 3, d).reshape(B, T, 3 * d)

        # (B, T, H_total*dk); member m's heads occupy a contiguous slice, in member order.
        Q = self.gfc_Q(X_feat)
        K = self.gfc_K(X_feat)
        V = self.gfc_V(X_feat)

        H, dk = self.H_total, self.dk
        # (B, H_total, T, dk) for scaled_dot_product_attention.
        Qh = Q.view(B, T, H, dk).transpose(1, 2).contiguous()
        Kh = K.view(B, T, H, dk).transpose(1, 2).contiguous()
        Vh = V.view(B, T, H, dk).transpose(1, 2).contiguous()
        # Default scaling 1/sqrt(dk), no mask, no dropout.
        C = F.scaled_dot_product_attention(Qh, Kh, Vh, attn_mask=None, dropout_p=0.0, is_causal=False)
        # Concatenate heads back to (B, T, H_total*dk).
        Z = C.transpose(1, 2).reshape(B, T, H * dk)

        # (B, T, 3*d) -> back to the interleaved (B, 3*T, d) token layout.
        Y_feat = self.gfc_O(Z)
        Y_tok = Y_feat.view(B, T, 3, d).reshape(B, 3 * T, d)
        return Y_tok


class FusedBlock_TokenStack(nn.Module):
    """One fused ViT encoder block acting on three interleaved token streams.

    Mirrors torchvision's ``EncoderBlock``::

        x = input + dropout(attn(ln_1(input)))
        out = x + mlp(ln_2(x))

    with a fused three-member attention (``FusedMHA_TokenFeature``). The
    LayerNorms and the dropout module are copied from ``src_layers[0]``, and
    the MLP weights are the average of all source MLPs. Only the attention
    keeps separate per-member weights.
    """

    def __init__(self, src_layers: List[nn.Module], gfc_impl: str):
        super().__init__()
        self.ln1 = copy.deepcopy(src_layers[0].ln_1).float().eval()
        self.ln2 = copy.deepcopy(src_layers[0].ln_2).float().eval()
        self.mha = FusedMHA_TokenFeature([s.self_attention for s in src_layers], gfc_impl=gfc_impl)
        self.mlp = average_modules([s.mlp for s in src_layers]).float().eval()
        self.dropout = copy.deepcopy(src_layers[0].dropout).eval()

    def forward(self, X3_tok: torch.Tensor) -> torch.Tensor:
        """Apply the block.

        Args:
            X3_tok: ``(B, 3*T, d)`` interleaved tokens (index ``3*t + m`` is
                token ``t`` of member ``m``).

        Returns:
            Tensor with the same shape and layout.
        """
        # Attention sub-block. As in torchvision's EncoderBlock, dropout is applied
        # to the attention output before the residual add. The residual is a plain
        # token-wise add, because the fused attention keeps the interleaved layout.
        attn_out = self.dropout(self.mha(self.ln1(X3_tok)))
        Y_tok = X3_tok + attn_out

        # MLP sub-block. The MLP carries its own internal dropout, so no extra dropout here.
        Z_tok = self.mlp(self.ln2(Y_tok))
        return Y_tok + Z_tok


class FusedEnsembleViT_Token(VisionTransformer):
    """Three-member ViT ensemble evaluated as one fused network.

    Subclasses torchvision's ``VisionTransformer`` to reuse its patch embedding
    (``conv_proj`` via ``_process_input``) and ``class_token``. The encoder and
    heads built by the parent constructor are not used by ``forward``. The fused
    computation runs through ``fused_layers``, ``final_ln`` and
    ``heads_ensemble``, which only exist after ``create_from_sources`` has been
    called.
    """

    def __init__(self, ref: VisionTransformer, num_classes: int, num_layers: int, gfc_impl: str):
        # Derive a head count from the reference's head_dim. This presumably
        # exists only so the parent constructor gets a valid hidden_dim/num_heads
        # pair; the parent's own attention layers are not used by forward().
        mha0 = ref.encoder.layers[0].self_attention
        head_dim = getattr(mha0, "head_dim", mha0.embed_dim // max(1, mha0.num_heads))
        num_heads_safe = max(1, mha0.embed_dim // head_dim)
        super().__init__(
            image_size=getattr(ref, "image_size", 224),
            patch_size=getattr(ref, "patch_size", 16),
            num_layers=num_layers,
            num_heads=num_heads_safe,
            hidden_dim=getattr(ref, "hidden_dim", ref.encoder.layers[0].self_attention.embed_dim),
            mlp_dim=self._infer_mlp_dim(ref),
            num_classes=num_classes,
        )
        # GFC implementation used when fused blocks are built in create_from_sources().
        self._gfc_impl = gfc_impl
        # Number of ensemble members.
        self.M = 3

    @staticmethod
    def _infer_mlp_dim(model: VisionTransformer) -> int:
        """Return ``out_features`` of the first ``nn.Linear`` in the first encoder layer's MLP."""
        mlp = model.encoder.layers[0].mlp
        first_linear = next(m for m in mlp.modules() if isinstance(m, nn.Linear))
        return first_linear.out_features

    def create_from_sources(self, sources: List[VisionTransformer], copy_global_from: int = 0) -> None:
        """Populate the fused network from exactly three source ViTs.

        The patch embedding, class token, positional embedding and final
        LayerNorm are copied from ``sources[copy_global_from]``. Each encoder
        layer is fused across the three sources with ``FusedBlock_TokenStack``,
        and every member keeps its own classification head.

        Args:
            sources: the three member models (presumably of identical depth/width).
            copy_global_from: index of the member whose shared parts are copied.
        """
        assert len(sources) == 3
        device = next(self.parameters()).device
        ref = sources[copy_global_from]

        with torch.no_grad():
            copy_module_(self.conv_proj, ref.conv_proj)
            self.class_token.copy_(ref.class_token)
            # torchvision keeps pos_embedding on the encoder; here it becomes a
            # frozen top-level parameter that forward() adds directly.
            self.pos_embedding = nn.Parameter(ref.encoder.pos_embedding.detach().clone().float(), requires_grad=False)

        self.fused_layers = nn.ModuleList()
        num_layers = len(ref.encoder.layers)
        for i in range(num_layers):
            src_layers = [m.encoder.layers[i] for m in sources]
            self.fused_layers.append(FusedBlock_TokenStack(src_layers, gfc_impl=self._gfc_impl).to(device).eval())

        self.final_ln = copy.deepcopy(ref.encoder.ln).float().eval()
        self.heads_ensemble = nn.ModuleList([copy.deepcopy(m.heads).to(device).float().eval() for m in sources])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-member logits stacked member-major.

        Args:
            x: input images (cast to float32).

        Returns:
            ``(3*B, num_classes)``; rows ``m*B:(m+1)*B`` are member ``m``'s logits.
        """
        x = x.float()
        B = x.shape[0]
        x = self._process_input(x)
        cls = self.class_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        # Unlike torchvision's Encoder, no dropout is applied after adding the
        # positional embedding.
        x = x + self.pos_embedding

        # Replicate every token for the three members: index 3*t + m is token t of member m.
        B, T, d = x.shape
        X3_tok = x.unsqueeze(2).expand(B, T, 3, d).reshape(B, 3 * T, d)

        for blk in self.fused_layers:
            X3_tok = blk(X3_tok)
        X3_tok = self.final_ln(X3_tok)

        # Indices 0..2 are the class tokens (t = 0) of members 0..2.
        cls_list = [X3_tok[:, i, :] for i in range(self.M)]
        logits_list = [head(cls) for head, cls in zip(self.heads_ensemble, cls_list)]
        return torch.cat(logits_list, dim=0)

    def forward_probavg(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(B, num_classes)`` probabilities averaged over the members' softmax outputs."""
        logits_cat = self.forward(x)
        B = x.shape[0]
        C = logits_cat.shape[-1]
        # (3*B, C) member-major -> (B, 3, C)
        logits_3 = logits_cat.view(self.M, B, C).transpose(0, 1)
        probs3 = torch.softmax(logits_3, dim=-1)
        return probs3.mean(dim=1)

    def forward_logitavg(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(B, num_classes)`` logits averaged over the members."""
        logits_cat = self.forward(x)
        B = x.shape[0]
        C = logits_cat.shape[-1]
        # (3*B, C) member-major -> (B, 3, C)
        logits_3 = logits_cat.view(self.M, B, C).transpose(0, 1)
        return logits_3.mean(dim=1)


def build_fused_ensemble_from_models_token(
    models: List[VisionTransformer],
    gfc_impl: str = "einsum",
    copy_global_from: int = 0,
) -> FusedEnsembleViT_Token:
    """Create a ``FusedEnsembleViT_Token`` from three trained ViTs.

    The number of classes is the ``out_features`` of the last ``nn.Linear``
    found in the reference model's ``heads``. The fused model is moved to the
    reference model's parameter device, cast to float32 and then filled via
    ``create_from_sources``.

    Args:
        models: the member models (``create_from_sources`` requires exactly three).
        gfc_impl: ``GFCLinearMultiImpl`` implementation to use.
        copy_global_from: index of the reference member for shared parts.

    Returns:
        The populated fused ensemble.
    """
    reference = models[copy_global_from]
    head_linear = next(
        (mod for mod in reversed(list(reference.heads.modules())) if isinstance(mod, nn.Linear)),
        None,
    )
    assert head_linear is not None, "Could not infer num_classes"
    n_classes = head_linear.out_features

    fused = FusedEnsembleViT_Token(
        reference,
        n_classes,
        num_layers=len(reference.encoder.layers),
        gfc_impl=gfc_impl,
    )
    target_device = next(reference.parameters()).device
    fused = fused.to(target_device).float()
    fused.create_from_sources(models, copy_global_from=copy_global_from)
    return fused


class HydraEnsembleClassifier(nn.Module):
    """Wrap a fused ensemble so it behaves like a single classifier (logit averaging)."""

    def __init__(self, fused: FusedEnsembleViT_Token):
        super().__init__()
        self.fused = fused

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(B, num_classes)`` logits averaged over the ``fused.M`` members.

        ``self.fused(x)`` is expected to return member-major stacked logits of
        shape ``(M * B, num_classes)``.
        """
        stacked = self.fused(x)
        batch, n_classes = x.shape[0], stacked.shape[-1]
        per_sample = stacked.view(self.fused.M, batch, n_classes).transpose(0, 1)
        return per_sample.mean(dim=1)


def build_hydra_ensemble(
    models: List[VisionTransformer],
    gfc_impl: str = "einsum",
    copy_global_from: int = 0,
) -> HydraEnsembleClassifier:
    """Build a fused three-member ensemble and wrap it as a logit-averaging classifier.

    Arguments are forwarded unchanged to ``build_fused_ensemble_from_models_token``.
    """
    return HydraEnsembleClassifier(
        build_fused_ensemble_from_models_token(
            models,
            gfc_impl=gfc_impl,
            copy_global_from=copy_global_from,
        )
    )
