import os
import numpy as np
import random
import torch
from avalanche.models import IncrementalClassifier
from avalanche.training.plugins import EarlyStoppingPlugin, LRSchedulerPlugin
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, Dataset
from utils.models import SlimResNet18, ResNet32

def get_free_gpu_idx():
    """
    Get the index of the GPU with the current lowest used memory.

    Runs ``nvidia-smi`` through the shell, writes its filtered "Used" memory lines
    to ``./output/tmp`` (creating ``./output/`` if needed), then parses the numeric
    value from each line.

    Returns:
        numpy integer: Position of the smallest parsed "Used" value. This assumes
        exactly one "Used" line per GPU in the filtered output -- TODO confirm
        against the installed nvidia-smi output format.

    Raises:
        ValueError: If no memory lines could be parsed (e.g. nvidia-smi is missing).
    """
    os.makedirs("./output/", exist_ok=True)
    os.system("nvidia-smi -q -d Memory |grep -A4 GPU|grep Used  > ./output/tmp")
    # Use a context manager so the temp file handle is always closed.
    with open("./output/tmp", "r") as tmp_file:
        memory_used = [int(line.split()[2]) for line in tmp_file]
    if not memory_used:
        raise ValueError("No GPU memory usage found in nvidia-smi output (./output/tmp is empty).")
    return np.argmin(memory_used)


def set_seed(seed: int):
    """
    Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed (int): Value used to seed every random number generator.
    """
    # Exported for child processes; the running interpreter's hash seed is
    # fixed at startup and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def load_model(benchmark_name: str, model_name: str = 'slimresnet'):
    """
    Build a fresh model together with its training components for a benchmark.

    Args:
        benchmark_name (str): The benchmark to configure. Supported values are
            'cifar10', 'cifar100', 'tinyimagenet', 'bloodmnist' and 'dermamnist'.
        model_name (str): The backbone to build. Default is 'slimresnet', which
            builds a SlimResNet18; any other value builds a ResNet32.

    Returns:
        tuple: ``(model, optimizer, early_stopping_plugin, scheduler_plugin,
        criterion, mem_size)`` where ``mem_size`` is the benchmark-specific
        memory size chosen below.

    Raises:
        ValueError: If ``benchmark_name`` is not a supported benchmark.
    """
    # Per-benchmark architecture settings, early-stopping patience and memory size.
    configs = {
        'cifar10': {'num_classes': 10, 'input_size': (3, 32, 32), 'nf': 20,
                    'patience': 10, 'mem_size': 1000},
        'cifar100': {'num_classes': 100, 'input_size': (3, 32, 32), 'nf': 32,
                     'patience': 20, 'mem_size': 4000},
        # On TinyImageNet the memory size depends on the chosen backbone.
        'tinyimagenet': {'num_classes': 200, 'input_size': (3, 64, 64), 'nf': 64,
                         'patience': 50,
                         'mem_size': 1000 if model_name == 'slimresnet' else 4000},
        'bloodmnist': {'num_classes': 8, 'input_size': (3, 28, 28), 'nf': 20,
                       'patience': 20, 'mem_size': 200},
        'dermamnist': {'num_classes': 7, 'input_size': (3, 64, 64), 'nf': 64,
                       'patience': 50, 'mem_size': 200},
    }
    if not isinstance(benchmark_name, str) or benchmark_name not in configs:
        # Derive the list from the table so the message never advertises an unsupported name.
        supported = ", ".join(f"'{name}'" for name in configs)
        raise ValueError(f"Unknown benchmark name: {benchmark_name}. Use one of {supported}.")
    cfg = configs[benchmark_name]

    if model_name == 'slimresnet':
        model = SlimResNet18(nclasses=cfg['num_classes'], input_size=cfg['input_size'], nf=cfg['nf'])
    else:
        model = ResNet32(nclasses=cfg['num_classes'], input_size=cfg['input_size'], nf=cfg['nf'])

    # Replace the fixed output layer with a head that can grow as new classes appear.
    model.output = IncrementalClassifier(model.output.in_features)
    # Optimizer and LR schedule are shared by all benchmarks: SGD at lr=0.1, decayed
    # by 10x every 20 scheduler steps.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
    scheduler_plugin = LRSchedulerPlugin(scheduler)
    early_stopping_plugin = EarlyStoppingPlugin(patience=cfg['patience'], val_stream_name='valid_stream')
    criterion = CrossEntropyLoss()

    return model, optimizer, early_stopping_plugin, scheduler_plugin, criterion, cfg['mem_size']


def evaluate(model, dataset, device):
    """
    Run ``model`` over ``dataset`` without gradients and collect logits and labels.

    Args:
        model (torch.nn.Module): Model to evaluate; it is switched to (and left in) eval mode.
        dataset: Dataset yielding ``(input, label, extra)`` triples; ``extra`` is ignored.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        tuple: ``(logits, labels)`` CPU tensors concatenated over all batches of 64,
        in dataset order.
    """
    model.eval()
    logits_list = []
    labels_list = []
    with torch.no_grad():
        for inputs, labels, _ in DataLoader(dataset, batch_size=64):
            # Move each batch's outputs to CPU right away so device memory does not
            # grow with the size of the dataset.
            logits_list.append(model(inputs.to(device)).cpu())
            labels_list.append(labels.cpu())
    return torch.cat(logits_list, dim=0), torch.cat(labels_list, dim=0)


def compute_metrics(pred_y, true_y, num_bins=10, from_logits=True):
    """
    Compute accuracy and expected calibration error (ECE) with equal-width bins.

    Args:
        pred_y: Per-sample class scores with classes on the last dimension. These are
            logits when ``from_logits`` is True; otherwise they are probabilities
            (tensor, NumPy array or nested list).
        true_y (torch.Tensor): Ground-truth class indices, compared element-wise with
            the argmax predictions.
        num_bins (int): Number of equal-width confidence bins covering [0, 1].
        from_logits (bool): Apply softmax to ``pred_y`` first when True.

    Returns:
        tuple: ``(accuracy, ece, bin_acc, bin_cnf, bin_cnt)``. ``accuracy`` and ``ece``
        are Python floats. The ``bin_*`` tensors hold the per-bin mean accuracy, mean
        confidence and sample count, and are zero for empty bins.
    """
    if from_logits:
        probs = torch.softmax(pred_y, dim=-1)
    else:
        # as_tensor converts arrays/lists and tensors of any float dtype. The legacy
        # torch.Tensor(...) constructor rejects tensors whose dtype differs from the
        # default (e.g. float64 probabilities).
        probs = torch.as_tensor(pred_y, dtype=torch.float32)
    conf, preds = probs.max(dim=-1)
    acc = (preds == true_y).float()
    final_acc = acc.mean()

    bin_boundaries = torch.linspace(0, 1, num_bins + 1)
    bin_lowers, bin_uppers = bin_boundaries[:-1], bin_boundaries[1:]
    bin_cnf = torch.zeros(num_bins)
    bin_acc = torch.zeros(num_bins)
    bin_cnt = torch.zeros(num_bins)
    ece = torch.zeros(1)
    n_samples = conf.shape[0]
    for i in range(num_bins):
        # Bins are (lower, upper]; a confidence of exactly 0 falls in no bin.
        in_bin = (conf > bin_lowers[i]) & (conf <= bin_uppers[i])
        count_in_bin = in_bin.sum().item()
        if count_in_bin > 0:
            bin_acc[i] = acc[in_bin].mean()
            bin_cnf[i] = conf[in_bin].mean()
            bin_cnt[i] = count_in_bin
            # ECE = sum over bins of (fraction of samples in bin) * |accuracy - confidence|.
            ece += (count_in_bin / n_samples) * abs(bin_acc[i] - bin_cnf[i])

    return final_acc.item(), ece.item(), bin_acc, bin_cnf, bin_cnt


class AugmentedDataset(Dataset):
    """Wrap a dataset of ``(x, y, extra)`` triples, applying ``transform`` to ``x`` on access."""

    def __init__(self, dataset, transform):
        self.dataset, self.transform = dataset, transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Only the input is transformed; target and extra field pass through unchanged.
        sample, target, extra = self.dataset[idx]
        return self.transform(sample), target, extra
    

@torch.no_grad()
def extract_features(model, dataset, device):
    """
    Compute ``model._forward_features`` for every sample of ``dataset``.

    Args:
        model (torch.nn.Module): Model exposing ``_forward_features``; it is switched
            to (and left in) eval mode.
        dataset: Dataset yielding ``(x, y, extra)`` triples; only ``x`` is used.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        torch.Tensor: CPU tensor of features concatenated over batches of 64, in
        dataset order. 4-D outputs have their trailing size-1 dimensions removed.
    """
    dataloader = DataLoader(dataset, batch_size=64)
    model.eval()
    feature_batches = []
    # Labels and extra fields are not needed, so only the inputs are sent to the device.
    for x, _, _ in dataloader:
        features = model._forward_features(x.to(device))
        if features.ndim == 4:
            # e.g. (B, D, 1, 1) -> (B, D); squeeze is a no-op on dims larger than 1.
            features = features.squeeze(-1).squeeze(-1)
        feature_batches.append(features.cpu())
    return torch.cat(feature_batches, dim=0)


def compute_nll(dict_results, calibrated=True, from_logits=True, num_runs=3):
    """
    Average negative log-likelihood over tasks, then report mean/std over runs.

    Args:
        dict_results (dict): ``dict_results['logits'][run][task]`` must be a
            ``(logits, cal_logits, labels)`` triple.
        calibrated (bool): Score ``cal_logits`` when True, the raw ``logits`` otherwise.
        from_logits (bool): If True, treat the scored tensor as logits and use
            cross-entropy on values clamped to [-50, 50]. Otherwise treat it as
            per-class probabilities.
        num_runs (int): Number of runs (indices ``0..num_runs-1``) to include. Defaults to 3.

    Returns:
        tuple: ``(avg_nll, std_nll)``, the mean and standard deviation across runs
        of each run's task-averaged NLL.
    """
    run_means = []
    for run in range(num_runs):
        run_results = dict_results['logits'][run]
        task_nlls = []
        # Index-based access keeps working whether runs/tasks are lists or int-keyed dicts.
        for task in range(len(run_results)):
            logits, cal_logits, labels = run_results[task]
            scores = cal_logits if calibrated else logits
            if from_logits:
                # Clamp extreme logits (presumably to guard against inf/overflow).
                clamped = torch.clamp(scores, -50, 50)
                nll = torch.nn.functional.cross_entropy(clamped, labels, reduction='mean')
            else:
                true_probs = scores[torch.arange(scores.size(0)), labels]
                # The epsilon avoids log(0) when the true class has zero probability.
                nll = -(torch.log(true_probs + 1e-12)).mean()
            task_nlls.append(nll.item())
        # Average each run on its own so runs may contain different numbers of tasks.
        run_means.append(torch.tensor(task_nlls).mean())
    run_means = torch.stack(run_means)
    return run_means.mean().item(), run_means.std().item()


def load_trained_model(model, dataset, seed, device, task_id):
    """
    Load a saved checkpoint and build a backbone sized to match it.

    The checkpoint is read from
    ``./checkpoints/{model}/{dataset}/seed{seed}/model_task{task_id}.pth``.
    Its weights are NOT loaded into the returned model here.

    Args:
        model (str): Backbone name. 'slimresnet' builds a SlimResNet18; anything
            else builds a ResNet32. Also a directory component of the path.
        dataset (str): Dataset name. It selects the input size and width ``nf``
            and is a directory component of the path.
        seed: Seed identifier used in the checkpoint path.
        device: ``map_location`` passed to ``torch.load``.
        task_id: Task index used in the checkpoint file name.

    Returns:
        tuple: ``(trained_model, checkpoint, classes_so_far)`` where
        ``classes_so_far`` is ``checkpoint["num_classes"]``.

    Raises:
        ValueError: If ``dataset`` is not a supported dataset name. This is checked
            before any file is read.
    """
    # Validate the dataset first, so an unknown name fails fast with a clear error
    # instead of an unbound-variable error after the checkpoint has been loaded.
    if dataset == 'cifar10':
        input_size, nf = (3, 32, 32), 20
    elif dataset == 'cifar100':
        input_size, nf = (3, 32, 32), 32
    elif dataset in ['tinyimagenet', 'dermamnist']:
        input_size, nf = (3, 64, 64), 64
    elif dataset in ['bloodmnist', 'tissuemnist', 'organsmnist', 'organamnist', 'organcmnist']:
        input_size, nf = (3, 28, 28), 20
    else:
        raise ValueError(
            f"Unknown dataset name: {dataset}. Use one of 'cifar10', 'cifar100', 'tinyimagenet', "
            "'dermamnist', 'bloodmnist', 'tissuemnist', 'organsmnist', 'organamnist', 'organcmnist'."
        )

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted locations.
    checkpoint = torch.load(f"./checkpoints/{model}/{dataset}/seed{seed}/model_task{task_id}.pth", map_location=device)
    classes_so_far = checkpoint["num_classes"]
    if model == 'slimresnet':
        trained_model = SlimResNet18(nclasses=classes_so_far, input_size=input_size, nf=nf)
    else:
        trained_model = ResNet32(nclasses=classes_so_far, input_size=input_size, nf=nf)
    return trained_model, checkpoint, classes_so_far