import os
import pickle
import torch
import torch.nn.functional as F
import random
from collections import defaultdict

from src.eval.eval import eval_single_dataset
from src.models import get_classification_head, ImageClassifier, ImageEncoder
from src.datasets import get_dataloader, get_dataset, maybe_dictionarize
from src.finetune.continual_finetune import continual_finetune
from src.finetune.sabcd_finetune import sabcd_finetune
from src.utils.utils import cosine_lr


class ExperienceReplay:
    """
    Experience Replay (ER) implementation

    Reduces catastrophic forgetting by storing and replaying samples from past tasks.
    The instance owns a MemoryBuffer with the stored samples and a registry that
    maps task names to their classification heads.
    """

    def __init__(self, model, device, total_capacity=256, sampling_ratio=0.1, max_samples_per_task=64,
                 sampling_strategy="random", per_class_quota=None):
        """
        Initialize Experience Replay

        Args:
            model: Model object (for fine-tuning); only stored here
            device: Computing device
            total_capacity: Total memory buffer size
            sampling_ratio: Fraction of total_capacity used as the buffer's
                per-draw sample size (int(total_capacity * sampling_ratio))
            max_samples_per_task: Maximum samples per task
            sampling_strategy: Sample selection strategy (stored, not used by this class)
            per_class_quota: Quota per class (stored, not used by this class)
        """
        self.model = model  # May be set later during use
        self.device = device
        self.total_capacity = total_capacity
        self.sampling_ratio = sampling_ratio
        self.sampling_strategy = sampling_strategy
        self.per_class_quota = per_class_quota
        self.max_samples_per_task = max_samples_per_task

        # Internal memory buffer shared by all tasks
        self.memory_buffer = MemoryBuffer(
            total_capacity=self.total_capacity,
            max_samples_per_task=self.max_samples_per_task,
            sample_size=int(self.total_capacity * self.sampling_ratio),
            device=device
        )

        # Task name -> Classification head
        self.task_heads = {}

    def update_memory(self, dataset_loader, task_name):
        """Update memory buffer by adding samples from current task"""
        collect_samples_for_memory(
            self.memory_buffer,
            task_name,
            dataset_loader,
            self.device,
            max_samples=self.memory_buffer.max_samples_per_task
        )

    def get_replay_samples(self, current_task=None):
        """Get replay samples, optionally excluding current task"""
        # NOTE(review): relies on the buffer exposing
        # get_samples(excluded_tasks=...) - confirm the buffer class provides it.
        excluded_tasks = [current_task] if current_task else None
        return self.memory_buffer.get_samples(excluded_tasks=excluded_tasks)

    def register_task_head(self, task_name, classification_head):
        """Register task-specific classification head"""
        self.task_heads[task_name] = classification_head
        print(f"Registered classification head for task {task_name}")

    def save_memory_state(self, save_path):
        """
        Save the buffer configuration and samples to ``save_path`` with pickle.

        The data is written to ``save_path + '.tmp'`` first and then moved over
        the final path, so an existing file is never left half-written.

        Args:
            save_path: Destination file path

        Returns:
            True on success, False if anything went wrong (the error is printed).
        """
        temp_path = None
        try:
            # Create directory (if it doesn't exist)
            os.makedirs(os.path.dirname(
                os.path.abspath(save_path)), exist_ok=True)

            temp_path = os.fspath(save_path) + '.tmp'
            buffer = self.memory_buffer

            save_data = {
                "total_capacity": buffer.total_capacity,
                "max_samples_per_task": buffer.max_samples_per_task,
                "sample_size": buffer.sample_size,
                "task_sample_counts": dict(buffer.task_sample_counts),
                # Samples are copied to CPU so the file does not depend on the
                # device that was in use when saving.
                "buffer": {
                    task: [
                        {
                            "image": sample["image"].cpu(),
                            "label": sample["label"].cpu(),
                            "task_name": sample["task_name"]
                        }
                        for sample in samples
                    ]
                    for task, samples in buffer.buffer.items()
                },
            }

            with open(temp_path, 'wb') as f:
                pickle.dump(save_data, f)
                f.flush()
                os.fsync(f.fileno())  # Ensure data is written to disk

            # os.replace overwrites an existing target and also works when the
            # target does not exist yet, so no existence check is needed.
            os.replace(temp_path, save_path)

            print(f"Memory buffer safely saved to {save_path}")
            return True
        except Exception as e:
            print(f"Error saving memory state: {e}")
            import traceback
            traceback.print_exc()
            # Best-effort cleanup so a failed save does not leave a stale,
            # possibly truncated temporary file next to the real one.
            if temp_path is not None:
                try:
                    os.remove(temp_path)
                except OSError:
                    pass
            return False

    def load_memory_state(self, load_path, device=None):
        """
        Load a buffer state previously written by ``save_memory_state``.

        The file is parsed completely before the live buffer is touched, so a
        missing key or corrupt file leaves the current buffer unchanged.

        Args:
            load_path: File to read
            device: Optional new device; when given it is stored on this
                instance and on the buffer (samples themselves stay as loaded)

        Returns:
            True on success, False if the file is missing or cannot be loaded.
        """
        if device:
            self.device = device
            self.memory_buffer.device = device

        if not os.path.exists(load_path):
            print(f"Buffer file not found: {load_path}")
            return False

        try:
            # NOTE(security): pickle.load can execute arbitrary code; only load
            # buffer files that this program (or a trusted source) produced.
            with open(load_path, 'rb') as f:
                save_data = pickle.load(f)

            total_capacity = save_data["total_capacity"]
            max_samples_per_task = save_data["max_samples_per_task"]
            sample_size = save_data["sample_size"]

            restored_buffer = {
                task: [
                    {
                        "image": sample["image"],  # Kept as stored (CPU)
                        "label": sample["label"],  # Kept as stored (CPU)
                        # Older files may lack task_name; fall back to the key
                        "task_name": sample.get("task_name", task)
                    }
                    for sample in samples
                ]
                for task, samples in save_data["buffer"].items()
            }
            restored_counts = defaultdict(
                int, save_data["task_sample_counts"])
        except Exception as e:
            print(f"Error loading memory buffer: {e}")
            return False

        # Everything parsed cleanly: commit to the live buffer in one go
        buffer = self.memory_buffer
        buffer.total_capacity = total_capacity
        buffer.max_samples_per_task = max_samples_per_task
        buffer.sample_size = sample_size
        buffer.buffer = restored_buffer
        buffer.task_sample_counts = restored_counts

        print(f"Loaded memory buffer from {load_path}")
        for task, count in buffer.task_sample_counts.items():
            print(f"  - Task {task}: {count} samples")

        return True


class MemoryBuffer:
    """
    Memory buffer for Experience Replay, dynamically balances sample allocation

    Samples are plain dicts grouped per task in ``buffer``. Each task is capped
    at ``max_samples_per_task`` and all tasks together at ``total_capacity``;
    when the total is exceeded the per-task quotas are recomputed.
    """

    def __init__(self, total_capacity=256, max_samples_per_task=64, sample_size=32, device="cuda"):
        """
        Initialize memory buffer

        Args:
            total_capacity: Total capacity (shared among all tasks)
            max_samples_per_task: Maximum samples per task
            sample_size: Number of samples per sampling
            device: Computing device
        """
        self.total_capacity = total_capacity
        self.max_samples_per_task = max_samples_per_task
        self.sample_size = sample_size
        self.device = device
        self.buffer = {}  # Task name -> Sample list
        self.task_sample_counts = defaultdict(int)

    def add_samples(self, task_name, samples):
        """
        Add samples to buffer and rebalance all tasks' samples if necessary

        Samples without a "task_name" key are tagged in place. When the task
        would exceed its limit, the most recent samples win (oldest are dropped).

        Args:
            task_name: Task name
            samples: List of samples to add
        """
        samples = list(samples)

        # Add task identifier to samples
        for sample in samples:
            if "task_name" not in sample:
                sample["task_name"] = task_name

        # Per-task limit can never exceed the whole buffer's capacity
        samples_limit = min(self.max_samples_per_task, self.total_capacity)

        # Old samples first, new ones last, then keep the newest `samples_limit`.
        # The explicit `> 0` check matters: a `[-0:]` slice would keep everything.
        merged = self.buffer.get(task_name, []) + samples
        if len(merged) > samples_limit:
            merged = merged[-samples_limit:] if samples_limit > 0 else []

        self.buffer[task_name] = merged
        self.task_sample_counts[task_name] = len(merged)

        # If total samples exceed total capacity, rebalance
        total_samples = sum(len(stored) for stored in self.buffer.values())
        if total_samples > self.total_capacity:
            self.rebalance_samples()

        print(
            f"Task {task_name} current buffer size: {self.task_sample_counts[task_name]}/{self.total_capacity}")
        print(
            f"Total samples: {sum(self.task_sample_counts.values())}/{self.total_capacity}")

    def get_samples(self, excluded_tasks=None, max_samples=None):
        """
        Draw a random set of stored samples for replay

        Args:
            excluded_tasks: Optional iterable of task names whose samples must
                not be returned (typically the task currently being trained)
            max_samples: Maximum number of samples to return; defaults to
                ``self.sample_size`` when None

        Returns:
            A new list of sample dicts (without replacement). Empty when no
            eligible sample exists or the limit is not positive.
        """
        excluded = set(excluded_tasks) if excluded_tasks else set()
        limit = self.sample_size if max_samples is None else max_samples

        pool = []
        for task, stored in self.buffer.items():
            if task not in excluded:
                pool.extend(stored)

        if limit <= 0 or not pool:
            return []
        if len(pool) <= limit:
            return pool
        return random.sample(pool, limit)

    def rebalance_samples(self):
        """
        Rebalance samples for all tasks

        Every task gets ``total_capacity // num_tasks`` slots; the last task (in
        insertion order) additionally receives the division remainder. Tasks
        holding more than their quota are randomly subsampled down to it.
        """
        task_names = list(self.buffer)  # dict preserves addition order
        num_tasks = len(task_names)
        if num_tasks == 0:
            return

        base_quota, remainder = divmod(self.total_capacity, num_tasks)
        quotas = {task: base_quota for task in task_names}
        quotas[task_names[-1]] += remainder

        # Apply reallocated quotas
        for task, quota in quotas.items():
            stored = self.buffer[task]
            if len(stored) > quota:
                self.buffer[task] = random.sample(stored, quota) if quota > 0 else []
                self.task_sample_counts[task] = len(self.buffer[task])

        # Print rebalanced status
        print("\nMemory buffer rebalancing completed:")
        for task, count in self.task_sample_counts.items():
            quota = quotas.get(task, 0)
            print(f"  - Task {task}: {count}/{quota} samples")


def collect_samples_for_memory(memory_buffer, task_name, data_loader, device, max_samples=64):
    """
    Collect samples for specified task into memory buffer

    Reads batches from ``data_loader`` until ``max_samples`` samples have been
    gathered, then hands them to ``memory_buffer.add_samples``.

    Args:
        memory_buffer: Buffer receiving the samples (needs ``add_samples`` and
            ``task_sample_counts``)
        task_name: Task name stored with every sample
        data_loader: Iterable of batches accepted by ``maybe_dictionarize``
        device: Computing device (unused here; samples are kept on CPU)
        max_samples: Maximum number of samples to collect
    """
    print(f"Collecting samples for task {task_name} into memory buffer...")

    samples = []
    total_collected = 0

    for i, batch in enumerate(data_loader):
        # Covers max_samples <= 0, where nothing should be read from the batch
        if total_collected >= max_samples:
            break

        batch = maybe_dictionarize(batch)
        inputs = batch["images"]  # Do not move to device immediately
        labels = batch["labels"]  # Do not move to device immediately

        collect_count = min(inputs.size(0), max_samples - total_collected)

        for j in range(collect_count):
            # inputs[j] is a view into the whole batch; without clone() every
            # stored sample would keep the full batch storage alive (and drag
            # it along when the buffer is pickled).
            samples.append({
                "image": inputs[j].detach().cpu().clone(),
                "label": labels[j].detach().cpu().clone(),
                "task_name": task_name  # Explicitly store task name
            })

        total_collected += collect_count

        if i % 10 == 0:
            print(f"Collected {total_collected}/{max_samples} samples", end="\r")

        # Stop right away instead of loading one more batch just to discard it
        if total_collected >= max_samples:
            break

    # Add to memory buffer
    memory_buffer.add_samples(task_name, samples)
    stored_count = memory_buffer.task_sample_counts[task_name]
    print(f"\nCollected {len(samples)} samples for task {task_name}, actually stored {stored_count} samples")


def process_replay_batch_by_task(replay_samples, device, task_heads, image_encoder, args=None, max_samples=16):
    """
    Process replay samples into batches by task, each task uses corresponding classification head, automatically creates missing heads

    Args:
        replay_samples: Replay sample list (dicts with "image", "label", "task_name")
        device: Computing device
        task_heads: Task classification head dictionary; newly created heads
            are added to it in place
        image_encoder: Image encoder (not used by this function)
        args: Parameter object (for creating classification heads)
        max_samples: Maximum samples per task

    Returns:
        Dict mapping task name -> {"images", "labels", "head"}. Tasks whose
        head cannot be created, or that end up with no samples, are omitted.
    """
    # Group samples by task
    samples_by_task = defaultdict(list)
    for sample in replay_samples:
        samples_by_task[sample["task_name"]].append(sample)

    batch_by_task = {}
    for task_name, task_samples in samples_by_task.items():
        # If samples exceed limit, randomly select subset
        if len(task_samples) > max_samples:
            task_samples = random.sample(task_samples, max_samples)

        # Nothing left to batch (e.g. max_samples == 0); stacking would fail
        if not task_samples:
            continue

        # Ensure task's classification head exists, create if not
        if task_name not in task_heads:
            try:
                print(f"Creating new classification head for task {task_name}")
                from src.models import get_classification_head
                new_head = get_classification_head(args, task_name + "Val")
                new_head = new_head.to(device)
                task_heads[task_name] = new_head
            except Exception as e:
                print(f"Warning: Unable to create classification head for task {task_name}: {e}")
                continue

        images = torch.stack([sample["image"] for sample in task_samples])
        # Stack instead of torch.tensor([...]): works for scalar tensors and
        # plain ints alike, and also for vector-valued (e.g. soft) labels.
        labels = torch.stack(
            [torch.as_tensor(sample["label"]) for sample in task_samples])

        batch_by_task[task_name] = {
            "images": images.to(device),
            "labels": labels.to(device),
            "head": task_heads[task_name]
        }

    return batch_by_task


def er_enhanced_finetune(args, train_dataset, starting_model_path, output_path, er_instance=None, use_sabcd=True):
    """
    Fine-tune on ``train_dataset``, with Experience Replay when an ER instance is given

    Without an ER instance the call is forwarded unchanged to the plain
    fine-tuning routine (SABCD or continual, depending on ``use_sabcd``).
    With one, the encoder is loaded from ``starting_model_path``, wrapped with
    the dataset's classification head, the head is recorded in the ER
    instance's head registry, and replay-based fine-tuning is started.

    Args:
        args: Training parameters (reads ``model`` and ``device``)
        train_dataset: Training dataset name
        starting_model_path: Path of the encoder state dict to start from
        output_path: Output path passed on to the fine-tuning routine
        er_instance: Experience Replay instance, or None to disable replay
        use_sabcd: Whether to use SABCD optimizer

    Returns:
        Whatever the selected fine-tuning routine returns.
    """
    # No replay memory available: fall back to the original fine-tuning methods
    if er_instance is None:
        baseline_finetune = sabcd_finetune if use_sabcd else continual_finetune
        return baseline_finetune(args, train_dataset, starting_model_path, output_path)

    print("Using ER-enhanced fine-tuning method")

    # Restore the starting encoder weights
    encoder = ImageEncoder(args.model)
    starting_weights = torch.load(starting_model_path, map_location=args.device)
    encoder.load_state_dict(starting_weights)
    encoder = encoder.to(args.device)

    # Encoder + head for the dataset being trained
    head = get_classification_head(args, train_dataset)
    classifier = ImageClassifier(encoder, head)

    # Task key is the dataset name with "Val" removed
    task_key = train_dataset.replace("Val", "")
    er_instance.task_heads[task_key] = head
    print(f"Registered classification head for current task {task_key}")

    classifier.freeze_head()
    classifier = classifier.to(args.device)

    return er_finetune_with_replay(
        args,
        train_dataset,
        classifier,
        output_path,
        er_instance.memory_buffer,
        er_instance.task_heads,
        use_sabcd,
    )


def _snapshot_state_dict(module):
    """Return a detached CPU copy of ``module.state_dict()`` that later training steps cannot modify."""
    return {
        name: tensor.detach().cpu().clone()
        for name, tensor in module.state_dict().items()
    }


def _draw_replay_samples(memory_buffer, excluded_tasks, max_samples):
    """Draw up to ``max_samples`` stored samples from tasks not in ``excluded_tasks``."""
    getter = getattr(memory_buffer, "get_samples", None)
    if callable(getter):
        return getter(excluded_tasks=excluded_tasks, max_samples=max_samples)

    # The buffer offers no sampling method: draw directly from its per-task
    # lists so replay still works instead of failing on every batch.
    excluded = set(excluded_tasks)
    pool = []
    for task, stored in memory_buffer.buffer.items():
        if task not in excluded:
            pool.extend(stored)
    if max_samples <= 0:
        return []
    if len(pool) > max_samples:
        pool = random.sample(pool, max_samples)
    return pool


def er_finetune_with_replay(args, train_dataset, model, output_path, memory_buffer, task_heads, use_sabcd=False):
    """
    Fine-tuning method combining Experience Replay, supporting multi-task classification heads

    Each training batch of the current task is combined with a small replay
    batch drawn from the other tasks in ``memory_buffer``; the replay loss is
    the sum of the per-task cross-entropy losses, weighted by 10. The encoder
    is evaluated every 5 epochs (and after the last one) and the weights of
    the best-scoring evaluation are saved to ``output_path``.

    Args:
        args: Training parameters (reads data_location, device and the optional
            batch_size, lr, wd, warmup_length, num_grad_accumulation)
        train_dataset: Training dataset name
        model: Classifier exposing ``train_preprocess`` and ``image_encoder``
        output_path: Where the best encoder state dict is saved
        memory_buffer: Replay buffer; current-task samples are added to it
        task_heads: Task name -> classification head, used for replay batches
        use_sabcd: Use the SABCD optimizer instead of AdamW

    Returns:
        Best validation accuracy observed (-1.0 if no evaluation succeeded).
    """
    print("Starting ER-enhanced fine-tuning training")

    # Current task name is the dataset name with "Val" removed
    current_task_name = train_dataset.replace("Val", "")
    print(f"Current task: {current_task_name}")

    # Prepare current task data
    preprocess_fn = model.train_preprocess
    dataset = get_dataset(
        train_dataset,
        preprocess_fn,
        location=args.data_location,
        batch_size=getattr(args, 'batch_size', 64),
    )
    current_loader = get_dataloader(
        dataset, is_train=True, args=args, image_encoder=None)
    num_batches = len(dataset.train_loader)

    # Collect current task samples into buffer
    collect_samples_for_memory(memory_buffer, current_task_name, current_loader,
                               args.device, max_samples=memory_buffer.max_samples_per_task)

    loss_fn = F.cross_entropy

    # Set optimizer
    lr = getattr(args, 'lr', 1e-5)
    wd = getattr(args, 'wd', 0.1)
    params = [p for p in model.parameters() if p.requires_grad]

    if use_sabcd:
        from src.optimizers.sabcd import SABCD
        optimizer = SABCD(params, lr=lr, weight_decay=wd)
        print("Using SABCD optimizer for fine-tuning")
    else:
        optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=wd)
        print("Using AdamW optimizer for fine-tuning")

    # Learning rate scheduler settings
    warmup_length = getattr(args, 'warmup_length', 0.1)
    num_grad_accumulation = getattr(args, 'num_grad_accumulation', 2)

    # Dataset-specific epochs mapping
    epochs_map = {
        "Cars": 35, "DTD": 76, "EuroSAT": 12, "GTSRB": 11,
        "MNIST": 5, "RESISC45": 15, "SUN397": 14, "SVHN": 4,
        "CIFAR10": 6, "CIFAR100": 6, "STL10": 60, "Food101": 4,
        "Flowers102": 147, "FER2013": 10, "PCAM": 1, "OxfordIIITPet": 82,
        "RenderedSST2": 39, "EMNIST": 2, "FashionMNIST": 5, "KMNIST": 5,
    }

    base_name = current_task_name
    epochs = epochs_map.get(base_name, 10)
    print(f"Setting training epochs to {epochs} for dataset {base_name}")

    scheduler = cosine_lr(
        optimizer, lr, warmup_length, epochs * num_batches // num_grad_accumulation,
    )

    # Replay never draws from the task that is currently being trained
    excluded_tasks = [current_task_name]

    best_model = None
    best_accuracy = -1.0

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        epoch_current_loss = 0.0
        epoch_replay_loss = 0.0
        num_epoch_batches = 0

        current_iter = iter(current_loader)
        total_batches = len(current_loader)

        for i in range(total_batches):
            step = (i // num_grad_accumulation + epoch *
                    total_batches // num_grad_accumulation)

            try:
                # Get current task batch data, restarting the loader if it ran dry
                try:
                    current_batch = next(current_iter)
                except StopIteration:
                    current_iter = iter(current_loader)
                    current_batch = next(current_iter)

                current_batch = maybe_dictionarize(current_batch)
                current_inputs = current_batch["images"].to(args.device)
                current_labels = current_batch["labels"].to(args.device)

                # Forward propagation for current task data
                current_outputs = model(current_inputs)
                current_loss = loss_fn(current_outputs, current_labels)

                # Total loss starts as the current task loss
                total_loss = current_loss
                replay_loss = 0.0

                # If memory buffer has samples from other tasks, add replay loss
                replay_samples = _draw_replay_samples(
                    memory_buffer, excluded_tasks, max_samples=4)

                if replay_samples:
                    try:
                        batch_by_task = process_replay_batch_by_task(
                            replay_samples, args.device, task_heads, model.image_encoder, args=args, max_samples=8
                        )

                        task_losses = []
                        for task_name, batch in batch_by_task.items():
                            # Shared encoder, task-specific classification head
                            features = model.image_encoder(batch["images"])
                            outputs = batch["head"](features)
                            task_losses.append(loss_fn(outputs, batch["labels"]))

                        if task_losses:
                            # Sum (not average) of the per-task replay losses
                            replay_loss = torch.sum(torch.stack(task_losses))
                            total_loss = current_loss + 10 * replay_loss
                    except Exception as e:
                        print(f"Error processing replay samples: {str(e)}")
                        # On error, use only current task loss
                        total_loss = current_loss
                        replay_loss = 0.0

                # Backpropagation
                try:
                    total_loss.backward()

                    if (i + 1) % num_grad_accumulation == 0:
                        scheduler(step)
                        torch.nn.utils.clip_grad_norm_(params, 1.0)
                        optimizer.step()
                        optimizer.zero_grad()
                except RuntimeError as e:
                    print(f"Backpropagation error: {str(e)}")
                    optimizer.zero_grad()
                    continue

                # Accumulate losses
                replay_value = replay_loss.item() if isinstance(
                    replay_loss, torch.Tensor) else replay_loss
                epoch_loss += total_loss.item()
                epoch_current_loss += current_loss.item()
                if isinstance(replay_loss, torch.Tensor):
                    epoch_replay_loss += replay_value
                num_epoch_batches += 1

                if i % 10 == 0:
                    print(f"Epoch {epoch+1}/{epochs}, Batch {i+1}/{total_batches}, "
                          f"Current Loss: {current_loss.item():.6f}, "
                          f"Replay Loss: {replay_value:.6f}, "
                          f"Total Loss: {total_loss.item():.6f}", end="\r")

                # Clean GPU memory periodically
                if i % 10 == 0:
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"Error processing batch: {str(e)}")
                import traceback
                traceback.print_exc()
                optimizer.zero_grad()
                continue

        # Calculate average losses
        avg_loss = epoch_loss / max(num_epoch_batches, 1)
        avg_current_loss = epoch_current_loss / max(num_epoch_batches, 1)
        avg_replay_loss = epoch_replay_loss / max(num_epoch_batches, 1)
        print(f"\nEpoch {epoch+1} average losses: {avg_loss:.6f}, "
              f"Current task: {avg_current_loss:.6f}, Replay: {avg_replay_loss:.6f}")

        # Evaluate model (every 5 epochs or last epoch)
        if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
            acc = eval_single_dataset(
                model.image_encoder, train_dataset, args)['top1']
            print(f"Epoch {epoch+1} validation accuracy: {acc*100:.2f}%")

            if acc > best_accuracy:
                best_accuracy = acc
                # state_dict().copy() would only copy the dict: its tensors
                # still alias the live parameters and would silently follow
                # later training steps. Take a real copy of the weights.
                best_model = _snapshot_state_dict(model.image_encoder)
                print(f"Better model found, accuracy: {best_accuracy*100:.2f}%")

    # Save best model
    if best_model is not None:
        torch.save(best_model, output_path)
        print(f"Saved best model (accuracy: {best_accuracy*100:.2f}%) to {output_path}")

    print(f"ER fine-tuning completed, best accuracy: {best_accuracy*100:.2f}%")
    return best_accuracy