import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.func import functional_call
import math

# Module-level default device string: 'cuda' when CUDA is available, otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def _adversarial_inconsistency(model, params, buffers, inputs, pred_soft, rho, noise_scale):
    """Return KL(pred_soft || softmax(f(inputs; w + delta))) for an ascent step delta.

    The ascent direction is the gradient of the KL divergence between the
    reference prediction ``pred_soft`` and the prediction at randomly noised
    weights, rescaled to global L2 norm ``rho``. Only parameters with
    ``requires_grad=True`` are noised and perturbed; frozen parameters are
    passed through untouched.

    The model's own weights are never modified: both the noised and the
    perturbed forward passes go through ``functional_call`` with substitute
    tensors, so there is nothing to restore afterwards and no rounding drift.

    Args:
        model: module to evaluate.
        params: ``dict(model.named_parameters())``.
        buffers: ``dict(model.named_buffers())``.
        inputs: batch fed to the model.
        pred_soft: clamped softmax of the model output at the current weights.
        rho: radius (global L2 norm) of the weight perturbation.
        noise_scale: scale of the Gaussian noise used to pick the ascent point.

    Returns:
        Scalar tensor, differentiable w.r.t. the entries of ``params``.

    Raises:
        RuntimeError: if no parameter requires gradients.
    """
    trainable = [name for name, p in params.items() if p.requires_grad]
    if not trainable:
        raise RuntimeError(
            "inconsistency loss needs at least one parameter with requires_grad=True"
        )
    noise_norm = math.sqrt(sum(params[name].numel() for name in trainable))

    # Noisy copies are fresh leaf tensors living on each parameter's own device.
    noisy_params = dict(params)
    for name in trainable:
        weight = params[name].detach()
        noise = noise_scale * torch.randn_like(weight) / noise_norm
        noisy_params[name] = (weight + noise).requires_grad_(True)

    with torch.enable_grad():
        noise_output = functional_call(model, (noisy_params, buffers), (inputs,))
        loss_kl = F.kl_div(
            F.log_softmax(noise_output, dim=1), pred_soft.detach(), reduction='batchmean'
        )

    # Kept from the previous implementation, which cleared the gradients here.
    model.zero_grad()

    # autograd.grad leaves ``.grad`` alone, so the ascent gradient cannot leak
    # into the caller's subsequent ``backward()``.
    raw_grads = torch.autograd.grad(
        loss_kl, [noisy_params[name] for name in trainable], allow_unused=True
    )
    # A parameter that does not influence the output contributes a zero direction.
    grads = [
        torch.zeros_like(params[name]) if g is None else g
        for name, g in zip(trainable, raw_grads)
    ]
    norm = torch.norm(torch.stack([torch.norm(g, p=2) for g in grads]), p=2) + 1e-12

    # delta is a constant; gradients still flow to the original parameters.
    perturbed_params = dict(params)
    for name, g in zip(trainable, grads):
        perturbed_params[name] = params[name] + (rho * g / norm).detach()

    output_prime = functional_call(model, (perturbed_params, buffers), (inputs,))
    log_q = F.log_softmax(output_prime, dim=1)
    return F.kl_div(log_q, pred_soft, reduction='batchmean')


def inconsistencyLoss(model, image, label, criterion, beta, rho, noise_scale):
    """Compute the supervised loss and the weighted inconsistency regulariser.

    Args:
        model: module producing logits with classes along dim 1.
        image: input batch.
        label: targets passed to ``criterion``.
        criterion: callable ``criterion(logits, label)`` returning the main loss.
        beta: weight applied to the inconsistency term.
        rho: L2 radius of the adversarial weight perturbation.
        noise_scale: scale of the random weight noise used to find the
            ascent direction.

    Returns:
        Tuple ``(loss, beta * inconsistency)`` of scalar tensors.

    Note:
        ``model.zero_grad()`` is called as a side effect.
    """
    params = dict(model.named_parameters())
    buffers = dict(model.named_buffers())

    pred = functional_call(model, (params, buffers), (image,))
    pred_soft = F.softmax(pred, dim=1).clamp(min=1e-6, max=1.0)

    inconsistency = _adversarial_inconsistency(
        model, params, buffers, image, pred_soft, rho, noise_scale
    )
    loss = criterion(pred, label)

    return loss, (beta * inconsistency)


def inconsistency_semi(model, image, val_image, label, criterion, beta, rho, noise_scale):
    """Semi-supervised variant: inconsistency over labelled plus extra inputs.

    The supervised loss uses only ``image``/``label``; the inconsistency term is
    evaluated on the concatenation of ``image`` and ``val_image``.

    Args:
        model: module producing logits with classes along dim 1.
        image: labelled input batch.
        val_image: additional input batch used only for the inconsistency term.
        label: targets for ``image``.
        criterion: callable ``criterion(logits, label)`` returning the main loss.
        beta: weight applied to the inconsistency term.
        rho: L2 radius of the adversarial weight perturbation.
        noise_scale: scale of the random weight noise used to find the
            ascent direction.

    Returns:
        Tuple ``(loss, beta * inconsistency)`` of scalar tensors.

    Note:
        ``model.zero_grad()`` is called as a side effect.
    """
    params = dict(model.named_parameters())
    buffers = dict(model.named_buffers())

    # The two batches are forwarded separately, then treated as one batch.
    ce_pred = functional_call(model, (params, buffers), (image,))
    val_pred = functional_call(model, (params, buffers), (val_image,))

    all_images = torch.cat([image, val_image], dim=0)
    pred = torch.cat([ce_pred, val_pred], dim=0)
    pred_soft = F.softmax(pred, dim=1).clamp(min=1e-6, max=1.0)

    inconsistency = _adversarial_inconsistency(
        model, params, buffers, all_images, pred_soft, rho, noise_scale
    )
    loss = criterion(ce_pred, label)

    return loss, beta * inconsistency


_BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

def _disable_running_stats(model: nn.Module) -> None:
    """disable running stats of BN layers"""
    def _disable(m: nn.Module):
        if isinstance(m, _BN_TYPES):
            if not hasattr(m, "backup_momentum"): 
                m.backup_momentum = m.momentum
            m.momentum = 0.0 
    model.apply(_disable)

def _enable_running_stats(model: nn.Module) -> None:
    """enable running stats of BN layers"""
    def _enable(m: nn.Module):
        if isinstance(m, _BN_TYPES) and hasattr(m, "backup_momentum"):
            m.momentum = m.backup_momentum 
    model.apply(_enable)


class IAM_S(torch.optim.Optimizer):
    """Custom PyTorch optimizer that wraps a base optimizer with the IAM-S update.

    One call to :meth:`step` works as follows:

    1. Compute the initial loss and its gradient (dl) through the closure.
    2. Add random noise to the weights and take the gradient (g) of the KL
       divergence between the original and the noisy predictions; scale it to
       norm ``rho`` to obtain the perturbation ``delta_adv``.
    3. Apply ``delta_adv`` to the original weights and evaluate a new loss
       (Loss_S) and its gradient there through the closure.
    4. Restore the original weights and let the base optimizer update them
       with the gradient of Loss_S.

    Only parameters with ``requires_grad=True`` are noised and perturbed.

    Note:
        The base optimizer's own state (e.g. momentum buffers) lives in
        ``self.base_optimizer`` and is not part of this object's ``state_dict``.
    """

    def __init__(self, params, base_optimizer_cls, rho: float, noise_scale: float, **kwargs):
        """Initialise the IAM-S optimizer.

        Args:
            params: iterable of parameters or parameter-group dicts.
            base_optimizer_cls: optimizer class that performs the actual update.
            rho: non-negative L2 radius of the adversarial weight perturbation.
            noise_scale: non-negative scale of the random weight noise.
            **kwargs: forwarded to ``base_optimizer_cls``.

        Raises:
            ValueError: if ``rho`` or ``noise_scale`` is negative, or if no
                parameter requires gradients.
        """
        if rho < 0.0:
            raise ValueError(f"Invalid rho, should be non-negative: {rho}")
        if noise_scale < 0.0:
            raise ValueError(f"Invalid noise_scale, should be non-negative: {noise_scale}")

        defaults = dict(rho=rho, noise_scale=noise_scale)
        super(IAM_S, self).__init__(params, defaults)

        # k_val: total number of trainable scalars; normalises the noise.
        self.k_val = sum(
            p.numel()
            for group in self.param_groups
            for p in group["params"]
            if p.requires_grad
        )
        if self.k_val == 0:
            raise ValueError("Optimizer initialized with no trainable parameters.")

        self.base_optimizer = base_optimizer_cls(self.param_groups, **kwargs)
        # Share the list itself (not only the group dicts) so that groups added
        # later are seen by both optimizers.
        self.param_groups = self.base_optimizer.param_groups

        # Use the first parameter of any group, not only of group 0.
        first_param = next(
            (p for group in self.param_groups for p in group["params"]), None
        )
        self.device = first_param.device if first_param is not None else torch.device("cpu")

        self.criterion_kl = nn.KLDivLoss(reduction='batchmean', log_target=False).to(self.device)

    def load_state_dict(self, state_dict):
        """Load the state and re-link the parameter groups with the base optimizer.

        ``Optimizer.load_state_dict`` replaces ``self.param_groups`` with new
        objects; without re-linking, later changes to them (for example by a
        learning-rate scheduler) would not reach the base optimizer.
        """
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

    @torch.no_grad()
    def step(self, closure_main_loss, model_nn_module, inputs_for_model):
        """Perform one optimisation step.

        Args:
            closure_main_loss: callable returning ``(loss, outputs)``. It is
                expected to run ``backward()`` itself, since this method never
                calls ``backward()`` on the returned loss. It is invoked twice:
                at the original and at the perturbed weights.
            model_nn_module: the model whose weights are perturbed.
            inputs_for_model: batch forwarded through ``model_nn_module`` to
                evaluate the noisy prediction.

        Returns:
            ``(loss_original, loss_s)``, both detached. ``loss_s`` is ``None``
            when the perturbation was skipped and the base optimizer fell back
            to the original gradients, or when no gradients were produced at all
            (in which case no update happens).

        Raises:
            ValueError: if ``closure_main_loss`` is ``None``.
        """
        if closure_main_loss is None:
            raise ValueError("closure_main_loss is required for this optimizer's step.")

        self.base_optimizer.zero_grad()
        with torch.enable_grad():
            loss_original, outputs_original = closure_main_loss()

        if all(p.grad is None for group in self.param_groups for p in group["params"]):
            print("Warning: No gradients (dl) found after initial closure. Skipping optimizer step.")
            return loss_original.detach() if loss_original is not None else None, None

        active_params = [p for p in model_nn_module.parameters() if p.requires_grad]
        if not active_params:
            print("Warning: No active parameters for 'g' calculation. Restoring original weights and skipping perturbation.")
            self.base_optimizer.step()
            return loss_original.detach(), None

        # Reference distribution at the original weights.
        eps = torch.finfo(outputs_original.dtype).tiny if outputs_original.dtype.is_floating_point else 1e-8
        p_ref = F.softmax(outputs_original, dim=1).clamp_min(eps).detach()

        # Exact copies of the weights; restored even if anything below raises.
        saved_weights = [p.detach().clone() for p in active_params]
        _disable_running_stats(model_nn_module)  # freeze BN statistics while perturbed
        try:
            loss_s = self._loss_at_perturbed_weights(
                closure_main_loss, model_nn_module, inputs_for_model,
                active_params, saved_weights, p_ref,
            )
        finally:
            with torch.no_grad():
                for p, saved in zip(active_params, saved_weights):
                    p.copy_(saved)
            _enable_running_stats(model_nn_module)

        # ``.grad`` now holds d(Loss_S) or, on fallback, the original gradient dl.
        self.base_optimizer.step()

        return loss_original.detach(), (None if loss_s is None else loss_s.detach())

    @torch.no_grad()
    def _loss_at_perturbed_weights(self, closure_main_loss, model_nn_module, inputs_for_model,
                                   active_params, saved_weights, p_ref):
        """Move the weights to ``w + delta_adv`` and evaluate the closure there.

        Leaves the weights modified; the caller is responsible for restoring
        them. Returns Loss_S, or ``None`` when the perturbation must be skipped
        (in which case ``.grad`` still holds the original gradients).
        """
        # Random starting point around the original weights.
        noise_std = self.defaults['noise_scale'] / math.sqrt(self.k_val + 1e-12)
        for p in active_params:
            p.add_(noise_std * torch.randn_like(p))

        with torch.enable_grad():
            logit_noise = model_nn_module(inputs_for_model)
            loss_noise_val = self.criterion_kl(F.log_softmax(logit_noise, dim=1), p_ref)
            if torch.isnan(loss_noise_val):
                print("Warning: NaN in loss_noise_val. Restoring original weights and skipping perturbation.")
                return None
            # autograd.grad keeps ``.grad`` (dl) untouched.
            grad_g = torch.autograd.grad(
                loss_noise_val, active_params, create_graph=False, allow_unused=True
            )

        if all(g is None for g in grad_g):
            print("Warning: Gradient 'g' is empty. Restoring original weights and using original gradients.")
            return None
        grad_g = [torch.zeros_like(p) if g is None else g for p, g in zip(active_params, grad_g)]

        # delta_adv = rho * g / ||g||, with the norm taken over all parameters.
        g_norm = torch.norm(torch.stack([g.norm(p=2) for g in grad_g]), p=2)
        scale = self.defaults['rho'] / (g_norm + 1e-12)

        # Flip delta_adv when it points against dl, so that it is an ascent
        # direction for the main loss. A zero dot product keeps the direction
        # as is instead of cancelling the perturbation.
        dl_pairs = [(g, p.grad) for p, g in zip(active_params, grad_g) if p.grad is not None]
        if dl_pairs:
            alignment = torch.stack([(g * dl).sum() for g, dl in dl_pairs]).sum()
            if alignment < 0:
                scale = -scale

        # Loss_S is evaluated at w_orig + delta_adv (the noise is discarded).
        for p, saved, g in zip(active_params, saved_weights, grad_g):
            p.copy_(saved)
            p.add_(scale * g)

        self.base_optimizer.zero_grad()
        with torch.enable_grad():
            loss_s, _ = closure_main_loss()  # p.grad = d(Loss_S)/d(w_orig + delta_adv)

        if torch.isnan(loss_s):
            print("NaN detected in Loss_S at step. Optimizer will use these NaN gradients from Loss_S.")
        return loss_s
    
def _assign_flat(flat, params):
    """Copy consecutive slices of the 1-D tensor ``flat`` into ``params`` in place."""
    offset = 0
    with torch.no_grad():
        for p in params:
            count = p.numel()
            p.copy_(flat[offset:offset + count].view_as(p))
            offset += count


def _ascent_inconsistency(model, inputs, rho, noise_scale, K, accumulate):
    """Measure local inconsistency by gradient ascent in weight space.

    Starting from randomly noised weights, the gradient of
    KL(p_orig || p_perturbed) is followed for ``K`` steps and the final KL value
    is returned. With ``accumulate=False`` every step restarts from the original
    weights; with ``accumulate=True`` the steps are summed and the sum is
    rescaled to norm ``rho`` after each step.

    The model runs in eval mode with all parameters temporarily requiring
    gradients. Weights, ``requires_grad`` flags and the train/eval mode are
    restored on exit, including when an exception is raised. Gradients are taken
    with ``torch.autograd.grad``, so existing ``.grad`` values are left as they
    were.

    Returns:
        The final KL divergence as a Python float, or ``nan`` when the model has
        no parameters.
    """
    params = list(model.parameters())
    if not params:
        print("[Error] No parameters require gradients.")
        return float('nan')

    was_training = model.training
    grad_flags = [p.requires_grad for p in params]
    original_flat = torch.nn.utils.parameters_to_vector(params).detach().clone()

    model.eval()
    try:
        # Reference prediction at the original weights, without a graph.
        with torch.no_grad():
            logits = model(inputs)
        # The floor comes from the output dtype; the inputs may be integer ids.
        eps = torch.finfo(logits.dtype).tiny
        p_orig = F.softmax(logits, dim=1).clamp(eps)

        for p in params:
            p.requires_grad_(True)

        def kl_to_original():
            # enable_grad keeps the function usable inside a no_grad block.
            with torch.enable_grad():
                probs = F.softmax(model(inputs), dim=1).clamp(eps)
                return F.kl_div(probs.log(), p_orig, reduction='batchmean', log_target=False)

        noise = noise_scale / math.sqrt(original_flat.numel()) * torch.randn_like(original_flat)
        _assign_flat(original_flat + noise, params)
        loss = kl_to_original()

        perturbation = torch.zeros_like(original_flat)
        for _ in range(K):
            grads = torch.autograd.grad(loss, params, allow_unused=True)
            grad_flat = torch.cat([
                (torch.zeros_like(p) if g is None else g).detach().flatten()
                for p, g in zip(params, grads)
            ])

            # Move rho along the normalised gradient direction.
            ascent = rho * (grad_flat / (grad_flat.norm() + 1e-12))
            if accumulate:
                perturbation = perturbation + ascent
                # Rescale to radius rho; the epsilon avoids 0/0 for a zero vector.
                perturbation = rho / (perturbation.norm() + 1e-12) * perturbation
            else:
                perturbation = ascent

            _assign_flat(original_flat + perturbation, params)
            loss = kl_to_original()

        return loss.item()
    finally:
        # Restore the original parameters, their flags and the model mode.
        _assign_flat(original_flat, params)
        for p, flag in zip(params, grad_flags):
            p.requires_grad_(flag)
        model.train(was_training)


def compute_S(model, inputs, rho=0.1, noise_scale=0.05, K=1):
    """Algorithm 1: compute the local inconsistency with a minimal graph.

    Each of the ``K`` iterations takes the gradient of the KL divergence at the
    current perturbed weights and moves ``rho`` along it from the original
    weights. The result is a plain float and cannot be used as a loss.

    Args:
        model: module producing logits with classes along dim 1.
        inputs: batch to evaluate.
        rho: L2 radius of the weight perturbation.
        noise_scale: scale of the initial random weight noise.
        K: number of ascent iterations.

    Returns:
        KL divergence between the original and the perturbed predictions.
    """
    return _ascent_inconsistency(model, inputs, rho, noise_scale, K, accumulate=False)


def S_PGA(model, inputs, rho=0.1, noise_scale=0.05, K=5):
    """Compute the local inconsistency with projected gradient ascent (PGA).

    Like ``compute_S``, but the ascent steps are accumulated and the accumulated
    perturbation is rescaled to norm ``rho`` after every step. The result is a
    plain float and cannot be used as a loss.

    Args:
        model: module producing logits with classes along dim 1.
        inputs: batch to evaluate.
        rho: L2 radius of the weight perturbation.
        noise_scale: scale of the initial random weight noise.
        K: number of ascent iterations.

    Returns:
        KL divergence between the original and the perturbed predictions.
    """
    return _ascent_inconsistency(model, inputs, rho, noise_scale, K, accumulate=True)
