import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim
import torch.optim.lr_scheduler as lr_scheduler
import time
import os
import glob
import random

import configs
import backbone
from data.datamgr import SimpleDataManager, SetDataManager
from methods.maml_moml import MAML_MOML
from methods.protonet_moml import ProtoNet_MOML
from methods.boil_moml import BOIL_MOML
from methods.maml_test import MAML_Test
from methods.constrained_meta import Constrained_meta
from methods.constrained_implicit import Constrained_implicit
from io_utils import model_dict, parse_args, get_resume_file  

import csv

#os.environ['CUDA_VISIBLE_DEVICES'] = '2'

def setup_seed(seed):
    """Seed the random number generators used by this script.

    Passes ``seed`` to ``torch.manual_seed``, ``torch.cuda.manual_seed_all``,
    ``np.random.seed`` and ``random.seed`` (in that order), then sets
    ``torch.backends.cudnn.deterministic`` to ``True``.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

# Fixed seed applied at import time; train() also uses this value to name
# the per-run log file it writes in the checkpoint directory.
seed = 200
setup_seed(seed)

def train(base_loader, val_loader, model, optimization, start_epoch, stop_epoch, params):
    """Run the train/validate loop, saving checkpoints and a per-run log.

    Each epoch calls ``model.train_loop`` on ``base_loader`` and then
    ``model.test_loop`` on ``val_loader``.  Whenever the first value returned
    by ``test_loop`` exceeds the best seen so far in this call, the weights
    are written to ``best_model.tar``.  Every ``params.save_freq`` epochs, and
    on the final epoch, the weights are written to ``<epoch>.tar`` and one row
    ``epoch, acc, acc2, B2`` is written to ``<seed>.cvs``, where ``seed`` is
    the module-level seed.  The log is started afresh at epoch 0 and appended
    to otherwise.

    Args:
        base_loader: loader handed to ``model.train_loop``.
        val_loader: loader handed to ``model.test_loop``.
        model: object exposing ``parameters``, ``train``, ``eval``,
            ``train_loop``, ``test_loop`` and ``state_dict``.
        optimization: optimizer name; only ``'Adam'`` is supported.
        start_epoch: first epoch index (inclusive).
        stop_epoch: last epoch index (exclusive).
        params: namespace providing ``checkpoint_dir`` and ``save_freq``.

    Returns:
        The ``model`` object that was passed in (it is updated in place).

    Raises:
        ValueError: if ``optimization`` is not ``'Adam'``.
    """
    if optimization == 'Adam':
        optimizer = torch.optim.Adam(model.parameters())
    else:
        raise ValueError('Unknown optimization, please define by yourself')

    # NOTE: the scheduler is constructed but this function never calls
    # scheduler.step(), so the cosine schedule is not applied here.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, stop_epoch, eta_min=0, last_epoch=-1)

    # Create the output directory once; exist_ok avoids the check-then-create race.
    os.makedirs(params.checkpoint_dir, exist_ok=True)
    log_path = os.path.join(params.checkpoint_dir, str(seed) + '.cvs')

    # Best validation score seen in this call.  A test_loop that reports a
    # non-positive score (e.g. -1) never triggers the "best model" save.
    max_acc = 0

    for epoch in range(start_epoch, stop_epoch):
        model.train()
        model.train_loop(epoch, base_loader, optimizer)  # the model is updated in place
        model.eval()

        acc, acc2, B2 = model.test_loop(val_loader)

        if acc > max_acc:
            print("best model! save...")
            max_acc = acc
            outfile = os.path.join(params.checkpoint_dir, 'best_model.tar')
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

        if (epoch % params.save_freq == 0) or (epoch == stop_epoch - 1):
            outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch))
            torch.save({'epoch': epoch, 'state': model.state_dict()}, outfile)

            # Epoch 0 starts a fresh log; later epochs (including resumed
            # runs) append to the existing one.
            log_mode = 'w' if epoch == 0 else 'a'
            with open(log_path, log_mode, encoding='utf-8', newline='') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([str(epoch), str(acc), str(acc2), str(B2)])

    return model

if __name__=='__main__':
    # Entry point: parse the options, build the episodic data loaders and the
    # selected meta-learning model, optionally restore weights, then train.

    # NOTE: this replaces the NumPy seed set by setup_seed(seed) at import time.
    np.random.seed(10)
    params = parse_args('train')
    params.train_aug=True  # augmentation is forced on, whatever was parsed
    print(params)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(params.device)
    base_file = configs.data_dir[params.dataset] + 'base.json'
    val_file = configs.data_dir[params.dataset] + 'val.json'

    # Input resolution depends on the backbone family and the dataset.
    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        image_size = 224

    optimization = 'Adam'

    # -1 means "pick a default number of epochs from the shot count".
    if params.stop_epoch == -1:
        if params.n_shot == 1:
            params.stop_epoch = 2400
        elif params.n_shot == 5:
            params.stop_epoch = 1200
        else:
            params.stop_epoch = 1000 #default

    if params.method not in ['constrained_meta','constrained_implicit','maml_moml', 'protonet_moml', 'boil_moml', 'maml_test']:
        raise ValueError('Unknown method')

    n_query = max(1, int(16* params.test_n_way/params.train_n_way)) #if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small

    train_few_shot_params = dict(n_way = params.train_n_way, n_support = params.n_shot)
    base_datamgr = SetDataManager(image_size, n_query = n_query, **train_few_shot_params)
    base_loader = base_datamgr.get_data_loader( base_file , aug = params.train_aug )

    test_few_shot_params = dict(n_way = params.test_n_way, n_support = params.n_shot)
    val_datamgr = SetDataManager(image_size, n_query = n_query, **test_few_shot_params)
    val_loader = val_datamgr.get_data_loader( val_file, aug = False)
    #a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

    if params.method == 'constrained_implicit':
        backbone.ConvBlock.maml = False
        backbone.SimpleBlock.maml = False
        backbone.BottleneckBlock.maml = False
        backbone.ResNet.maml = False
        model = Constrained_implicit( model_dict[params.model], **train_few_shot_params )
        # Only 'SOML' and 'COML' are accepted; anything else falls back to 'COML'.
        if params.weighting_mode=='SOML' or params.weighting_mode=='COML':
            model.weighting_mode = params.weighting_mode
        else:
            model.weighting_mode = 'COML'
        # meta_lambda: 1.0 by default, 8.0 for 1-shot, 7.0 for 1-shot on CUB.
        model.meta_lambda=1.0
        if params.n_shot==1:
            model.meta_lambda=8.0
            if params.dataset== 'CUB':
                model.meta_lambda=7.0

    elif params.method == 'constrained_meta':
        backbone.ConvBlock.maml = False
        backbone.SimpleBlock.maml = False
        backbone.BottleneckBlock.maml = False
        backbone.ResNet.maml = False
        model = Constrained_meta( model_dict[params.model], **train_few_shot_params )
        model.weighting_mode = 'COML'
        # meta_lambda: 1.0 by default, 8.0 for 1-shot, 7.0 for 1-shot on CUB.
        model.meta_lambda=1.0
        if params.n_shot==1:
            model.meta_lambda=8.0
            if params.dataset== 'CUB':
                model.meta_lambda=7.0

    elif params.method == 'maml_moml':
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        # NOTE(review): within this branch the approx expression is always False.
        model = MAML_MOML( model_dict[params.model], approx = (params.method == 'maml_moml_appro'), **train_few_shot_params )
        model.weighting_mode = params.weighting_mode

    elif params.method == 'protonet_moml':
        model = ProtoNet_MOML( model_dict[params.model], **train_few_shot_params )
        model.weighting_mode = params.weighting_mode

    elif params.method == 'boil_moml':
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        # NOTE(review): within this branch the approx expression is always True.
        model = BOIL_MOML( model_dict[params.model], approx = (params.method == 'boil_moml'), **train_few_shot_params )
        model.weighting_mode = params.weighting_mode

    elif params.method == 'maml_test':
        backbone.ConvBlock.maml = True
        backbone.SimpleBlock.maml = True
        backbone.BottleneckBlock.maml = True
        backbone.ResNet.maml = True
        # NOTE(review): compares against 'boil_moml', so approx is always False
        # here; looks copy-pasted from the branch above -- confirm intent.
        model = MAML_Test( model_dict[params.model], approx = (params.method == 'boil_moml'), **train_few_shot_params )
        model.weighting_mode = params.weighting_mode

    model = model.cuda()

    # Checkpoint directory encodes dataset, backbone, method, weighting mode,
    # mark, augmentation and the way/shot setting.
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s_%s_%s' %(configs.save_dir, params.dataset, params.model, params.method, params.weighting_mode, params.mark)

    if params.train_aug:
        params.checkpoint_dir += '_aug'

    params.checkpoint_dir += '_%dway_%dshot' %( params.train_n_way, params.n_shot)

    os.makedirs(params.checkpoint_dir, exist_ok=True)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch

    # NOTE(review): 'maml_robust' is rejected by the method check above, so
    # this branch cannot run as the script stands.
    if params.method == 'maml_robust' :
        stop_epoch = params.stop_epoch * model.n_task

    if params.resume:
        # Continue from the checkpoint found in checkpoint_dir, if there is one;
        # otherwise training silently starts from start_epoch.
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch']+1
            model.load_state_dict(tmp['state'])
    elif params.warmup: #We also support warmup from pretrained baseline feature, but we never used in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' %(configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        # Reject a missing checkpoint before torch.load is handed None, the
        # same way the resume branch treats a None result.
        if warmup_resume_file is None:
            raise ValueError('No warm_up file')
        tmp = torch.load(warmup_resume_file)
        if tmp is None:
            raise ValueError('No warm_up file')
        state = tmp['state']
        # Keep only the entries whose key contains "feature." and strip that
        # prefix, e.g. 'feature.trunk.xx' -> 'trunk.xx', so they load into
        # model.feature; every other entry is dropped.
        for key in list(state.keys()):
            value = state.pop(key)
            if "feature." in key:
                state[key.replace("feature.","")] = value
        model.feature.load_state_dict(state)

    model = train(base_loader, val_loader, model, optimization, start_epoch, stop_epoch, params)
