import math
import types
from typing import Dict, Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import VisionTransformer
from torch import Tensor
from torch.func import vmap, jacrev

from augment import DiffAugment
from utils.post_processing import quantile_clamp, PostProcessing
from utils.detach_mode import GELUDetached, _make_detach_attn_forward, _make_centering_layer_norm


class DAVEExplainer:
    """Gradient-times-input attribution explainer for a timm ``VisionTransformer``.

    On construction the wrapped model is put in eval mode and patched in place
    (see :meth:`detach_mode`). Attributions are built from gradients of
    selected logits w.r.t. the input, computed over several augmented and
    noised forward passes. They are multiplied by the input, clipped, and
    combined with a MAD-based robust mean.
    """

    def __init__(self, vit_model: "VisionTransformer", aug_cfg, postproc_cfg, device):
        """Wrap ``vit_model`` and build the augmentation / post-processing modules.

        Args:
            vit_model: model to explain. It is switched to eval mode and its
                GELU / attention / LayerNorm modules are patched in place.
            aug_cfg: config forwarded to ``DiffAugment``.
            postproc_cfg: config forwarded to ``PostProcessing``.
            device: device the augmentation and post-processing modules are
                moved to (the model itself is not moved here).
        """
        self.model = vit_model.eval()
        self.detach_mode()

        self.aug = DiffAugment(aug_cfg).to(device)
        self.post_proc = PostProcessing(postproc_cfg).to(device)

    def get_noise_schedule(
        self,
        num_steps: int,
        device,
        alpha_start: float = 0.0,
        alpha_end: float = 0.3,
    ) -> Tensor:
        """Return ``num_steps`` evenly spaced noise levels in [alpha_start, alpha_end]."""
        return torch.linspace(
            alpha_start,
            alpha_end,
            steps=num_steps,
            device=device,
        )

    def logit_reconstruction(
        self,
        x: Tensor,
        y: Tensor,
        k: int,
        num_steps: int,
        clamp: bool = False,
    ):
        """Compare the top-``k`` logits with the sums of their attribution maps.

        Only samples whose top-1 prediction equals ``y`` are kept. For each
        kept sample, the (non-post-processed) attribution of every top-``k``
        logit is summed over dims 1-3.

        Args:
            x: input batch, shape (N, C, H, W).
            y: labels. They are flattened with ``view(-1)``, so (N,) or (N, 1).
            k: number of top logits to reconstruct.
            num_steps: number of stochastic attribution passes.
            clamp: forwarded to :meth:`attribute_supervised`.

        Returns:
            ``(top_vals, rec_vals)`` on CPU, where M is the number of correctly
            classified samples:
            - ``top_vals``: the top-``k`` logits, shape (M, k).
            - ``rec_vals``: their reconstructions, shape (M, k) for ``k > 1``
              and (M,) for ``k == 1``.
            Returns ``(None, None)`` if no sample is classified correctly.
        """
        self.model.eval()
        self.aug.train()

        with torch.no_grad():
            logits = self.model(x)
            top_vals, top_idx = torch.topk(
                logits, k=k, dim=1, largest=True, sorted=True,
            )

        # Keep only the samples the model classifies correctly.
        valid = y.view(-1) == top_idx[:, 0]
        x = x[valid]
        top_vals = top_vals[valid]
        top_idx = top_idx[valid]

        if x.numel() == 0:
            return None, None

        c = self.attribute_supervised(
            x=x,
            y=top_idx,
            num_steps=num_steps,
            clamp=clamp,
            post_proc=False,
        )['attribution']
        rec_vals = c.sum(dim=(1, 2, 3))

        return top_vals.detach().cpu(), rec_vals.detach().cpu()

    def attribute_supervised(
        self,
        x: Tensor,
        y: Tensor,
        num_steps: int,
        clamp: bool = False,
        post_proc: bool = True,
    ) -> Dict[str, Tensor]:
        """Robustly average gradient-times-input attributions over noisy passes.

        Each of the ``num_steps`` passes does three things:
        - computes input gradients of the logits selected by ``y``;
        - multiplies them by ``x``;
        - clips them to their 1%/99% quantiles along axis 3.
        The passes are combined with :meth:`_mad_filtered_mean`.

        Args:
            x: input batch, shape (N, C, H, W).
            y: logit indices, shape (N, K).
            num_steps: number of passes (one per noise level of the schedule).
            clamp: apply ``quantile_clamp`` to each pass's raw gradients.
            post_proc: pass the result through ``self.post_proc``.

        Returns:
            ``{'attribution': c}``. ``c`` has the shape of ``x`` when K == 1,
            and (N, C, H, W, K) — one map per selected logit — otherwise.

        Raises:
            ValueError: if ``num_steps`` is not positive.
        """
        if num_steps <= 0:
            raise ValueError(f"num_steps must be positive, got {num_steps}")

        self.model.eval()
        self.aug.train()

        t_schedule = self.get_noise_schedule(num_steps=num_steps, device=x.device)
        x_det = x.detach()

        buffer = []
        for t in t_schedule:
            w_eff, _ = self.attribution_step(x=x, y=y, t=t, clamp=clamp)
            # Multi-logit gradients carry a trailing logit axis; align x with it
            # so every logit's map is multiplied by the same input.
            x_b = x_det.unsqueeze(-1) if x_det.ndim < w_eff.ndim else x_det
            c = w_eff.detach() * x_b
            buffer.append(self._clip_quantiles_axis3(c))

        c = self._mad_filtered_mean(torch.stack(buffer, dim=0))

        if post_proc:
            c = self.post_proc(c)

        return {'attribution': c}

    @staticmethod
    def _clip_quantiles_axis3(
        c: Tensor, q_low: float = 0.01, q_high: float = 0.99,
    ) -> Tensor:
        """Clamp ``c`` to its [q_low, q_high] quantiles taken along axis 3."""
        # Axis 3 is the last image axis of the 4-D input, so the single-logit
        # and multi-logit (extra trailing K axis) paths clip identically.
        lo = torch.quantile(c, q_low, dim=3, keepdim=True)
        hi = torch.quantile(c, q_high, dim=3, keepdim=True)
        return torch.clamp(c, lo, hi)

    @staticmethod
    def _mad_filtered_mean(stack: Tensor, k: float = 2.5) -> Tensor:
        """Mean over dim 0 of the entries within ``k`` robust scales of the median.

        The robust scale is ``1.4826 * MAD + 1e-8``. The 1.4826 factor makes the
        MAD a standard-deviation estimate for normal data, and the epsilon keeps
        identical entries (MAD == 0) inside the mask. Positions where every
        entry is rejected divide by 1 and yield 0.
        """
        med = stack.median(dim=0).values
        dev = (stack - med).abs()
        scale = 1.4826 * dev.median(dim=0).values + 1e-8
        keep = dev <= k * scale
        total = (stack * keep).sum(dim=0)
        count = keep.sum(dim=0).clamp(min=1)
        return total / count

    def attribution_step(
        self,
        x: Tensor,
        y: Tensor,
        t: Tensor,
        clamp: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        """Run one noisy pass and return the detached input gradients and logits.

        Returns:
            ``(w_eff, z)`` as produced by :meth:`get_w_eff`. ``w_eff`` is
            detached and, when ``clamp`` is set, passed through
            ``quantile_clamp(..., q_high=0.99)``.
        """
        w_eff, z = self.get_w_eff(self._clone_input(x), y, t)
        w_eff = w_eff.detach()

        if clamp:
            w_eff = quantile_clamp(w_eff, q_high=0.99)
        return w_eff, z

    def get_w_eff(
        self,
        x: Tensor,
        y: Optional[Tensor],
        t: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Gradients of the selected logits w.r.t. the input.

        Args:
            x: input batch, shape (N, C, H, W). It must require grad on the
                single-logit path.
            y: logit indices, shape (N, K).
            t: scalar noise level.

        Returns:
            ``(w_eff, z)``:
            - K == 1: ``w_eff`` has ``x``'s shape and ``z`` is (N,).
            - K > 1: ``w_eff`` is (N, C, H, W, K) and ``z`` is (N, K).

        Raises:
            NotImplementedError: if ``y`` is None.
        """
        self._check_batch_shapes(x, y, t)

        if y is None:
            raise NotImplementedError(
                "y=None not yet supported in get_w_eff"
            )

        if y.shape[-1] == 1:
            z = self.pred_batch(x, y, t).squeeze(-1)
            # Gradient of the batch-summed logit. This equals per-sample
            # gradients only if samples do not interact inside aug/model —
            # NOTE(review): presumed true for DiffAugment; confirm.
            w_eff = torch.autograd.grad(
                outputs=z,
                inputs=[x],
                grad_outputs=torch.ones_like(z),
                retain_graph=False,
            )[0]
        else:
            # One scalar noise level per sample so vmap can map over it.
            t = t.unsqueeze(0).repeat(x.shape[0])

            def _pred_with_aux(x_s, y_s, t_s):
                z_s = self.pred_smp(x_s, y_s, t_s)
                # The second output is auxiliary data, so the logits are
                # returned alongside the Jacobian.
                return z_s, z_s

            w_fn = jacrev(_pred_with_aux, argnums=0, has_aux=True)
            w_eff, z = vmap(w_fn, randomness="same")(x, y, t)
            # (N, K, C, H, W) -> (N, C, H, W, K)
            w_eff = w_eff.permute(0, 2, 3, 4, 1)

        return w_eff, z

    def pred_smp(
        self,
        x: Tensor,
        y: Optional[Tensor],
        t: Tensor,
    ) -> Tensor:
        """Single-sample wrapper around :meth:`pred_batch` (for vmap/jacrev).

        Adds a leading batch axis of size 1 to ``x`` (and ``y``), then removes
        it from the output.
        """
        self._check_smp_shapes(x, y, t)

        x = x.unsqueeze(0)

        if y is not None:
            y = y.unsqueeze(0)

        z = self.pred_batch(x, y, t)
        return z.squeeze(0)

    def pred_batch(
        self,
        x: Tensor,
        y: Optional[Tensor],
        t: Tensor,
    ) -> Tensor:
        """Augment, noise and classify ``x``, then optionally gather logits.

        Returns:
            The full model output when ``y`` is None. Otherwise the logits at
            the indices ``y`` (gathered along dim 1), shape (N, K).
        """
        self._check_batch_shapes(x, y, t)

        x = self.aug(x)
        x = self.prog_noise(x, t)
        z = self.model(x)

        if y is not None:
            y = y.long()
            z = z.gather(dim=1, index=y)

        return z

    def prog_noise(self, x: Tensor, t: Tensor) -> Tensor:
        """Add standard-normal noise to ``x``.

        NOTE(review): ``t`` is currently unused. Every step adds unit-variance
        noise regardless of the schedule from ``get_noise_schedule``; confirm
        this is intended.
        NOTE(review): ``torch.seed()`` reseeds the default generators with a
        non-deterministic seed on every call. Results therefore cannot be made
        reproducible with ``torch.manual_seed``.
        """
        torch.seed()
        noise = torch.randn_like(x, device=x.device)
        x = x + noise
        return x

    def detach_mode(self):
        """Patch the model's GELU, attention and LayerNorm modules in place."""
        self._detach_gelu(self.model)
        self._detach_attention(self.model)
        self._detach_layer_norm(self.model)

    def _detach_attention(self, model):
        """Rebind each ``blk.attn.forward`` in ``model.blocks`` to the patched forward.

        The original bound forward is kept as ``_orig_forward``.
        """
        for blk in model.blocks:
            attn_module = blk.attn
            attn_module._orig_forward = attn_module.forward
            omenn_forward = _make_detach_attn_forward()
            attn_module.forward = types.MethodType(omenn_forward, attn_module)

    def _detach_gelu(self, model):
        """Recursively replace every ``nn.GELU`` child with ``GELUDetached()``."""
        for name, module in model.named_children():
            if isinstance(module, nn.GELU):
                setattr(model, name, GELUDetached())
            else:
                self._detach_gelu(module)

    def _detach_layer_norm(self, model):
        """Recursively rebind every ``nn.LayerNorm.forward`` to the centering variant.

        The original bound forward is kept as ``_orig_forward``.
        """
        for name, module in model.named_children():
            if isinstance(module, nn.LayerNorm):
                new_forward = _make_centering_layer_norm(module)
                module._orig_forward = module.forward
                module.forward = types.MethodType(new_forward, module)
            else:
                self._detach_layer_norm(module)

    def _clone_input(self, x: Tensor) -> Tensor:
        """Return a detached copy of ``x`` that requires grad."""
        return x.clone().detach().requires_grad_(True)

    def _check_smp_shapes(
        self,
        x: Tensor,
        y: Optional[Tensor],
        t: Tensor,
    ):
        """Assert single-sample ranks: x 3-D, t scalar, y (if given) 1-D."""
        assert x.ndim == 3, "Expected single image sample!"
        assert t.ndim == 0, "Expected single noise level!"
        if y is not None:
            assert y.ndim == 1, "Expected single sample labels!"

    def _check_batch_shapes(
        self,
        x: Tensor,
        y: Optional[Tensor],
        t: Tensor,
    ):
        """Assert batch ranks: x 4-D, t scalar, y (if given) 2-D."""
        assert x.ndim == 4, "Expected batch of image samples!"
        assert t.ndim == 0, "Expected single noise level!"
        if y is not None:
            assert y.ndim == 2, "Expected batch of sample labels!"