import torch
import torch.nn as nn
import pytorch_lightning as pl
import pandas as pd
import numpy as np
import torch.nn.functional as F
from typing import List


def accumulate_gradients(model, teacher_idx, cumulative_grads, active_flags, threshold):
    """
    Add the model's current gradients to per-teacher running totals.

    Both dictionaries are mutated in place. A parameter keeps accumulating
    until a call in which the largest absolute entry of its gradient is
    below ``threshold``; the gradient from that call is still added, and
    the parameter is skipped on every later call.

    Args:
        model: Module whose ``named_parameters()`` gradients are read
        teacher_idx: Key selecting which teacher's totals to update
        cumulative_grads: ``{teacher_idx: {param_name: summed_grad_tensor}}``
        active_flags: ``{teacher_idx: {param_name: bool}}``; False once a
            parameter has stopped accumulating
        threshold: Cut-off on max |grad|; ``None`` means never stop
    """
    # A new teacher gets fresh (empty) entries in both dictionaries.
    if teacher_idx not in cumulative_grads:
        cumulative_grads[teacher_idx] = {}
        active_flags[teacher_idx] = {}

    totals = cumulative_grads[teacher_idx]
    still_active = active_flags[teacher_idx]

    for param_name, tensor in model.named_parameters():
        grad = tensor.grad
        if grad is None:
            continue

        # First gradient seen for this parameter: start from zero, active.
        if param_name not in totals:
            totals[param_name] = torch.zeros_like(grad)
            still_active[param_name] = True

        if not still_active[param_name]:
            continue

        totals[param_name] += grad.clone()

        if threshold is None:
            continue

        # Deactivate once every entry of the gradient is below the threshold.
        largest_entry = torch.max(torch.abs(grad)).item()
        if largest_entry < threshold:
            still_active[param_name] = False


def get_layer_param_groups(model):
    """
    Collect the parameters of every Linear and Conv2d sub-module.

    Args:
        model: Module whose ``named_modules()`` are scanned

    Returns:
        dict: ``{module_name: [parameters of that module]}`` in
        ``named_modules()`` order; modules without parameters are left out
    """
    grouped = {}

    for module_name, submodule in model.named_modules():
        if not isinstance(submodule, (nn.Linear, nn.Conv2d)):
            continue

        # Weight and, when present, bias of this layer.
        own_params = list(submodule.parameters())
        if own_params:
            grouped[module_name] = own_params

    return grouped


def _unwrap_logits(model_output):
    """Return the logits from an output that may be a ``(logits, extras)`` tuple."""
    if isinstance(model_output, tuple):
        logits, _ = model_output
        return logits
    return model_output


def sam_sharpness_per_layer(model, loss_fn, batch, layer_param_groups, rho=0.05):
    """
    Measure, layer by layer, how much the loss rises after a SAM-style
    ascent step of radius ``rho`` applied to that layer alone.

    For each layer the parameters are moved to
    ``theta + rho * grad / ||grad||`` (gradient and norm taken over that
    layer's parameters only), the loss is re-evaluated on the same batch,
    and the parameters are restored. The model runs in eval mode for the
    duration of the call and its previous train/eval mode is restored on
    exit. All ``.grad`` fields of the model are cleared by this function.

    Args:
        model: The neural network model; may return logits or a
            ``(logits, extras)`` tuple
        loss_fn: Callable ``loss_fn(logits, y)`` returning a scalar tensor
        batch: Input batch ``(x, y)``
        layer_param_groups: Dictionary mapping layer names to parameter lists
        rho: Perturbation radius

    Returns:
        dict: Dictionary mapping layer names to sharpness values
        (perturbed loss minus base loss). Layers with no gradient, or an
        exactly zero gradient, map to ``0.0``.
    """
    original_mode = model.training
    model.eval()

    x, y = batch

    try:
        # 1) Base loss and gradients.
        for p in model.parameters():
            p.grad = None

        base_loss = loss_fn(_unwrap_logits(model(x)), y)
        base_loss.backward()
        base_loss_value = base_loss.item()

        # Snapshot each layer's (parameter, gradient) pairs BEFORE touching
        # any .grad again. Clearing gradients inside the per-layer loop
        # would leave every layer after the first with nothing to perturb.
        layer_grads = {
            layer_name: [(p, p.grad) for p in params if p.grad is not None]
            for layer_name, params in layer_param_groups.items()
        }
        for p in model.parameters():
            p.grad = None

        results = {}

        # 2) Per-layer perturbation; no graph is needed for these passes.
        with torch.no_grad():
            for layer_name, grad_pairs in layer_grads.items():
                if not grad_pairs:
                    results[layer_name] = 0.0
                    continue

                grad_sq = sum((grad ** 2).sum() for _, grad in grad_pairs)
                if grad_sq.item() == 0.0:
                    # Zero gradient means zero perturbation.
                    results[layer_name] = 0.0
                    continue
                gnorm = torch.sqrt(grad_sq + 1e-12)

                saved_values = [p.detach().clone() for p, _ in grad_pairs]
                try:
                    for p, grad in grad_pairs:
                        p.add_(rho * grad / gnorm)
                    loss_perturbed = loss_fn(_unwrap_logits(model(x)), y).item()
                finally:
                    # Always put the weights back, even if the forward fails.
                    for (p, _), saved in zip(grad_pairs, saved_values):
                        p.copy_(saved)

                results[layer_name] = loss_perturbed - base_loss_value
    finally:
        model.train(original_mode)

    return results


def train_dual_heads(
        student: nn.Module, 
        teachers: List[pl.LightningModule], 
        datamodules: List[pl.LightningDataModule], 
        epochs_list: List[int],
        save_dir: str,
        lr = 0.01,
        momentum = 0.0,
        device = "cuda:0",
        gradient_threshold = 1e-6,
        min_param_value = None):
    """
    Train the student against each teacher in turn and record diagnostics.

    For teacher ``j`` the student's head is switched with
    ``student.update_head(j)`` and the student is trained with SGD and
    cross-entropy on the teacher's argmax predictions. After every epoch
    both heads are evaluated on ``datamodules[0]`` / ``datamodules[1]``
    test loaders against ``teachers[0]`` / ``teachers[1]``, so at least two
    teachers and datamodules are required.

    Files written (each name is appended directly to ``save_dir``; no path
    separator is inserted):
        ``task_{j}_epoch_{epoch}.pth``   per-epoch checkpoint
        ``_cumulative_residuals.pth``    per-teacher sum over epochs of the
                                         last batch's mean (softmax - one_hot)
        ``_cumulative_gradients.pth``    accumulated gradients and flags
        ``_pre_activations.pth``         last-batch activation means (if any)
        ``_co_activations.pth``          per-epoch activation covariances (if any)
        ``_sharpness.pth``               per-epoch, per-layer sharpness

    Args:
        student: Model to train; must provide ``update_head``,
            ``forward_h0`` and ``forward_h1``, and may return either logits
            or ``(logits, pre_activations_dict)`` from its forward
        teachers: Teacher modules, called as ``teacher.forward(data_input=x)``
        datamodules: One datamodule per teacher
        epochs_list: Number of epochs to train against each teacher
        save_dir: Prefix for every file written
        lr: SGD learning rate
        momentum: SGD momentum
        device: Device used for training and for the per-epoch evaluation
        gradient_threshold: Passed to ``accumulate_gradients``
        min_param_value: Not used by this function

    Returns:
        pandas.DataFrame with one row per trained epoch (across all
        teachers) holding the per-head accuracy and test-loss metrics.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(student.parameters(), lr=lr, momentum=momentum)

    student.to(device)

    # Gradient tracking
    cumulative_gradients = {}  # {teacher_idx: {param_name: cumulative_grad_tensor}}
    gradient_active = {}       # {teacher_idx: {param_name: bool}} - still accumulating?

    sharpness_data = {}        # {teacher_idx: {epoch: {layer_name: sharpness_value}}}
    residuals_data = {}        # {teacher_idx: cumulative residual tensor}
    pre_activations_data = {}  # {teacher_idx: {epoch: {layer_name: mean_tensor}}}
    co_activations_data = {}   # {teacher_idx: {epoch: {layer_name: covariance}}}

    h0_metric_list, h1_metric_list = [], []
    h0_test_loss_list, h1_test_loss_list = [], []
    h0_teacher_vs_true_list, h1_teacher_vs_true_list = [], []
    h0_student_vs_teacher_list, h1_student_vs_teacher_list = [], []

    for j, teacher in enumerate(teachers):

        print(f"training on teacher {j}...")
        teacher.to(device)
        # The teacher only supplies labels. Keep it in eval mode from the
        # first batch on; otherwise the first epoch would run in whatever
        # mode the caller left it, while later epochs run in eval mode
        # (the evaluation helper switches it).
        teacher.eval()
        train_dataloader = datamodules[j].train_dataloader()
        student.update_head(j)
        epochs = epochs_list[j]
        cumulative_residuals = None

        # Layer parameter groups for sharpness computation
        layer_param_groups = get_layer_param_groups(student)

        for epoch in range(epochs):  # loop over the dataset multiple times
            student.train()
            running_loss = 0.0
            teacher_accuracy_list = []

            # {layer_name: [per-batch float16 chunks]}; stays None until a
            # batch actually yields pre-activations.
            epoch_co_activation_chunks = None

            # Only the last batch is kept for sharpness / residual metrics
            last_batch = None
            last_batch_pre_activations = None
            last_batch_residuals = None

            for i, data in enumerate(train_dataloader, 0):

                inputs, true_labels = data[0].to(device), data[1].to(device)

                # Teacher labels need no autograd graph.
                with torch.no_grad():
                    teacher_out = teacher.forward(data_input=inputs)
                    teacher_labels = torch.argmax(torch.softmax(teacher_out, dim=-1), dim=-1)

                # training step
                optimizer.zero_grad()
                student_output = student(inputs)
                # Student may return (logits, pre_activations)
                if isinstance(student_output, tuple):
                    logits, pre_activations = student_output
                else:
                    logits = student_output
                    pre_activations = None
                loss = criterion(logits, teacher_labels)
                loss.backward()

                accumulate_gradients(student, j, cumulative_gradients, gradient_active, gradient_threshold)

                if pre_activations is not None:
                    last_batch_pre_activations = pre_activations

                    # Collect chunks and concatenate once per epoch instead
                    # of re-copying the growing tensor on every batch.
                    if epoch_co_activation_chunks is None:
                        epoch_co_activation_chunks = {}
                    for layer_name, bool_tensor in pre_activations.items():
                        epoch_co_activation_chunks.setdefault(layer_name, []).append(
                            bool_tensor.detach().float().to(torch.float16))

                # Residuals of this batch (softmax - one-hot teacher label).
                # Computed without a graph so that summing them across
                # epochs does not keep old graphs alive.
                num_classes = logits.shape[-1]
                with torch.no_grad():
                    probs = F.softmax(logits, dim=-1)
                    y = F.one_hot(teacher_labels, num_classes=num_classes).float()
                    last_batch_residuals = probs - y

                optimizer.step()

                # Store last batch for sharpness computation
                last_batch = (inputs, teacher_labels)

                # print statistics
                teacher_correct = (true_labels == teacher_labels).sum()
                teacher_accuracy_list.append(teacher_correct.float()/true_labels.shape[0])

                running_loss += loss.item()
                if i % 100 == 99:
                    print(f"{epoch + 1}, {i + 1:5d} loss: {running_loss / 100:.3f}")
                    running_loss = 0.0

            # Residuals from last batch, summed over epochs
            if last_batch_residuals is not None:
                E_K_epoch = last_batch_residuals.mean(dim=0)  # mean across batch dimension
                if cumulative_residuals is None:
                    cumulative_residuals = E_K_epoch
                else:
                    cumulative_residuals += E_K_epoch

            # Mean pre-activation pattern of the last batch
            if last_batch_pre_activations is not None:
                epoch_final_pre_activations = {}
                for layer_name, bool_tensor in last_batch_pre_activations.items():
                    epoch_final_pre_activations[layer_name] = bool_tensor.float().to(torch.float16).mean(dim=0)

                if j not in pre_activations_data:
                    pre_activations_data[j] = {}
                pre_activations_data[j][epoch] = epoch_final_pre_activations

            # Covariance of activations over all samples of this epoch
            if epoch_co_activation_chunks is not None:

                epoch_covariance_activations = {}
                for layer_name, chunks in epoch_co_activation_chunks.items():
                    z = torch.cat(chunks, dim=0)
                    N = z.shape[0]
                    mu = z.mean(dim=0)                              # [d]
                    # NOTE(review): z is float16, so this product is also
                    # float16; integers above 2048 are not exactly
                    # representable there - confirm the precision is acceptable.
                    C_mle = (z.T @ z) / N - torch.outer(mu, mu)     # centered covariance (ML)
                    C = C_mle * (N / (N - 1)) if N > 1 else C_mle   # unbiased (optional)
                    C = 0.5*(C + C.T)                               # symmetrize for stability
                    epoch_covariance_activations[layer_name] = C

                if j not in co_activations_data:
                    co_activations_data[j] = {}
                co_activations_data[j][epoch] = epoch_covariance_activations

            # Sharpness on the last batch of this epoch
            if last_batch is not None:
                epoch_sharpness = sam_sharpness_per_layer(
                    model=student,
                    loss_fn=criterion,
                    batch=last_batch,
                    layer_param_groups=layer_param_groups,
                    rho=0.05
                )

                if j not in sharpness_data:
                    sharpness_data[j] = {}
                sharpness_data[j][epoch] = epoch_sharpness

                sharpness_str = ", ".join([f"{layer}: {sharpness:.6f}" for layer, sharpness in epoch_sharpness.items()])
                print(f"Epoch {epoch + 1} sharpness: {sharpness_str}")

            # Evaluation. The training device is passed through so that
            # evaluation does not silently fall back to the helpers'
            # own default device.
            teacher_accuracy = sum(teacher_accuracy_list)/len(teacher_accuracy_list)
            h0_accuracy_dict = get_head_accuracy(model=student, teacher=teachers[0], dataloader=datamodules[0].test_dataloader(), head_flag=0, device=device)
            h1_accuracy_dict = get_head_accuracy(model=student, teacher=teachers[1], dataloader=datamodules[1].test_dataloader(), head_flag=1, device=device)

            # head losses
            h0_loss = get_head_loss(model=student, dataloader=datamodules[0].test_dataloader(), head_flag=0, device=device)
            h1_loss = get_head_loss(model=student, dataloader=datamodules[1].test_dataloader(), head_flag=1, device=device)

            print(f"epoch: {epoch + 1} teacher accuracy: {teacher_accuracy:.3f}")
            print(f"head 0: student_vs_true = {h0_accuracy_dict['student_vs_true']:.3f}, teacher_vs_true = {h0_accuracy_dict['teacher_vs_true']:.3f}, student_vs_teacher = {h0_accuracy_dict['student_vs_teacher']:.3f}, loss = {h0_loss:.3f}")
            print(f"head 1: student_vs_true = {h1_accuracy_dict['student_vs_true']:.3f}, teacher_vs_true = {h1_accuracy_dict['teacher_vs_true']:.3f}, student_vs_teacher = {h1_accuracy_dict['student_vs_teacher']:.3f}, loss = {h1_loss:.3f}")

            # append epoch quantities (h*_metric keeps meaning student_vs_true)
            h0_metric_list.append(h0_accuracy_dict['student_vs_true'].cpu().numpy())
            h1_metric_list.append(h1_accuracy_dict['student_vs_true'].cpu().numpy())
            h0_test_loss_list.append(h0_loss.cpu().numpy())
            h1_test_loss_list.append(h1_loss.cpu().numpy())

            h0_teacher_vs_true_list.append(h0_accuracy_dict['teacher_vs_true'].cpu().numpy())
            h1_teacher_vs_true_list.append(h1_accuracy_dict['teacher_vs_true'].cpu().numpy())
            h0_student_vs_teacher_list.append(h0_accuracy_dict['student_vs_teacher'].cpu().numpy())
            h1_student_vs_teacher_list.append(h1_accuracy_dict['student_vs_teacher'].cpu().numpy())

            # save epoch checkpoint
            # NOTE(review): running_loss is reset every 100 batches, so
            # "train_loss" is the partial sum since the last reset, not the
            # loss of the whole epoch - confirm this is what is wanted.
            checkpoint = {
                "epoch": epoch,
                "model_state_dict": student.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_loss": running_loss,
                "h0_test_loss": h0_loss,
                "h1_test_loss": h1_loss
            }

            ckpt_name = f"task_{j}_epoch_{epoch}.pth"
            torch.save(checkpoint, save_dir+ckpt_name)

        residuals_data[j] = cumulative_residuals

    # Save residuals
    residuals_filename = "_cumulative_residuals.pth"
    torch.save(residuals_data, save_dir + residuals_filename)
    print(f"Cumulative residuals saved to {save_dir + residuals_filename}")

    # Save cumulative gradients to disk
    gradient_data = {
        'cumulative_gradients': cumulative_gradients,
        'gradient_active': gradient_active,
        'threshold_used': gradient_threshold,
        'num_teachers': len(teachers),
        'epochs_per_teacher': epochs_list
    }
    gradients_filename = "_cumulative_gradients.pth"
    torch.save(gradient_data, save_dir + gradients_filename)
    print(f"Cumulative gradients saved to {save_dir + gradients_filename}")

    # Save pre-activations to disk
    if pre_activations_data:
        pre_activations_filename = "_pre_activations.pth"
        torch.save(pre_activations_data, save_dir + pre_activations_filename)
        print(f"Pre-activations saved to {save_dir + pre_activations_filename}")

    # Save co-activations to disk
    if co_activations_data:
        co_activations_filename = "_co_activations.pth"
        torch.save(co_activations_data, save_dir + co_activations_filename)
        print(f"Co-activations saved to {save_dir + co_activations_filename}")

    # Save sharpness data to disk
    if sharpness_data:
        sharpness_filename = "_sharpness.pth"
        torch.save(sharpness_data, save_dir + sharpness_filename)
        print(f"Sharpness data saved to {save_dir + sharpness_filename}")

    return pd.DataFrame.from_dict({
        "h0_metric": h0_metric_list, 
        "h1_metric": h1_metric_list,
        "h0_test_loss": h0_test_loss_list, 
        "h1_test_loss": h1_test_loss_list,
        "h0_teacher_vs_true": h0_teacher_vs_true_list,
        "h1_teacher_vs_true": h1_teacher_vs_true_list,
        "h0_student_vs_teacher": h0_student_vs_teacher_list,
        "h1_student_vs_teacher": h1_student_vs_teacher_list
        })
    
    
def get_head_accuracy(
    model: nn.Module, 
    teacher: nn.Module,
    dataloader, 
    head_flag: int,
    device = "cuda:0"):
    """
    Evaluate one student head against true labels and teacher predictions.

    Accuracies are computed over all samples (correct count divided by the
    total sample count), so a shorter final batch is not over-weighted.
    Both ``model`` and ``teacher`` are left in eval mode, and ``teacher``
    is moved to ``device``.

    Args:
        model: Student exposing ``forward_h0`` / ``forward_h1``; a head may
            return logits or a ``(logits, extras)`` tuple
        teacher: Module called as ``teacher.forward(data_input=X)``
        dataloader: Iterable of batches where ``batch[0]`` is the input and
            ``batch[1]`` the true labels
        head_flag: 0 to evaluate ``forward_h0``, 1 for ``forward_h1``
        device: Device the batches and the teacher are moved to

    Returns:
        dict with 0-dim float tensors under ``"student_vs_true"``,
        ``"teacher_vs_true"`` and ``"student_vs_teacher"``.

    Raises:
        AssertionError: If ``head_flag`` is not 0 or 1.
        ZeroDivisionError: If the dataloader yields no samples.
    """
    # Raised explicitly (rather than via `assert`) so the check survives -O.
    if head_flag not in (0, 1):
        raise AssertionError(f"head_flag must be 0 or 1, got {head_flag!r}")

    forward_head = model.forward_h0 if head_flag == 0 else model.forward_h1

    model.eval()
    teacher.eval()
    # Ensure teacher is on the correct device
    teacher.to(device)

    # Running correct-prediction counts; become tensors after the first batch.
    student_true_hits = 0
    teacher_true_hits = 0
    student_teacher_hits = 0
    total_samples = 0

    with torch.no_grad():

        for batch in dataloader:
            X, Y = batch[0].to(device), batch[1].to(device)

            # Ignore any extras returned alongside the logits
            head_output = forward_head(X)
            if isinstance(head_output, tuple):
                student_out, _ = head_output
            else:
                student_out = head_output

            student_pred = torch.argmax(torch.softmax(student_out, dim=-1), dim=-1)

            teacher_out = teacher.forward(data_input=X)
            teacher_pred = torch.argmax(torch.softmax(teacher_out, dim=-1), dim=-1)

            student_true_hits = student_true_hits + (Y == student_pred).sum()
            teacher_true_hits = teacher_true_hits + (Y == teacher_pred).sum()
            student_teacher_hits = student_teacher_hits + (teacher_pred == student_pred).sum()
            total_samples += Y.shape[0]

    if total_samples == 0:
        raise ZeroDivisionError("dataloader yielded no samples; accuracy is undefined")

    return {
        "student_vs_true": student_true_hits.float() / total_samples,
        "teacher_vs_true": teacher_true_hits.float() / total_samples,
        "student_vs_teacher": student_teacher_hits.float() / total_samples
    }


def get_head_loss(
    model: nn.Module, 
    dataloader, 
    head_flag: int,
    criterion = nn.CrossEntropyLoss(),
    device = "cuda:0"):
    """
    Sum the per-batch loss of one student head over a dataloader.

    The model is switched to eval mode and left there. The result is the
    SUM of the batch losses (not their mean); for an empty dataloader it
    is the integer ``0``.

    Args:
        model: Student exposing ``forward_h0`` / ``forward_h1``; a head may
            return logits or a ``(logits, extras)`` tuple
        dataloader: Iterable of batches where ``batch[0]`` is the input and
            ``batch[1]`` the target
        head_flag: 0 to evaluate ``forward_h0``, 1 for ``forward_h1``
        criterion: Callable ``criterion(output, target)``
        device: Device the batches are moved to

    Returns:
        Sum of ``criterion`` over all batches.
    """
    assert head_flag in [0,1]

    model.eval()

    # Pick the head once; the choice does not change between batches.
    if head_flag == 0:
        run_head = model.forward_h0
    elif head_flag == 1:
        run_head = model.forward_h1

    total_loss = 0

    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0].to(device)
            targets = batch[1].to(device)

            result = run_head(inputs)
            # Drop any extras returned alongside the output.
            out = result[0] if isinstance(result, tuple) and len(result) == 2 else None
            if out is None:
                if isinstance(result, tuple):
                    out, _ = result
                else:
                    out = result

            total_loss = total_loss + criterion(out, targets)

    return total_loss
