import numpy as np
import torch
from avalanche.benchmarks.utils import as_classification_dataset
from avalanche.training.storage_policy import ClassBalancedBuffer
from torch.utils.data import ConcatDataset, Subset
from utils.training import evaluate

class CalibrationBuffer:
    """Accumulate random subsets of per-experience validation data for calibration.

    Each call to :meth:`update` shuffles the given experience's validation
    dataset and keeps a prefix of it, whose size depends on ``mode``. That
    prefix becomes :attr:`new_buffer` and is appended to :attr:`buffer`.
    :meth:`update_confidence_scores` then appends the model's max-softmax
    confidence on :attr:`new_buffer` to :attr:`confidence_scores`.

    Args:
        benchmark_name: Selects the fraction kept per experience in
            ``'balanced'`` mode (see ``_BENCHMARK_FRACTIONS``).
        mode: Selection rule used by :meth:`update`. The options are:
            ``'full'`` keeps every sample;
            ``'balanced'`` keeps ``floor(fraction * n)`` samples;
            ``'inv_proportional'`` keeps ``n // (task_id + 1)`` samples.
            ``'original'`` is accepted here, but :meth:`update` has no
            selection rule for it and raises ``NotImplementedError``.
        transform_type: Transform-group name passed to
            ``dataset.with_transforms`` before sampling.

    Raises:
        ValueError: If ``mode`` or ``benchmark_name`` is not supported.
    """

    # Fraction of each experience's validation set kept in 'balanced' mode.
    _BENCHMARK_FRACTIONS = {
        'cifar10': 0.1,
        'cifar100': 0.4,
        'dermamnist': 0.5,
        'bloodmnist': 0.5,
        'tinyimagenet': 0.2,
    }
    _MODES = ('full', 'balanced', 'inv_proportional', 'original')

    def __init__(self, benchmark_name='cifar100', mode='balanced', transform_type='eval'):
        self.buffer = None
        # Subset added by the most recent update(); scored by update_confidence_scores().
        self.new_buffer = None
        self.confidence_scores = None
        self.mode = mode
        self.benchmark_name = benchmark_name
        self.transform_type = transform_type
        if mode not in self._MODES:
            raise ValueError(f"Unknown mode: {mode}. Use 'full', 'balanced', 'inv_proportional', or 'original'.")

        try:
            self.fraction = self._BENCHMARK_FRACTIONS[self.benchmark_name]
        except (KeyError, TypeError):
            raise ValueError(f"Unknown benchmark: {self.benchmark_name}. Supported benchmarks are 'cifar10', 'cifar100', 'tinyimagenet', 'bloodmnist', and 'dermamnist'.") from None

    def _split_index(self, buffer_length, task_id):
        """Return how many of ``buffer_length`` shuffled samples to keep for ``task_id``.

        Raises:
            NotImplementedError: If ``self.mode`` has no selection rule ('original').
        """
        if self.mode == 'full':
            return int(buffer_length)
        if self.mode == 'balanced':
            return int(np.floor(self.fraction * buffer_length))
        if self.mode == 'inv_proportional':
            # Later tasks contribute proportionally fewer samples.
            return int(buffer_length // (task_id + 1))
        # 'original' passes __init__ validation, but no split rule exists for it.
        # Fail explicitly rather than with an UnboundLocalError.
        raise NotImplementedError(
            f"CalibrationBuffer.update() has no sample-selection rule for mode '{self.mode}'."
        )

    def update(self, experience_val):
        """Add a random subset of ``experience_val``'s dataset to the buffer.

        Args:
            experience_val: Experience exposing ``current_experience`` (used as
                the task id) and ``dataset`` (which must support ``with_transforms``).

        Raises:
            NotImplementedError: If the configured mode has no selection rule.
        """
        current_task_id = experience_val.current_experience
        experience_val_data = as_classification_dataset(
            experience_val.dataset.with_transforms(self.transform_type)
        )
        buffer_length = len(experience_val_data)
        indices = list(range(buffer_length))
        np.random.shuffle(indices)  # uses the global NumPy RNG
        val_split_index = self._split_index(buffer_length, current_task_id)

        self.new_buffer = Subset(experience_val_data, indices[:val_split_index])
        # Compare against None: a Dataset's truthiness calls __len__, which is
        # not the intended "has anything been added yet" check.
        if self.buffer is None:
            self.buffer = self.new_buffer
        else:
            self.buffer = ConcatDataset([self.buffer, self.new_buffer])

    def update_confidence_scores(self, model, device):
        """Append the model's max-softmax confidence on the latest subset.

        NOTE(review): the scores are appended per call, so call this exactly
        once after each update() to keep them aligned with ``self.buffer``.

        Raises:
            RuntimeError: If called before any update().
        """
        if self.new_buffer is None:
            raise RuntimeError(
                "update_confidence_scores() called before update(); there are no samples to score."
            )
        buffer_logits, _ = evaluate(model, self.new_buffer, device)
        new_confidence_scores = buffer_logits.softmax(dim=1).max(dim=1).values
        if self.confidence_scores is None:
            self.confidence_scores = new_confidence_scores
        else:
            self.confidence_scores = torch.cat([self.confidence_scores, new_confidence_scores])

class CustomClassBalancedBuffer(ClassBalancedBuffer):
    """``ClassBalancedBuffer`` whose ``update`` accepts an explicit experience."""

    def update(self, strategy, experience=None, **kwargs):
        """Refresh the buffer from an experience's dataset.

        Args:
            strategy: Training strategy. It is forwarded to ``update_from_dataset``,
                and its ``experience`` is used when ``experience`` is omitted.
            experience: Experience whose ``dataset`` feeds the buffer. Defaults
                to ``strategy.experience``.
            **kwargs: Accepted for signature compatibility; unused.
        """
        source = strategy.experience if experience is None else experience
        # NOTE(review): confirm the installed avalanche version's
        # update_from_dataset accepts (new_data, strategy) in this order.
        self.update_from_dataset(source.dataset, strategy)
