import torch
import numpy as np

def print_trainable_parameters(model):
    """
    Print and return the trainable / total parameter counts of ``model``.

    Parameters whose name contains ``'classifier'`` are left out of both
    counts.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs; each ``param`` must provide ``numel()``
            and a ``requires_grad`` attribute.

    Returns:
        tuple: ``(trainable_params, all_param)`` as integers.
    """
    trainable_params = 0
    all_param = 0
    for name, param in model.named_parameters():
        if 'classifier' in name:
            continue
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    # A model with no counted parameters (empty, or classifier-only) would
    # otherwise divide by zero; report 0% instead of crashing.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.2f}"
    )
    return trainable_params, all_param


def subset_of(dataset, n_classes):
    """
    Restrict ``dataset`` to the examples of its first ``n_classes`` labels.

    "First" follows the order in which ``dataset.unique("label")`` reports
    the distinct label values.

    Args:
        dataset: Object providing ``unique(column)`` and ``filter(function)``.
        n_classes: How many distinct labels to keep.

    Returns:
        Whatever ``dataset.filter`` returns for the kept examples.
    """
    kept_labels = dataset.unique("label")[:n_classes]

    def has_kept_label(example):
        return example["label"] in kept_labels

    return dataset.filter(has_kept_label)


# Hub location of every supported dataset, as ``(repo, split)`` pairs.
# A ``None`` "val" or "test" entry means that split is not loaded from the Hub
# and is instead carved out of the training data.  ``load_kwargs`` (optional)
# is forwarded to every ``load_dataset`` call for that dataset.
_IMAGE_DATASET_SOURCES = {
    "flowers": {
        "train": ("nelorth/oxford-flowers", "train"),
        "val": None,
        "test": ("nelorth/oxford-flowers", "test"),
    },
    "dtd": {
        "train": ("cansa/Describable-Textures-Dataset-DTD", "train"),
        "val": None,
        "test": None,
    },
    "food": {
        "train": ("food101", "train"),
        "val": None,
        "test": ("food101", "validation"),
        "load_kwargs": {"trust_remote_code": True},
    },
    "pets": {
        "train": ("timm/oxford-iiit-pet", "train"),
        "val": None,
        "test": ("timm/oxford-iiit-pet", "test"),
    },
    "resisc": {
        "train": ("timm/resisc45", "train"),
        "val": ("timm/resisc45", "validation"),
        "test": ("timm/resisc45", "test"),
    },
    "eurosat": {
        "train": ("timm/eurosat-rgb", "train"),
        "val": ("timm/eurosat-rgb", "validation"),
        "test": ("timm/eurosat-rgb", "test"),
    },
    "cars": {
        "train": ("Multimodal-Fatima/StanfordCars_train", "train"),
        "val": None,
        "test": ("Multimodal-Fatima/StanfordCars_test", "test"),
    },
    "fgvc": {
        "train": ("Multimodal-Fatima/FGVC_Aircraft_train", "train"),
        "val": None,
        "test": ("Multimodal-Fatima/FGVC_Aircraft_test", "test"),
    },
    "cifar10": {
        "train": ("Multimodal-Fatima/CIFAR10_train", "train"),
        "val": None,
        "test": ("Multimodal-Fatima/CIFAR10_test", "test"),
    },
    "cifar100": {
        "train": ("Multimodal-Fatima/CIFAR100_train", "train"),
        "val": None,
        "test": ("Multimodal-Fatima/CIFAR100_test", "test"),
    },
}

# Fractions held out when a split has to be carved from the training data.
_CARVED_TEST_FRACTION = 0.2
_CARVED_VAL_FRACTION = 0.1


def get_image_classification_data(dataset_name, seed=None):
    """
    Load the train / validation / test splits of a supported image dataset.

    Splits that the Hub repository provides are loaded directly.  A missing
    test split is carved out of the training data first (20%), then a missing
    validation split is carved out of what remains (10%).

    Args:
        dataset_name: One of the keys of ``_IMAGE_DATASET_SOURCES``
            (``"flowers"``, ``"dtd"``, ``"food"``, ``"pets"``, ``"resisc"``,
            ``"eurosat"``, ``"cars"``, ``"fgvc"``, ``"cifar10"``,
            ``"cifar100"``).
        seed: Optional seed forwarded to ``train_test_split`` so carved
            splits are reproducible.  The default ``None`` keeps the
            library's unseeded behaviour.

    Returns:
        tuple: ``(train_ds, val_ds, test_ds)``.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    # Validate before importing `datasets` so a typo fails immediately.
    try:
        source = _IMAGE_DATASET_SOURCES[dataset_name]
    except (KeyError, TypeError):
        raise ValueError(f"Unknown dataset: {dataset_name}") from None

    from datasets import load_dataset

    load_kwargs = source.get("load_kwargs", {})

    def load(repo_and_split):
        repo, split = repo_and_split
        return load_dataset(repo, split=split, **load_kwargs)

    train_ds = load(source["train"])

    if source["test"] is None:
        parts = train_ds.train_test_split(test_size=_CARVED_TEST_FRACTION, seed=seed)
        train_ds, test_ds = parts["train"], parts["test"]
    else:
        test_ds = load(source["test"])

    if source["val"] is None:
        parts = train_ds.train_test_split(test_size=_CARVED_VAL_FRACTION, seed=seed)
        train_ds, val_ds = parts["train"], parts["test"]
    else:
        val_ds = load(source["val"])

    return train_ds, val_ds, test_ds


def prepare_image_classification_data(dataset_name, model_name_or_path):
    """
    Load an image dataset and attach the preprocessing a model checkpoint needs.

    The training split gets random-resized-crop + horizontal-flip
    augmentation; the validation and test splits get resize + centre crop.
    All splits are converted to RGB, turned into tensors and normalised with
    the image processor's mean / std.  Transforms are registered with
    ``set_transform``.

    Args:
        dataset_name: Name understood by ``get_image_classification_data``.
        model_name_or_path: Checkpoint passed to
            ``AutoImageProcessor.from_pretrained``.

    Returns:
        tuple: ``(train_ds, val_ds, test_ds, collate_fn, label2id, id2label,
        image_processor)``.  ``collate_fn`` stacks ``pixel_values`` and
        builds a ``labels`` tensor from a list of examples.

    Raises:
        KeyError: If the image processor's ``size`` exposes neither
            ``"height"`` nor ``"shortest_edge"``.
    """
    train_ds, val_ds, test_ds = get_image_classification_data(dataset_name)
    labels = train_ds.features["label"].names
    label2id = {label: i for i, label in enumerate(labels)}
    id2label = {i: label for i, label in enumerate(labels)}

    from transformers import AutoImageProcessor
    image_processor = AutoImageProcessor.from_pretrained(model_name_or_path, use_fast=True)

    from torchvision.transforms import (
        CenterCrop,
        Compose,
        Normalize,
        RandomHorizontalFlip,
        RandomResizedCrop,
        Resize,
        ToTensor,
    )

    def size_entry(key):
        """Return ``image_processor.size[key]`` or ``None`` when it is absent."""
        try:
            return image_processor.size[key]
        except KeyError:
            return None

    # Not every processor describes its resolution with "height"; some only
    # give "shortest_edge".  Prefer "height" and fall back when it is missing.
    target_size = size_entry("height")
    if target_size is None:
        target_size = size_entry("shortest_edge")
    if target_size is None:
        raise KeyError(
            f"Cannot infer the input resolution from image_processor.size={image_processor.size!r}"
        )

    normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
    train_transforms = Compose(
        [
            RandomResizedCrop(target_size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

    val_transforms = Compose(
        [
            Resize(target_size),
            CenterCrop(target_size),
            ToTensor(),
            normalize,
        ]
    )

    def preprocess_train(example_batch):
        """Apply train_transforms across a batch."""
        example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
        return example_batch

    def preprocess_val(example_batch):
        """Apply val_transforms across a batch."""
        example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
        return example_batch

    train_ds.set_transform(preprocess_train)
    val_ds.set_transform(preprocess_val)
    # The test split is evaluated with the same deterministic pipeline as validation.
    test_ds.set_transform(preprocess_val)

    def collate_fn(examples):
        """Batch a list of examples into ``pixel_values`` / ``labels`` tensors."""
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example["label"] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels}

    return train_ds, val_ds, test_ds, collate_fn, label2id, id2label, image_processor


def image_classification_metric(eval_pred):
    """Return the share of rows whose arg-max prediction matches the label.

    Reads ``eval_pred.predictions`` (arg-max taken along axis 1) and
    ``eval_pred.label_ids``; the result is ``{"accuracy": value}``.
    """
    predicted_classes = np.argmax(eval_pred.predictions, axis=1)
    is_correct = predicted_classes == eval_pred.label_ids
    accuracy = is_correct.mean()
    return {"accuracy": accuracy}


# Per-task value returned to callers under "metric_lower_bond_for_pruner".
# Its keys also define which GLUE tasks this module has settings for.
_GLUE_PRUNER_METRIC_FLOOR = {
    "cola": 0.0,
    "sst2": 0.8,
    "mrpc": 0.75,
    "qnli": 0.8,
    "rte": 0.5,
    "stsb": 0.8,
}

# Per-checkpoint training schedules (epochs and batch size per task).
_GLUE_SCHEDULES = {
    "FacebookAI/roberta-base": {
        "epoch": {
            "cola": 80,
            "sst2": 40,
            "mrpc": 80,
            "qnli": 40,
            "rte": 80,
            "stsb": 80,
        },
        "batch_size": {
            "cola": 128,
            "sst2": 128,
            "mrpc": 128,
            "qnli": 64,
            "rte": 64,
            "stsb": 128,
        },
    },
    "FacebookAI/roberta-large": {
        "epoch": {
            "cola": 70,
            "sst2": 10,
            "mrpc": 80,
            "qnli": 30,
            "rte": 60,
            "stsb": 40,
        },
        "batch_size": {
            "cola": 128,
            "sst2": 128,
            "mrpc": 128,
            "qnli": 32,
            "rte": 64,
            "stsb": 128,
        },
    },
}

# Schedule used for checkpoints without their own entry.  Its batch sizes are
# the values that were already applied to such checkpoints; the epoch counts
# are reused so those checkpoints no longer fail on an unbound ``epoch``.
_GLUE_FALLBACK_SCHEDULE = "FacebookAI/roberta-base"

# Text column(s) fed to the tokenizer, per task; other tasks use the default.
_GLUE_TEXT_COLUMNS = {
    "sst2": ("sentence",),
    "cola": ("sentence",),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
}
_GLUE_DEFAULT_TEXT_COLUMNS = ("sentence1", "sentence2")

# Tasks whose validation data is split off the training set (5%), with the
# GLUE "validation" split then used as the test set.
_GLUE_TASKS_TESTED_ON_VALIDATION = ("cola", "qnli", "rte", "sst2", "stsb")


def prepare_glue(task, model_name_or_path, max_length, padding_mode, seed=None):
    """
    Build the datasets, metric, hyperparameters and tokenizer for a GLUE task.

    Args:
        task: GLUE task name.  Must be a key of ``_GLUE_PRUNER_METRIC_FLOOR``
            (``cola``, ``sst2``, ``mrpc``, ``qnli``, ``rte``, ``stsb``).
        model_name_or_path: Checkpoint passed to
            ``AutoTokenizer.from_pretrained``.  Names containing ``gpt``,
            ``opt`` or ``bloom`` are padded on the left, all others on the
            right.  Checkpoints without an entry in ``_GLUE_SCHEDULES`` use
            the ``FacebookAI/roberta-base`` schedule.
        max_length: Truncation length for tokenization and ``max_length``
            for padding in the collate function.
        padding_mode: ``padding`` argument handed to ``tokenizer.pad``.
        seed: Optional seed for the 5% validation split carved from the
            training data.  ``None`` (default) leaves the split unseeded.

    Returns:
        tuple: ``(ds_related, metric_related, training_hyperparameters,
        tokenizer)`` where

        * ``ds_related`` has ``train_ds``, ``val_ds``, ``test_ds`` and
          ``collate_fn``;
        * ``metric_related`` has ``compute_metric``, ``metric_name`` and
          ``metric_lower_bond_for_pruner``;
        * ``training_hyperparameters`` has ``epoch`` and ``batch_size``.

    Raises:
        KeyError: If no settings exist for ``task``.
    """
    # Resolve the per-task settings first: an unsupported task should fail
    # here rather than after the dataset was downloaded and tokenized.
    if task not in _GLUE_PRUNER_METRIC_FLOOR:
        raise KeyError(
            f"No settings configured for GLUE task {task!r}; "
            f"supported tasks: {sorted(_GLUE_PRUNER_METRIC_FLOOR)}"
        )
    schedule = _GLUE_SCHEDULES.get(model_name_or_path, _GLUE_SCHEDULES[_GLUE_FALLBACK_SCHEDULE])
    training_hyperparameters = {
        "epoch": schedule["epoch"][task],
        "batch_size": schedule["batch_size"][task],
    }

    from transformers import AutoTokenizer
    from datasets import load_dataset
    import evaluate
    import uuid

    if any(k in model_name_or_path for k in ("gpt", "opt", "bloom")):
        padding_side = "left"
    else:
        padding_side = "right"

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)

    # Tokenizers without a pad token reuse the end-of-sequence token for padding.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    datasets = load_dataset("glue", task, trust_remote_code=True)
    # A fresh experiment_id per call is passed to evaluate.load.
    metric = evaluate.load("glue", task, trust_remote_code=True, experiment_id=uuid.uuid4().hex)

    text_columns = _GLUE_TEXT_COLUMNS.get(task, _GLUE_DEFAULT_TEXT_COLUMNS)

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        texts = [examples[column] for column in text_columns]
        return tokenizer(*texts, truncation=True, max_length=max_length)

    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", *text_columns],
    )

    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def collate_fn(examples):
        return tokenizer.pad(examples, max_length=max_length, padding=padding_mode, return_tensors="pt")

    def compute_metric(eval_pred):
        # stsb predictions are passed to the metric as-is; every other task
        # takes the arg-max over the last axis.
        predictions = eval_pred.predictions.argmax(axis=-1) if task != "stsb" else eval_pred.predictions
        return metric.compute(predictions=predictions, references=eval_pred.label_ids)

    if task == "stsb":
        metric_name = "pearson"
    elif task == "cola":
        metric_name = "matthews_correlation"
    else:
        metric_name = "accuracy"

    if task in _GLUE_TASKS_TESTED_ON_VALIDATION:
        train_ds, val_ds = tokenized_datasets["train"].train_test_split(test_size=0.05, seed=seed).values()
        test_ds = tokenized_datasets["validation"]
    else:
        train_ds = tokenized_datasets["train"]
        val_ds = tokenized_datasets["validation"]
        test_ds = tokenized_datasets["test"]

    ds_related = {
        "train_ds": train_ds,
        "val_ds": val_ds,
        "test_ds": test_ds,
        "collate_fn": collate_fn,
    }

    metric_related = {
        "compute_metric": compute_metric,
        "metric_name": metric_name,
        "metric_lower_bond_for_pruner": _GLUE_PRUNER_METRIC_FLOOR[task],
    }

    return ds_related, metric_related, training_hyperparameters, tokenizer