import torch
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import numpy as np
import random
import PIL
from sklearn.model_selection import train_test_split, StratifiedKFold
from timm.data import create_transform
from util.crop import RandomResizedCrop
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import torchvision.transforms as T
import pickle
import copy

def find_dataset_mean_std(dataset="cifar10"):
    """Return per-channel normalisation statistics for a supported dataset.

    Args:
        dataset: one of "cifar10", "cifar100", "food101", "imagenet", "mini",
            "imagenet32", "road_sign", "inat".

    Returns:
        (mean, std): two freshly built 3-element lists of floats, in the form
        expected by ``transforms.Normalize``.

    Raises:
        AssertionError: if ``dataset`` is not supported. It is raised
            explicitly rather than through ``assert`` so that the check still
            happens under ``python -O``. Without it, an unknown name would
            silently yield ``(None, None)``.
    """
    # Built per call, so callers always receive independent lists.
    stats = {
        "cifar10": ([0.491, 0.482, 0.446], [0.247, 0.243, 0.261]),
        "cifar100": ([0.507, 0.487, 0.441], [0.262, 0.251, 0.271]),
        "food101": ([0.550, 0.445, 0.340], [0.266, 0.269, 0.274]),
        "imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        "mini": ([0.471, 0.450, 0.404], [0.269, 0.261, 0.277]),
        "imagenet32": ([0.481, 0.457, 0.408], [0.252, 0.245, 0.261]),
        "road_sign": ([0.563, 0.567, 0.604], [0.338, 0.281, 0.345]),
        "inat": ([0.467, 0.482, 0.376], [0.234, 0.225, 0.243]),
    }
    if dataset not in stats:
        raise AssertionError("Codes are not available for the selected dataset")
    mean, std = stats[dataset]
    return mean, std


def init_pretrain_transform(mean, std):
    """Build the (train, val) transforms used during pre-training.

    Train: random-resized crop to 224 (scale 0.2-1.0, interpolation code 3,
    i.e. PIL's bicubic constant), random horizontal flip, then ToTensor and
    Normalize(mean, std).
    Val: ToTensor and Normalize(mean, std) only; no resizing is applied.
    """
    # NOTE(review): the val pipeline has no resize/crop, so batching val images
    # assumes they already share one size (true for CIFAR) - confirm for the
    # other datasets this is used with.
    def tensor_and_normalize():
        # Separate instances for each pipeline.
        return [transforms.ToTensor(), transforms.Normalize(mean, std)]

    augmentations = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.0), interpolation=3),
        transforms.RandomHorizontalFlip(),
    ]
    train_transform = transforms.Compose(augmentations + tensor_and_normalize())
    val_transform = transforms.Compose(tensor_and_normalize())
    return train_transform, val_transform

def check_dataset_valid(dataset):
    """Validate that ``dataset`` names a supported dataset.

    Args:
        dataset: dataset name to check.

    Raises:
        AssertionError: if ``dataset`` is not one of the supported names. It is
            raised explicitly instead of via ``assert``, so the check is not
            stripped when Python runs with ``-O``. The exception type and
            message are unchanged for callers that already expect them.
    """
    available_datasets = ("cifar10", "cifar100", "food101", "imagenet", "mini", "imagenet32", "road_sign", "inat")
    if dataset not in available_datasets:
        raise AssertionError("Codes are not available for the selected dataset")

def init_dataset(train_transform, val_transform, dataset="cifar10"):
    """Instantiate the train/val datasets for ``dataset``.

    CIFAR-10/100 are loaded (and downloaded if needed) under ``./data``. The
    other datasets are read with ``datasets.ImageFolder`` from fixed paths
    relative to the current working directory.

    Args:
        train_transform: transform applied to training samples.
        val_transform: transform applied to validation samples.
        dataset: dataset name; must pass ``check_dataset_valid``.

    Returns:
        (train_dataset, val_dataset, num_classes). For "mini" there is no
        validation split, so ``val_dataset`` is None. For "imagenet32", which
        passes validation but has no loader here, the result is
        ``(None, None, 0)``.
    """
    check_dataset_valid(dataset)

    cifar_sets = {
        "cifar10": (torchvision.datasets.CIFAR10, 10),
        "cifar100": (torchvision.datasets.CIFAR100, 100),
    }
    # name -> (train root, val root or None, number of classes)
    folder_sets = {
        "food101": ("./data/food-101/train", "./data/food-101/valid", 101),
        "imagenet": ("../../../imagenet/train", "../../../imagenet/val", 1000),
        "mini": ("../../../imagenet/mini_imagenet/train", None, 100),
        "road_sign": ('../road_sign/train', '../road_sign/test', 8),
        "inat": ('../../../iNat2021/train_mini', '../../../iNat2021/val', 10000),
    }

    if dataset in cifar_sets:
        dataset_cls, n_classes = cifar_sets[dataset]
        train_set = dataset_cls(root='./data', train=True, download=True, transform=train_transform)
        val_set = dataset_cls(root='./data', train=False, download=True, transform=val_transform)
        return train_set, val_set, n_classes

    if dataset in folder_sets:
        train_root, val_root, n_classes = folder_sets[dataset]
        train_set = datasets.ImageFolder(root=train_root, transform=train_transform)
        if val_root is None:
            val_set = None
        else:
            val_set = datasets.ImageFolder(root=val_root, transform=val_transform)
        return train_set, val_set, n_classes

    return None, None, 0

def divide_dataset(train_dataset, ratio, dataset):
    """Split the indices of ``train_dataset`` into supervised and unsupervised parts.

    The split is stratified by label and seeded (``random_state=0``), so it is
    reproducible for a given dataset and ``ratio``.

    Args:
        train_dataset: training dataset. Labels are read from ``.targets``
            (CIFAR), from ``.imgs`` (ImageFolder-style datasets) or from
            ``.train_labels`` (the remaining supported name, "imagenet32").
        ratio: forwarded to ``train_test_split`` as ``train_size``. A float is
            the fraction of samples in the supervised part; an int is an
            absolute count.
        dataset: dataset name; must pass ``check_dataset_valid``.

    Returns:
        (super_train_idxs, unsuper_train_idxs): index sequences into
        ``train_dataset``.
    """
    # Shared validation instead of a duplicated copy of the supported-name list.
    check_dataset_valid(dataset)

    if dataset in ("cifar10", "cifar100"):
        labels = train_dataset.targets
    elif dataset in ("food101", "imagenet", "mini", "road_sign", "inat"):
        # np.array turns the (path, class_index) pairs into a string array, so
        # these labels are strings. They are deliberately not cast to int:
        # the stratified shuffle orders classes by label value, and changing
        # that order could change the seeded split.
        labels = np.array(train_dataset.imgs)[:, 1]
    else:
        labels = train_dataset.train_labels

    super_train_idxs, unsuper_train_idxs, _, _ = train_test_split(
        range(len(train_dataset)),
        labels,
        stratify=labels,
        train_size=ratio,
        random_state=0,
    )
    return super_train_idxs, unsuper_train_idxs

def get_super_loaders(train_dataset, val_dataset, super_train_idxs, batch_size):
    """Build the supervised ``{'train': ..., 'val': ...}`` DataLoader pair.

    Args:
        train_dataset: full training dataset; only ``super_train_idxs`` are used.
        val_dataset: validation dataset, used whole.
        super_train_idxs: indices of the labelled training samples.
        batch_size: batch size for both loaders.

    Returns:
        dict with keys 'train' and 'val'. Both loaders shuffle and use
        ``num_workers=0``.
    """
    def make_loader(source):
        return torch.utils.data.DataLoader(source, batch_size=batch_size, shuffle=True, num_workers=0)

    labelled_subset = torch.utils.data.Subset(train_dataset, super_train_idxs)
    return {
        'train': make_loader(labelled_subset),
        'val': make_loader(val_dataset),
    }

def get_unsuper_loaders(train_dataset, val_dataset, unsuper_train_idxs, clientIDlist, batch_size):
    """Split unsupervised training indices into contiguous per-client shards.

    Each client gets an equal contiguous slice of ``unsuper_train_idxs``, in
    ``clientIDlist`` order. The last client also takes the
    ``len(unsuper_train_idxs) % len(clientIDlist)`` leftover indices, so no
    index is dropped. All clients share one validation loader.

    Args:
        train_dataset: dataset indexed by the entries of ``unsuper_train_idxs``.
        val_dataset: dataset behind every client's 'val' loader.
        unsuper_train_idxs: sliceable sequence of training indices to distribute.
        clientIDlist: client identifiers; shards are assigned in this order.
        batch_size: batch size for every loader.

    Returns:
        dict mapping each client ID to ``{'train': DataLoader, 'val': DataLoader}``.
        All loaders shuffle and use ``num_workers=0``. The result is empty
        when ``clientIDlist`` is empty.
    """
    num_clients = len(clientIDlist)
    if num_clients == 0:
        return {}

    total = len(unsuper_train_idxs)
    shard_size = total // num_clients
    valloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    unsupervised_dataloaders = {}
    for step, clientID in enumerate(clientIDlist):
        start = step * shard_size
        # The last shard runs to the end so the division remainder is kept.
        end = total if step == num_clients - 1 else start + shard_size
        c_trainset = torch.utils.data.Subset(train_dataset, unsuper_train_idxs[start:end])
        c_trainloader = torch.utils.data.DataLoader(c_trainset, batch_size=batch_size, shuffle=True, num_workers=0)
        unsupervised_dataloaders[clientID] = {'train': c_trainloader, 'val': valloader}

    return unsupervised_dataloaders

def get_unsuper_datasets(train_dataset, clientIDlist,  sampling="iid", alpha=0.1, dataset="cifar10", num_classes=10):
    """Partition ``train_dataset`` across clients, either IID or Dirichlet non-IID.

    Args:
        train_dataset: full training dataset. Labels are read from ``.targets``
            (CIFAR), from ``.imgs`` (ImageFolder-style) or from
            ``.train_labels`` (others).
        clientIDlist: client identifiers, used as keys of the returned dict.
        sampling: ``"iid"`` selects a stratified K-fold split. Any other value
            selects the Dirichlet-based non-IID split.
        alpha: concentration of the symmetric Dirichlet used by the non-IID
            split. Smaller values give more skewed per-client class mixes.
        dataset: dataset name; must pass ``check_dataset_valid``.
        num_classes: only labels ``0 .. num_classes - 1`` are distributed by
            the non-IID split.

    Returns:
        dict mapping each client ID to a ``torch.utils.data.Subset`` of
        ``train_dataset``.
    """
    check_dataset_valid(dataset)
    unsupervised_datasets = {}
    num_clients = len(clientIDlist)
    train_size = len(train_dataset)
    # One label per training sample, in dataset order.
    if dataset == "cifar10" or dataset == "cifar100":
        all_ids_train = np.array(train_dataset.targets)
    elif dataset == "food101" or dataset == "imagenet" or dataset == "mini" or dataset == "road_sign" or dataset == "inat":
        # np.array stringifies the (path, class_index) pairs, so the label
        # column is converted back to int here.
        imgs = np.array(train_dataset.imgs)
        all_ids_train = [int(imgs[i, 1]) for i in range(imgs.shape[0])]
        all_ids_train = np.array(all_ids_train)
    else:
        all_ids_train = np.array(train_dataset.train_labels)

    if sampling == "iid":
        # One fold per client. Each client receives the held-out indices of
        # one fold, so client shards are disjoint and stratified by label. No
        # random_state is passed, so the shuffle follows NumPy's global RNG.
        skf = StratifiedKFold(n_splits=num_clients, shuffle=True)
        for i, (_, c_idxs) in enumerate(skf.split(np.zeros(train_size), all_ids_train)):
            clientID = clientIDlist[i]
            c_trainset = torch.utils.data.Subset(train_dataset, c_idxs)
            unsupervised_datasets[clientID] = c_trainset
    else:
        # Indices of each class's samples, in ascending dataset order.
        class_ids_train = {class_num: np.where(all_ids_train == class_num)[0] for class_num in range(num_classes)}
        # (num_clients, num_classes) matrix. Before the transpose, each row is
        # a Dirichlet(alpha) distribution over clients for one class. It is
        # then normalised so that all entries sum to 1.
        dist_of_client = np.random.dirichlet(np.repeat(alpha, num_clients), size=num_classes).transpose()
        dist_of_client /= dist_of_client.sum()

        # NOTE(review): s0 (per-class column sums) and s1 (per-client row sums)
        # are both computed before either division. This is therefore not the
        # usual alternating (Sinkhorn) normalisation, and no particular row or
        # column sums are guaranteed afterwards. Confirm the intended balancing.
        for i in range(100):
            s0 = dist_of_client.sum(axis=0, keepdims=True)
            s1 = dist_of_client.sum(axis=1, keepdims=True)
            dist_of_client /= s0
            dist_of_client /= s1
        
        # Per (client, class) sample counts.
        # NOTE(review): these counts are not bounded by the real class sizes.
        # The slices below silently truncate at the end of each class, so later
        # clients may get fewer (or no) samples of a class. Samples past the
        # last cumulative offset of a class are never assigned to any client.
        samples_per_class_train = (np.floor(dist_of_client * train_size))

        # Cumulative per-class offsets: client i takes class c's indices in
        # [start_ids_train[i, c], start_ids_train[i + 1, c]). The float counts
        # are truncated when stored into this int32 array.
        start_ids_train = np.zeros((num_clients + 1, num_classes), dtype=np.int32)
        for i in range(0, num_clients):
            start_ids_train[i+1] = start_ids_train[i] + samples_per_class_train[i]

        # Each client's index list is its slice of every class, concatenated
        # in class order.
        for client_num in range(num_clients):
            l = np.array([], dtype=np.int32)
            for class_num in range(num_classes):
                start, end = start_ids_train[client_num, class_num], start_ids_train[client_num + 1, class_num]
                l = np.concatenate((l, class_ids_train[class_num][start:end].tolist())).astype(np.int32)
            c_trainset = torch.utils.data.Subset(train_dataset, l)
            client_ID = clientIDlist[client_num]
            unsupervised_datasets[client_ID] = c_trainset

    return unsupervised_datasets

"""
def get_varied_unsuper_datasets(train_dataset, clientIDlist, dataset="cifar10", num_classes=10):
    # Assign varied size of train data to each client in network
    check_dataset_valid(dataset)
    unsupervised_datasets = {}
    client_data_ids = {}
    num_clients = len(clientIDlist)
    train_size = len(train_dataset)
    all_ids_train = [i for i in range(train_size)]
    random.shuffle(all_ids_train)
    groups_train = {}
    num_groups = int(train_size / num_classes)
    for i in range(num_groups):
        groups_train[i] = all_ids_train[i*num_classes : (i+1)*num_classes]
    if train_size % num_classes != 0:
        groups_train[num_groups] = all_ids_train[num_groups*num_classes :]
    remained_groups = list(groups_train.keys())
    for i in range(num_clients):
        cid = clientIDlist[i]
        num_max_allowed_groups = len(remained_groups) - (num_clients-(i+1))
        pick_num = random.randint(1, num_max_allowed_groups)
        pick_group_options = random.sample(remained_groups, pick_num)
        pick_train_ids = []
        for gid in pick_group_options:
            pick_train_ids.append(groups_train[gid])
            remained_groups.remove(gid)
        c_train_ids = sum(pick_train_ids, [])
        assert len(c_train_ids) > 0, "each node in network should be assigned with some data"
        c_trainset = torch.utils.data.Subset(train_dataset, c_train_ids)
        unsupervised_datasets[cid] = c_trainset
        client_data_ids[cid] = c_train_ids
    return unsupervised_datasets, client_data_ids
"""

def increment_elements(matrix, num):
    """Distribute ``num`` unit increments over the non-zero entries of ``matrix``.

    Entries are visited repeatedly in row-major order. Each non-zero entry
    that is visited gets +1 until ``num`` increments have been made. Entries
    that are zero when visited are skipped, so an entry that starts at zero
    stays at zero.

    Args:
        matrix: 2-D NumPy array, modified in place.
        num: number of increments to distribute. It may be a float holding an
            integral value, e.g. a difference involving a float-array sum.

    Returns:
        ``matrix`` itself (the same object), after modification.

    Raises:
        ValueError: if increments remain but a full pass finds no non-zero
            entry. Otherwise the loop could never finish.
    """
    rows, cols = matrix.shape

    while num > 0:
        incremented = False
        for i in range(rows):
            for j in range(cols):
                if matrix[i, j] != 0:
                    matrix[i, j] += 1
                    num -= 1
                    incremented = True
                    # `<= 0` rather than `== 0`: a non-integral num must not
                    # step past zero and keep incrementing.
                    if num <= 0:
                        return matrix
        if not incremented:
            raise ValueError(
                "increment_elements: no non-zero entries to increment (%s increments remaining)" % num
            )

    return matrix

def get_varied_unsuper_datasets(train_dataset, clientIDlist, dataset="cifar10", num_classes=10, alpha=0.1):
    """Split ``train_dataset`` across clients with Dirichlet-weighted, varying shard sizes.

    For each class, a symmetric Dirichlet(alpha) over clients gives each
    (client, class) cell a weight. The weight table is normalised to sum to 1
    and scaled to ``train_size``. Rounding leftovers are handed out by
    ``increment_elements``. In addition, for each class the last client with
    a non-zero allocation takes whatever indices of that class remain. As a
    result, every sample whose label is in ``range(num_classes)`` is assigned
    to exactly one client.

    Args:
        train_dataset: full training dataset. Labels are read from ``.targets``
            (CIFAR), from ``.imgs`` (ImageFolder-style) or from
            ``.train_labels`` (others).
        clientIDlist: client identifiers, in assignment order.
        dataset: dataset name; must pass ``check_dataset_valid``.
        num_classes: labels ``0 .. num_classes - 1`` are distributed.
        alpha: Dirichlet concentration. Smaller values give more skewed
            per-client class mixes.

    Returns:
        (unsupervised_datasets, client_data_ids):
        - unsupervised_datasets: per-client ``torch.utils.data.Subset`` objects.
        - client_data_ids: the underlying int32 index arrays, which is the
          mapping that ``save_data_split`` persists.

    Raises:
        AssertionError: via ``assert``, so the check is skipped under
            ``python -O``. Raised if the assigned indices do not add up to
            ``len(train_dataset)``, e.g. when some labels fall outside
            ``range(num_classes)``.
    """
    check_dataset_valid(dataset)
    unsupervised_datasets = {}
    num_clients = len(clientIDlist)
    train_size = len(train_dataset)
    if dataset == "cifar10" or dataset == "cifar100":
        all_ids_train = np.array(train_dataset.targets)
    elif dataset == "food101" or dataset == "imagenet" or dataset == "mini" or dataset == "road_sign" or dataset == "inat":
        # np.array stringifies the (path, class_index) pairs, so the label
        # column is converted back to int here.
        imgs = np.array(train_dataset.imgs)
        all_ids_train = [int(imgs[i, 1]) for i in range(imgs.shape[0])]
        all_ids_train = np.array(all_ids_train)
    else:
        all_ids_train = np.array(train_dataset.train_labels)

    # Indices of each class's samples, in ascending dataset order.
    class_ids_train = {class_num: np.where(all_ids_train == class_num)[0] for class_num in range(num_classes)}
    # (num_clients, num_classes) Dirichlet weights, normalised so that all
    # cells together sum to 1.
    dist_of_client = np.random.dirichlet(np.repeat(alpha, num_clients), size=num_classes).transpose()
    dist_of_client /= dist_of_client.sum()

    # Flooring loses less than one sample per cell. The shortfall (a float
    # here) is handed back to the non-zero cells so the table sums to
    # train_size. increment_elements modifies the array in place.
    samples_per_class_train = np.floor(dist_of_client * train_size)
    remain_num = train_size - samples_per_class_train.sum()
    samples_per_class_train = increment_elements(samples_per_class_train, remain_num)

    # Cumulative per-class offsets: client i nominally takes class c's indices
    # in [start_ids_train[i, c], start_ids_train[i + 1, c]).
    start_ids_train = np.zeros((num_clients + 1, num_classes), dtype=np.int32)
    for i in range(0, num_clients):
        start_ids_train[i+1] = start_ids_train[i] + samples_per_class_train[i]

    client_data_ids = {}
    # Per class: indices not yet covered by the nominal allocations. This can
    # go negative when the allocations exceed the class size.
    remained_class_num = {}
    # Per class: whether some later client still has a non-zero allocation.
    next_client_existed = {}
    for class_num in range(num_classes):
        remained_class_num[class_num] = len(class_ids_train[class_num])
        next_client_existed[class_num] = True
    for client_num in range(num_clients):
        cid = clientIDlist[client_num]
        l = np.array([], dtype=np.int32)
        for class_num in range(num_classes):
            start, end = start_ids_train[client_num, class_num], start_ids_train[client_num + 1, class_num]

            # Re-scan the later clients only while a later holder was still
            # known to exist. Once it turns False it stays False.
            if next_client_existed[class_num]:
                next_client_existed[class_num] = False
                for remained_client_num in range(client_num + 1, num_clients):
                    new_start, new_end = start_ids_train[remained_client_num, class_num], start_ids_train[remained_client_num + 1, class_num]
                    if new_end - new_start > 0:
                        next_client_existed[class_num] = True
                        break
            # No later client will take this class: extend this client's slice
            # to the end of the class so leftover indices are not dropped.
            if not next_client_existed[class_num] and remained_class_num[class_num] > 0:
                end = len(class_ids_train[class_num])
            # Slices past the end of a class are truncated by NumPy slicing.
            l = np.concatenate((l, class_ids_train[class_num][start:end].tolist())).astype(np.int32)
            remained_class_num[class_num] = remained_class_num[class_num] - (end - start)

        c_trainset = torch.utils.data.Subset(train_dataset, l)
        unsupervised_datasets[cid] = c_trainset
        client_data_ids[cid] = l

    # Sanity check that every sample was assigned exactly once in total.
    num_total_ids = 0
    for _, ids in client_data_ids.items():
        num_total_ids += len(ids)

    assert num_total_ids == train_size, "The total number of data in the data split record, which is %s, should match with the size of the specified dataset, which is %s" % (num_total_ids, train_size)

    return unsupervised_datasets, client_data_ids

    


def save_data_split(client_data_ids, save_path):
    """Pickle the client -> data-index mapping to ``save_path`` and report it.

    Args:
        client_data_ids: mapping to persist (e.g. client ID -> index array).
        save_path: destination file; it is created or overwritten.
    """
    with open(save_path, 'wb') as out_file:
        out_file.write(pickle.dumps(client_data_ids))
        print('Data split saved successfully to %s' % save_path)

def load_data_split(train_dataset, clientIDlist, dataset, load_path):
    """Rebuild per-client subsets from a pickled client -> index-list record.

    The record is a mapping like the ``client_data_ids`` that
    ``save_data_split`` writes.

    Args:
        train_dataset: dataset the stored indices refer to.
        clientIDlist: client identifiers expected in the record.
        dataset: dataset name; must pass ``check_dataset_valid``.
        load_path: path of the pickled record.

    Returns:
        dict mapping each client ID to a ``torch.utils.data.Subset``.

    Raises:
        AssertionError: if the client count or the total number of indices in
            the record does not match ``clientIDlist`` / ``train_dataset``.
        KeyError: if some ID in ``clientIDlist`` is absent from the record.
    """
    check_dataset_valid(dataset)

    # SECURITY: pickle.load can execute arbitrary code. Only load split files
    # produced by this project and stored in a trusted location.
    with open(load_path, 'rb') as fp:
        client_data_ids = pickle.load(fp)

    assert len(client_data_ids) == len(clientIDlist), "The number of clients in the data split record should match with the specified number of clients"

    num_total_ids = sum(len(ids) for ids in client_data_ids.values())
    assert num_total_ids == len(train_dataset), "The total number of data in the data split record, which is %s, should match with the size of the specified dataset, which is %s" % (num_total_ids, len(train_dataset))

    # Equal client counts do not guarantee matching IDs, so report every
    # missing ID up front instead of failing on the first lookup.
    missing = [cid for cid in clientIDlist if cid not in client_data_ids]
    if missing:
        raise KeyError("Client IDs %s are missing from the data split record at %s" % (missing, load_path))

    unsupervised_datasets = {
        cid: torch.utils.data.Subset(train_dataset, client_data_ids[cid]) for cid in clientIDlist
    }
    print("Succesefully loads data split from %s" % load_path)
    return unsupervised_datasets


def get_super_dataset(train_dataset, super_train_idxs):
    """Return the labelled subset of ``train_dataset`` selected by ``super_train_idxs``."""
    return torch.utils.data.Subset(train_dataset, indices=super_train_idxs)

def init_finetune_transform(args, mean, std):
    """Build the (train, val) transforms used for fine-tuning.

    Args:
        args: config namespace. Reads ``input_size``, ``color_jitter``, ``aa``
            (passed as timm's ``auto_augment``), and ``reprob``, ``remode``,
            ``recount`` (passed as ``re_prob``, ``re_mode``, ``re_count``).
        mean: per-channel normalisation mean.
        std: per-channel normalisation std.

    Returns:
        (train_transform, val_transform).
        - Training augmentation is delegated to timm's ``create_transform``
          with bicubic interpolation.
        - Evaluation resizes, then centre-crops to ``args.input_size``.
    """
    train_transform = create_transform(
        input_size=args.input_size,
        is_training=True,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation='bicubic',
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        mean=mean,
        std=std,
    )

    # Eval sizing:
    # - input_size <= 224: resize to input_size / 0.875 (256 for 224), then
    #   centre-crop back to input_size.
    # - larger inputs: resize straight to input_size, so the centre crop only
    #   trims the longer side.
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    size = int(args.input_size / crop_pct)

    val_transform = transforms.Compose([
        # torchvision's InterpolationMode enum is used because this module
        # only runs `import PIL`, which does not by itself load the PIL.Image
        # submodule that PIL.Image.BICUBIC would need.
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return train_transform, val_transform


def init_linprobe_transform(mean, std):
    """Build the (train, val) transforms used for linear probing.

    Train: random-resized crop to 224 (interpolation code 3, PIL's bicubic
    constant), random horizontal flip, then ToTensor and Normalize(mean, std).
    Val: resize to 256 (interpolation code 3), centre-crop to 224, then
    ToTensor and Normalize(mean, std).
    """
    def tensor_tail():
        return [transforms.ToTensor(), transforms.Normalize(mean, std)]

    train_steps = [
        transforms.RandomResizedCrop(224, interpolation=3),
        transforms.RandomHorizontalFlip(),
    ]
    train_steps.extend(tensor_tail())

    val_steps = [
        transforms.Resize(256, interpolation=3),
        transforms.CenterCrop(224),
    ]
    val_steps.extend(tensor_tail())

    return transforms.Compose(train_steps), transforms.Compose(val_steps)


    
        
            


    



