import os
import pdb
import glob
import random
import numpy as np

from misc.utils import *

class DataLoader:
    """ Data Loader

    Loads the per-client task data (train / valid / test splits) and keeps a
    small per-client ``state`` dict that can be saved to and restored from
    ``args.state_dir``.

    Created by:
        Wonyong Jeong (wyjeong@kaist.ac.kr)
    """

    # Task names accepted by ``load_tasks``.
    _SUPPORTED_TASKS = ('non_iid_50',)

    def __init__(self, args):
        """Store ``args`` and resolve the directory task files are read from.

        Args:
            args: configuration namespace. This class reads ``task``,
                ``task_path``, ``state_dir``, ``merged_datasets_names`` and
                ``num_tasks_per_dataset`` from it.
        """
        self.args = args
        # NOTE(review): data is always read from the hard-coded '5_20'
        # sub-directory of task_path; an earlier version used
        # os.path.join(task_path, task) instead -- confirm '5_20' is intended.
        self.base_dir = os.path.join(self.args.task_path, '5_20')
        # Dataset id -> dataset name (not referenced elsewhere in this class).
        self.did_to_dname = {
            0: 'cifar10',
            1: 'cifar100',
            2: 'mnist',
            3: 'svhn',
            4: 'fashion_mnist',
            5: 'traffic_sign',
            6: 'face_scrub',
            7: 'not_mnist',
        }

    def init_state(self, cid):
        """Create a fresh state for client ``cid`` and populate its task list."""
        self.state = {
            'client_id': cid,
            'tasks': []
        }
        self.load_tasks(cid)

    def load_state(self, cid):
        """Restore ``self.state`` for client ``cid`` from ``<state_dir>/<cid>_data.npy``.

        ``.item()`` unwraps the loaded value; presumably ``np_load`` returns a
        0-d object array wrapping the dict written by ``save_state`` -- confirm.
        """
        self.state = np_load(os.path.join(self.args.state_dir, '{}_data.npy'.format(cid))).item()

    def save_state(self):
        """Persist ``self.state`` via ``np_save`` under ``args.state_dir``.

        The object is saved under the name ``'<client_id>_data'``; presumably
        ``np_save`` appends ``.npy``, matching the file ``load_state`` reads --
        confirm against ``misc.utils.np_save``.
        """
        np_save(self.args.state_dir, '{}_data'.format(self.state['client_id']), self.state)

    def load_tasks(self, cid):
        """Build the task list and pick a random current task index.

        For ``args.task == 'non_iid_50'`` the task names are ``'<dataset>_<k>'``
        for every ``dataset`` in ``args.merged_datasets_names`` (at position
        ``i``) and every ``k`` in ``range(args.num_tasks_per_dataset[i])``.
        The list is identical for all clients; ``cid`` is not used.

        Sets ``self.state['tasks']`` and ``self.state['curr_task']``.
        Any other ``args.task`` prints an error and terminates the process
        with exit status 1.
        """
        if self.args.task not in self._SUPPORTED_TASKS:
            # os._exit terminates immediately without flushing stdio buffers,
            # so flush the message explicitly; a non-zero status marks the run
            # as failed (it previously exited with 0, i.e. "success").
            print('no correct task was given: {}'.format(self.args.task), flush=True)
            os._exit(1)

        task_set = []
        for index, dataset_name in enumerate(self.args.merged_datasets_names):
            num_tasks = self.args.num_tasks_per_dataset[index]
            for dataset_task_id in range(num_tasks):
                task_set.append('{}_{}'.format(dataset_name, dataset_task_id))

        self.state['curr_task'] = self.load_random_task_id(task_set)
        self.state['tasks'] = task_set  # list of tasks, same for all clients

    def load_random_task(self, cid, task_set):
        """Return ``'cid_<cid>_<task>'`` for a task drawn uniformly from ``task_set``."""
        random_task = random.choice(task_set)
        return "cid_" + str(cid) + "_" + random_task

    def load_random_task_id(self, task_set):
        """Return a uniformly random index into ``task_set``.

        Raises:
            IndexError: if ``task_set`` is empty.
        """
        return random.choice(range(len(task_set)))

    def get_train(self, task_id):
        """Load the training split of task index ``task_id`` for this client.

        Returns the ``.item()`` of what ``load_task`` returns for
        ``<base_dir>/cid_<client_id>_<task>_train.npy`` (presumably a dict of
        training arrays -- confirm against the data files).
        """
        modified_task_id = self.get_modified_task_id(task_id)
        return load_task(self.base_dir, modified_task_id + '_train.npy').item()

    def get_valid(self, task_id):
        """Return ``(x_valid, y_valid)`` for task index ``task_id`` of this client."""
        modified_task_id = self.get_modified_task_id(task_id)
        valid = load_task(self.base_dir, modified_task_id + '_valid.npy').item()
        return valid['x_valid'], valid['y_valid']

    def get_test(self, task_id):
        """Return test data for ALL tasks of this client (``task_id`` is ignored).

        Returns:
            ``(x_test_list, y_test_list)``: one entry per task, in the order of
            ``self.state['tasks']``.
        """
        print("getting test data")
        x_test_list = []
        y_test_list = []
        for tid, _ in enumerate(self.state['tasks']):
            modified_task_id = self.get_modified_task_id(tid)
            test = load_task(self.base_dir, modified_task_id + '_test.npy').item()
            x_test_list.append(test['x_test'])
            y_test_list.append(test['y_test'])
        return x_test_list, y_test_list

    def get_modified_task_id(self, task_id):
        """Return the file prefix ``'cid_<client_id>_<task name>'`` for task index ``task_id``."""
        return "cid_" + str(self.state['client_id']) + "_" + self.state['tasks'][task_id]

# Example training invocations:
# python3 ../main.py --gpu 0,1,2,3,4 --work-type train --model fedweit --task non_iid_50 --gpu-mem-multiplier 9 --num-rounds 3 --num-epochs 1 --batch-size 100 --seed 777


# python ../main.py --gpu 0,1,2,3,4 --work-type train --model fedweit --task non_iid_50 --gpu-mem-multiplier 9 --num-rounds 1 --total-rounds 10 --num-epochs 1 --batch-size 100 --seed 777