from lib.config import cfg, cfg_from_file
from lib.utils import *
from data.dataset import *
from data.dataset import get_dataset, build_transform, Lambda
from copy import deepcopy


import numpy as np
import torch
import random
import argparse
import pprint
import random
import copy

def set_seed(seed):
    """Record ``seed`` on the global ``cfg`` and seed every RNG from it.

    Seeders are applied in this order: all CUDA devices, torch (CPU),
    NumPy's global RNG, then Python's ``random`` module. Each one is fed the
    value read back from ``cfg.seed``.
    """
    cfg.seed = seed
    seeders = (
        torch.cuda.manual_seed_all,
        torch.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(cfg.seed)

def _make_loader_pair(dataset_train, dataset_val, batch_size):
    """Return ``(train_loader, val_loader)`` for one task.

    The train loader draws with a ``RandomSampler``; the val loader walks the
    dataset in order with a ``SequentialSampler``.
    """
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=torch.utils.data.RandomSampler(dataset_train),
        batch_size=batch_size,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=torch.utils.data.SequentialSampler(dataset_val),
        batch_size=batch_size,
    )
    return data_loader_train, data_loader_val


def _collect_samples(dataset, wanted_classes, indices):
    """Return ``(path, label)`` pairs for items of ``dataset`` whose label is in ``wanted_classes``.

    ``indices`` fixes the visiting order (and hence the order of the result).
    The path is taken from ``dataset.samples[ind][0]`` and the label from
    ``dataset[ind][1]``, i.e. through ``__getitem__``. This is presumably so that
    the dataset's ``target_transform`` (if any) is applied; TODO confirm.
    """
    picked = []
    for ind in indices:
        # Index the dataset only once per item. ``dataset[ind]`` goes through
        # __getitem__, which presumably also loads/transforms the image, so it is costly.
        label = dataset[ind][1]
        if label in wanted_classes:
            picked.append((dataset.samples[ind][0], label))
    return picked


def build_vtab_sim_dataloader(batch_size=32, sim_level=50):
    """Build per-task VTAB dataloaders whose class sets are blended within VTAB groups.

    One task is built per entry of a fixed dataset list (optionally shuffled via
    ``cfg.dtask.shuffle``), for ``cfg.continual.n_tasks`` tasks. Each task gets
    global class ids offset by the class counts of the previous tasks.

    Each task is then rebuilt:
      * it keeps ``int(n_classes * (100 - sim_level) / 100)`` of its own classes;
      * the remaining class budget is drawn from the other datasets in the same
        VTAB group (Natural / Structured / Specialized), split evenly, with the
        remainder going to the last partner.

    A task with no same-group partner keeps all of its remaining classes.
    Finally the task order is shuffled.

    Args:
        batch_size: batch size of every train/val DataLoader.
        sim_level: percentage (0-100) of each task's classes that are replaced
            by classes of same-group datasets.

    Returns:
        ``(dataloader, class_mask, new_blurry_train_datasets, new_blurry_test_datasets)``:
          * ``dataloader`` is the shuffled list of ``{'train', 'val', 'task'}`` dicts.
          * ``class_mask`` holds, for each entry of ``dataloader``, the list of labels
            present in that task's blended train samples.
          * The two dataset lists are the blended datasets, in original task order.

    Side effects: resets and accumulates ``cfg.dtask.nb_classes``, and re-assigns
    ``cfg.continual.n_tasks`` (to an unchanged value while ``num_overlaps == 1``).
    """
    dataloader = []
    class_mask = []

    transform_train = build_transform(True)
    transform_val = build_transform(False)

    # The task list is fixed; cfg.dtask.data_path is not scanned.
    dataset_list = ['eurosat', 'oxford_flowers102', 'oxford_iiit_pet', 'resisc45']

    if cfg.dtask.shuffle:
        random.shuffle(dataset_list)

    print(dataset_list)

    vtabgroup2dts = {'Natural': ['caltech101', 'cifar', 'dtd', 'oxford_flowers102', 'oxford_iiit_pet', 'sun397', 'svhn'],
                     'Structured': ['clevr_dist', 'clevr_count', 'dsprites_ori', 'dsprites_loc',
                                    'smallnorb_ele', 'smallnorb_azi', 'kitti', 'dmlab'],
                     'Specialized': ['patch_camelyon', 'diabetic_retinopathy', 'eurosat', 'resisc45'], }

    # Optionally detach some tasks from their group so they get no "mild" partners.
    # Disabled (0 tasks) but kept as the original experiment's switch.
    num_no_mild_tasks = 0
    no_mild_ts = random.sample(dataset_list, k=num_no_mild_tasks)
    print('no_mild_tasks', no_mild_ts)
    for task in no_mild_ts:
        for dts in vtabgroup2dts.values():
            if task in dts:
                dts.remove(task)

    dt2dtids = {dt: dt_id for dt_id, dt in enumerate(dataset_list)}  # e.g., {dtd: 0, ...}
    vtabgroup2dtids = {group: [dt2dtids[dt] for dt in dts if dt in dt2dtids]
                       for group, dts in vtabgroup2dts.items()}
    dtid2groupdtids = {}  # e.g., {0: [0, 3, 4], ...}
    for dtid in dt2dtids.values():
        in_group = False
        for dtids in vtabgroup2dtids.values():
            if dtid in dtids:
                dtid2groupdtids[dtid] = dtids
                in_group = True
        if not in_group:  # task outside every group: no mild partner
            dtid2groupdtids[dtid] = [dtid]

    cfg.dtask.nb_classes = 0
    num_overlapping_tasks = 4
    overlapping_inds = torch.randperm(cfg.continual.n_tasks)[:num_overlapping_tasks]
    print('overlapping_inds', overlapping_inds)

    # Every task is used exactly once. ``datasets`` and ``tracking_classes`` are
    # indexed by task id below, which relies on this.
    datasets = []
    for i in range(cfg.continual.n_tasks):
        dataset_train, dataset_val = get_dataset(dataset_list[i], transform_train, transform_val)

        transform_target = Lambda(target_transform, cfg.dtask.nb_classes)
        exposed_cls = dataset_train.classes
        # Global class ids of this task: offset by all previous tasks' class counts.
        class_mask.append([c + cfg.dtask.nb_classes for c in range(len(exposed_cls))])
        cfg.dtask.nb_classes += len(exposed_cls)
        print('cfg.dtask.nb_classes', cfg.dtask.nb_classes)

        if not cfg.dtask.task_inc:
            dataset_train.target_transform = transform_target
            dataset_val.target_transform = transform_target

        datasets.append({'data_id': i, 'data_train': dataset_train, 'data_test': dataset_val})

        # Random subset at scale 1 (i.e. all samples, permuted).
        num_subset = len(dataset_train)
        subset_inds = torch.randperm(len(dataset_train))[:num_subset]
        sub_dataset_train = Subset(dataset_train, subset_inds)
        data_loader_train, data_loader_val = _make_loader_pair(sub_dataset_train, dataset_val, batch_size)
        dataloader.append({'train': data_loader_train, 'val': data_loader_val, 'task': i})
        print("add dataloader")

    print("BEFORE: dataloader", len(dataloader))
    # Per-task pool of global class ids not yet assigned to any blended task.
    tracking_classes = copy.deepcopy(class_mask)
    print('tracking classes ', tracking_classes)

    new_blurry_train_datasets = []
    new_blurry_test_datasets = []
    overlap_similarity = sim_level

    for dt in datasets:
        data_id, data_train, data_test = dt['data_id'], dt['data_train'], dt['data_test']
        # Other tasks of the same VTAB group.
        group_dt_ids = deepcopy(dtid2groupdtids[data_id])
        group_dt_ids.remove(data_id)
        group_datas = [datasets[other_dtid] for other_dtid in group_dt_ids]

        print('data_id: ', data_id, 'other_data_id: ', group_dt_ids,)
        # Debug peek at the first labels; bounded so small datasets do not raise IndexError.
        print('100samples of data', [data_train[j][1] for j in range(min(50, len(data_train)))])

        n_classes_data = len(data_train.classes)
        print('n_classes_data ', n_classes_data)
        # NOTE(review): keep/replace classes are *global* ids from class_mask. The labels
        # read through __getitem__ only match them if a global-offset target_transform
        # was attached (cfg.dtask.task_inc False) — confirm for task_inc runs.
        if group_datas:
            # Keep (100 - n)% of this task's classes ...
            keep_num_classes = int(n_classes_data * ((100 - overlap_similarity) / 100))
            random.shuffle(tracking_classes[data_id])
            keep_classes = tracking_classes[data_id][:keep_num_classes]
            tracking_classes[data_id] = list(set(tracking_classes[data_id]) - set(keep_classes))
            # ... and fill the other n% from same-group datasets, split evenly
            # (the last partner absorbs the remainder).
            update_num_classes = n_classes_data - keep_num_classes
            update_classes_per_other_dt = update_num_classes // len(group_datas)
            update_classes_list = [update_classes_per_other_dt] * len(group_datas)
            update_classes_list[-1] += update_num_classes - sum(update_classes_list)

            print('total samplings: ', n_classes_data)
            print('n_sampling : ', [keep_num_classes] + update_classes_list)
        else:
            # No same-group partner: keep every remaining class id of this task.
            # (Previously an int was used here, making the membership test raise TypeError.)
            keep_classes = tracking_classes[data_id]
            tracking_classes[data_id] = []
            update_classes_list = []

        shuffle_indices = torch.randperm(len(data_train.samples))
        new_samples_train = _collect_samples(data_train, keep_classes, shuffle_indices)
        new_samples_test = _collect_samples(data_test, keep_classes, range(len(data_test.samples)))

        for other_dt, n_replace_classes_per_dt in zip(group_datas, update_classes_list):
            other_id = other_dt['data_id']
            random.shuffle(tracking_classes[other_id])
            replace_classes = tracking_classes[other_id][:n_replace_classes_per_dt]
            tracking_classes[other_id] = list(set(tracking_classes[other_id]) - set(replace_classes))
            pick_shuffle_indices = torch.randperm(len(other_dt['data_train'].samples))
            new_samples_train += _collect_samples(other_dt['data_train'], replace_classes, pick_shuffle_indices)
            new_samples_test += _collect_samples(
                other_dt['data_test'], replace_classes, range(len(other_dt['data_test'].samples)))

        # The labels stored above were already read through __getitem__, so the
        # clones drop target_transform (presumably to avoid applying it twice).
        clone_data_train = deepcopy(data_train)
        clone_data_train.samples = new_samples_train
        clone_data_train.target_transform = None
        new_blurry_train_datasets.append(clone_data_train)

        clone_data_test = deepcopy(data_test)
        clone_data_test.samples = new_samples_test
        clone_data_test.target_transform = None
        new_blurry_test_datasets.append(clone_data_test)

    # Replace each task's loaders and class mask with the blended versions.
    print("new_blurry_train_datasets ", len(new_blurry_train_datasets))
    for dt_i, (new_dt, new_dt_test) in enumerate(zip(new_blurry_train_datasets, new_blurry_test_datasets)):
        data_loader_train, data_loader_val = _make_loader_pair(new_dt, new_dt_test, batch_size)
        dataloader[dt_i] = {'train': data_loader_train, 'val': data_loader_val, 'task': dt_i}
        new_class_mask = [label for _path, label in new_dt.samples]  # exposed classes
        class_mask[dt_i] = list(set(new_class_mask))

    # Shuffle the task order (an explicit order may be supplied via shuffle_overlaps_inds).
    shuffle_overlaps_inds = []
    rand_inds = torch.tensor(shuffle_overlaps_inds) if shuffle_overlaps_inds else torch.randperm(len(dataloader))
    print('rand_inds', rand_inds)
    dataloader = [dataloader[i.item()] for i in rand_inds]
    class_mask = [class_mask[i.item()] for i in rand_inds]

    print("Before) num_tasks: ", cfg.continual.n_tasks)
    num_overlaps = 1
    cfg.continual.n_tasks = cfg.continual.n_tasks + num_overlapping_tasks * (num_overlaps - 1)
    print("After) num_tasks: ", cfg.continual.n_tasks)
    print([loaders['task'] for loaders in dataloader])

    return dataloader, class_mask, new_blurry_train_datasets, new_blurry_test_datasets

def _class_mask_to_array(class_mask):
    """Convert per-task class masks into an array that ``np.save`` accepts.

    Equal-length masks become a regular 2-D array, as before. Ragged masks (tasks
    with different class counts) become a 1-D ``object`` array of the per-task
    lists, because NumPy >= 1.24 refuses to build an inhomogeneous array
    implicitly. Loading such a file needs ``np.load(..., allow_pickle=True)``.
    """
    try:
        return np.asarray(class_mask)
    except ValueError:
        arr = np.empty(len(class_mask), dtype=object)
        # Element-wise assignment keeps each mask as one object entry.
        for task_id, mask in enumerate(class_mask):
            arr[task_id] = mask
        return arr


def main():
    """Script entry point: build the blended VTAB dataloaders and save them.

    Loads the YAML config given by ``--cfg`` and selects the first GPU listed in
    ``cfg.gpu_ids``. Then seeds all RNGs, builds the dataloaders, and writes
    ``dataloader_output.pt`` (``torch.save``) and ``class_mask.npy``
    (``np.save``) into ``cfg.dtask.data_path``.
    """
    import os

    parser = argparse.ArgumentParser(description='PROTEUS in Continual Learning')
    parser.add_argument('--cfg', dest='cfg_file', default='./config/cifar-100.yml')
    args = parser.parse_args()

    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    log(pprint.pformat(cfg))

    gpus = [int(gpu_id) for gpu_id in cfg.gpu_ids.split(',')]
    cfg.device = torch.device('cuda:' + str(gpus[0]))

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    set_seed(cfg.seed)
    # NOTE(review): this result is overwritten below. The call is kept, presumably for
    # its side effects on cfg / RNG state — confirm before removing.
    data_loader, class_mask = build_continual_dataloader(batch_size=cfg.dtask.batch_size)
    cfg.continual.n_tasks = 4
    data_loader, class_mask, new_blurry_train_datasets, new_blurry_test_datasets = \
        build_vtab_sim_dataloader(cfg.dtask.batch_size)
    batches = list(data_loader)
    torch.save(batches, os.path.join(cfg.dtask.data_path, "dataloader_output.pt"))
    np.save(os.path.join(cfg.dtask.data_path, "class_mask.npy"), _class_mask_to_array(class_mask))
    
if __name__ == "__main__":
    # Build and save the dataloaders only when run as a script, not on import.
    main()