import torch
import torch.nn as nn
from models.common import LoRALinear

class RaGrad:
    """Reliability-aware gradient correction for paired RGB / IR conv layers.

    ``gradient_correction`` pairs, position by position, the parameters of every
    3x3 ``nn.Conv2d`` found in the selected RGB and IR backbone layers. For each
    pair it builds one score per row (dim 0, i.e. per output channel for a conv
    weight) from

    * an "energy" term ``-g * w + 0.5 * diag(H) * (w**2 + eps)`` summed over the
      row, with ``diag(H)`` from a Hutchinson estimate, and
    * the row-mean of an exponential moving average of ``(g - running_mean)**2``.

    Rows whose RGB and IR gradients conflict (cosine similarity below the mean
    of the negative cosines) and whose score gap is above average are
    rewritten. The modality with the lower score keeps ``(1 - lam)`` times its
    component orthogonal to the other modality's gradient. It then adds
    ``lam`` times that other gradient.

    Args:
        gradient_mean_rgb, gradient_stability_rgb, gradient_mean_ir,
        gradient_stability_ir: mutable mappings keyed by ``id(param)``. They
            are read and updated in place, so they carry state across calls.
        alpha: EMA factor of the gradient-variance ("stability") statistic.
        eps: numerical-stability constant.
        n_samples: number of Rademacher probes for the Hessian estimate.
        rgb_backbone, ir_backbone: indices into ``model.model`` selecting the
            RGB / IR backbone layers. Both default to empty, in which case
            ``gradient_correction`` does nothing (the previous hard-coded
            behaviour).
    """

    def __init__(self, gradient_mean_rgb, gradient_stability_rgb,
                 gradient_mean_ir, gradient_stability_ir,
                 alpha=0.95, eps=1e-8, n_samples=3,
                 rgb_backbone=None, ir_backbone=None):
        self.gradient_mean_rgb = gradient_mean_rgb
        self.gradient_stability_rgb = gradient_stability_rgb
        self.gradient_mean_ir = gradient_mean_ir
        self.gradient_stability_ir = gradient_stability_ir
        self.alpha = alpha
        self.eps = eps
        self.n_samples = n_samples
        # Layer indices into ``model.model``; empty means "no correction".
        self.rgb_backbone = list(rgb_backbone) if rgb_backbone is not None else []
        self.ir_backbone = list(ir_backbone) if ir_backbone is not None else []

    @staticmethod
    def normalize(t1, t2, eps=1e-8):
        """Jointly z-score ``t1`` and ``t2`` (along dim 0) and squash with a sigmoid.

        The inputs are concatenated along dim 0 and standardised with the shared
        mean / std. They are then mapped through ``torch.sigmoid`` and split back,
        so outputs lie in (0, 1) and are comparable across the two inputs.

        If all values are (numerically) identical every z-score is 0, so 0.5 is
        returned everywhere. This is consistent with the general path, rather
        than leaking the raw, unbounded (possibly negative) inputs.
        """
        combined = torch.cat([t1, t2], dim=0)
        if torch.allclose(combined, combined[:1].expand_as(combined), atol=eps):
            return torch.full_like(t1, 0.5), torch.full_like(t2, 0.5)
        z_score = (combined - combined.mean()) / (combined.std() + eps)
        t1_norm, t2_norm = torch.split(torch.sigmoid(z_score), [t1.shape[0], t2.shape[0]], dim=0)
        return t1_norm, t2_norm

    @staticmethod
    def vector_proj_batched(a: torch.Tensor, b: torch.Tensor, eps=1e-6):
        """Row-wise projection of ``a`` onto ``b`` (both ``(rows, dim)``)."""
        dot = (a * b).sum(dim=1, keepdim=True)
        norm_sq = (b * b).sum(dim=1, keepdim=True) + eps
        return (dot / norm_sq) * b

    def estimate_hessian_diag(self, loss, params):
        """Hutchinson estimate of the Hessian diagonal of ``loss`` w.r.t. ``params``.

        Uses ``self.n_samples`` Rademacher probes ``z``: ``mean(z * (H @ z))``.
        The graph of ``loss`` must still be alive when this is called.

        Returns:
            dict mapping ``id(p)`` to a tensor shaped like ``p``.
        """
        params = list(params)
        grads = torch.autograd.grad(loss, params, create_graph=True)
        hessian_diag = {id(p): torch.zeros_like(p) for p in params}
        n = self.n_samples

        for k in range(n):
            # Rademacher probe (+1 / -1) in each gradient's own dtype (fp16/bf16 safe).
            zs = [(torch.randint(0, 2, g.shape, device=g.device) * 2 - 1).to(g.dtype)
                  for g in grads]
            # Keep the double-backward graph for every probe but the last;
            # freeing it after the first probe breaks n_samples > 1.
            h_zs = torch.autograd.grad(grads, params, grad_outputs=zs,
                                       retain_graph=k < n - 1, allow_unused=True)
            for p, z, h_z in zip(params, zs, h_zs):
                if h_z is not None:  # None: gradient does not depend on p -> zero
                    hessian_diag[id(p)] += (z * h_z) / n

        return hessian_diag

    @staticmethod
    def _conv3x3_params(root, indices):
        """Parameters of every 3x3 ``nn.Conv2d`` inside ``root.model[i]`` for i in ``indices``."""
        return [
            p
            for i in indices
            for module in root.model[i].modules()
            if isinstance(module, nn.Conv2d) and module.kernel_size == (3, 3)
            for p in module.parameters()
        ]

    def _energy_and_stability(self, p, g, hess, decay_rate, mean_store, stab_store):
        """Per-row energy E and mean gradient variance S for ``p``; updates its running stats."""
        pid = id(p)
        delta = p.detach()
        energy = -g * delta + 0.5 * hess * (delta ** 2 + self.eps)
        E = energy.reshape(energy.shape[0], -1).sum(dim=1)

        prev_mean = mean_store.get(pid)
        if prev_mean is None:
            prev_mean = g.clone()
        mean = decay_rate * prev_mean + (1 - decay_rate) * g
        mean_store[pid] = mean

        prev_stab = stab_store.get(pid)
        if prev_stab is None:
            prev_stab = torch.zeros_like(p)
        stability = self.alpha * prev_stab + (1 - self.alpha) * ((g - mean) ** 2)
        stab_store[pid] = stability
        S = stability.reshape(stability.shape[0], -1).mean(dim=1)
        return E, S

    def _correct_pair(self, p_rgb, p_ir, hess_rgb, hess_ir, decay_rate):
        """Correct the ``.grad`` of one RGB / IR parameter pair in place.

        Once both gradients exist and the shapes match, the running statistics
        of both parameters are always updated. Returns True if any row was
        rewritten.
        """
        if p_rgb.grad is None or p_ir.grad is None or p_rgb.shape != p_ir.shape:
            # Pairing is positional; missing grads or unequal shapes mean there
            # is nothing comparable to correct.
            return False

        g_rgb = p_rgb.grad.detach()
        g_ir = p_ir.grad.detach()

        E_rgb, S_rgb = self._energy_and_stability(
            p_rgb, g_rgb, hess_rgb, decay_rate,
            self.gradient_mean_rgb, self.gradient_stability_rgb)
        E_ir, S_ir = self._energy_and_stability(
            p_ir, g_ir, hess_ir, decay_rate,
            self.gradient_mean_ir, self.gradient_stability_ir)

        # Joint normalisation puts both modalities on one scale; R is (rows,).
        E_norm_rgb, E_norm_ir = self.normalize(E_rgb, E_ir, self.eps)
        S_norm_rgb, S_norm_ir = self.normalize(S_rgb, S_ir, self.eps)
        R_rgb = E_norm_rgb * S_norm_rgb
        R_ir = E_norm_ir * S_norm_ir

        # Work on (rows, -1) copies; they are written back to .grad at the end.
        rows = g_rgb.shape[0]
        g_rgb_flat = g_rgb.reshape(rows, -1).clone()
        g_ir_flat = g_ir.reshape(rows, -1).clone()

        cos_sim = (g_rgb_flat * g_ir_flat).sum(dim=1) / (
            g_rgb_flat.norm(dim=1) * g_ir_flat.norm(dim=1) + self.eps)
        conflict = cos_sim < 0
        if not bool(conflict.any()):
            return False
        # NOTE(review): strict '<' means a lone conflicting row (equal to the
        # mean) is never selected -- kept as-is; confirm this is intended.
        severe = cos_sim < cos_sim[conflict].mean()
        reliability_gap = (R_rgb - R_ir).abs()
        selected = severe & (reliability_gap > reliability_gap.mean())
        idx = selected.nonzero(as_tuple=True)[0]
        if idx.numel() == 0:
            return False

        rgb_vec = g_rgb_flat[idx]
        ir_vec = g_ir_flat[idx]
        rgb_score = R_rgb[idx]
        ir_score = R_ir[idx]
        lam = torch.sigmoid((rgb_score - ir_score).abs()
                            / (rgb_score + ir_score + self.eps)).unsqueeze(1)
        rgb_better = rgb_score > ir_score

        # RGB scores higher: IR drops its component along RGB and blends toward RGB.
        m = rgb_better
        proj = self.vector_proj_batched(ir_vec[m], rgb_vec[m], self.eps)
        g_ir_flat[idx[m]] = (1 - lam[m]) * (ir_vec[m] - proj) + lam[m] * rgb_vec[m]
        # IR scores higher (or ties): symmetric update of the RGB gradient.
        m = ~rgb_better
        proj = self.vector_proj_batched(rgb_vec[m], ir_vec[m], self.eps)
        g_rgb_flat[idx[m]] = (1 - lam[m]) * (rgb_vec[m] - proj) + lam[m] * ir_vec[m]

        p_rgb.grad.copy_(g_rgb_flat.reshape_as(p_rgb.grad))
        p_ir.grad.copy_(g_ir_flat.reshape_as(p_ir.grad))
        return True

    def gradient_correction(self, model, loss, batch_count):
        """Correct the ``.grad`` of paired RGB / IR 3x3-conv parameters in place.

        Two preconditions apply. ``.grad`` must already be populated (e.g. by
        ``loss.backward(retain_graph=True)``). The autograd graph of ``loss``
        must still be alive, because the Hessian estimate differentiates
        ``loss`` again.

        Args:
            model: the network, optionally wrapped (``model.module`` is used if
                present). Its ``model`` attribute must be indexable by the
                configured backbone indices.
            loss: scalar loss; presumably the one the current ``.grad`` came from.
            batch_count: batches seen so far. The running gradient mean uses
                weight ``batch_count / (batch_count + 1)`` on its history.

        Returns:
            False if either backbone yields no matching parameters. Otherwise
            True, even if no row needed correcting.
        """
        real_model = model.module if hasattr(model, 'module') else model

        rgb_params = self._conv3x3_params(real_model, self.rgb_backbone)
        ir_params = self._conv3x3_params(real_model, self.ir_backbone)
        if not rgb_params or not ir_params:
            return False

        hessian_diag = self.estimate_hessian_diag(loss, rgb_params + ir_params)
        decay_rate = batch_count / (batch_count + 1)

        for p_rgb, p_ir in zip(rgb_params, ir_params):
            self._correct_pair(p_rgb, p_ir, hessian_diag[id(p_rgb)],
                               hessian_diag[id(p_ir)], decay_rate)

        return True

# LoRA variant: applies the same correction to the lora_B weights of LoRALinear modules in the RGB / IR backbones.
class LoRaGrad:
    """Reliability-aware gradient correction for paired RGB / IR LoRA weights.

    ``correct`` collects, in module order, the trainable ``lora_B`` parameters of
    every ``LoRALinear`` inside the selected RGB and IR backbone layers. It then
    pairs them position by position. For each pair it builds one score per row
    (dim 0) from

    * an "energy" term ``-g * w + 0.5 * diag(H) * (w**2 + eps)`` averaged over
      the row, with ``diag(H)`` from a Hutchinson estimate, and
    * the row-mean of an exponential moving average of ``(g - running_mean)**2``.

    Rows whose RGB and IR gradients conflict (cosine similarity below the mean
    of the negative cosines) and whose score gap is above average are
    rewritten. The modality with the lower score keeps ``(1 - lam)`` times its
    component orthogonal to the other modality's gradient. It then adds
    ``lam`` times that other gradient.

    Args:
        gradient_mean_rgb, gradient_stability_rgb, gradient_mean_ir,
        gradient_stability_ir: mutable mappings keyed by ``id(param)``. They
            are read and updated in place, so they carry state across calls.
        alpha: EMA factor of the gradient-variance ("stability") statistic.
        eps: numerical-stability constant.
        n_samples: number of Rademacher probes for the Hessian estimate.
        rgb_backbone, ir_backbone: indices into ``model.model`` selecting the
            RGB / IR backbone layers. They default to ``(0,)`` and ``(1,)``,
            the previously hard-coded layers.
    """

    def __init__(self,
                 gradient_mean_rgb, gradient_stability_rgb,
                 gradient_mean_ir, gradient_stability_ir,
                 alpha=0.95, eps=1e-8, n_samples=3,
                 rgb_backbone=(0,), ir_backbone=(1,)):
        self.gradient_mean_rgb = gradient_mean_rgb
        self.gradient_stability_rgb = gradient_stability_rgb
        self.gradient_mean_ir = gradient_mean_ir
        self.gradient_stability_ir = gradient_stability_ir
        self.alpha = alpha
        self.eps = eps
        self.n_samples = n_samples
        self.rgb_backbone = list(rgb_backbone)
        self.ir_backbone = list(ir_backbone)

    @staticmethod
    def normalize(t1, t2, eps=1e-8):
        """Jointly z-score ``t1`` and ``t2`` (along dim 0) and squash with a sigmoid.

        If all values are (numerically) identical every z-score is 0, so 0.5 is
        returned everywhere. This is consistent with the general path, rather
        than leaking the raw, unbounded (possibly negative) inputs.
        """
        combined = torch.cat([t1, t2], dim=0)
        if torch.allclose(combined, combined[:1].expand_as(combined), atol=eps):
            return torch.full_like(t1, 0.5), torch.full_like(t2, 0.5)
        z_score = (combined - combined.mean()) / (combined.std() + eps)
        return torch.split(torch.sigmoid(z_score), [t1.shape[0], t2.shape[0]], dim=0)

    @staticmethod
    def vector_proj_batched(a: torch.Tensor, b: torch.Tensor, eps=1e-6):
        """Row-wise projection of ``a`` onto ``b`` (both ``(rows, dim)``)."""
        dot = (a * b).sum(dim=1, keepdim=True)
        norm_sq = (b * b).sum(dim=1, keepdim=True) + eps
        return (dot / norm_sq) * b

    @staticmethod
    def get_lora_params(layers):
        """Trainable parameters whose name contains 'lora_B' in every ``LoRALinear`` of ``layers``."""
        params = []
        for layer in layers:
            for module in layer.modules():
                if isinstance(module, LoRALinear):
                    for name, p in module.named_parameters():
                        if 'lora_B' in name and p.requires_grad:
                            params.append(p)
        return params

    def estimate_hessian_diag(self, loss, params):
        """Hutchinson estimate of the Hessian diagonal of ``loss`` w.r.t. ``params``.

        Uses ``self.n_samples`` Rademacher probes ``z``: ``mean(z * (H @ z))``.
        The graph of ``loss`` must still be alive when this is called.

        Returns:
            dict mapping ``id(p)`` to a tensor shaped like ``p``.
        """
        params = list(params)
        grads = torch.autograd.grad(loss, params, create_graph=True)
        hessian_diag = {id(p): torch.zeros_like(p) for p in params}
        n = self.n_samples
        for k in range(n):
            # Rademacher probe (+1 / -1) in each gradient's own dtype (fp16/bf16 safe).
            zs = [(torch.randint(0, 2, g.shape, device=g.device) * 2 - 1).to(g.dtype)
                  for g in grads]
            # Keep the double-backward graph for every probe but the last;
            # freeing it after the first probe breaks n_samples > 1.
            h_zs = torch.autograd.grad(grads, params, grad_outputs=zs,
                                       retain_graph=k < n - 1, allow_unused=True)
            for p, z, h_z in zip(params, zs, h_zs):
                if h_z is not None:  # None: gradient does not depend on p -> zero
                    hessian_diag[id(p)] += (z * h_z) / n
        return hessian_diag

    def _energy_and_stability(self, p, g, hess, decay_rate, mean_store, stab_store):
        """Per-row energy E and mean gradient variance S for ``p``; updates its running stats."""
        pid = id(p)
        delta = p.detach()
        energy = -g * delta + 0.5 * hess * (delta ** 2 + self.eps)
        E = energy.reshape(energy.shape[0], -1).mean(dim=1)

        prev_mean = mean_store.get(pid)
        if prev_mean is None:
            prev_mean = g.clone()
        mean = decay_rate * prev_mean + (1 - decay_rate) * g
        mean_store[pid] = mean

        prev_stab = stab_store.get(pid)
        if prev_stab is None:
            prev_stab = torch.zeros_like(p)
        stability = self.alpha * prev_stab + (1 - self.alpha) * ((g - mean) ** 2)
        stab_store[pid] = stability
        S = stability.reshape(stability.shape[0], -1).mean(dim=1)
        return E, S

    def _correct_pair(self, p_rgb, p_ir, hess_rgb, hess_ir, decay_rate):
        """Correct the ``.grad`` of one RGB / IR parameter pair in place.

        Once both gradients exist and the shapes match, the running statistics
        of both parameters are always updated. Returns True if any row was
        rewritten.
        """
        if p_rgb.grad is None or p_ir.grad is None or p_rgb.shape != p_ir.shape:
            # Pairing is positional; missing grads or unequal shapes mean there
            # is nothing comparable to correct.
            return False

        g_rgb = p_rgb.grad.detach()
        g_ir = p_ir.grad.detach()

        E_rgb, S_rgb = self._energy_and_stability(
            p_rgb, g_rgb, hess_rgb, decay_rate,
            self.gradient_mean_rgb, self.gradient_stability_rgb)
        E_ir, S_ir = self._energy_and_stability(
            p_ir, g_ir, hess_ir, decay_rate,
            self.gradient_mean_ir, self.gradient_stability_ir)

        # Joint normalisation puts both modalities on one scale; R is 1-D (rows,)
        # so index tensors below stay 1-D even when a single row is selected.
        E_norm_rgb, E_norm_ir = self.normalize(E_rgb, E_ir, self.eps)
        S_norm_rgb, S_norm_ir = self.normalize(S_rgb, S_ir, self.eps)
        R_rgb = E_norm_rgb * S_norm_rgb
        R_ir = E_norm_ir * S_norm_ir

        # Work on (rows, -1) copies; they are written back to .grad at the end.
        rows = g_rgb.shape[0]
        g_rgb_flat = g_rgb.reshape(rows, -1).clone()
        g_ir_flat = g_ir.reshape(rows, -1).clone()

        cos_sim = (g_rgb_flat * g_ir_flat).sum(dim=1) / (
            g_rgb_flat.norm(dim=1) * g_ir_flat.norm(dim=1) + self.eps)
        conflict = cos_sim < 0
        if not bool(conflict.any()):
            return False
        # NOTE(review): strict '<' means a lone conflicting row (equal to the
        # mean) is never selected -- kept as-is; confirm this is intended.
        severe = cos_sim < cos_sim[conflict].mean()
        reliability_gap = (R_rgb - R_ir).abs()
        selected = severe & (reliability_gap > reliability_gap.mean())
        idx = selected.nonzero(as_tuple=True)[0]
        if idx.numel() == 0:
            return False

        rgb_vec = g_rgb_flat[idx]
        ir_vec = g_ir_flat[idx]
        rgb_score = R_rgb[idx]
        ir_score = R_ir[idx]
        lam = torch.sigmoid((rgb_score - ir_score).abs()
                            / (rgb_score + ir_score + self.eps)).unsqueeze(1)
        rgb_better = rgb_score > ir_score

        # RGB scores higher: IR drops its component along RGB and blends toward RGB.
        m = rgb_better
        proj = self.vector_proj_batched(ir_vec[m], rgb_vec[m], self.eps)
        g_ir_flat[idx[m]] = (1 - lam[m]) * (ir_vec[m] - proj) + lam[m] * rgb_vec[m]
        # IR scores higher (or ties): symmetric update of the RGB gradient.
        m = ~rgb_better
        proj = self.vector_proj_batched(rgb_vec[m], ir_vec[m], self.eps)
        g_rgb_flat[idx[m]] = (1 - lam[m]) * (rgb_vec[m] - proj) + lam[m] * ir_vec[m]

        p_rgb.grad.copy_(g_rgb_flat.reshape_as(p_rgb.grad))
        p_ir.grad.copy_(g_ir_flat.reshape_as(p_ir.grad))
        return True

    def correct(self, model, loss, batch_count):
        """Correct the ``.grad`` of paired RGB / IR ``lora_B`` parameters in place.

        Two preconditions apply. ``.grad`` must already be populated. The
        autograd graph of ``loss`` must still be alive, because the Hessian
        estimate differentiates ``loss`` again.

        Args:
            model: the network, optionally wrapped (``model.module`` is used if
                present). Its ``model`` attribute must be indexable by the
                configured backbone indices.
            loss: scalar loss; presumably the one the current ``.grad`` came from.
            batch_count: batches seen so far. The running gradient mean uses
                weight ``batch_count / (batch_count + 1)`` on its history.

        Returns:
            False if either backbone yields no LoRA parameters. Otherwise True,
            even if no row needed correcting.
        """
        real_model = model.module if hasattr(model, 'module') else model
        rgb_params = self.get_lora_params([real_model.model[i] for i in self.rgb_backbone])
        ir_params = self.get_lora_params([real_model.model[i] for i in self.ir_backbone])

        if not rgb_params or not ir_params:
            return False

        hessian_diag = self.estimate_hessian_diag(loss, rgb_params + ir_params)
        decay_rate = batch_count / (batch_count + 1)

        for p_rgb, p_ir in zip(rgb_params, ir_params):
            self._correct_pair(p_rgb, p_ir, hessian_diag[id(p_rgb)],
                               hessian_diag[id(p_ir)], decay_rate)

        return True
