import pickle
import pdb
import pandas as pd
import ast
import torch
from torch.utils.data import Dataset

from collections import defaultdict

from networks import TEACHER_NET_NAME
from networks.resnet_lit import get_resnet
from myutils import InfIterator
from get_dataloader import get_dataloader


class TaskSampler():
    def __init__(self, mode='meta_train',
                 default_data_path=None,
                 ds_split=1, ds_name='tiny_imagenet',
                 image_size=64, batch_size=64, search_space='resnet',
                 tc_mul_seeds_on=False, tc_stage_num=4, tc_stage_depth=5,
                 tc_stage_default_channel_widths=[16, 32, 64, 128],
                 tc_stage_strides=[1, 2, 2, 2],
                 channel_mul=2,
                 net_info_path=None,
                 minmax_norm=False,
                 n_support=1,
                 n_query=30,
                 n_support_tr=5,
                 n_query_tr=30,
                 bilevel=True,
                 user='sh',
                 tc_support=False,
                 tc_support_tr=False,
                 tc_norm=False,
                 pr_type='random_init',
                 meta_test_support_index=None,):
        """Configure the sampler, load result tables, data loaders and teacher net.

        Args:
            mode: Sampling mode. Must be 'meta_train' or 'meta_valid' when
                ``ds_name`` contains 'tiny_imagenet', and neither of those
                otherwise.
            default_data_path: Forwarded to ``get_dataloader``.
            ds_split: Dataset split id; forwarded to ``get_dataloader`` and
                used to build per-split column/key names.
            ds_name: Dataset name.
            image_size: Forwarded to ``get_dataloader``.
            batch_size: Forwarded to ``get_dataloader``.
            search_space: Stored on the instance only.
            tc_mul_seeds_on: If true, ``load_tc_net`` raises
                ``NotImplementedError``.
            tc_stage_num: Number of teacher stages.
            tc_stage_depth: Depth of every teacher stage; also the key used
                to look up ``TEACHER_NET_NAME``.
            tc_stage_default_channel_widths: Per-stage base widths, scaled by
                ``channel_mul``. The default list is shared between calls and
                is never mutated here.
            tc_stage_strides: Per-stage strides passed to ``get_resnet``.
            channel_mul: Multiplier applied to the base channel widths.
            net_info_path: Path given to ``torch.load`` for the student
                network info list.
            minmax_norm: Min-max normalise accuracies when building tasks.
            n_support: Number of teacher support entries kept by
                ``get_tc_support_set``; must be 1 when ``tc_support`` is set.
            n_query: Stored on the instance only.
            n_support_tr: Support-set size used by ``get_random_task``.
            n_query_tr: Query-set size used by ``get_random_task``.
            bilevel: Stored on the instance only.
            user: Stored on the instance only.
            tc_support: Use the teacher table as support set at test time.
            tc_support_tr: Use the teacher table as support set in
                ``get_random_task``.
            tc_norm: Divide accuracies by the teacher's accuracy.
            pr_type: 'random_init' or 'copy_paste_first'; selects which
                student result table is loaded.
            meta_test_support_index: Stored on the instance only.

        Raises:
            ValueError: If ``mode`` and ``ds_name`` are incompatible, or if
                ``tc_support`` is set while ``n_support != 1``.
        """
        if 'tiny_imagenet' in ds_name:
            if mode not in ['meta_train', 'meta_valid']:
                raise ValueError(mode, ds_name)
        else:
            if mode in ['meta_train', 'meta_valid']:
                raise ValueError(mode, ds_name)

        ## General
        self.mode = mode
        self.default_data_path = default_data_path
        self.ds_name = ds_name
        self.ds_split = ds_split
        self.image_size = image_size
        self.batch_size = batch_size

        # minmax normalization
        self.minmax_norm = minmax_norm
        self.max_arch_idx = 10
        self.min_arch_idx = 5

        ## Search space
        self.search_space = search_space
        self.channel_mul = channel_mul

        ## Bi-level
        self.bilevel = bilevel
        self.n_support = n_support
        self.n_query = n_query
        self.n_support_tr = n_support_tr
        self.n_query_tr = n_query_tr
        self.meta_test_support_index = meta_test_support_index
        self.tc_support = tc_support
        if self.tc_support and self.n_support != 1:
            raise ValueError(f'if use tc net as suppport set, set n_support as 1, not {self.n_support}')
        self.tc_support_tr = tc_support_tr
        ## Normalize with teacher net
        # Must be assigned before the teacher-table check below reads it;
        # assigning it at the end of __init__ raised AttributeError whenever
        # tc_support and tc_support_tr were both falsy.
        self.tc_norm = tc_norm
        ## MISC
        self.user = user
        ## PR type
        self.pr_type = pr_type

        ## Teacher
        self.tc_stage_num = tc_stage_num
        self.tc_stage_depth = tc_stage_depth
        self.tc_net_name = TEACHER_NET_NAME[self.tc_stage_depth]
        self.tc_stage_default_channel_widths = tc_stage_default_channel_widths
        self.tc_stage_channel_widths = [int(self.channel_mul * w) for w in self.tc_stage_default_channel_widths]
        self.tc_stage_strides = tc_stage_strides
        self.tc_mul_seeds_on = tc_mul_seeds_on

        ## Student
        self.net_info_path = net_info_path
        self.net_info_list = torch.load(net_info_path)
        self.get_stnet_final_info_dict()
        if self.mode in ['meta_valid', 'meta_test']:
            self.get_support_stnet_final_info_dict()
        # The teacher table is needed both for teacher support sets and for
        # teacher-based normalisation (tcnorm reads it).
        if self.tc_support or self.tc_support_tr or self.tc_norm:
            self.get_support_tcnet_final_info_dict()

        ## Dataloader
        self.train_loader, self.valid_loader, self.n_classes = get_dataloader(
            self.mode, self.default_data_path, self.image_size,
            self.batch_size, self.ds_name, self.ds_split)
        self.train_img_data_iter = InfIterator(self.train_loader)
        self.valid_img_data_iter = InfIterator(self.valid_loader)

        ## Load teacher network (needs self.n_classes from the data loader)
        self.load_tc_net()

        print(f'==> load {ds_name}-{ds_split} task sampler')


    def get_support_tcnet_final_info_dict(self):
        """Read the teacher results CSV into ``self.support_tcnet_final_info``.

        The resulting dict holds the parsed 'depth_config' and
        'channel_widths' columns plus 'final_acc', 'best_acc' (both divided
        by 100) and 'final_loss' taken from the dataset-specific columns.
        """
        # NOTE(review): DATAPATH is not defined in the visible part of this
        # file -- confirm where it comes from.
        info = {}
        self.support_tcnet_final_info = info
        table = pd.read_csv(DATAPATH)

        # Tiny-ImageNet columns carry the split id; other datasets do not.
        if self.ds_name == 'tiny_imagenet':
            column_prefix = f'{self.ds_name}-{self.ds_split}'
        else:
            column_prefix = f'{self.ds_name}'

        # Architecture columns are stored as Python literals in the CSV.
        for column in ('depth_config', 'channel_widths'):
            info[column] = [ast.literal_eval(cell) for cell in table[column].tolist()]

        for target, suffix in (('final_acc', 'final'), ('best_acc', 'best')):
            info[target] = [value / 100. for value in table[f'{column_prefix}_{suffix}'].tolist()]
        info['final_loss'] = table[f'{column_prefix}_final_loss'].tolist()


    def get_support_stnet_final_info_dict(self):
        """Read the support-set student results into ``self.support_stnet_final_info``.

        Called from ``__init__`` only for the 'meta_valid' and 'meta_test'
        modes. The dict holds 'net_index', the parsed 'depth_config' and
        'channel_widths', 'final_acc' and 'best_acc' (divided by 100) and
        'final_loss'.
        """
        # NOTE(review): DATAPATH is not defined in the visible part of this
        # file, and both branches read the same name -- confirm whether they
        # are meant to point at different files.
        info = {}
        self.support_stnet_final_info = info

        per_split = self.ds_name == 'tiny_imagenet' and self.mode == 'meta_valid'
        if per_split:
            table = pd.read_csv(DATAPATH)
            column_prefix = f'{self.ds_name}-{self.ds_split}'
        else:
            table = pd.read_csv(DATAPATH)
            column_prefix = f'{self.ds_name}'

        info['net_index'] = table['net_index'].tolist()
        # Architecture columns are stored as Python literals in the CSV.
        for column in ('depth_config', 'channel_widths'):
            info[column] = [ast.literal_eval(cell) for cell in table[column].tolist()]

        for target, suffix in (('final_acc', 'final'), ('best_acc', 'best')):
            info[target] = [value / 100. for value in table[f'{column_prefix}_{suffix}'].tolist()]
        info['final_loss'] = table[f'{column_prefix}_final_loss'].tolist()


    def get_stnet_final_info_dict(self):
        """Load the student results into ``self.stnet_final_info``.

        Which source is read depends on ``ds_name``, ``mode`` and, outside
        meta-training, ``pr_type``. Also sets ``self.num_stnet`` to the
        number of 'final_acc' entries.

        Raises:
            ValueError: If ``pr_type`` is not 'random_init' or
                'copy_paste_first' in a mode that depends on it. Previously
                this surfaced as an UnboundLocalError / AttributeError.
        """
        # NOTE(review): DATAPATH is not defined in the visible part of this
        # file, and every branch reads the same name -- confirm the intended
        # per-branch paths.

        def load_finetuning_table(path):
            # Parse a fine-tuning results CSV into a column-wise dict.
            df = pd.read_csv(path)
            return {
                'net_index': df['net_index'].tolist(),
                # Architecture columns are stored as Python literals.
                'depth_config': [ast.literal_eval(i) for i in df['depth_config'].tolist()],
                'channel_widths': [ast.literal_eval(i) for i in df['channel_widths'].tolist()],
                'final_acc': [acc / 100. for acc in df['finetuning_60_epoch_val_acc'].tolist()],
                'best_acc': [acc / 100. for acc in df['finetuning_best_acc'].tolist()],
                'final_loss': df['finetuning_60_epoch_val_loss'].tolist(),
            }

        def unknown_pr_type():
            return ValueError(
                f"unknown pr_type {self.pr_type!r}; "
                "expected 'random_init' or 'copy_paste_first'")

        if self.ds_name == 'tiny_imagenet' and self.mode == 'meta_train':
            self.stnet_final_info = load_finetuning_table(DATAPATH)

        elif self.ds_name == 'tiny_imagenet' and self.mode == 'meta_valid':
            if self.pr_type == 'random_init':
                # stnet_final_info['tiny_imagenet-{task}'][net_index] = {}
                # SECURITY: pickle.load executes arbitrary code; only use
                # with trusted, locally generated files.
                with open(DATAPATH, 'rb') as f:
                    self.stnet_final_info = pickle.load(f)[f'{self.ds_name}-{self.ds_split}']
            elif self.pr_type == 'copy_paste_first':
                self.stnet_final_info = load_finetuning_table(DATAPATH)
            else:
                raise unknown_pr_type()

        else: ## Meta-test
            if self.pr_type == 'random_init':
                df = pd.read_csv(DATAPATH)
            elif self.pr_type == 'copy_paste_first':
                df = pd.read_csv(DATAPATH)
            else:
                raise unknown_pr_type()
            self.stnet_final_info = {}
            for appendix in ['final', 'best']:
                self.stnet_final_info[f'{appendix}_acc'] = [
                    acc / 100. for acc in df[f'{self.ds_name}_{appendix}'].tolist()]
            # No loss is read for meta-test; use a zero per known network.
            self.stnet_final_info['final_loss'] = [0.0] * len(self.net_info_list)

        self.num_stnet = len(self.stnet_final_info['final_acc'])

    def minmax(self, k, v):
        """Min-max scale ``v`` using the range of ``self.stnet_final_info[k]``.

        Returns ``(v - min) / (max - min)`` where min and max are taken over
        all values stored under key ``k``.
        """
        reference = self.stnet_final_info[k]
        upper = max(reference)
        lower = min(reference)
        return (v - lower) / (upper - lower)

    def tcnorm(self, k, v):
        tc_v = self.support_tcnet_final_info[k][0]
        v /= tc_v
        return v

    def load_tc_net(self):
        """Build the teacher ResNet and load its checkpoint into ``self.tc_net``.

        Every stage has depth ``tc_stage_depth``, and every layer of a stage
        uses that stage's channel width.

        Raises:
            NotImplementedError: If ``tc_mul_seeds_on`` is set (raised after
                the network has been constructed).
        """
        depth_per_stage = [self.tc_stage_depth for _ in range(self.tc_stage_num)]
        widths_per_stage = []
        for stage_width in self.tc_stage_channel_widths:
            widths_per_stage.append([stage_width] * self.tc_stage_depth)

        self.tc_net = get_resnet(
            self.n_classes,
            depth_config=depth_per_stage,
            channel_widths=widths_per_stage,
            stage_strides=self.tc_stage_strides,
            tc_stage_channel_widths=self.tc_stage_channel_widths,
        )

        if self.tc_mul_seeds_on:
            raise NotImplementedError

        # NOTE(review): tc_net_ckpt_path is not defined in the visible part
        # of this file -- confirm where it comes from.
        checkpoint = torch.load(tc_net_ckpt_path)
        self.tc_net.load_state_dict(checkpoint['state_dict'])


    def get_tc_support_set(self):
        """Build a support set from the teacher results table.

        Returns:
            A dict with 'arch_info' (the teacher table's 'depth_config' and
            'channel_widths' lists, not copied) and 'y' (the first
            ``n_support`` 'final_acc', 'best_acc' and 'final_loss' values,
            each stacked and reshaped with ``view(-1, 1)``). Accuracies are
            optionally min-max and/or teacher normalised.
        """
        teacher = self.support_tcnet_final_info
        arch_info = {
            'depth_config': teacher['depth_config'],
            'channel_widths': teacher['channel_widths'],
        }

        # y info
        final_acc = [torch.tensor(value) for value in teacher['final_acc']]
        best_acc = [torch.tensor(value) for value in teacher['best_acc']]
        final_loss = [torch.tensor(value) for value in teacher['final_loss']]

        # Only the accuracies are normalised; the loss is left as is.
        if self.minmax_norm:
            final_acc = [self.minmax('final_acc', t) for t in final_acc]
            best_acc = [self.minmax('best_acc', t) for t in best_acc]
        if self.tc_norm:
            final_acc = [self.tcnorm('final_acc', t) for t in final_acc]
            best_acc = [self.tcnorm('best_acc', t) for t in best_acc]

        keep = self.n_support
        targets = {
            'final_acc': torch.stack(final_acc[:keep]).view(-1, 1),
            'best_acc': torch.stack(best_acc[:keep]).view(-1, 1),
            'final_loss': torch.stack(final_loss[:keep]).view(-1, 1),
        }
        return {'arch_info': arch_info, 'y': targets}


    def get_random_task(self):
        x, _ = next(self.train_img_data_iter)
        ds_info = {
            'ds_name': self.ds_name,
            'ds_split': self.ds_split,
            'ds_imgs': x
        }
        
        index_list = torch.randperm(self.num_stnet)[:self.n_support_tr+ self.n_query_tr]

        depth_config = [self.stnet_final_info['depth_config'][index] for index in index_list]
        channel_widths = [self.stnet_final_info['channel_widths'][index] for index in index_list]
        final_acc = [torch.tensor(self.stnet_final_info['final_acc'][index]) for index in index_list]
        best_acc = [torch.tensor(self.stnet_final_info['best_acc'][index]) for index in index_list]
        final_loss = [torch.tensor(self.stnet_final_info['final_loss'][index]) for index in index_list]
        if self.minmax_norm:
            final_acc = [self.minmax('final_acc', fc) for fc in final_acc]
            best_acc = [self.minmax('best_acc', ba) for ba in best_acc]
        if self.tc_norm:
            final_acc = [self.tcnorm('final_acc', fc) for fc in final_acc]
            best_acc = [self.tcnorm('best_acc', ba) for ba in best_acc]

        support = {'arch_info': {
                                    'depth_config': depth_config[:self.n_support_tr],
                                    'channel_widths': channel_widths[:self.n_support_tr]
                                },
                    'y': {
                            'final_acc': torch.stack(final_acc[:self.n_support_tr]).view(-1, 1),
                            'best_acc': torch.stack(best_acc[:self.n_support_tr]).view(-1, 1),
                            'final_loss': torch.stack(final_loss[:self.n_support_tr]).view(-1, 1),
                            }}
        if self.tc_support_tr:
            support = self.get_tc_support_set()

        query = {'arch_info': {
                                    'depth_config': depth_config[self.n_support_tr:],
                                    'channel_widths': channel_widths[self.n_support_tr:]
                                },
                    'y': {
                            'final_acc': torch.stack(final_acc[self.n_support_tr:]).view(-1, 1),
                            'best_acc': torch.stack(best_acc[self.n_support_tr:]).view(-1, 1),
                            'final_loss': torch.stack(final_loss[self.n_support_tr:]).view(-1, 1),
                            }}

        return ds_info, self.tc_net, support, query


    def get_test_task_w_all_samples(self):
        """Build an evaluation task whose query set is every student network.

        The support set is the teacher table (``tc_support``) or the whole
        support student table. Query architectures come from
        ``stnet_final_info`` when it has a 'depth_config' entry, otherwise
        from positions 2 and 3 of each ``net_info_list`` entry.

        Returns:
            Tuple ``(ds_info, self.tc_net, support, query)``; 'y' tensors are
            stacked and reshaped with ``view(-1, 1)``.
        """
        images, _ = next(self.train_img_data_iter)
        ds_info = {
            'ds_name': self.ds_name,
            'ds_split': self.ds_split,
            'ds_imgs': images
        }

        def target_columns(table):
            # Tensorise, optionally normalise the accuracies, then stack.
            final_acc = [torch.tensor(value) for value in table['final_acc']]
            best_acc = [torch.tensor(value) for value in table['best_acc']]
            final_loss = [torch.tensor(value) for value in table['final_loss']]
            if self.minmax_norm:
                final_acc = [self.minmax('final_acc', t) for t in final_acc]
                best_acc = [self.minmax('best_acc', t) for t in best_acc]
            if self.tc_norm:
                final_acc = [self.tcnorm('final_acc', t) for t in final_acc]
                best_acc = [self.tcnorm('best_acc', t) for t in best_acc]
            return {
                'final_acc': torch.stack(final_acc).view(-1, 1),
                'best_acc': torch.stack(best_acc).view(-1, 1),
                'final_loss': torch.stack(final_loss).view(-1, 1),
            }

        ## Support set
        if self.tc_support:
            support = self.get_tc_support_set()
        else:
            support_table = self.support_stnet_final_info
            support = {
                'arch_info': {
                    'depth_config': support_table['depth_config'],
                    'channel_widths': support_table['channel_widths'],
                },
                'y': target_columns(support_table),
            }

        ## Query set
        net_ids = range(self.num_stnet)
        student = self.stnet_final_info
        if 'depth_config' in student:
            query_depths = [student['depth_config'][i] for i in net_ids]
            query_widths = [student['channel_widths'][i] for i in net_ids]
        else:
            query_depths = [self.net_info_list[i][2] for i in net_ids]
            query_widths = [self.net_info_list[i][3] for i in net_ids]

        query = {
            'arch_info': {
                'depth_config': query_depths,
                'channel_widths': query_widths,
            },
            'y': target_columns(student),
        }

        return ds_info, self.tc_net, support, query

    def get_nas_task(self):
        """Build a search task: a labelled support set and an unlabelled query set.

        The support set is the teacher table (``tc_support``) or the whole
        support student table. The query set holds only architecture info,
        taken from positions 2 and 3 of each entry loaded from DATAPATH.

        Returns:
            Tuple ``(ds_info, self.tc_net, support, query)``; ``query`` has
            no 'y' entry.
        """
        images, _ = next(self.train_img_data_iter)
        ds_info = {
            'ds_name': self.ds_name,
            'ds_split': self.ds_split,
            'ds_imgs': images
        }

        ## Support set
        if self.tc_support:
            support = self.get_tc_support_set()
        else:
            support_table = self.support_stnet_final_info
            arch_info = {
                'depth_config': support_table['depth_config'],
                'channel_widths': support_table['channel_widths'],
            }
            # y info
            final_acc = [torch.tensor(value) for value in support_table['final_acc']]
            best_acc = [torch.tensor(value) for value in support_table['best_acc']]
            final_loss = [torch.tensor(value) for value in support_table['final_loss']]
            # Only the accuracies are normalised; the loss is left as is.
            if self.minmax_norm:
                final_acc = [self.minmax('final_acc', t) for t in final_acc]
                best_acc = [self.minmax('best_acc', t) for t in best_acc]
            if self.tc_norm:
                final_acc = [self.tcnorm('final_acc', t) for t in final_acc]
                best_acc = [self.tcnorm('best_acc', t) for t in best_acc]
            support = {
                'arch_info': arch_info,
                'y': {
                    'final_acc': final_acc,
                    'best_acc': best_acc,
                    'final_loss': final_loss,
                },
            }

        ## Query set
        # NOTE(review): DATAPATH is not defined in the visible part of this
        # file -- confirm where it comes from.
        candidates = torch.load(DATAPATH)
        positions = range(len(candidates))
        query = {
            'arch_info': {
                'depth_config': [candidates[i][2] for i in positions],
                'channel_widths': [candidates[i][3] for i in positions],
            }
        }

        if not self.tc_support:
            for key in ('final_acc', 'best_acc', 'final_loss'):
                support['y'][key] = torch.stack(support['y'][key]).view(-1, 1)

        return ds_info, self.tc_net, support, query
    