import argparse
import time
from tqdm import tqdm
import os
import random
import numpy as np
from scipy.stats import spearmanr
from scipy.stats import rankdata
from collections import OrderedDict
import wandb
import torch
import torch.nn as nn
from networks.modules import gradient_update_parameters

from init import Initial
from predictor import PredictorModel
from sampler import TaskSampler
from myutils import InfIterator
from myutils import Logger
from myutils import save_model
from myutils import get_optimizer
from myutils import get_scheduler
from myutils import str2bool



class Meta:
    def __init__(self, args, main_path, save_path):
        self.args = args
        self.main_path = main_path
        self.save_path = save_path
        ## General
        self.mode = args.mode
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.default_data_path = args.default_data_path
        self.image_size = args.image_size
        self.batch_size = args.batch_size
        self.minmax_norm = args.minmax_norm

        ## Search space
        self.search_space = args.search_space
        self.channel_mul = args.channel_mul

        ## Teacher
        self.tc_stage_num = args.tc_stage_num
        self.tc_stage_depth = args.tc_stage_depth
        self.tc_stage_default_channel_widths = args.tc_stage_default_channel_widths
        self.tc_stage_strides = args.tc_stage_strides
        self.tc_mul_seeds_on = args.tc_mul_seeds_on

        ## Student
        self.net_info_path = args.net_info_path

        ## Meta-learning
        self.mtrn_ds_split = args.mtrn_ds_split
        self.mvld_ds_split = args.mvld_ds_split
        self.mtst_datasets = args.mtst_datasets
        self.num_episodes = args.num_episodes
        self.meta_lr = args.meta_lr
        self.mvld_frequency = args.mvld_frequency
        ## Bilevel
        self.bilevel = args.bilevel
        self.num_train_updates = args.num_train_updates
        self.first_order = args.first_order
        self.step_size = args.step_size
        self.task_lr = args.task_lr
        self.n_support = args.n_support
        self.n_query = args.n_query
        self.n_support_tr = args.n_support_tr
        self.n_query_tr = args.n_query_tr
        self.meta_test_support_index = args.meta_test_support_index
        self.tc_support = args.tc_support
        self.tc_support_tr = args.tc_support_tr

        ## Save model at specific epi
        self.save_epi = args.save_epi
        ## Normalize with teacher net acc
        self.tc_norm = args.tc_norm
        ## Plot acc
        self.plot_acc = args.plot_acc
        ## PR type
        self.pr_type = args.pr_type
        ## Meta-valid On/Off
        self.mvld_on = args.mvld_on

        ## Sampler
        self.get_samplers()

        # ## Predictor
        self.model = PredictorModel(args)
        self.model = self.model.to(self.device)

        self.model_params = list(self.model.parameters())
        if self.task_lr:
            self.define_task_lr_params()
            self.model_params += list(self.task_lr.values())

        # ## Optimizer
        self.meta_optimizer = torch.optim.Adam(
            self.model_params, lr=self.meta_lr)
        self.scheduler = None

        # Criterion
        self.criterion = nn.MSELoss() 
        self.criterion = self.criterion.to(self.device)

        ## PR type
        self.pr_type = args.pr_type

        ## Meta-valid On/Off
        self.mvld_on = args.mvld_on

        ## Logs
        self.logger = Logger(
                log_dir=main_path,
                exp_name=exp_name,
                exp_suffix=exp_suffix,
                write_textfile=(False if self.args.folder_name == 'debug' else True),
                use_wandb=(False if self.args.folder_name == 'debug' else True),
                wandb_project_name=self.args.wandb_project_name,
            )
        self.logger.update_config(self.args, is_args=True)


    def define_task_lr_params(self):
        self.task_lr = OrderedDict()
        for key, val in self.model.named_parameters():
            self.task_lr[key] = nn.Parameter(
                self.step_size * torch.ones_like(val))


    def get_samplers(self):
        self.mtrn_samplers = self.mvld_samplers = self.mtst_samplers = None
        self.mtrn_samplers_for_test = None

        if self.mode == 'meta_train':
            self.mtrn_samplers = []
            self.mtrn_samplers_for_test = {}
            for ds_split in self.mtrn_ds_split:
                self.mtrn_samplers.append(\
                    TaskSampler(mode='meta_train', 
                        default_data_path=self.default_data_path,
                        ds_split=ds_split,
                        ds_name='tiny_imagenet', 
                        image_size=self.image_size,
                        batch_size=self.batch_size,
                        search_space=self.search_space,
                        tc_mul_seeds_on=self.tc_mul_seeds_on,
                        tc_stage_num=self.tc_stage_num, 
                        tc_stage_depth=self.tc_stage_depth, 
                        tc_stage_default_channel_widths=self.tc_stage_default_channel_widths, 
                        tc_stage_strides=self.tc_stage_strides,
                        channel_mul=self.channel_mul,
                        net_info_path=self.net_info_path,
                        minmax_norm=self.minmax_norm,
                        n_support=self.n_support,
                        n_query=self.n_query,
                        n_support_tr=self.n_support_tr,
                        n_query_tr=self.n_query_tr,
                        bilevel=self.bilevel,
                        user=self.user,
                        tc_support=self.tc_support,
                        tc_support_tr=self.tc_support_tr,
                        tc_norm=self.tc_norm,
                        pr_type=self.pr_type))

            if self.mvld_on:
                self.mvld_samplers = {}
                for ds_split in self.mvld_ds_split:
                    self.mvld_samplers[f'tiny_imagenet_{ds_split}'] = \
                        TaskSampler(mode='meta_train', 
                            default_data_path=self.default_data_path,
                            ds_split=ds_split,
                            ds_name='tiny_imagenet', 
                            image_size=self.image_size,
                            batch_size=self.batch_size,
                            search_space=self.search_space,
                            tc_mul_seeds_on=self.tc_mul_seeds_on,
                            tc_stage_num=self.tc_stage_num, 
                            tc_stage_depth=self.tc_stage_depth, 
                            tc_stage_default_channel_widths=self.tc_stage_default_channel_widths, 
                            tc_stage_strides=self.tc_stage_strides,
                            channel_mul=self.channel_mul,
                            net_info_path=self.net_info_path,
                            minmax_norm=self.minmax_norm,
                            n_support=self.n_support,
                            n_query=self.n_query,
                            n_support_tr=self.n_support_tr,
                            n_query_tr=self.n_query_tr,
                            bilevel=self.bilevel,
                            user=self.user,
                            tc_support=self.tc_support,
                            tc_support_tr=self.tc_support_tr,
                            tc_norm=self.tc_norm,
                            pr_type=self.pr_type,
                            meta_test_support_index=self.meta_test_support_index)
            print('==> load samplers')

        elif self.mode == 'meta_test': 
            self.mtst_samplers = {}
            for ds_name in self.mtst_datasets:
                self.mtst_samplers[f'{ds_name}'] = \
                    TaskSampler(mode='meta_test', 
                        default_data_path=self.default_data_path,
                        ds_split=ds_split,
                        ds_name=ds_name, 
                        image_size=self.image_size,
                        batch_size=self.batch_size,
                        search_space=self.search_space,
                        tc_mul_seeds_on=self.tc_mul_seeds_on,
                        tc_stage_num=self.tc_stage_num, 
                        tc_stage_depth=self.tc_stage_depth, 
                        tc_stage_default_channel_widths=self.tc_stage_default_channel_widths, 
                        tc_stage_strides=self.tc_stage_strides,
                        channel_mul=self.channel_mul,
                        net_info_path=self.net_info_path,
                        minmax_norm=self.minmax_norm,
                        n_support=self.n_support,
                        n_query=self.n_query,
                        n_support_tr=self.n_support_tr,
                        n_query_tr=self.n_query_tr,
                        bilevel=self.bilevel,
                        user=self.user,
                        tc_support=self.tc_support,
                        tc_support_tr=self.tc_support_tr,
                        tc_norm=self.tc_norm,
                        pr_type=self.pr_type,
                        meta_test_support_index=self.meta_test_support_index)


    def forward(self, task, test=False):
        ds_info, tc_net, support, query = task
        n_query = self.n_query if test else self.n_query_tr

        if test:
            with torch.no_grad():
                query_y_hat = self.model(D=ds_info, 
                                        F=tc_net, 
                                        A=query['arch_info'], 
                                        n=n_query)
        else:
            query_y_hat = self.model(D=ds_info, 
                                    F=tc_net, 
                                    A=query['arch_info'], 
                                    n=n_query
                                    )
        query_y = query['y']['final_acc'].to(self.device)

        return self.criterion(query_y_hat, query_y), query_y_hat


    def forward_bilevel(self, task, test=False):
        # Run inner loops to get adapted parameters (theta_t`)
        ds_info, tc_net, support, query = task
        n_support = self.n_support if test else self.n_support_tr
        n_query = self.n_query if test else self.n_query_tr
        if self.mode == 'meta_test':
            n_query = 10000


        params = OrderedDict(self.model.meta_named_parameters())

        for n_update in range(self.num_train_updates):
            
            self.model.zero_grad()
            support_y_hat = self.model(D=ds_info, 
                                F=tc_net, 
                                A=support['arch_info'], 
                                n=n_support,
                                params=params)
            support_y = support['y']['final_acc'].to(self.device)
            inner_loss = self.criterion(support_y_hat, support_y)
            
            if self.task_lr is not False:
                params = gradient_update_parameters(self.model,
                                            inner_loss,
                                            step_size=self.task_lr,
                                            first_order=self.first_order)
            else:
                params = gradient_update_parameters(self.model,
                                                inner_loss,
                                                step_size=self.step_size,
                                                first_order=self.first_order)
        
        if test:
            with torch.no_grad():
                query_y_hat = self.model(D=ds_info, 
                                        F=tc_net, 
                                        A=query['arch_info'], 
                                        n=n_query,
                                        params=params)
        else:
            query_y_hat = self.model(D=ds_info, 
                                    F=tc_net, 
                                    A=query['arch_info'], 
                                    n=n_query,
                                    params=params)
        query_y = query['y']['final_acc'].to(self.device)
        

        return self.criterion(query_y_hat, query_y), query_y_hat


    def meta_training(self):

        is_best = False
        loss_keys = ['tr_loss', 'va_loss', 'te_loss']
        min_info = {k: 10000000000 for k in loss_keys}
        corr_keys = ['va_spearman', 'te_spearman'] 
        min_info = {k: -1 for k in corr_keys}
        log_keys = loss_keys + corr_keys

        num_meta_batch = len(self.mtrn_samplers)
        with tqdm(total=self.num_episodes, desc=f'Meta-training') as t:
            for i_epi in range(self.num_episodes):
                st_epi_time = time.time()

                mtrn_keys = ['tr_loss']

                self.model.train()
                self.model.zero_grad()
                meta_loss = torch.tensor(0., device=self.device)
                for mb in range(num_meta_batch):
                    
                    task = self.mtrn_samplers[mb].get_random_task()
                    if self.bilevel:
                        outer_loss, _ = self.forward_bilevel(task)
                    else:
                        outer_loss, _ = self.forward(task)
                    meta_loss += outer_loss

                meta_loss = meta_loss / float(num_meta_batch) 
                self.meta_optimizer.zero_grad()
                meta_loss.backward()
                self.meta_optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step(meta_loss)

                self.logger.update(key='tr_time', v=(time.time() - st_epi_time))                    
                self.logger.update(key='tr_loss', v=meta_loss.item()/num_meta_batch) 
                
                if (i_epi + 1) % self.mvld_frequency == 0:
                    last_info = {}
                    element = {'meta_train_loss': mtrn_keys}
                    self.logger.reset(except_keys=mtrn_keys)
                    
                    self.model.eval()
                    ## Meta-valid
                    test_mode_list = ['meta_test']
                    if self.mvld_on:
                        test_mode_list.append('meta_valid')
                    for test_mode in test_mode_list:
                        if test_mode == 'meta_valid':
                            loss_dict, spearman_corr_dict, \
                                y_all_dict, y_pred_all_dict = self.meta_valid(i_epi + 1)
                        elif test_mode == 'meta_test':
                            loss_dict, spearman_corr_dict, \
                                y_all_dict, y_pred_all_dict = self.meta_test(i_epi + 1)

                        for k, v in loss_dict.items():
                            self.logger.update(k, v)
                        for k, v in spearman_corr_dict.items():
                            self.logger.update(k, v)
                        element.update({f'{test_mode}_loss': loss_dict.keys(),
                                        f'{test_mode}_spearman': spearman_corr_dict.keys()})
                        
                        last_info.update(loss_dict)
                        last_info.update(spearman_corr_dict)
                        last_info.update(y_all_dict)
                        last_info.update(y_pred_all_dict)
                        
                    is_best = min_info['te_spearman'] < last_info['te_spearman']    
                    if is_best:
                        min_info = last_info
                        print('best for meta-test is updated')
                    self.model.cpu()
                    save_model({
                                'epoch': i_epi+1,
                                'optimizer': self.meta_optimizer.state_dict(),
                                'state_dict': self.model.state_dict(),
                                'last_info': last_info,
                                'min_info': min_info,
                            }, self.save_path, is_best=is_best, model_name=None)
                    save_path_temp = os.path.join(self.main_path, f'{i_epi+1}', 'checkpoint')
                    os.makedirs(save_path_temp, exist_ok=True)
                    save_model({
                                'epoch': i_epi+1,
                                'optimizer': self.meta_optimizer.state_dict(),
                                'state_dict': self.model.state_dict(),
                                'last_info': last_info,
                                'min_info': min_info,
                            }, save_path_temp, is_best=False, model_name=None)
                    self.model.to(self.device)
                    print(f'=> save model at epi {i_epi+1}')
                    
                    
                    if self.save_epi is not None and (i_epi+1 == self.save_epi):
                        self.model.cpu()
                        save_model({
                                    'epoch': i_epi+1,
                                    'optimizer': self.meta_optimizer.state_dict(),
                                    'state_dict': self.model.state_dict(),
                                    'last_info': last_info,
                                    'min_info': min_info,
                                }, os.path.join(self.save_path, f'{self.save_epi}'), is_best=False, model_name=None)
                        self.model.to(self.device)

                    self.logger.write_log(element=element, step=i_epi+1)
                else:
                    t.set_postfix(self.logger.avg(log_keys))
                    t.update(1)
        self.logger.update_config(min_info)
        self.logger.save_log()
 

    def _test_task(self, sampler):
        task = sampler.get_test_task_w_all_samples()
        _, _, _, query = task
        y_all = query['y']['final_acc']
        if self.bilevel:
            outer_loss, y_pred_all = self.forward_bilevel(task, test=True)
        else:
            outer_loss, y_pred_all = self.forward(task, test=True)
        y_all = y_all.cpu().squeeze(1).numpy()
        y_pred_all = y_pred_all.cpu().squeeze(1).numpy()

        spearman_corr = spearmanr(y_all, y_pred_all)[0]
        return outer_loss, spearman_corr, y_all, y_pred_all


    def meta_valid(self, epi):
        head = 'va'
        loss_dict = {}
        spearman_corr_dict = {}
        y_all_dict = {}
        y_pred_all_dict = {}
        for ds_name, sampler in self.mvld_samplers.items():
            loss, spearman_corr, y_all, y_pred_all = self._test_task(sampler)
            if self.plot_acc:
                self.plot_accuracy(ds_name, y_all, y_pred_all, epi)
            ## Logs
            loss_dict[f'{head}_{ds_name}_loss'] = loss.item()
            spearman_corr_dict[f'{head}_{ds_name}_spearman'] = spearman_corr
            y_all_dict[f'{head}_{ds_name}_y_all'] = y_all
            y_pred_all_dict[f'{head}_{ds_name}_y_pred_all'] = y_pred_all
            print(f'{ds_name} | loss {loss:.3f} | spearman: {spearman_corr: .3f}')
            
        loss_dict[f'{head}_loss'] = sum(loss_dict.values()) / len(loss_dict)
        spearman_corr_dict[f'{head}_spearman'] = sum(spearman_corr_dict.values()) / len(spearman_corr_dict)
        return loss_dict, spearman_corr_dict, y_all_dict, y_pred_all_dict


    def meta_test(self, epi):
        head = 'te'
        loss_dict = {}
        spearman_corr_dict = {}
        y_all_dict = {}
        y_pred_all_dict = {}
        for ds_name, sampler in self.mtst_samplers.items():
            loss, spearman_corr, y_all, y_pred_all = self._test_task(sampler)
            if self.plot_acc:
                self.plot_accuracy(ds_name, y_all, y_pred_all, epi)
            ## Logs
            loss_dict[f'{head}_{ds_name}_loss'] = loss
            spearman_corr_dict[f'{head}_{ds_name}_spearman'] = spearman_corr
            y_all_dict[f'{head}_{ds_name}_y_all'] = y_all
            y_pred_all_dict[f'{head}_{ds_name}_y_pred_all'] = y_pred_all
            print(f'{ds_name} | loss {loss:.3f} | spearman: {spearman_corr: .3f}')
        loss_dict[f'{head}_loss'] = sum(loss_dict.values()) / len(loss_dict)
        spearman_corr_dict[f'{head}_spearman'] = sum(spearman_corr_dict.values()) / len(spearman_corr_dict)
        return loss_dict, spearman_corr_dict, y_all_dict, y_pred_all_dict


    def load_model(self, load_path=None):
        print(f'==> load predictor model from {load_path}')
        self.model.load_state_dict(torch.load(load_path)['state_dict'])
        

if __name__ == "__main__":
    # Script entry point: parse the options, derive the experiment directory
    # from them, seed all RNGs and run meta-training.
    # NOTE: ``exp_name`` and ``exp_suffix`` must stay module-level names here;
    # ``Meta.__init__`` reads them as globals when it builds the logger.

    parser = argparse.ArgumentParser()
    parser.add_argument('--wandb_project_name', type=str, default=None)
    parser.add_argument('--mode', type=str, default='meta_train')
    parser.add_argument('--default_data_path', type=str, default=None)
    parser.add_argument('--search_space', type=str, default='resnet')

    parser.add_argument('--gpu', type=str, default='0')
    parser.add_argument('--manual_seed', type=int, default=0)
    parser.add_argument('--folder_name', type=str, default='debug')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--image_size', type=int, default=64)
    parser.add_argument('--net_info_path', type=str, default=None)

    parser.add_argument('--tc_net_name', type=str, default='resnet42')
    # A list of ints like the default, not a single string.
    parser.add_argument('--tc_stage_strides', nargs='+', type=int, default=[1, 2, 2, 2])
    parser.add_argument('--tc_stage_depth', type=int, default=5)
    parser.add_argument('--tc_stage_num', type=int, default=4)
    # type=bool would turn every non-empty string (including 'False') into True.
    parser.add_argument('--tc_mul_seeds_on', type=str2bool, default=False)
    parser.add_argument('--tc_stage_default_channel_widths', nargs='+', type=int, default=[16, 32, 64, 128])
    parser.add_argument('--channel_mul', type=int, default=2)
    parser.add_argument('--width_mult_list', nargs='+', type=float,
                        default=[0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.])

    parser.add_argument('--mtst_datasets', nargs='+', type=str,
        default=['cifar100', 'dtd', 'stl10', 'cub'])
    parser.add_argument('--mtst_dataset_option', type=int, default=0)
    parser.add_argument('--mtrn_ds_split', nargs='+', type=int,
        default=[0, 1, 2, 3, 4, 5, 6, 7])
    parser.add_argument('--mvld_ds_split', nargs='+', type=int,
        default=[8, 9])
    parser.add_argument('--meta_lr', type=float, default=1e-4)
    parser.add_argument('--num_episodes', type=int, default=2000)
    parser.add_argument('--mvld_frequency', type=int, default=50)
    # Encoding
    parser.add_argument('--input_type', type=str, default='FA')
    parser.add_argument('--h_dim', type=int, default=64)
    parser.add_argument('--d_inp_dim', type=int, default=512)
    parser.add_argument('--d_out_dim', type=int, default=64)
    parser.add_argument('--nz', type=int, default=56)
    parser.add_argument('--num_sample', type=int, default=20)
    parser.add_argument('--f_inp_dim', type=int, default=256)
    parser.add_argument('--f_out_dim', type=int, default=64)
    parser.add_argument('--a_inp_dim', type=int, default=161)
    parser.add_argument('--a_out_dim', type=int, default=64)
    parser.add_argument('--minmax_norm', type=str2bool, default=True)
    # Bi-level
    parser.add_argument('--bilevel', type=str2bool, default=True)
    parser.add_argument('--num_train_updates', type=int, default=3)
    parser.add_argument('--num_test_updates', type=int, default=3)
    parser.add_argument('--first_order', type=str2bool, default=False)
    parser.add_argument('--step_size', type=float, default=0.001)
    parser.add_argument('--n_support', type=int, default=1)
    parser.add_argument('--n_query', type=int, default=30)
    parser.add_argument('--n_support_tr', type=int, default=1)
    parser.add_argument('--n_query_tr', type=int, default=30)
    parser.add_argument('--task_lr', type=str2bool, default=True)

    # Parsed as ints so command-line values have the same type as the default.
    parser.add_argument('--meta_test_support_index', nargs='+', type=int, default=[0, 1, 2, 3, 4])
    parser.add_argument('--tc_support', type=str2bool, default=False)
    parser.add_argument('--tc_support_tr', type=str2bool, default=False)

    # MISC
    parser.add_argument('--use_l2norm', type=str2bool, default=True)
    parser.add_argument('--user', type=str, default='sh')
    ## Input type
    parser.add_argument('--func_type', type=int, default=0)
    parser.add_argument('--set_type', type=int, default=0)
    ## Save model for specific episode
    parser.add_argument('--save_epi', type=int, default=None)

    ## Normalize with Teacher net acc
    parser.add_argument('--tc_norm', type=str2bool, default=False)
    ## plot acc
    parser.add_argument('--plot_acc', type=str2bool, default=True)
    ## PR type
    parser.add_argument('--pr_type', type=str, default='random_init')

    ## Meta-valid On/Off
    parser.add_argument('--mvld_on', type=str2bool, default=False)

    ## use TANS ver. Predictor
    parser.add_argument('--m_inp_dim', type=int, default=128)
    parser.add_argument('--use_noise', type=str2bool, default=True)
    parser.add_argument('--use_attnmap', type=str2bool, default=False)
    parser.add_argument('--use_abs', type=str2bool, default=False)

    args = parser.parse_args()
    base_configs = ['common.yaml']
    initial = Initial(args, base_configs=base_configs)
    args = initial.args

    # The experiment directory encodes the main hyper-parameters.
    main_path = f'../exp/{args.folder_name}'
    main_path += f'/e-{args.num_episodes}/lr-{args.meta_lr}'
    main_path += f'/pr-{str(args.pr_type)[0]}'
    ## Encoder type
    main_path += f'/inp-{args.input_type}'
    # The former test `'T' or 'S' in ...` was always true because the bare
    # string 'T' is truthy. These path components are now only added for
    # input types containing 'T' or 'S' with noise enabled, so the directory
    # layout changes for other configurations.
    uses_t_or_s = any(tag in args.input_type for tag in ('T', 'S'))
    if uses_t_or_s and args.use_noise:
        main_path += f'/noi-{str(args.use_noise)[0]}'
        main_path += f'/attn-{str(args.use_attnmap)[0]}'
        main_path += f'/abs-{str(args.use_abs)[0]}'
    ## Normalization
    main_path += f'/tcn-{str(args.tc_norm)[0]}'

    ## Bi-level optim
    main_path += f'/bi-{str(args.bilevel)[0]}-tsk-{str(args.task_lr)[0]}'
    if args.bilevel:
        main_path += f'/tspp-{str(args.tc_support_tr)[0]}'
        main_path += f'/ntr-{args.num_train_updates}-nts-{args.num_test_updates}'
        main_path += f'/s-{args.n_support}/q-{args.n_query}'
        main_path += f'/sspp-{args.n_support_tr}-{args.n_query_tr}'

    save_path = os.path.join(main_path, 'checkpoint')
    exp_name = main_path.replace('../exp/', '')
    exp_name = exp_name.replace(f'{args.tc_net_name}/stud/', '').replace('/', '_')
    exp_suffix = ""

    os.makedirs(main_path, exist_ok=True)
    os.makedirs(save_path, exist_ok=True)

    # Restrict the visible GPUs before the first CUDA query below.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    args.__dict__['device'] = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

    # Seed every RNG used by the run.
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed_all(args.manual_seed)
    np.random.seed(args.manual_seed)
    random.seed(args.manual_seed)

    meta = Meta(args, main_path, save_path)
    if args.mode == 'meta_train':
        print('start meta-training')
        meta.meta_training()
    