import itertools
import torch
import torch.nn.functional as F
import pickle
import os
from copy import deepcopy
from src.eval.eval import eval_single_dataset
from src.models import get_classification_head, ImageClassifier, ImageEncoder
from src.datasets import get_dataloader, get_dataset, maybe_dictionarize
from src.finetune.continual_finetune import continual_finetune
from src.finetune.sabcd_finetune import sabcd_finetune
from src.utils.utils import cosine_lr


class EWC:
    """
    Elastic Weight Consolidation (EWC) implementation

    EWC prevents catastrophic forgetting by adding a regularization term to the loss function:
    L = L_task + λ * Σ_i F_i * (θ_i - θ*_i)^2 / 2

    where:
    - L_task is the loss for the current task
    - λ is the regularization strength (``reg_lambda``)
    - F_i is the diagonal (empirical) Fisher information for parameter θ_i, i.e. its importance
    - θ*_i is the parameter value saved as the reference point for previous tasks

    ``ewc_loss`` returns only the Σ term; multiplying it by λ is left to the caller.
    Parameters recognised as classification-head parameters are excluded from the penalty.
    """

    # Name substrings that mark a parameter as part of a classification head.
    _CLASSIFIER_SUBSTRINGS = ('classifier', 'head', 'logit', 'output')
    # Names that mark a classification head only when they are a whole dotted
    # component (e.g. ``fc.weight``). Substring-matching 'fc' would also catch
    # backbone MLP layers such as ``mlp.c_fc`` or ``mlp.fc1``.
    _CLASSIFIER_COMPONENTS = ('fc',)

    def __init__(self, model, dataset_loader, device, reg_lambda=1e5):
        """
        Initialize EWC

        Args:
            model: Neural network model
            dataset_loader: Data loader for computing Fisher information
            device: Computing device
            reg_lambda: EWC regularization strength
        """
        self.model = model
        self.device = device
        self.reg_lambda = reg_lambda

        # Reference parameters (θ*) and diagonal Fisher information, keyed by parameter name
        self.saved_params = {}
        self.fisher_info = {}

        # First compute Fisher information matrix and save important parameters
        self._first_task_setup(dataset_loader)

    @classmethod
    def _is_classifier_param(cls, name):
        """Return True if the parameter name looks like a classification-head parameter."""
        lowered = name.lower()
        if any(key in lowered for key in cls._CLASSIFIER_SUBSTRINGS):
            return True
        return any(part in cls._CLASSIFIER_COMPONENTS for part in lowered.split('.'))

    @staticmethod
    def _num_fisher_batches(dataset_loader):
        """
        Return the maximum number of batches to draw for Fisher estimation.

        Returns None (meaning "iterate the whole loader") when no size information is available.
        """
        batch_size = getattr(dataset_loader, 'batch_size', None)
        try:
            num_samples = len(dataset_loader.dataset)
        except (AttributeError, TypeError):
            num_samples = 0

        if num_samples > 0 and batch_size:
            # +1 covers a partial final batch; islice stops early if the loader has fewer.
            return num_samples // batch_size + 1

        try:
            # len() of a loader is already a batch count - it must not be divided again.
            return len(dataset_loader)
        except TypeError:
            return None

    def _estimate_fisher(self, dataset_loader, n_batches):
        """
        Estimate the diagonal Fisher information of the currently trainable parameters.

        For each batch, the model's own argmax predictions are used as targets. The squared
        gradient of the batch-mean cross-entropy is weighted by the batch size and summed.
        The result is divided by the number of samples from batches that completed without error.

        Args:
            dataset_loader: Iterable of batches accepted by ``maybe_dictionarize``
            n_batches: Maximum number of batches to consume (None = all)

        Returns:
            tuple: (dict mapping parameter name -> Fisher tensor, number of valid samples)
        """
        # Only parameters that are trainable right now are tracked. requires_grad is
        # deliberately NOT toggled: doing so would silently unfreeze frozen modules.
        fisher = {
            name: torch.zeros_like(param.data)
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

        self.model.train()
        valid_samples = 0
        total_samples = 0

        for batch_idx, batch in enumerate(itertools.islice(dataset_loader, n_batches)):
            try:
                batch = maybe_dictionarize(batch)
                inputs = batch["images"].to(self.device)
                batch_size = inputs.size(0)
                total_samples += batch_size

                # Reset gradients so each batch contributes only its own gradient
                self.model.zero_grad()

                outputs = self.model(inputs)
                # Empirical Fisher: the model's own predictions serve as pseudo-labels
                loss = F.cross_entropy(outputs, outputs.argmax(1))
                loss.backward()

                for name, param in self.model.named_parameters():
                    if name not in fisher or param.grad is None:
                        continue
                    grad_sq = param.grad.detach().pow(2)
                    if torch.isfinite(grad_sq).all():
                        fisher[name] += grad_sq * batch_size
                    else:
                        print(f"Warning: Fisher information for parameter {name} contains invalid values")

                valid_samples += batch_size

                if batch_idx % 10 == 0:
                    print(f"Fisher information computation progress: {batch_idx}/{n_batches}", end="\r")

            except Exception as e:
                # Best-effort: a failing batch is reported and skipped
                print(f"Error processing batch {batch_idx}: {e}")
                continue

        # Don't leave Fisher-pass gradients on the model for a later optimizer step
        self.model.zero_grad()

        if valid_samples > 0:
            print(f"\nValid samples: {valid_samples}/{total_samples}")
            for tensor in fisher.values():
                tensor /= valid_samples

        return fisher, valid_samples

    def _first_task_setup(self, dataset_loader):
        """
        EWC setup for the first task, compute Fisher information and save current parameters

        Args:
            dataset_loader: Data loader for computing Fisher information
        """
        print("Initializing EWC Fisher information matrix...")

        # First save current model parameters as important parameters
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.saved_params[name] = param.data.clone()

        # Then compute Fisher information
        self._compute_fisher_information(dataset_loader)

        print("EWC initialization completed")

    def _compute_fisher_information(self, dataset_loader):
        """Recompute ``self.fisher_info`` from ``dataset_loader`` (``saved_params`` is left untouched)."""
        print("Computing Fisher information matrix...")

        n_batches = self._num_fisher_batches(dataset_loader)
        print(f"Starting to compute Fisher information matrix, planning to process {n_batches} batches...")

        self.fisher_info, valid_samples = self._estimate_fisher(dataset_loader, n_batches)

        print(f"\nFisher information computation completed, using {valid_samples} valid samples")

    def ewc_loss(self):
        """
        Compute the unscaled EWC penalty Σ F_i * (θ_i - θ*_i)^2 / 2.

        Classification-head parameters, parameters without saved state, and parameters
        whose shape no longer matches the saved state are skipped.

        Returns:
            torch.Tensor: 0-dim EWC regularization loss (a zero tensor on ``self.device``
            when no parameter contributes)
        """
        penalty = None
        skipped_params = 0

        for name, param in self.model.named_parameters():
            if self._is_classifier_param(name):
                continue

            saved = self.saved_params.get(name)
            fisher = self.fisher_info.get(name)
            if saved is None or fisher is None:
                continue

            # Architecture changes (e.g. a new head size) make the old state unusable
            if param.shape != saved.shape or param.shape != fisher.shape:
                skipped_params += 1
                continue

            term = torch.sum(fisher * (param - saved).pow(2)) / 2
            penalty = term if penalty is None else penalty + term

        # Only print information when a significant number of parameters are skipped
        if skipped_params > 10:
            print(f"Note: Skipped {skipped_params} parameters due to shape mismatch")

        if penalty is None:
            # Always return a tensor so callers can use .item()/.backward() uniformly
            return torch.zeros((), device=self.device)
        return penalty

    def update_fisher_and_save_params(self, dataset_loader, alpha=0.5):
        """
        Key function for task switching: save the current parameters as the new reference
        point, then update the Fisher information with the new task's estimate.

        For backbone parameters whose shape is unchanged, the stored Fisher becomes
        ``alpha * new + (1 - alpha) * old``. Classification-head parameters, new
        parameters, and parameters whose shape changed take the new estimate only.

        Args:
            dataset_loader: Data loader for the new task
            alpha: Weight of the new task's Fisher when blending with the previous one
        """
        print("Task switching - first saving current parameters as reference point...")

        # 1. Snapshot current parameters; they become θ* once the update completes
        current_params = {
            name: param.data.clone()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

        # self.fisher_info is rebound below and its tensors are never modified in place,
        # so holding a reference is enough (no copy needed).
        old_fisher_info = self.fisher_info

        # 2. Compute Fisher information for the new task
        n_batches = self._num_fisher_batches(dataset_loader)
        print(f"Starting to compute Fisher information matrix for new task, planning to process {n_batches} batches...")
        new_fisher, _ = self._estimate_fisher(dataset_loader, n_batches)

        # 3. Merge Fisher information
        updated_fisher = {}
        for name, fisher in new_fisher.items():
            old = old_fisher_info.get(name)
            if self._is_classifier_param(name) or old is None:
                updated_fisher[name] = fisher
            elif old.shape == fisher.shape:
                updated_fisher[name] = alpha * fisher + (1 - alpha) * old
            else:
                updated_fisher[name] = fisher
                print(f"Parameter {name} shape changed: {old.shape} -> {fisher.shape}")

        # 4. Update EWC state
        self.fisher_info = updated_fisher
        self.saved_params = current_params

        print(f"Task switching completed: Fisher information updated, processed {len(updated_fisher)} parameters")

    def save_ewc_state(self, save_path):
        """Save EWC state (Fisher, reference parameters, λ) as a pickle, with tensors moved to CPU."""
        print(f"Saving EWC state to: {save_path}")
        ewc_state = {
            'fisher_info': {name: info.cpu() for name, info in self.fisher_info.items()},
            'saved_params': {name: param.cpu() for name, param in self.saved_params.items()},
            'reg_lambda': self.reg_lambda
        }

        # os.makedirs('') raises, so only create a directory when the path has one
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(save_path, 'wb') as f:
            pickle.dump(ewc_state, f)
        print("EWC state saved successfully")

    def load_ewc_state(self, load_path, device):
        """
        Load EWC state written by ``save_ewc_state`` and move tensors to ``device``.

        NOTE: pickle.load can execute arbitrary code - only load state files you created.

        Returns:
            bool: True on success, False if the file is missing or loading failed
        """
        if not os.path.exists(load_path):
            print(f"EWC state file not found: {load_path}")
            return False

        print(f"Loading EWC state from: {load_path}")
        try:
            with open(load_path, 'rb') as f:
                ewc_state = pickle.load(f)

            self.fisher_info = {name: info.to(device) for name, info in ewc_state['fisher_info'].items()}
            self.saved_params = {name: param.to(device) for name, param in ewc_state['saved_params'].items()}
            self.reg_lambda = ewc_state['reg_lambda']
            self.device = device

            params_count = len(self.saved_params)
            fisher_count = len(self.fisher_info)
            print(
                f"EWC state loaded successfully: loaded {params_count} reference parameters and {fisher_count} Fisher information items")

            # Report (but tolerate) mismatches between the model and the stored state
            model_params = {
                name for name, param in self.model.named_parameters() if param.requires_grad
            }
            saved_params = set(self.saved_params)
            missing_params = model_params - saved_params
            extra_params = saved_params - model_params

            if missing_params:
                print(f"Warning: Found {len(missing_params)} parameters in model but not in saved EWC state")
            if extra_params:
                print(f"Warning: Found {len(extra_params)} parameters in EWC state but not in current model")

            return True
        except Exception as e:
            print(f"Error loading EWC state: {e}")
            return False


def apply_ewc_regularization(model, ewc_regularizer, task_loss):
    """
    Add the λ-scaled EWC penalty to the current task loss.

    Args:
        model: Unused; kept for interface compatibility (the regularizer holds its own
            ``model`` reference)
        ewc_regularizer: EWC regularizer exposing ``ewc_loss()`` and ``reg_lambda``, or None
        task_loss: Current task loss tensor

    Returns:
        tuple: ``(total_loss, scaled_ewc_value)``. ``total_loss`` is ``task_loss`` plus
        ``reg_lambda * ewc_loss()``. ``scaled_ewc_value`` is that penalty as a Python float
        (0.0 when no regularizer is given).
    """
    if ewc_regularizer is None:
        return task_loss, 0.0

    scaled_ewc_loss = ewc_regularizer.reg_lambda * ewc_regularizer.ewc_loss()
    # float() works for both a 0-dim tensor and a plain number; ewc_loss() may return
    # a plain 0 when no parameter matched, which has no .item().
    scaled_ewc_value = float(scaled_ewc_loss)

    return task_loss + scaled_ewc_loss, scaled_ewc_value


def ewc_enhanced_finetune(args, train_dataset, starting_model_path, output_path, ewc_regularizer=None, use_sabcd=True):
    """
    Fine-tune an image encoder, adding EWC regularization when a regularizer is supplied.

    Without a regularizer, the call is delegated to ``sabcd_finetune`` or
    ``continual_finetune`` (chosen by ``use_sabcd``). With one, the encoder is
    loaded from ``starting_model_path``, wrapped with a frozen classification head,
    attached to the regularizer, and trained by ``ewc_finetune_with_regularization``.

    Args:
        args: Training parameters (reads ``model`` and ``device`` here)
        train_dataset: Training dataset name
        starting_model_path: Path to the encoder state dict to start from
        output_path: Where the fine-tuned encoder is written
        ewc_regularizer: EWC regularizer, or None for plain fine-tuning
        use_sabcd: Whether to use SABCD fine-tuning

    Returns:
        Whatever the selected fine-tuning routine returns.
    """
    if ewc_regularizer is None:
        # Plain path: defer entirely to the non-EWC fine-tuning routines
        plain_finetune = sabcd_finetune if use_sabcd else continual_finetune
        return plain_finetune(args, train_dataset, starting_model_path, output_path)

    print(f"Using EWC-enhanced fine-tuning method (λ={ewc_regularizer.reg_lambda})")

    # Build the encoder and restore the starting weights onto the target device
    encoder = ImageEncoder(args.model)
    starting_state = torch.load(starting_model_path, map_location=args.device)
    encoder.load_state_dict(starting_state)
    encoder = encoder.to(args.device)

    # Wrap with the dataset's classification head, keeping the head frozen
    head = get_classification_head(args, train_dataset)
    classifier = ImageClassifier(encoder, head)
    classifier.freeze_head()
    classifier = classifier.to(args.device)

    # The penalty must be computed against the model that is actually being trained
    ewc_regularizer.model = classifier

    return ewc_finetune_with_regularization(
        args, train_dataset, classifier, output_path, ewc_regularizer, use_sabcd
    )


def ewc_finetune_with_regularization(args, train_dataset, model, output_path, ewc_regularizer, use_sabcd=True):
    """
    Fine-tune ``model`` on ``train_dataset`` with the EWC penalty added to every batch loss.

    The image encoder is evaluated every 5 epochs and after the last epoch. The state dict
    with the highest top-1 accuracy is saved to ``output_path``.

    Args:
        args: Training configuration. Reads ``data_location`` and ``device``, plus (with
            defaults) ``batch_size`` (64), ``lr`` (1e-5), ``wd`` (0.1),
            ``warmup_length`` (0.1), ``num_grad_accumulation`` (2)
        train_dataset: Dataset name; with "Val" removed it selects the epoch count
        model: Classifier exposing ``train_preprocess`` and ``image_encoder``
        output_path: File path for the best image-encoder state dict
        ewc_regularizer: EWC regularizer whose penalty is added to the task loss
        use_sabcd: Use the SABCD optimizer instead of AdamW

    Returns:
        float: Best validation top-1 accuracy observed (-1.0 if never evaluated)
    """
    print(f"Starting EWC-enhanced fine-tuning training (λ={ewc_regularizer.reg_lambda})")

    # Prepare data
    dataset = get_dataset(
        train_dataset,
        model.train_preprocess,
        location=args.data_location,
        batch_size=getattr(args, 'batch_size', 64),
    )
    loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None)
    num_batches = len(dataset.train_loader)

    loss_fn = F.cross_entropy

    # Optimizer over trainable parameters only (e.g. a head frozen via freeze_head() is excluded)
    lr = getattr(args, 'lr', 1e-5)
    wd = getattr(args, 'wd', 0.1)
    params = [p for p in model.parameters() if p.requires_grad]

    if use_sabcd:
        from src.optimizers.sabcd import SABCD
        optimizer = SABCD(params, lr=lr, weight_decay=wd)
        print("Using SABCD optimizer for fine-tuning")
    else:
        optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=wd)
        print("Using AdamW optimizer for fine-tuning")

    # Learning rate scheduler settings
    warmup_length = getattr(args, 'warmup_length', 0.1)
    num_grad_accumulation = getattr(args, 'num_grad_accumulation', 2)

    # Dataset-specific epochs mapping
    epochs_map = {
        "Cars": 35, "DTD": 76, "EuroSAT": 12, "GTSRB": 11,
        "MNIST": 5, "RESISC45": 15, "SUN397": 14, "SVHN": 4,
        "CIFAR10": 6, "CIFAR100": 6, "STL10": 60, "Food101": 4,
        "Flowers102": 147, "FER2013": 10, "PCAM": 1, "OxfordIIITPet": 82,
        "RenderedSST2": 39, "EMNIST": 2, "FashionMNIST": 5, "KMNIST": 5,
    }

    base_name = train_dataset.replace("Val", "")
    epochs = epochs_map.get(base_name, 10)
    print(f"Setting training epochs to {epochs} for dataset {base_name}")

    scheduler = cosine_lr(
        optimizer, lr, warmup_length, epochs * num_batches // num_grad_accumulation,
    )

    best_model = None
    best_accuracy = -1.0

    # Start from clean gradients: the model may carry gradients from an earlier
    # backward pass, which would otherwise leak into the first optimizer step.
    optimizer.zero_grad()

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        epoch_task_loss = 0.0
        epoch_ewc_loss = 0.0
        num_epoch_batches = 0

        for i, batch in enumerate(loader):
            step = (i // num_grad_accumulation + epoch *
                    num_batches // num_grad_accumulation)

            batch = maybe_dictionarize(batch)
            inputs = batch["images"].to(args.device)
            labels = batch["labels"].to(args.device)

            # Forward propagation
            outputs = model(inputs)
            task_loss = loss_fn(outputs, labels)

            # Add EWC regularization
            total_loss, ewc_loss_value = apply_ewc_regularization(
                model, ewc_regularizer, task_loss)

            # Backpropagation; gradients accumulate until the next optimizer step
            total_loss.backward()

            if (i + 1) % num_grad_accumulation == 0:
                scheduler(step)
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                optimizer.step()
                optimizer.zero_grad()

            # Read each loss value once (each .item() is a device sync)
            total_value = total_loss.item()
            task_value = task_loss.item()
            epoch_loss += total_value
            epoch_task_loss += task_value
            epoch_ewc_loss += ewc_loss_value
            num_epoch_batches += 1

            if i % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs}, Batch {i+1}/{num_batches}, "
                      f"Task Loss: {task_value:.6f}, "
                      f"EWC Loss: {ewc_loss_value:.6f}, "
                      f"Total Loss: {total_value:.6f}", end="\r")

        if num_epoch_batches:
            print(f"\nEpoch {epoch+1}/{epochs} mean losses - "
                  f"Task: {epoch_task_loss / num_epoch_batches:.6f}, "
                  f"EWC: {epoch_ewc_loss / num_epoch_batches:.6f}, "
                  f"Total: {epoch_loss / num_epoch_batches:.6f}")

        # Evaluate model (every 5 epochs or last epoch)
        if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
            acc = eval_single_dataset(
                model.image_encoder, train_dataset, args)['top1']
            print(f"Epoch {epoch+1} validation accuracy: {acc*100:.2f}%")

            if acc > best_accuracy:
                best_accuracy = acc
                # state_dict() tensors share storage with the live parameters, so a
                # shallow copy would be overwritten by later optimizer steps.
                best_model = deepcopy(model.image_encoder.state_dict())
                print(f"Better model found, accuracy: {best_accuracy*100:.2f}%")

    # Save best model
    if best_model is not None:
        torch.save(best_model, output_path)
        print(f"Saved best model (accuracy: {best_accuracy*100:.2f}%) to {output_path}")

    print(f"EWC fine-tuning completed, best accuracy: {best_accuracy*100:.2f}%")
    return best_accuracy