import pickle
import numpy as np
import math
import torch
import os
from collections import OrderedDict
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import transforms
from get_dataset import get_dataset, NUM_CLASSES


# Load the per-dataset normalization statistics once, at import time.
# Entries are read elsewhere in this module as DS_MEAN_STD[ds_name]['mean']
# and DS_MEAN_STD[ds_name]['std'].
# NOTE(review): DATAPATH is neither defined nor imported in the visible part
# of this file -- confirm where it is supposed to come from.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point DATAPATH at a trusted file.
with open (DATAPATH, 'rb') as f:
    DS_MEAN_STD = pickle.load(f)


def get_transform(
    ds_name, image_size, random_crop=False, random_horizontal_flip=False, normalize_mean=(0.5,), normalize_std=(0.5,)
):
    """Build a resize -> (augment) -> tensor -> normalize pipeline.

    Args:
        ds_name: Dataset key used to look up normalization statistics in
            ``DS_MEAN_STD``.
        image_size: Side length; images are resized to
            ``(image_size, image_size)``.
        random_crop: If true, add a ``RandomCrop`` of ``image_size`` with
            padding 4 when ``image_size == 32`` and 8 otherwise.
        random_horizontal_flip: If true, add a ``RandomHorizontalFlip``.
        normalize_mean: Fallback mean, used only when ``ds_name`` has no
            entry in ``DS_MEAN_STD``.
        normalize_std: Fallback std, used only when ``ds_name`` has no
            entry in ``DS_MEAN_STD``.

    Returns:
        A ``transforms.Compose`` holding the assembled pipeline.
    """
    transform_list = [transforms.Resize((image_size, image_size))]
    if random_crop:
        transform_list.append(transforms.RandomCrop(image_size, padding=(4 if image_size == 32 else 8)))
    if random_horizontal_flip:
        transform_list.append(transforms.RandomHorizontalFlip())
    transform_list.append(transforms.ToTensor())

    # Stored per-dataset statistics take precedence. The caller-supplied
    # values used to be overwritten unconditionally (and an unknown dataset
    # raised KeyError); they now serve as the fallback for that case.
    try:
        stats = DS_MEAN_STD[ds_name]
    except KeyError:
        stats = None
    if stats is not None:
        normalize_mean = stats['mean']
        normalize_std = stats['std']

    transform_list.append(transforms.Normalize(normalize_mean, normalize_std))
    return transforms.Compose(transform_list)


def get_ofa_train_transform(image_size, resize_scale=0.08, distort_color='tf', print_log=False):
    """Build the OFA-style training pipeline.

    The pipeline is random-resized-crop -> horizontal flip -> optional color
    jitter -> tensor conversion -> normalization with fixed mean/std.

    Args:
        image_size: Output size passed to ``RandomResizedCrop``.
        resize_scale: Lower bound of the crop scale range; the upper bound
            is fixed at 1.0.
        distort_color: ``'torch'`` or ``'tf'`` selects one of two
            ``ColorJitter`` settings; any other value disables jitter.
        print_log: If true, print the chosen settings.

    Returns:
        A ``transforms.Compose`` holding the assembled pipeline.
    """
    if print_log:
        print(f'Color jitter: {distort_color}, resize_scale: {resize_scale}, image_size: {image_size}')

    # Geometric augmentation always comes first.
    geometric = [
        transforms.RandomResizedCrop(image_size, scale=(resize_scale, 1.0)),
        transforms.RandomHorizontalFlip(),
    ]

    # Optional color augmentation; unknown modes add nothing.
    if distort_color == 'torch':
        color = [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)]
    elif distort_color == 'tf':
        color = [transforms.ColorJitter(brightness=32. / 255., saturation=0.5)]
    else:
        color = []

    finalize = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    return transforms.Compose(geometric + color + finalize)


def get_ofa_valid_transform(image_size):
    """Build the OFA-style evaluation pipeline.

    Resizes to ``ceil(image_size / 0.875)``, center-crops to ``image_size``,
    converts to a tensor and normalizes with fixed mean/std.

    Args:
        image_size: Size of the final center crop.

    Returns:
        A ``transforms.Compose`` holding the assembled pipeline.
    """
    resize_to = int(math.ceil(image_size / 0.875))
    steps = [
        transforms.Resize(resize_to),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


def get_meta_train_dataloader(
    name, image_size, batch_size, default_data_path, use_ofa_transform=True, split=0, total_split=10, num_workers=1):
    """Create meta-training loaders for one pre-computed dataset split.

    Split metadata is read from
    ``{default_data_path}/{name}/{total_split}_split/split_{split}.npz``,
    which must provide ``label_list``, ``idx_train`` and ``idx_valid``.
    Both loaders wrap the ``train=True`` dataset; the validation loader uses
    the evaluation transform and the ``idx_valid`` indices.

    Args:
        name: Dataset name; only ``"tiny_imagenet"`` is accepted.
        image_size: Image size forwarded to the transform builders.
        batch_size: Batch size for both loaders.
        default_data_path: Root directory for datasets and split files.
        use_ofa_transform: Pick the OFA transforms when true, otherwise the
            ``get_transform`` pipelines.
        split: Index of the split file (variable for distributed learning,
            i.e. the rank).
        total_split: Total number of splits; part of the split directory name.
        num_workers: Worker count for both loaders.

    Returns:
        ``(train_loader, valid_loader, None, num_labels)`` where
        ``num_labels`` is ``len(label_list)`` of the split.

    Raises:
        NotImplementedError: If ``name`` is not a supported dataset.
    """
    supported = ("tiny_imagenet",)
    if name not in supported:
        raise NotImplementedError

    # Read the class subset and sample indices belonging to this split.
    split_file = f"{default_data_path}/{name}/{total_split}_split/split_{split}.npz"
    split_info = np.load(split_file)
    label_list = list(split_info["label_list"])
    train_indices = list(split_info["idx_train"])
    valid_indices = list(split_info["idx_valid"])

    def _relabel(y):
        # Map an original label to its position inside this split's label list.
        return label_list.index(y)

    if use_ofa_transform:
        train_tf = get_ofa_train_transform(image_size)
        eval_tf = get_ofa_valid_transform(image_size)
    else:
        train_tf = get_transform(name, image_size, random_crop=True, random_horizontal_flip=True)
        eval_tf = get_transform(name, image_size)

    # Same underlying training data twice, differing only in the transform.
    train_ds = get_dataset(
        name, train=True, default_data_path=default_data_path,
        transform=train_tf, target_transform=_relabel)
    valid_ds = get_dataset(
        name, train=True, default_data_path=default_data_path,
        transform=eval_tf, target_transform=_relabel)

    def _subset_loader(dataset, indices):
        return DataLoader(
            dataset,
            sampler=SubsetRandomSampler(indices),
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=False,
            drop_last=True,
        )

    train_loader = _subset_loader(train_ds, train_indices)
    valid_loader = _subset_loader(valid_ds, valid_indices)

    # No test loader is built for meta-training; the slot is kept so the
    # return shape stays a 4-tuple.
    return train_loader, valid_loader, None, len(label_list)


def get_meta_test_dataloader(name, image_size, batch_size, default_data_path, use_ofa_transform=True, split=None, num_workers=1,
                             num_instances=None):
    """Create train/test loaders over a full dataset for meta-testing.

    Args:
        name: Dataset name passed to ``get_dataset`` and used as the key
            into ``NUM_CLASSES``.
        image_size: Image size forwarded to the transform builders.
        batch_size: Batch size for both loaders.
        default_data_path: Root directory for datasets.
        use_ofa_transform: Pick the OFA transforms when true, otherwise the
            ``get_transform`` pipelines.
        split: Accepted for signature compatibility; not used here.
        num_workers: Worker count for both loaders.
        num_instances: Optional cap on the number of training samples kept
            per class. ``None`` (the default) or ``0`` uses the whole
            training set, which is what this function always did before the
            parameter existed.

    Returns:
        ``(train_loader, test_loader, None, NUM_CLASSES[name])``.
    """
    if use_ofa_transform:
        transform_train = get_ofa_train_transform(image_size)
        transform_test = get_ofa_valid_transform(image_size)
    else:
        transform_train = get_transform(name, image_size, random_crop=True, random_horizontal_flip=True)
        transform_test = get_transform(name, image_size)

    train_ds = get_dataset(name, train=True, default_data_path=default_data_path,
                transform=transform_train)
    test_ds = get_dataset(name, train=False, default_data_path=default_data_path,
                transform=transform_test)

    kwargs = {"batch_size": batch_size, "num_workers": num_workers, "pin_memory": True, "drop_last": True}
    if num_instances:
        # Datasets expose labels as either `.labels` or `.targets`; resolve
        # the attribute once instead of relying on AttributeError per class.
        try:
            labels = train_ds.labels
        except AttributeError:
            labels = train_ds.targets
        labels = np.asarray(labels)

        # Keep the first `num_instances` samples of every class. The old
        # `.targets` path capped at a hard-coded 50 regardless of
        # `num_instances`; both paths now honour the same cap.
        train_idx = []
        for c in range(NUM_CLASSES[name]):
            train_idx.extend(np.argwhere(labels == c)[:num_instances, 0].tolist())
        train_loader = DataLoader(train_ds, sampler=SubsetRandomSampler(train_idx), **kwargs)
    else:
        train_loader = DataLoader(train_ds, shuffle=True, **kwargs)
    test_loader = DataLoader(test_ds, **kwargs)
    return train_loader, test_loader, None, NUM_CLASSES[name]


def get_dataloader(default_data_path, mode, image_size, batch_size, 
                    ds_name, ds_split, mtrn_hetero_on, mtst_subset_on,
                    class_split_ratio, instance_split_ratio):
    """Dispatch to the loader factory matching the requested setting.

    Selection order:
      1. ``mtrn_hetero_on`` -> heterogeneous loader.
      2. ``ds_split`` given -> meta-training loader for that split.
      3. ``ds_split`` is None (meta-test) -> heterogeneous loader when
         ``mtst_subset_on`` is set, otherwise the full meta-test loader.

    ``mode``, ``class_split_ratio`` and ``instance_split_ratio`` are only
    forwarded to the heterogeneous loader. The non-heterogeneous loaders
    are always built with ``use_ofa_transform=False``.

    Returns:
        ``(train_loader, valid_loader, n_classes)``.
    """
    shared = dict(
        name=ds_name,
        image_size=image_size,
        batch_size=batch_size,
        default_data_path=default_data_path,
    )

    # Heterogeneous loading covers both the meta-train switch and the
    # meta-test subset switch (the latter only when no split is given).
    wants_hetero = mtrn_hetero_on or (ds_split is None and mtst_subset_on)

    if wants_hetero:
        bundle = get_dataloader_heterogeneous(
            mode=mode,
            class_split_ratio=class_split_ratio,
            instance_split_ratio=instance_split_ratio,
            **shared)
    elif ds_split is not None:
        bundle = get_meta_train_dataloader(
            use_ofa_transform=False,
            split=ds_split,
            **shared)
    else:  # Meta-test
        bundle = get_meta_test_dataloader(
            use_ofa_transform=False,
            split=ds_split,
            **shared)

    # Every factory returns a 4-tuple; the third slot is unused here.
    train_loader, valid_loader, _, n_classes = bundle
    return train_loader, valid_loader, n_classes

