# adapted from
# https://github.com/VICO-UoE/DatasetCondensation

import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from networks import ConvNet, MLP, Linear


class TensorDataset(Dataset):
    """Map-style dataset backed by an image tensor and a label tensor.

    Both tensors are detached on construction, and the images are
    additionally converted with ``.float()``. Indexing returns the
    ``(image, label)`` pair found at the same position in each tensor.
    """

    def __init__(self, images, labels):  # images: n x c x h x w tensor
        detached_images = images.detach()
        self.images = detached_images.float()
        self.labels = labels.detach()

    def __len__(self):
        # The number of samples is the size of the leading image dimension.
        return self.images.size(0)

    def __getitem__(self, index):
        image = self.images[index]
        label = self.labels[index]
        return image, label
    
    
def get_default_model_setting(model):
    """Return the default architecture settings for a model name.

    Args:
        model: One of ``'ConvNet'``, ``'ReLU'`` or ``'Linear'``.

    Returns:
        A 5-tuple ``(net_width, net_depth, net_act, net_norm, net_pooling)``.
        Entries that do not apply to a model are the string ``'None'``.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    if model == 'ConvNet':
        return 64, 1, 'relu', 'instancenorm', 'avgpooling'
    if model == 'ReLU':
        # A fresh list is built on every call so callers may mutate it freely.
        return [1000], 1, 'relu', 'None', 'None'
    if model == 'Linear':
        return 'None', 'None', 'None', 'None', 'None'
    # Previously an unknown name fell through to the return statement and
    # failed with an uninformative UnboundLocalError.
    raise ValueError('unknown model: %r' % (model,))



def get_network(model, channel, num_classes, im_size=(32, 32)):
    """Build the requested network and move it to the available device.

    Args:
        model: ``'ConvNet'``, ``'ReLU'`` (an ``MLP``) or ``'Linear'``.
        channel: Forwarded to the network constructor as ``channel``.
        num_classes: Forwarded to the network constructor as ``num_classes``.
        im_size: Forwarded as ``im_size`` to ``ConvNet`` and ``Linear``
            (``MLP`` is constructed without it).

    Returns:
        The constructed network, moved to ``'cuda'`` when
        ``torch.cuda.is_available()`` and to ``'cpu'`` otherwise.

    Raises:
        SystemExit: With the message ``'DC error: unknown model'`` when
            ``model`` is not one of the supported names.
    """
    # Validate first: looking up the default settings for an unknown model
    # fails on its own, which used to make this error message unreachable.
    if model not in ('ConvNet', 'ReLU', 'Linear'):
        raise SystemExit('DC error: unknown model')

    net_width, net_depth, net_act, net_norm, net_pooling = get_default_model_setting(model)

    if model == 'ConvNet':
        net = ConvNet(channel=channel, num_classes=num_classes, net_width=net_width, net_depth=net_depth, net_act=net_act, net_norm=net_norm, net_pooling=net_pooling, im_size=im_size)
    elif model == 'ReLU':
        net = MLP(channel=channel, num_classes=num_classes, net_width=net_width, net_act=net_act)
    else:  # 'Linear'
        net = Linear(channel=channel, num_classes=num_classes, im_size=im_size)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = net.to(device)

    return net



def get_dataset(dataset, data_path, batch_size=1, args=None):
    """Load a dataset by name and build its test loader.

    Only ``'CIFAR10'`` is handled; any other name terminates via ``exit``
    with an ``'unknown dataset: ...'`` message. ``batch_size`` and ``args``
    are accepted but not used by this function.

    Returns:
        ``(channel, im_size, num_classes, mean, std, dst_train, dst_test,
        testloader)``, where ``testloader`` iterates ``dst_test`` unshuffled
        in batches of 128 using 2 workers.
    """
    # Reject unsupported names up front; exit() raises, so nothing below runs.
    if dataset != 'CIFAR10':
        exit('unknown dataset: %s'%dataset)

    channel, im_size, num_classes = 3, (32, 32), 10
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]

    # data preprocessing: tensor conversion followed by per-channel normalisation
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([to_tensor, normalize])

    # Train split is created first, then the test split (both may download).
    dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)
    dst_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)

    testloader = torch.utils.data.DataLoader(
        dst_test,
        batch_size=128,
        shuffle=False,
        num_workers=2,
    )

    return channel, im_size, num_classes, mean, std, dst_train, dst_test, testloader



def get_tasks_dataset(dst_train, num_task, data_per_task, framework, num_classes):
    """Split the head of ``dst_train`` into per-task training sets.

    Samples are read in index order from ``dst_train``; each sample is
    indexed as ``sample[0]`` (image tensor) and ``sample[1]`` (label).

    Args:
        dst_train: Indexable dataset of ``(image, label)`` samples.
        num_task: Number of tasks to build.
        data_per_task: Number of samples collected for every task.
        framework: ``"TIL_naive"`` gives every task the next
            ``data_per_task`` samples whose label is below ``num_classes``.
            ``'CIL'`` routes each sample to the task owning its label, using
            a fixed 10-class split that exists only for 3 or 5 tasks.
        num_classes: Label cut-off used by ``"TIL_naive"``; not used by
            ``'CIL'``.

    Returns:
        ``(train_datasets, total_datasets)``. ``train_datasets[t]`` is
        ``{0: images, 1: labels}`` for task ``t`` and ``total_datasets`` is
        the same pair over every collected sample; all tensors are on the
        CPU and labels are ``torch.long``. For any other ``framework`` the
        returned structures only hold empty lists.

    Raises:
        ValueError: If ``framework`` is ``'CIL'`` and ``num_task`` is
            neither 3 nor 5.
    """
    if framework == "TIL_naive":
        return _split_til_naive(dst_train, num_task, data_per_task, num_classes)
    if framework == 'CIL':
        return _split_cil(dst_train, num_task, data_per_task)

    # Unrecognised framework: hand back the empty placeholders, as before.
    train_datasets = {x: {0: [], 1: []} for x in range(num_task)}
    total_datasets = {0: [], 1: []}
    return train_datasets, total_datasets


def _stack_samples(images, labels):
    """Pack lists of 1 x ... image tensors and labels into a {0, 1} CPU dict."""
    return {
        0: torch.cat(images, dim=0).to("cpu"),
        1: torch.tensor(labels, dtype=torch.long, device="cpu"),
    }


def _split_til_naive(dst_train, num_task, data_per_task, num_classes):
    """Give each task the next ``data_per_task`` samples with an accepted label."""
    train_datasets = {}
    total_images, total_labels = [], []
    cur = 0
    for task in range(num_task):
        images, labels = [], []
        class_dist = {c: 0 for c in range(num_classes)}
        while len(labels) < data_per_task:
            # Fetch each index exactly once; indexing may run a transform.
            sample = dst_train[cur]
            cur += 1
            image, label = sample[0], sample[1]
            if not label < num_classes:
                continue
            images.append(torch.unsqueeze(image, dim=0))
            labels.append(label)
            class_dist[label] += 1
        print("task{} class distribution: ".format(task), class_dist)
        total_images.extend(images)
        total_labels.extend(labels)
        train_datasets[task] = _stack_samples(images, labels)
    return train_datasets, _stack_samples(total_images, total_labels)


def _split_cil(dst_train, num_task, data_per_task):
    """Fill every task with ``data_per_task`` samples from its own label range."""
    if num_task == 3:
        class_ranges = [range(0, 3), range(3, 6), range(6, 10)]
    elif num_task == 5:
        class_ranges = [range(0, 2), range(2, 4), range(4, 6), range(6, 8), range(8, 10)]
    else:
        # No class-to-task mapping is defined for other task counts.
        raise ValueError(
            "CIL framework supports num_task of 3 or 5, got %r" % (num_task,)
        )
    task_of_class = {c: t for t, classes in enumerate(class_ranges) for c in classes}

    task_images = {t: [] for t in range(num_task)}
    task_labels = {t: [] for t in range(num_task)}
    total_images, total_labels = [], []
    class_dist = {t: {c: 0 for c in range(10)} for t in range(num_task)}

    n_full_task = 0
    idx = 0
    while n_full_task < num_task:
        sample = dst_train[idx]
        idx += 1
        image, label = sample[0], sample[1]
        task = task_of_class[label]
        if len(task_labels[task]) >= data_per_task:
            # This task already has its quota; skip the sample.
            continue
        batched = torch.unsqueeze(image, dim=0)
        task_images[task].append(batched)
        task_labels[task].append(label)
        total_images.append(batched)
        total_labels.append(label)
        class_dist[task][label] += 1
        if len(task_labels[task]) == data_per_task:
            n_full_task += 1

    train_datasets = {}
    for task in range(num_task):
        train_datasets[task] = _stack_samples(task_images[task], task_labels[task])
        print("task{} class distribution: ".format(task), class_dist[task])
    return train_datasets, _stack_samples(total_images, total_labels)