import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random

# LwF
def lwf_loss(current_model, current_out, old_model, features, temperature=2.0):
    """
    Compute the LwF (Learning without Forgetting) distillation loss between the
    current model's outputs and those of a frozen copy of the previous model.

    The two models may use different vocabularies (``vocab_map``: token -> index).
    Only output columns whose token exists in both vocabularies (plus the
    unknown-token column, when both models define one) are compared, and those
    columns are aligned token-by-token.

    Args:
        current_model (nn.Module): The current (training) model; must expose
            ``vocab_map`` and ``unknown_idx``.
        current_out (torch.Tensor | list[torch.Tensor]): Output(s) of the current
            model from the training forward pass on ``features``.
        old_model (nn.Module): Frozen copy of the model before learning the new
            task; must expose ``vocab_map`` and ``unknown_idx``. It is switched to
            eval mode and moved to the current model's device.
        features (torch.Tensor): Input batch of current-vocabulary ids (no labels
            needed). It is re-encoded into old-vocabulary ids, keeping its shape,
            before the old model's forward pass.
        temperature (float): Softmax temperature (higher = softer targets); used
            for multi-class heads only.

    Returns:
        torch.Tensor | float: Sum over output heads of the KL distillation loss
        (``0.0`` if there are no heads).
    """
    old_model.eval()
    device = next(current_model.parameters()).device
    old_model = old_model.to(device)

    # Tokens known to both models, in the current model's order. Both index lists
    # are derived from this single ordering, so column j of the old logits and
    # column j of the current logits always refer to the same token.
    shared_vocab = [vocab for vocab in current_model.vocab_map if vocab in old_model.vocab_map]
    relevant_params_old_model = [old_model.vocab_map[vocab] for vocab in shared_vocab]
    relevant_params_current_model = [current_model.vocab_map[vocab] for vocab in shared_vocab]
    # Compare the unknown-token column only when both models have one; adding it
    # to just one side would make the two slices differ in width.
    if current_model.unknown_idx is not None and old_model.unknown_idx is not None:
        relevant_params_old_model.append(old_model.unknown_idx)
        relevant_params_current_model.append(current_model.unknown_idx)

    # Re-encode the input from current-vocabulary ids to old-vocabulary ids;
    # tokens the old model never saw map to its unknown index.
    new_idx_to_old_map = {
        current_model.vocab_map[vocab]: old_model.vocab_map.get(vocab, old_model.unknown_idx)
        for vocab in current_model.vocab_map
    }
    if current_model.unknown_idx is not None:
        new_idx_to_old_map[current_model.unknown_idx] = old_model.unknown_idx
    # Flatten -> map -> restore shape, so batched (N-D) inputs work as well as 1-D ones.
    old_ids = [new_idx_to_old_map[i] for i in features.reshape(-1).tolist()]
    old_features = torch.tensor(old_ids, dtype=torch.long, device=device).reshape(features.shape)

    with torch.no_grad():
        old_out = old_model(old_features)

    # Wrap single-head outputs as lists so single- and multi-head models share one path.
    if not isinstance(old_out, list):
        old_out = [old_out]
    if not isinstance(current_out, list):
        current_out = [current_out]

    loss = 0.0
    for old_logits, curr_logits in zip(old_out, current_out):
        if old_logits.shape[-1] == 1:
            # Binary head (1 output unit): distill sigmoid probabilities.
            # NOTE(review): kl_div(log q, p) only yields the positive-class term
            # p * log(p / q) of the Bernoulli KL, and temperature is not applied
            # here — confirm whether the full two-term KL was intended.
            old_prob = torch.sigmoid(old_logits).to(device)
            curr_log_prob = torch.log(torch.sigmoid(curr_logits) + 1e-8)
            kl = F.kl_div(curr_log_prob, old_prob, reduction='batchmean')
        else:
            # Multi-class head: temperature-softened softmax over the aligned columns.
            old_prob = F.softmax(old_logits[:, relevant_params_old_model] / temperature, dim=1).to(device)
            curr_log_prob = F.log_softmax(curr_logits[:, relevant_params_current_model] / temperature, dim=1)
            # The usual T**2 distillation factor keeps gradient scale comparable across temperatures.
            kl = F.kl_div(curr_log_prob, old_prob, reduction='batchmean') * (temperature ** 2)

        loss += kl

    return loss

def compute_quadratic_loss(current_params, prev_params, importance):
    """
    Return ``sum(importance * (current_params - prev_params) ** 2)``.

    The three tensors may differ in size (e.g. when a vocabulary grew or shrank
    between tasks). All three are cropped along each shared dimension to the
    smallest size among them, so the penalty covers only their common
    leading region.

    ``prev_params`` and ``importance`` are detached and moved to the device of
    ``current_params``: they act as fixed anchors/weights, so no gradient
    flows into the old model or the importance tensor.

    Args:
        current_params (torch.Tensor): Parameters being trained.
        prev_params (torch.Tensor): Reference parameters from the previous task.
        importance (torch.Tensor): Per-element penalty weights (e.g. MAS omega or
            EWC Fisher diagonal).

    Returns:
        torch.Tensor: Scalar weighted squared distance.
    """
    # Crop all three tensors together; cropping only pairwise left the params
    # un-cropped when `importance` was the smallest, causing a broadcast error.
    common = [min(dims) for dims in zip(current_params.shape, prev_params.shape, importance.shape)]
    region = tuple(slice(0, size) for size in common)

    device = current_params.device
    current = current_params[region]
    anchor = prev_params[region].detach().to(device)
    weight = importance[region].detach().to(device)

    return (weight * (current - anchor).pow(2)).sum()
# MAS
def mas_loss(current_model, old_model, importance_weights, lambda_mas=1.0):
    """
    Compute the MAS (Memory Aware Synapses) regularization loss to prevent forgetting.

    Rows of the current and old parameters are matched token-by-token through
    the two models' ``vocab_map`` dictionaries (token -> index). Only tokens
    known to both models are compared. The unknown-token row is compared when
    both models define ``unknown_idx``. For parameters whose name contains
    ``"embedding"``, the padding row is compared when both models define
    ``padding_idx``.

    NOTE(review): dim 0 of every penalized parameter is selected with these
    vocabulary indices, which presumes all penalized parameters are
    vocabulary-indexed along dim 0 — confirm for non-vocabulary layers.

    Args:
        current_model (nn.Module): Current model (exposes ``vocab_map``,
            ``unknown_idx`` and optionally ``padding_idx``).
        old_model (nn.Module): A frozen copy of the model before learning the new task.
        importance_weights (dict): {param_name: tensor} importance (omega) for each
            parameter, laid out like the old model's parameters.
        lambda_mas (float): Regularization strength.

    Returns:
        torch.Tensor | float: Scalar MAS regularization loss (``0.0`` times
        ``lambda_mas`` if no parameter is penalized).
    """
    prev_params = dict(old_model.named_parameters())

    # Build BOTH index lists from one key order (the current vocabulary) so that
    # position j in each list refers to the same token. Iterating the two
    # vocabularies in their own orders misaligned the compared rows.
    shared_vocab = [vocab for vocab in current_model.vocab_map if vocab in old_model.vocab_map]
    old_rows = [old_model.vocab_map[vocab] for vocab in shared_vocab]
    current_rows = [current_model.vocab_map[vocab] for vocab in shared_vocab]
    if current_model.unknown_idx is not None and old_model.unknown_idx is not None:
        old_rows.append(old_model.unknown_idx)
        current_rows.append(current_model.unknown_idx)

    # Embedding matrices also keep the padding row, paired only when both models have one.
    old_padding = getattr(old_model, "padding_idx", None)
    current_padding = getattr(current_model, "padding_idx", None)
    has_padding_pair = old_padding is not None and current_padding is not None
    old_embedding_rows = old_rows + ([old_padding] if has_padding_pair else [])
    current_embedding_rows = current_rows + ([current_padding] if has_padding_pair else [])

    loss = 0.0
    for name, param in current_model.named_parameters():
        if not (param.requires_grad and name in importance_weights):
            continue
        is_embedding = "embedding" in name
        old_idx = old_embedding_rows if is_embedding else old_rows
        current_idx = current_embedding_rows if is_embedding else current_rows

        prev_param = prev_params[name][old_idx]
        omega = importance_weights[name][old_idx]
        current_param = param[current_idx]

        loss += compute_quadratic_loss(current_param, prev_param, omega)
    return lambda_mas * loss

def compute_mas_importance_weights(model, dataloader, device = torch.device('cpu'), num_samples=None):
    """
    Compute MAS (Memory Aware Synapses) importance weights for each parameter.

    The importance of a parameter is the per-sample mean of the absolute
    gradient of the sum of all model outputs (summed over every head when the
    model returns a list). The model is put in eval mode. Gradients produced
    by this pass are cleared before returning.

    NOTE(review): the MAS paper differentiates the squared L2 norm of the
    output; this uses the plain sum of outputs — confirm this is intended.

    Args:
        model (nn.Module): The model after training on a task.
        dataloader (iterable): Yields batches whose first element is the input tensor.
        device (torch.device | str): Device the inputs (and importance tensors) go to.
        num_samples (int, optional): Stop after at least this many samples have
            been seen. Whole batches are used, so slightly more may be consumed.

    Returns:
        dict: {param_name: importance tensor} for every parameter with
        ``requires_grad``. All zeros if the dataloader yields no batches.
    """
    model.eval()
    importance = {
        name: torch.zeros_like(p, device=device)
        for name, p in model.named_parameters()
        if p.requires_grad
    }
    total_seen = 0

    for batch in dataloader:
        x = batch[0].to(device)
        model.zero_grad()
        out = model(x)

        # Sum output elements to backprop through all outputs in one pass.
        out_sum = sum(o.sum() for o in out) if isinstance(out, list) else out.sum()
        out_sum.backward()

        for name, param in model.named_parameters():
            if param.requires_grad and param.grad is not None:
                importance[name] += param.grad.detach().abs()

        total_seen += x.size(0)
        if num_samples and total_seen >= num_samples:
            break

    # Do not leave the importance-pass gradients on the model for a later optimizer step.
    model.zero_grad()

    if total_seen == 0:
        # Empty dataloader: return zeros instead of 0 / 0 = NaN importances.
        return importance

    for name in importance:
        importance[name] /= total_seen

    return importance

def ewc_loss(current_model, old_model, fisher_info, lambda_ewc=0.4):
    """
    Compute the EWC (Elastic Weight Consolidation) penalty given Fisher information.

    Rows of the current and old parameters are matched token-by-token through
    the two models' ``vocab_map`` dictionaries (token -> index). Only tokens
    known to both models are compared. The unknown-token row is compared when
    both models define ``unknown_idx``. For parameters whose name contains
    ``"embedding"``, the padding row is compared when both models define
    ``padding_idx``.

    NOTE(review): dim 0 of every penalized parameter is selected with these
    vocabulary indices, which presumes all penalized parameters are
    vocabulary-indexed along dim 0 — confirm for non-vocabulary layers.

    Args:
        current_model (nn.Module): Current model (exposes ``vocab_map``,
            ``unknown_idx`` and optionally ``padding_idx``).
        old_model (nn.Module): A frozen copy of the model before learning the new task.
        fisher_info (dict): {name: fisher (tensor)} Fisher information estimate,
            laid out like the old model's parameters.
        lambda_ewc (float): EWC regularization strength.

    Returns:
        torch.Tensor | float: Scalar EWC penalty loss.
    """
    prev_params = dict(old_model.named_parameters())

    # Build BOTH index lists from one key order (the current vocabulary) so that
    # position j in each list refers to the same token.
    shared_vocab = [vocab for vocab in current_model.vocab_map if vocab in old_model.vocab_map]
    old_rows = [old_model.vocab_map[vocab] for vocab in shared_vocab]
    current_rows = [current_model.vocab_map[vocab] for vocab in shared_vocab]
    if current_model.unknown_idx is not None and old_model.unknown_idx is not None:
        old_rows.append(old_model.unknown_idx)
        current_rows.append(current_model.unknown_idx)

    # Embedding matrices also keep the padding row, paired only when both models have one.
    old_padding = getattr(old_model, "padding_idx", None)
    current_padding = getattr(current_model, "padding_idx", None)
    has_padding_pair = old_padding is not None and current_padding is not None
    old_embedding_rows = old_rows + ([old_padding] if has_padding_pair else [])
    current_embedding_rows = current_rows + ([current_padding] if has_padding_pair else [])

    loss = 0.0
    for name, param in current_model.named_parameters():
        # Only penalize parameters that exist in Fisher info.
        if name not in fisher_info:
            continue
        is_embedding = "embedding" in name
        old_idx = old_embedding_rows if is_embedding else old_rows
        current_idx = current_embedding_rows if is_embedding else current_rows

        prev_param = prev_params[name][old_idx]
        # fisher_info maps name -> tensor; the former `fisher_info[name][name]`
        # indexed that tensor with a string.
        fisher = fisher_info[name][old_idx]
        current_param = param[current_idx]

        loss += compute_quadratic_loss(current_param, prev_param, fisher)
    return lambda_ewc * loss

# EWC Fisher Information
def update_fisher_information(fisher_info, model, scale=1.0):
    """
    Add the squared current gradients of ``model`` to a running diagonal Fisher estimate.

    Every parameter that has ``requires_grad`` set and a non-``None`` ``.grad``
    contributes ``scale * grad ** 2`` to ``fisher_info[name]``. A zero tensor
    shaped like the gradient is created the first time a name appears.
    Parameters without a gradient are skipped.

    Args:
        fisher_info (dict): Running Fisher estimate ``{param_name: tensor}``;
            updated in place.
        model (nn.Module): Model whose ``.grad`` fields have been populated.
        scale (float): Multiplier for each contribution (usually ``1 / N`` to
            normalize per batch/sample).

    Returns:
        None
    """
    for param_name, weight in model.named_parameters():
        if weight.grad is None or not weight.requires_grad:
            continue
        grad = weight.grad.detach()
        if param_name not in fisher_info:
            fisher_info[param_name] = torch.zeros_like(grad)
        fisher_info[param_name] += scale * grad ** 2
            
# Gradient and Activation Regularization
def _squeeze_leading_and_detach(tensor):
    """Squeeze leading singleton dims (two for >2-D, one for 2-D) and detach."""
    if tensor.dim() > 2:
        tensor = tensor.squeeze(0).squeeze(0)
    elif tensor.dim() == 2:
        tensor = tensor.squeeze(0)
    return tensor.detach()


def _resize_running_average(old_avg, target):
    """Return zeros shaped like *target* holding *old_avg*'s overlap along every dim."""
    new_avg = torch.zeros_like(target)
    # The overlap is only meaningful when the number of dimensions is unchanged;
    # otherwise the average restarts from zeros.
    if old_avg.dim() == target.dim():
        overlap = tuple(slice(0, min(a, b)) for a, b in zip(old_avg.shape, target.shape))
        new_avg[overlap] = old_avg[overlap]
    return new_avg


def _ema_deviation_penalty(running, name, value, weight, alpha):
    """Return weight * mean((value - avg)**2) and update the EMA; first sighting only seeds it."""
    if name not in running:
        running[name] = value.clone()
        return 0.0
    if running[name].shape != value.shape:
        running[name] = _resize_running_average(running[name], value)
    penalty = weight * ((value - running[name]) ** 2).mean()
    running[name] = alpha * running[name] + (1 - alpha) * value
    return penalty


def compute_stabilization_regularization(
    activations,
    gradients,
    avg_activations,
    avg_gradients,
    lambda_grad=0.01,
    lambda_act=0.01,
    alpha=0.9,
):
    """
    Penalize deviation of gradients/activations from their running averages.

    Each ``name`` in ``gradients`` is processed in turn, together with the same
    ``name`` in ``activations`` if present. Each tensor first has its leading
    singleton dimensions squeezed (two for tensors with more than 2 dims, one
    for 2-D tensors) and is detached.

    - The first time a name is seen, its running average is seeded with the
      value and no penalty is added.
    - Afterwards, ``lambda * mean((value - avg) ** 2)`` is added and the
      average becomes ``alpha * avg + (1 - alpha) * value``.

    If a stored average's shape no longer matches (e.g. a layer grew), it is
    resized to the new shape. The overlapping region is kept along every
    dimension and the rest is zero-filled; the average is all zeros if the
    number of dimensions changed. Previously only dim 0 was resized, which
    crashed when another dimension changed.

    Args:
        activations (dict): name -> activation tensor. Only names also present
            in ``gradients`` are used.
        gradients (dict): name -> gradient tensor.
        avg_activations (dict): Running activation averages; updated in place.
        avg_gradients (dict): Running gradient averages; updated in place.
        lambda_grad (float): Weight of the gradient penalty.
        lambda_act (float): Weight of the activation penalty.
        alpha (float): EMA decay for the running averages.

    Returns:
        torch.Tensor | float: Summed penalty (``0.0`` when no term was added).

    NOTE(review): every value is detached before the penalty is computed, so the
    returned loss carries no gradient back to the model — confirm whether it is
    meant as a monitoring signal rather than a trainable regularizer.
    """
    reg_loss = 0.0

    for name in gradients:
        grad = _squeeze_leading_and_detach(gradients[name])
        reg_loss = reg_loss + _ema_deviation_penalty(avg_gradients, name, grad, lambda_grad, alpha)

        # --- Activation regularization (only for names that also have gradients) ---
        if name in activations:
            act = _squeeze_leading_and_detach(activations[name])
            reg_loss = reg_loss + _ema_deviation_penalty(avg_activations, name, act, lambda_act, alpha)

    return reg_loss

# Overfitting detection
def detect_loss_divergence(train_loss, val_loss, growth_factor=1.05, patience=5):
    """
    Detect whether validation loss has started diverging from training loss.

    From the second epoch onward, an epoch ``i`` counts as diverging when all
    of the following hold:

    * ``val_loss[i] / train_loss[i] > growth_factor``. A zero training loss
      counts as an infinite ratio when the validation loss is positive, instead
      of raising ZeroDivisionError.
    * The gap ``val_loss[i] - train_loss[i]`` is larger than at epoch ``i - 1``.
    * ``val_loss[i]`` is larger than ``val_loss[i - 1]``.

    Divergence is reported as soon as ``patience`` consecutive epochs are diverging.

    Args:
        train_loss (list of float): Training loss per epoch.
        val_loss (list of float): Validation loss per epoch.
        growth_factor (float): Minimum ratio val_loss / train_loss to consider as divergence.
        patience (int): Number of consecutive diverging epochs required. Histories
            shorter than this are never reported as diverging.

    Returns:
        bool: True if divergence is detected, False otherwise (including when the
        two histories have different lengths).
    """
    if len(train_loss) != len(val_loss) or len(train_loss) < patience:
        return False

    epochs = list(zip(train_loss, val_loss))
    diverging = 0

    for (prev_train, prev_val), (train, val) in zip(epochs, epochs[1:]):
        # Guard the ratio: a perfect (zero) training loss used to crash here.
        if train != 0:
            ratio_exceeded = val / train > growth_factor
        else:
            ratio_exceeded = val > 0
        gap_growing = (val - train) > (prev_val - prev_train)
        val_growing = val > prev_val

        if ratio_exceeded and gap_growing and val_growing:
            diverging += 1
            if diverging >= patience:
                return True
        else:
            diverging = 0

    return False