import os
import sys

import torch
import numpy as np
from torch.utils.data import DataLoader, ConcatDataset, Subset

import torchvision.datasets as datasets
import torchvision.transforms as transforms

from tqdm import tqdm
import clip

# Maps the short model keys accepted by get_data_feats (also used there as the
# per-model feature-cache folder names) to the identifiers that are handed to
# get_model / get_splitted_datasets. 'dinov2' maps to itself.
model_name_mappings = {'clipvitB32': 'ViT-B/32', 'clipvitB16': 'ViT-B/16', 'clipvitL14': 'ViT-L/14', 'dinov2': 'dinov2'}


class MyFER2013(datasets.FER2013):
    """FER2013 variant that reads ``train_split.csv`` / ``test_split.csv``.

    Only the ``_RESOURCES`` table of the torchvision base class is overridden;
    everything else is inherited unchanged.
    """

    # split name -> (file name, hash string).
    # NOTE(review): the second entry is presumably the MD5 checksum the base
    # class verifies - confirm against the installed torchvision version.
    _RESOURCES = dict(
        train=("train_split.csv", "aa1bdf3e64bc6697783ce586283a2b74"),
        test=("test_split.csv", "8576e0f5a806d7b337d6eeda66d71dc0"),
    )

class MySUN397(datasets.SUN397):
    """SUN397 restricted to one of the ``Training_XX.txt`` / ``Testing_XX.txt`` partitions.

    The base class is initialised without downloading; afterwards ``filter``
    replaces its image list and labels with the entries of the partition file
    selected by ``partition`` and ``split``.
    """

    def __init__(self, root, transform, target_transform=None, partition=1, split="train"):
        super().__init__(root=root, transform=transform, target_transform=target_transform, download=False)
        self.partition = partition
        self.split = split
        self.filter()

    def filter(self):
        """Overwrite ``_image_files`` and ``_labels`` from the partition listing file."""
        # Any split other than "train" selects the testing listing.
        prefix = "Training" if self.split == "train" else "Testing"
        listing_path = self._data_dir / f"{prefix}_{self.partition:02d}.txt"

        with open(listing_path) as listing:
            entries = listing.read().splitlines()

        # The first character of every entry is dropped before joining with the
        # data directory (presumably a leading "/" - verify against the files).
        self._image_files = [self._data_dir / entry[1:] for entry in entries]

        # The class key is the relative path without its first component and
        # without the file name, joined with "/".
        self._labels = [
            self.class_to_idx["/".join(image_path.relative_to(self._data_dir).parts[1:-1])]
            for image_path in self._image_files
        ]

def get_data_feats(root_dir, dataset_name, model_names, device):
    """Load, or compute and cache, labels and per-model features for a dataset.

    Cached arrays live under ``{root_dir}/feats``: the label files directly in
    that folder, the feature files in one sub-folder per entry of
    ``model_names``. Anything missing from the cache is regenerated through
    ``generate_pretrained_data_labels`` / ``generate_pretrained_data``.

    Args:
        root_dir: Base directory; created if it does not exist.
        dataset_name: Dataset key, used in the cache file names and forwarded
            to ``get_splitted_datasets``.
        model_names: Iterable of keys of ``model_name_mappings``.
        device: Device the returned feature tensors are moved to (also passed
            on to the model / dataset helpers).

    Returns:
        ``(train_feat_sets, test_feat_sets, train_labels, test_labels)``. The
        feature sets are lists with one float tensor per model, in the order
        of ``model_names``. The labels are loaded or generated once, while
        handling the first model, and stay ``None`` if ``model_names`` is empty.

    Raises:
        KeyError: If a model name is not a key of ``model_name_mappings``.
    """
    feat_dir = f"{root_dir}/feats"
    # exist_ok replaces the racy "check, then create" pattern; creating
    # feat_dir also creates root_dir.
    os.makedirs(feat_dir, exist_ok=True)

    batch_size = 200
    train_feat_sets = []
    test_feat_sets = []
    train_labels, test_labels = None, None
    for idx, model_name in enumerate(model_names):
        model_dir = f"{feat_dir}/{model_name}"
        os.makedirs(model_dir, exist_ok=True)

        # Labels do not depend on the model, so they are handled only once.
        if idx == 0:
            train_label_file_path = f"{feat_dir}/{dataset_name}_train_labels.npy"
            test_label_file_path = f"{feat_dir}/{dataset_name}_test_labels.npy"
            # Both files must exist: an interrupted earlier run may have
            # written only the first one.
            labels_cached = (os.path.exists(train_label_file_path)
                             and os.path.exists(test_label_file_path))
            if labels_cached:
                train_labels = np.load(train_label_file_path)
                test_labels = np.load(test_label_file_path)
            else:
                train_dataset, test_dataset = get_splitted_datasets(root_dir, dataset_name,
                                                                    model_name_mappings[model_name], device)
                train_labels, test_labels = generate_pretrained_data_labels(train_dataset, test_dataset,
                                                                            train_label_file_path,
                                                                            test_label_file_path)

        # Features of the training and test samples produced by the pretrained model.
        train_feat_file_path = f"{model_dir}/{dataset_name}_train_feats.npy"
        test_feat_file_path = f"{model_dir}/{dataset_name}_test_feats.npy"
        print(f'{idx}: {model_name_mappings[model_name]}')

        # Same rule as for the labels: a half-written cache is regenerated.
        feats_cached = (os.path.exists(train_feat_file_path)
                        and os.path.exists(test_feat_file_path))
        if feats_cached:
            train_feats = np.load(train_feat_file_path)
            test_feats = np.load(test_feat_file_path)
        else:
            model = get_model(root_dir, model_name_mappings[model_name], device)
            train_dataset, test_dataset = get_splitted_datasets(root_dir, dataset_name,
                                                                model_name_mappings[model_name], device)
            train_feats, test_feats = generate_pretrained_data(train_dataset, test_dataset, model, batch_size,
                                                               device, train_feat_file_path, test_feat_file_path)
        train_feat_sets.append(torch.from_numpy(train_feats).float().to(device))
        test_feat_sets.append(torch.from_numpy(test_feats).float().to(device))

    return train_feat_sets, test_feat_sets, train_labels, test_labels


@torch.no_grad()
def generate_pretrained_features(data_loader, model, device):
    """Run ``model`` over every batch of ``data_loader`` and collect the outputs.

    Args:
        data_loader: Iterable yielding ``(inputs, targets)`` pairs; the targets
            are ignored.
        model: Callable applied to each input batch after it is moved to
            ``device``.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        A NumPy array holding the outputs of all batches concatenated along
        the first dimension. ``torch.cat`` needs at least one batch, so an
        empty loader is not supported.
    """
    pretrained_features = []

    # Wrap the loader exactly once: wrapping it twice rendered two nested
    # progress bars for the same iteration.
    for x, _ in tqdm(data_loader, file=sys.stdout):
        features = model(x.to(device))
        pretrained_features.append(features.detach().cpu())

    return torch.cat(pretrained_features).numpy()


def get_model(root_dir, model_name, device):
    """Load a pretrained backbone and return the callable used for feature extraction.

    For ``'dinov2'`` the ``dinov2_vitg14`` hub model is loaded (the torch hub
    directory is pointed at ``{root_dir}/checkpoints/dinov2``) and the module
    itself is returned. Any other name is passed to ``clip.load`` and the
    model's ``encode_image`` method is returned. In both cases the network is
    switched to eval mode and its parameter count is printed.
    """
    if model_name == 'dinov2':
        hub_dir = os.path.join(root_dir, "checkpoints/dinov2")
        torch.hub.set_dir(hub_dir)
        network = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14', source='github').to(device)
        extractor = network
    else:
        clip_dir = os.path.join(root_dir, "checkpoints/clip")
        network, _ = clip.load(model_name, device=device, download_root=clip_dir)
        # Only the image tower is needed by the callers.
        extractor = network.encode_image

    network.eval()
    print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in network.parameters()]):,}")

    return extractor


def get_splitted_datasets(root_dir, dataset_name, model_name, device):
    """Build the train/test datasets with the preprocessing that matches ``model_name``.

    ``'dinov2'`` uses ``get_default_transforms``. Every other name uses the
    preprocessing pipeline returned by ``clip.load``, with its entries at
    positions 2 and 3 replaced by the tensor-tolerant helpers of this module.

    Returns:
        ``(train_dataset, test_dataset)`` as produced by ``get_datasets``.
    """
    if model_name != 'dinov2':
        clip_dir = os.path.join(root_dir, "checkpoints/clip")
        _, transform = clip.load(model_name, device=device, download_root=clip_dir)
        # NOTE(review): assumes slots 2 and 3 of CLIP's pipeline are its
        # RGB-conversion and to-tensor steps - confirm for the installed clip.
        transform.transforms[2] = _convert_image_to_rgb
        transform.transforms[3] = _safe_to_tensor
    else:
        transform = get_default_transforms()

    train_split, test_split = get_datasets(dataset_name, transform, root_dir)
    return train_split, test_split


def generate_pretrained_data_labels(train_dataset, test_dataset, train_label_file_path, test_label_file_path):
    """Extract the labels of both datasets, save them as ``.npy`` files and return them.

    The labels are converted to NumPy arrays before they are returned, so a
    first run hands back the same kind of object that a later ``np.load`` of
    the saved files produces (previously the raw ``get_labels`` result, e.g. a
    Python list, was returned).

    Args:
        train_dataset: Dataset whose labels are written to ``train_label_file_path``.
        test_dataset: Dataset whose labels are written to ``test_label_file_path``.
        train_label_file_path: Destination of the training labels.
        test_label_file_path: Destination of the test labels.

    Returns:
        ``(train_labels, test_labels)`` as NumPy arrays.
    """
    train_labels = np.asarray(get_labels(train_dataset))
    test_labels = np.asarray(get_labels(test_dataset))
    np.save(train_label_file_path, train_labels)
    np.save(test_label_file_path, test_labels)

    return train_labels, test_labels


def generate_pretrained_data(train_dataset, test_dataset, model, batch_size, device,
                             train_feat_file_path, val_feat_file_path, num_workers=10):
    """Compute the features of both datasets with ``model`` and save them as ``.npy`` files.

    Both feature arrays are computed before either file is written.

    Args:
        train_dataset: Dataset whose features go to ``train_feat_file_path``.
        test_dataset: Dataset whose features go to ``val_feat_file_path``.
        model: Callable handed to ``generate_pretrained_features``.
        batch_size: Batch size of both data loaders.
        device: Device used for the forward passes.
        train_feat_file_path: Destination of the training features.
        val_feat_file_path: Destination of the test features.
        num_workers: Number of worker processes per data loader. Defaults to
            10, the value that used to be hard-coded.

    Returns:
        ``(train_feats, test_feats)`` as returned by
        ``generate_pretrained_features``.
    """
    # shuffle=False keeps the feature rows in dataset order.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    train_feats = generate_pretrained_features(train_loader, model, device)
    test_feats = generate_pretrained_features(test_loader, model, device)

    np.save(train_feat_file_path, train_feats)
    np.save(val_feat_file_path, test_feats)

    return train_feats, test_feats


# Per-channel mean and standard deviation handed to transforms.Normalize in
# get_default_transforms (ImageNet statistics, going by the constant names).
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def get_default_transforms():
    """Return the default preprocessing pipeline.

    Steps, in order: bicubic resize to 256, centre crop to 224, RGB
    conversion, tensor conversion, and normalisation with
    ``IMAGENET_DEFAULT_MEAN`` / ``IMAGENET_DEFAULT_STD``.
    """
    resize = transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC)
    center_crop = transforms.CenterCrop(224)
    normalize = transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)

    pipeline = [resize, center_crop, _convert_image_to_rgb, _safe_to_tensor, normalize]
    return transforms.Compose(pipeline)


def _sample_indices_per_class(targets, per_class, seed=42):
    """Draw ``per_class`` distinct sample indices for every label value in ``targets``.

    Returns a list with one index array per label, in ``np.unique`` order.
    """
    draws = []
    for label in np.unique(targets):
        # Re-seeding before every draw keeps the selection identical to what
        # earlier runs (and their cached label/feature files) used.
        # NOTE(review): this resets NumPy's *global* RNG as a side effect.
        np.random.seed(seed)
        draws.append(np.random.choice(np.where(targets == label)[0], size=per_class, replace=False))
    return draws


def get_datasets(dataset, transform, root_dir='./data'):
    """Create the train and test datasets for the given dataset key.

    All data is looked up (or, where the torchvision class is asked to,
    downloaded) under ``{root_dir}/datasets``.

    Args:
        dataset: One of ``food101``, ``cifar10``, ``cifar100``, ``stl10``,
            ``imagenet``, ``aircraft``, ``caltech101``, ``flowers``, ``dtd``,
            ``pets``, ``gtsrb``, ``fer2013``, ``eurosat``, ``kitti``, ``sst``,
            ``sun397``.
        transform: Transform applied to every image.
        root_dir: Base directory that contains the ``datasets`` folder.

    Returns:
        ``(train_dataset, test_dataset)``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported keys
            (previously this surfaced as an ``UnboundLocalError``).
    """
    data_path = os.path.join(root_dir, "datasets")
    if dataset == "food101":
        train_dataset = datasets.Food101(root=data_path, split="train", transform=transform, download=True)
        test_dataset = datasets.Food101(root=data_path, split="test", transform=transform, download=True)

    elif dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=data_path, train=True, transform=transform, download=True)
        test_dataset = datasets.CIFAR10(root=data_path, train=False, transform=transform, download=True)

    elif dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(root=data_path, train=True, transform=transform, download=True)
        test_dataset = datasets.CIFAR100(root=data_path, train=False, transform=transform, download=True)

    elif dataset == 'stl10':
        train_dataset = datasets.STL10(root=data_path, split="train", transform=transform, download=True)
        test_dataset = datasets.STL10(root=data_path, split="test", transform=transform, download=True)

    elif dataset == "imagenet":
        train_dataset = datasets.ImageFolder(root=os.path.join(data_path, "imagenet/train"), transform=transform)
        test_dataset = datasets.ImageFolder(root=os.path.join(data_path, "imagenet/val"), transform=transform)

    elif dataset == "aircraft":
        train_dataset = datasets.ImageFolder(root=os.path.join(data_path, "aircraft/trainval"), transform=transform)
        test_dataset = datasets.ImageFolder(root=os.path.join(data_path, "aircraft/test"), transform=transform)

    elif dataset == "caltech101":
        # 30 randomly drawn images per class form the training set; every
        # remaining image goes to the test set.
        tmp_dataset = datasets.ImageFolder(root=os.path.join(data_path, "caltech-101/101_ObjectCategories"), transform=transform)
        tmp_targets = np.array(tmp_dataset.targets)
        subset = [i for draw in _sample_indices_per_class(tmp_targets, 30) for i in draw]
        # sorted() pins the test order instead of relying on set iteration order.
        subset_val = sorted(set(range(len(tmp_targets))) - set(subset))
        train_dataset = Subset(tmp_dataset, subset)
        test_dataset = Subset(tmp_dataset, subset_val)

    elif dataset == "flowers":
        # train and val splits are merged into the training set.
        tmp_dataset1 = datasets.Flowers102(root=data_path, split="train", transform=transform, download=True)
        tmp_dataset2 = datasets.Flowers102(root=data_path, split="val", transform=transform, download=True)
        train_dataset = ConcatDataset((tmp_dataset1, tmp_dataset2))
        test_dataset = datasets.Flowers102(root=data_path, split="test", transform=transform, download=True)

    elif dataset == "dtd":
        # train and val splits are merged into the training set.
        tmp_dataset1 = datasets.DTD(root=data_path, split="train", transform=transform, download=True)
        tmp_dataset2 = datasets.DTD(root=data_path, split="val", transform=transform, download=True)
        train_dataset = ConcatDataset((tmp_dataset1, tmp_dataset2))
        test_dataset = datasets.DTD(root=data_path, split="test", transform=transform, download=True)

    elif dataset == "pets":
        train_dataset = datasets.ImageFolder(root=os.path.join(data_path, "pets/train"), transform=transform)
        test_dataset = datasets.ImageFolder(root=os.path.join(data_path, "pets/test"), transform=transform)

    elif dataset == "gtsrb":
        train_dataset = datasets.GTSRB(root=data_path, split="train", transform=transform, download=True)
        test_dataset = datasets.GTSRB(root=data_path, split="test", transform=transform, download=True)

    elif dataset == "fer2013":
        train_dataset = MyFER2013(root=data_path, split="train", transform=transform)
        test_dataset = MyFER2013(root=data_path, split="test", transform=transform)

    elif dataset == "eurosat":
        # 1500 images are drawn per class: the first 1000 for training, the
        # remaining 500 for testing.
        tmp_dataset = datasets.EuroSAT(root=data_path, transform=transform, download=False)
        tmp_targets = np.array(tmp_dataset.targets)
        subset_train = []
        subset_val = []
        for draw in _sample_indices_per_class(tmp_targets, 1500):
            subset_train.extend(draw[:1000])
            subset_val.extend(draw[1000:])
        train_dataset = Subset(tmp_dataset, subset_train)
        test_dataset = Subset(tmp_dataset, subset_val)

    elif dataset == "kitti":
        train_dataset = datasets.ImageFolder(root=os.path.join(data_path, "Kitti/train"), transform=transform)
        test_dataset = datasets.ImageFolder(root=os.path.join(data_path, "Kitti/val"), transform=transform)

    elif dataset == "sst":
        # train and valid folders are merged into the training set.
        tmp_dataset1 = datasets.ImageFolder(root=os.path.join(data_path, "rendered-sst2/train"), transform=transform)
        tmp_dataset2 = datasets.ImageFolder(root=os.path.join(data_path, "rendered-sst2/valid"), transform=transform)
        train_dataset = ConcatDataset((tmp_dataset1, tmp_dataset2))
        test_dataset = datasets.ImageFolder(root=os.path.join(data_path, "rendered-sst2/test"), transform=transform)

    elif dataset == "sun397":
        train_dataset = MySUN397(root=data_path, partition=1, split="train", transform=transform)
        test_dataset = MySUN397(root=data_path, partition=1, split="test", transform=transform)

    else:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of: food101, cifar10, cifar100, stl10, "
            "imagenet, aircraft, caltech101, flowers, dtd, pets, gtsrb, fer2013, eurosat, "
            "kitti, sst, sun397"
        )

    return train_dataset, test_dataset


def get_labels(dataset):
    """Return the labels of ``dataset`` without loading samples when possible.

    The attributes ``targets``, ``labels`` and ``_labels`` are probed in that
    order and the first one present is returned as is (the original notes
    mention food101 / aircraft for ``_labels``). A ``_samples`` attribute
    (noted as "cars" originally) yields the second item of every entry.
    Otherwise each sample is fetched by index and its second item is collected.
    """
    for attr_name in ("targets", "labels", "_labels"):
        if hasattr(dataset, attr_name):
            return getattr(dataset, attr_name)

    if hasattr(dataset, "_samples"):
        return [sample[1] for sample in dataset._samples]

    # Slow path: materialises every sample just to read its label.
    return [dataset[position][1] for position in range(len(dataset))]


def _convert_image_to_rgb(image):
    if torch.is_tensor(image):
        return image
    else:
        return image.convert("RGB")


def _safe_to_tensor(x):
    if torch.is_tensor(x):
        return x
    else:
        return transforms.ToTensor()(x)
