import json
import os
from collections import OrderedDict

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from .abide import ABIDE
from .celeba import CelebA
from .const import GTSRB_LABEL_MAP
from .dataset_json import CUB200Dataset, NabirdsDataset, StanfordDogs, OxfordFlowers
from .dataset_lmdb import CLSWSLMDBDataset
from .dataset_vtab_1k import get_data
from .waterbirds import Waterbirds


def refine_classnames(class_names):
    """Normalise class names in place and return the same list.

    Every entry is lower-cased and has its underscores and hyphens turned
    into single spaces.

    Args:
        class_names: mutable sequence of strings; it is modified in place.

    Returns:
        The very same ``class_names`` object, after normalisation.
    """
    # One translation table covers both separator characters in a single pass.
    separators_to_space = str.maketrans({'_': ' ', '-': ' '})
    for position, raw_name in enumerate(class_names):
        class_names[position] = raw_name.lower().translate(separators_to_space)
    return class_names


def get_class_names_from_split(root):
    """Return the class names listed in ``<root>/split.json``, ordered by class index.

    Only the ``"test"`` entries of the JSON file are read.  For each entry the
    second-to-last element is used as the class index and the last element as
    the class name.  If an index occurs more than once, the name of its last
    occurrence wins.

    Args:
        root: directory that contains ``split.json``.

    Returns:
        A list of class names sorted by ascending class index.
    """
    # JSON is UTF-8 by specification; be explicit so the result does not
    # depend on the platform's default locale encoding.
    with open(os.path.join(root, "split.json"), encoding="utf-8") as f:
        test_entries = json.load(f)["test"]
    index_to_name = {entry[-2]: entry[-1] for entry in test_entries}
    return [index_to_name[index] for index in sorted(index_to_name)]


def _to_rgb(image):
    """Convert an image to RGB via its ``convert`` method.

    Defined at module level (instead of as a lambda) so the transform can be
    pickled when DataLoader worker processes are spawned.
    """
    return image.convert("RGB")


def _train_transform(image_resize, to_rgb=False, to_tensor=True):
    """Build the training pipeline: resize to 256/224 of the target size, random crop, random flip."""
    enlarged = int(image_resize * 256 / 224)
    steps = [transforms.Lambda(_to_rgb)] if to_rgb else []
    steps += [
        transforms.Resize((enlarged, enlarged)),
        transforms.RandomCrop((image_resize, image_resize)),
        transforms.RandomHorizontalFlip(),
    ]
    if to_tensor:
        steps.append(transforms.ToTensor())
    return transforms.Compose(steps)


def _eval_transform(image_resize, to_rgb=False):
    """Build the deterministic evaluation pipeline: plain resize followed by ``ToTensor``."""
    steps = [transforms.Lambda(_to_rgb)] if to_rgb else []
    steps += [
        transforms.Resize((image_resize, image_resize)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)


def prepare_data_loader(dataset, data_path, batch_size=128, image_resize=224, per_class_size=None,
                        target_attr="Smiling"):
    """Create train/test data loaders and the matching configuration for ``dataset``.

    Supported names: any ``vtab1k-<task>`` key of ``VTAB_CLASS_NUM``, ``cifar10``,
    ``cifar100``, ``gtsrb``, ``svhn``, ``abide``, ``food101``, ``eurosat``,
    ``sun397``, ``ucf101``, ``stanfordcars``, ``flowers102``, ``dtd``,
    ``oxfordpets``, ``caltech101``, ``celeba``, ``waterbirds``, ``pcam``,
    ``cub200``, ``nabirds``, ``stanforddogs`` and ``oxfordflowers``.

    Args:
        dataset: name of the dataset (see above).
        data_path: root directory; the dataset is read from (or downloaded to)
            ``<data_path>/<dataset>``.  For ``vtab1k-<task>`` every ``-`` in the
            name is turned into a path separator (``<data_path>/vtab1k/<task>``).
        batch_size: batch size used for both loaders.
        image_resize: side length of the square images produced by the transforms.
        per_class_size: forwarded to ``CLSWSLMDBDataset`` for the training split
            of the LMDB-backed datasets; unused otherwise.
        target_attr: target attribute forwarded to ``CelebA``; unused otherwise.

    Returns:
        A tuple ``(loaders, configs)`` where ``loaders`` has the keys ``'train'``
        and ``'test'`` and ``configs`` has the keys ``'class_names'`` (list of
        strings) and ``'mask'`` (an all-zero ``(image_resize, image_resize)``
        array, or ``ABIDE.get_mask()`` for ``abide``).

    Raises:
        NotImplementedError: if ``dataset`` is not one of the supported names.
        KeyError: if ``dataset`` starts with ``vtab1k`` but is not a key of
            ``VTAB_CLASS_NUM``.
    """
    is_vtab = dataset.startswith("vtab1k")
    data_path = os.path.join(data_path, dataset.replace("-", "/") if is_vtab else dataset)

    if is_vtab:
        # VTAB-1k ships its own loader factory; nothing else to build here.
        train_loader, test_loader = get_data(data_path, batch_size=batch_size, evaluate=True)
        loaders = {
            'train': train_loader,
            'test': test_loader,
        }
        configs = {
            'class_names': [f'{i}' for i in range(VTAB_CLASS_NUM[dataset])],
            'mask': np.zeros((image_resize, image_resize)),
        }
        return loaders, configs

    # Per-dataset overrides of the common loader / config settings.
    num_workers = 2
    pin_memory = False
    mask = None  # None -> all-zero (image_resize, image_resize) mask

    if dataset in ("cifar10", "cifar100"):
        cifar_cls = datasets.CIFAR10 if dataset == "cifar10" else datasets.CIFAR100
        train_data = cifar_cls(root=data_path, train=True, download=True,
                               transform=_train_transform(image_resize))
        # Evaluation must be deterministic: no random crop / flip on the test split.
        test_data = cifar_cls(root=data_path, train=False, download=True,
                              transform=_eval_transform(image_resize))
        class_names = refine_classnames(test_data.classes)

    elif dataset == "gtsrb":
        train_data = datasets.GTSRB(root=data_path, split="train", download=True,
                                    transform=_train_transform(image_resize))
        # Evaluation must be deterministic: no random crop / flip on the test split.
        test_data = datasets.GTSRB(root=data_path, split="test", download=True,
                                   transform=_eval_transform(image_resize))
        class_names = refine_classnames(list(GTSRB_LABEL_MAP.values()))

    elif dataset == "svhn":
        train_data = datasets.SVHN(root=data_path, split="train", download=True,
                                   transform=_train_transform(image_resize))
        test_data = datasets.SVHN(root=data_path, split="test", download=True,
                                  transform=_eval_transform(image_resize))
        class_names = [f'{i}' for i in range(10)]

    elif dataset == "abide":
        full_set = ABIDE(root=data_path)
        num_samples = len(full_set.data)
        num_train = int(num_samples * 0.9)
        # Split ONE permutation into disjoint 90% / 10% parts so that no test
        # sample can also end up in the training set.
        shuffled = torch.randperm(num_samples)
        train_idx, test_idx = shuffled[:num_train], shuffled[num_train:]

        train_data = ABIDE(root=data_path, transform=transforms.Compose([transforms.ToTensor()]))
        train_data.data = full_set.data[train_idx]
        train_data.targets = full_set.targets[train_idx]
        test_data = ABIDE(root=data_path, transform=transforms.Compose([transforms.ToTensor()]))
        test_data.data = full_set.data[test_idx]
        test_data.targets = full_set.targets[test_idx]
        class_names = ["non ASD", "ASD"]
        mask = full_set.get_mask()

    elif dataset in ["food101", "eurosat", "sun397", "ucf101", "stanfordcars", "flowers102", "dtd", "oxfordpets",
                     "caltech101"]:
        train_data = CLSWSLMDBDataset(root=data_path, split="train",
                                      transform=_train_transform(image_resize, to_rgb=True),
                                      per_class_size=per_class_size)
        test_data = CLSWSLMDBDataset(root=data_path, split="test",
                                     transform=_eval_transform(image_resize, to_rgb=True))
        class_names = refine_classnames(test_data.classes)
        num_workers = 8

    elif dataset == "celeba":
        # No ToTensor here, exactly as before; presumably CelebA converts the
        # images itself -- TODO confirm.
        transform_test = transforms.Compose([
            transforms.Resize((image_resize, image_resize)),
            transforms.CenterCrop(image_resize),
        ])
        train_data = CelebA(data_path, target_attr, domain_attrs=None,
                            img_transform=_train_transform(image_resize, to_tensor=False), type="train")
        test_data = CelebA(data_path, target_attr, domain_attrs=None, img_transform=transform_test, type="test")
        # A list (not a set) so the order is stable across runs.
        # NOTE(review): index order assumes label 0 = negative, 1 = positive, and
        # the names are fixed to "Smiling" regardless of target_attr -- confirm.
        class_names = ["Not Smiling", "Smiling"]
        pin_memory = True

    elif dataset == "waterbirds":
        enlarged = int(image_resize * 256 / 224)
        transform_test = transforms.Compose([
            transforms.Resize((enlarged, enlarged)),
            transforms.CenterCrop(image_resize),
            transforms.ToTensor(),
        ])
        train_data = Waterbirds(root=data_path, train=True, download=True,
                                transform=_train_transform(image_resize))
        test_data = Waterbirds(root=data_path, train=False, download=True, transform=transform_test)
        # A list (not a set) so the order is stable across runs.
        # NOTE(review): index order assumes label 0 = landbird, 1 = waterbird -- confirm.
        class_names = ["Landbirds", "Waterbirds"]
        pin_memory = True

    elif dataset == "pcam":
        train_data = datasets.PCAM(root=data_path, split="train", download=True,
                                   transform=_train_transform(image_resize))
        test_data = datasets.PCAM(root=data_path, split="test", download=True,
                                  transform=_eval_transform(image_resize))
        class_names = [f'{i}' for i in range(2)]

    elif dataset in ["cub200", "nabirds", "stanforddogs", "oxfordflowers"]:
        dataset_dict = {
            "cub200": CUB200Dataset,
            "nabirds": NabirdsDataset,
            "stanforddogs": StanfordDogs,
            "oxfordflowers": OxfordFlowers,
        }
        num_classes = {
            "cub200": 200,
            # NOTE(review): NABirds defines 555 visual categories; the previous
            # value of 55 looked like a typo -- confirm against NabirdsDataset.
            "nabirds": 555,
            "stanforddogs": 120,
            "oxfordflowers": 102,
        }
        train_data = dataset_dict[dataset](data_dir=data_path, split="train",
                                           transforms=_train_transform(image_resize))
        test_data = dataset_dict[dataset](data_dir=data_path, split="test",
                                          transforms=_eval_transform(image_resize))
        class_names = [f'{i}' for i in range(num_classes[dataset])]

    else:
        raise NotImplementedError(f"{dataset} not supported")

    loaders = {
        'train': DataLoader(train_data, batch_size, shuffle=True,
                            num_workers=num_workers, pin_memory=pin_memory),
        # Test loaders are never shuffled so evaluation order is reproducible.
        'test': DataLoader(test_data, batch_size, shuffle=False,
                           num_workers=num_workers, pin_memory=pin_memory),
    }
    configs = {
        'class_names': class_names,
        'mask': np.zeros((image_resize, image_resize)) if mask is None else mask,
    }
    print(f"Dataset loaded: train size {len(train_data)}, test size {len(test_data)}")
    return loaders, configs


# Number of classes of every VTAB-1k task, keyed by the full dataset name
# ("vtab1k-<task>") that prepare_data_loader receives.
VTAB_CLASS_NUM = {
    f"vtab1k-{task}": class_count
    for task, class_count in (
        ("caltech101", 102),
        ("clevr_count", 8),
        ("diabetic_retinopathy", 5),
        ("dsprites_loc", 16),
        ("dtd", 47),
        ("kitti", 4),
        ("oxford_iiit_pet", 37),
        ("resisc45", 45),
        ("smallnorb_ele", 9),
        ("svhn", 10),
        ("cifar", 100),
        ("clevr_dist", 6),
        ("dmlab", 6),
        ("dsprites_ori", 16),
        ("eurosat", 10),
        ("oxford_flowers102", 102),
        ("patch_camelyon", 2),
        ("smallnorb_azi", 18),
        ("sun397", 397),
    )
}
