import argparse
from asyncore import write
import datetime
import json
import math
import os
import random
from re import L
import time
from collections import OrderedDict
from unittest import result
#from sqlalchemy import true

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import numpy as np
from tqdm import tqdm, trange
from torch.utils.data import DataLoader

from tensorboardX import SummaryWriter

import sys
sys.path.append('..')
from models import meta_sparse

from config.configuration import *
import models
from util import enlist_transformation
from data_generate.dataset import FewShotImageDataset
from data_generate.sampler import SuppQueryBatchSampler

class MAML(object):
    """Meta-learner that adapts a model on a support set and scores it on a query set.

    The inner loop takes ``config['inner_args']['n_step']`` gradient steps of
    size ``config['step_size']`` on the parameters named in
    ``inner_loop_params``.  Inner gradients are computed with
    ``create_graph=False``, so the outer update does not differentiate
    through them.

    When ``config['gradient_mask']`` or ``config['weight_mask']`` is truthy,
    a mask module must be assigned to ``self.mask`` before calling
    :meth:`step`; its ``forward()`` is expected to return a mapping from
    parameter name to mask tensor.

    All settings are read from ``self.config`` (never from a module-level
    global), so the class works when imported from another module.
    """

    def __init__(self, model, inner_loop_params, optimizer_theta=None,
                 optimizer_mask=None,
                 loss_function=F.cross_entropy, config=None):
        """Store the collaborators; no computation happens here.

        Args:
            model: module whose ``forward`` accepts a ``params=`` keyword
                holding an ordered name -> tensor mapping.
            inner_loop_params: names of the parameters adapted in the
                inner loop.
            optimizer_theta: optimizer over the model parameters.
            optimizer_mask: optimizer over the mask parameters (only used
                when masking is enabled).
            loss_function: callable ``(logits, labels) -> scalar loss``.
            config: dict of settings (see the class docstring).
        """
        self.model = model
        self.optimizer_theta = optimizer_theta
        self.optimizer_mask = optimizer_mask
        self.loss_function = loss_function
        self.inner_loop_params = inner_loop_params
        self.config = config
        # Attached by the caller when masking is enabled; defined here so the
        # attribute always exists.
        self.mask = None

    def accuracy(self, logits, targets):
        """Return the fraction of rows whose arg-max over dim 1 equals ``targets``."""
        with torch.no_grad():
            _, predictions = torch.max(logits, dim=1)
            correct = predictions.eq(targets).float()
        return torch.mean(correct).item()

    def _masking_enabled(self):
        """Return True when either mask mode is switched on in the config."""
        return bool(self.config['gradient_mask'] or self.config['weight_mask'])

    def _zero_grads(self, use_mask):
        """Clear gradients on the model, the mask (if used) and their optimizers."""
        self.model.zero_grad()
        self.optimizer_theta.zero_grad()
        if use_mask:
            self.mask.zero_grad()
            self.optimizer_mask.zero_grad()

    def step(self, support_img, support_lbl, query_img, query_lbl, evaluation=False):
        """Adapt on the support set, score on the query set, optionally update.

        Args:
            support_img, support_lbl: inputs/labels used for the inner loop.
            query_img, query_lbl: inputs/labels used for the outer loss.
            evaluation: when True, no optimizer step is taken.

        Returns:
            ``float`` query accuracy when ``evaluation`` is True; otherwise
            the tuple ``(detached outer loss, query accuracy, mask)`` where
            ``mask`` is the mapping returned by ``self.mask.forward()`` or
            ``None`` when masking is disabled.
        """
        cfg = self.config
        use_mask = self._masking_enabled()
        step_size = cfg['step_size']
        # Set lookup instead of scanning the name list for every parameter.
        adaptable = set(self.inner_loop_params)

        mask = self.mask.forward() if use_mask else None

        # Starting point of the inner loop (weights pre-multiplied by the
        # mask in weight-mask mode).
        params = OrderedDict()
        for name, param in self.model.named_parameters():
            if cfg['weight_mask'] and name in mask:
                params[name] = param * mask[name]
            else:
                params[name] = param

        self._zero_grads(use_mask)

        # INNER LOOP
        for _ in range(cfg['inner_args']['n_step']):
            train_logits = self.model(support_img, params=params)
            inner_loss = self.loss_function(train_logits, support_lbl)
            self.model.zero_grad()
            grads = torch.autograd.grad(inner_loss, list(params.values()),
                                        create_graph=False)
            updated = OrderedDict()
            for (name, param), grad in zip(params.items(), grads):
                if name not in adaptable:
                    # No inner loop adaptation
                    updated[name] = param
                elif cfg['gradient_mask'] and name in mask:
                    updated[name] = param - step_size * (mask[name] * grad)
                elif cfg['weight_mask'] and name in mask:
                    updated[name] = (param - step_size * grad) * mask[name]
                else:
                    updated[name] = param - step_size * grad
            params = updated

        test_logit = self.model(query_img, params=params)
        outer_loss = self.loss_function(test_logit, query_lbl)
        outer_accuracy = float(self.accuracy(test_logit, query_lbl))

        if evaluation:
            # We assume that the test_batch_size is set to 1
            return outer_accuracy

        # Clear anything accumulated so far before the outer backward pass.
        self._zero_grads(use_mask)

        outer_loss.backward()
        for param in self.model.parameters():
            # A parameter that did not take part in the graph has no
            # gradient to clamp.
            if param.requires_grad and param.grad is not None:
                param.grad.data.clamp_(-10, 10)

        self.optimizer_theta.step()
        if use_mask:
            self.optimizer_mask.step()

        return outer_loss.detach(), outer_accuracy, mask

    def train(self, dataloader, max_batches=500, epoch=None, writer=None):
        """Run one pass over every loader in ``dataloader`` in random order.

        Args:
            dataloader: indexable collection of iterables yielding
                ``(images, labels)``; the first ``num_way * num_shot``
                entries of each batch are used as the support set.
            max_batches: accepted for interface compatibility; not used.
            epoch: step value used for the logged scalars.
            writer: optional summary writer exposing ``add_scalar``.

        Returns:
            Fraction of zero entries over all mask groups returned by the
            last step (``0.`` when masking is disabled).
        """
        cfg = self.config
        device = cfg['device']
        supp_idx = cfg['num_way'] * cfg['num_shot']

        num_batches = 0
        acc, loss = 0., 0.
        masks = None  # stays None when no batch is produced

        seq_index = np.random.choice(len(dataloader), len(dataloader), replace=False).tolist()
        print("{}:{} dataset".format(epoch, seq_index))

        for index_ in seq_index:
            for images, labels in dataloader[index_]:
                imgs = images.to(device)
                lbls = labels.to(device)

                support_img, query_img = imgs[:supp_idx], imgs[supp_idx:]
                support_lbl, query_lbl = lbls[:supp_idx], lbls[supp_idx:]

                l, a, masks = self.step(support_img, support_lbl, query_img, query_lbl)
                loss += l
                acc += a
                num_batches += 1

        # num_batches already counts every batch of every loader, so it is
        # the only divisor needed for the mean.
        if num_batches:
            acc /= num_batches
            loss /= num_batches

        # Write some stats
        mean_sparsity = 0.
        with torch.no_grad():
            if writer is not None:
                writer.add_scalar('Training Loss', loss, epoch)
                writer.add_scalar('Training Accuracy', acc, epoch)
            if masks:
                nonzero_total, entry_total = 0., 0.
                last_group = None
                for k in masks:
                    nonzero = np.count_nonzero(masks[k].detach().cpu().numpy())
                    entries = np.prod(masks[k].shape)

                    nonzero_total += nonzero
                    entry_total += entries
                    last_group = k

                    if writer is not None:
                        writer.add_scalar('zeros (%) in group ' + k,
                                          100 - (nonzero / entries) * 100, epoch)

                mean_sparsity = 1 - nonzero_total / entry_total
                if writer is not None:
                    # NOTE(review): the tag ends with the name of the last
                    # mask group; kept unchanged so existing logs line up.
                    writer.add_scalar('mean zeros (%) ' + last_group,
                                      mean_sparsity * 100, epoch)

        return mean_sparsity

    def evaluate(self, dataloader, num_batches, epoch, test=False, writer=None):
        """Measure query accuracy on every loader without updating weights.

        Args:
            dataloader: sequence of iterables yielding ``(images, labels)``.
            num_batches: accepted for interface compatibility; not used.
            epoch: step value used for the logged scalars.
            test: accepted for interface compatibility; not used.
            writer: optional summary writer exposing ``add_scalar``.

        Returns:
            ``(means, cis)``: per-loader mean accuracy and per-loader
            ``1.96 * std / sqrt(config['num_val_task'])``.
        """
        cfg = self.config
        device = cfg['device']
        supp_idx = cfg['num_way'] * cfg['num_shot']
        # presumably each loader yields num_val_task episodes -- TODO confirm
        sqrt_nsample = math.sqrt(cfg['num_val_task'])

        all_acc_mean = []
        all_acc_std = []

        # Remember the current mode so evaluation does not leave the model
        # in eval mode for the training that follows.
        was_training = self.model.training
        self.model.eval()
        try:
            #* for each dataset
            for index_, testloader in enumerate(dataloader):
                accuracy_one_dis = []

                for images, labels in testloader:
                    imgs = images.to(device)
                    lbls = labels.to(device)
                    support_img, query_img = imgs[:supp_idx], imgs[supp_idx:]
                    support_lbl, query_lbl = lbls[:supp_idx], lbls[supp_idx:]

                    accuracy_one_dis.append(
                        self.step(support_img, support_lbl, query_img, query_lbl,
                                  evaluation=True))

                acc_mean = np.mean(accuracy_one_dis)
                acc_95ci = 1.96 * np.std(accuracy_one_dis, ddof=1) / sqrt_nsample

                if writer is not None:
                    writer.add_scalar(
                        tag='accuracy_meta_eval_task{}'.format(index_),
                        scalar_value=acc_mean, global_step=epoch
                    )
                    writer.add_scalar(
                        tag='acc95ci_meta_eval_task{}'.format(index_),
                        scalar_value=acc_95ci, global_step=epoch
                    )

                all_acc_mean.append(acc_mean)
                all_acc_std.append(acc_95ci)
        finally:
            self.model.train(was_training)

        # some logging
        if writer is not None:
            writer.add_scalar(
                tag='accuracy_meta_eval_task_average',
                scalar_value=np.mean(all_acc_mean), global_step=epoch
            )
            writer.add_scalar(
                tag='acc95ci_meta_eval_task_average',
                scalar_value=np.mean(all_acc_std), global_step=epoch
            )

        return all_acc_mean, all_acc_std

def _resolve_dataset_count(requested, available):
    """Turn a dataset-count setting into an int; the string "all" means ``available``."""
    return available if requested == "all" else requested


def _build_loaders(config, task_names, split_file, split_label, num_task, transformation):
    """Create one few-shot DataLoader per task from ``<data_dir>/5-shot/<task>/<split_file>``."""
    loaders = []
    for task in task_names:
        # NOTE(review): the '5-shot' directory is used regardless of
        # config['num_shot'] -- confirm this is intended.
        split_dir = os.path.join(config['data_dir'], '5-shot', task)
        task_list = np.load(os.path.join(split_dir, split_file))
        dataset = FewShotImageDataset(task_list=task_list, transform=transformation,
                                      device=config['device'], task_name=task,
                                      verbose='{} {}'.format(task, split_label))
        sampler = SuppQueryBatchSampler(
            dataset=dataset, num_way=config['num_way'], num_shot=config['num_shot'],
            num_query_per_cls=config['num_query_per_cls'], num_task=num_task
        )
        loaders.append(DataLoader(dataset, batch_sampler=sampler))
    return loaders


def main(config, run_spec):
    """Train a masked MAML model and report the best periodic evaluation.

    The function creates the run directory, dumps ``config`` as JSON next to
    a TensorBoard log, builds the classifier, the mask modules and both
    optimizers, then trains for ``config['num_epoch']`` epochs.  The model is
    evaluated at epoch 0 and after every 1000th epoch; whenever the average
    test accuracy improves, the classifier weights are saved to
    ``<run dir>/result/step<epoch>.pt``.

    Args:
        config: dict of settings loaded from the configuration file.  It is
            modified in place: ``dataset_ls`` is truncated to the number of
            datasets to run and ``classifier_args['n_way']`` is set.
        run_spec: run identifier used in the output directory and in the
            name of the dumped configuration file.
    """
    # "all" is a valid setting for both counts, so resolve them to integers
    # once and use the resolved values everywhere below.
    num_dataset_to_run = _resolve_dataset_count(
        config['num_dataset_to_run'], len(config['dataset_ls']))
    num_dataset_to_test = _resolve_dataset_count(
        config['num_dataset_to_test'], len(config['test_dataset_ls']))

    config['dataset_ls'] = config['dataset_ls'][:num_dataset_to_run]
    config['classifier_args']['n_way'] = config['num_way']

    #* set the save file path
    experiment_date = '{:%Y-%m-%d_%H:%M:%S}'.format(datetime.datetime.now())
    run_name = get_run_name(config['dataset_ls'])
    run_leaf = (run_spec + '_{}'.format(experiment_date)) if config['adjust'] else run_spec
    meta_save_path = os.path.join(config['run_dir'], '{}_shot'.format(config['num_shot']),
                                  run_name, run_leaf)
    ensure_path(meta_save_path)

    #* save the config file
    with open(os.path.join(meta_save_path, 'config_{}.json'.format(run_spec)), 'w') as outfile:
        outfile.write(json.dumps(config, indent=4))

    #* set the tensorboard writer
    writer = SummaryWriter(os.path.join(meta_save_path, 'log'))
    result_dir = os.path.join(meta_save_path, 'result')
    os.makedirs(result_dir, exist_ok=True)

    try:
        # create the model
        if config['encoder'] == 'resnet12':
            classifier = meta_sparse.ResNet().to(config['device'])
        else:
            classifier = meta_sparse.MetaConvModel(in_channels=3,
                                                   out_features=config['num_way'],
                                                   hidden_size=64,
                                                   feature_size=1600,
                                                   bias=True).to(config['device'])

        # NOTE(review): the per-group 'lr' set here takes precedence over the
        # learning rate passed to the optimizer constructor below, so
        # config['metaSparse_args']['lr'] has no effect on this group --
        # confirm this is intended.
        parameters = [{'params': classifier.parameters(), 'lr': 0.001},]

        # Parameters adapted in the inner loop: everything except those
        # whose name contains "norm".
        inner_loop_params = [name for name, _ in classifier.named_parameters()
                             if "norm" not in name]

        # MASK plus
        conv_mask_names = [name for name, _ in classifier.named_parameters()
                           if "conv.weight" in name and name in inner_loop_params]
        #the number of channels for each layer, in order
        mask_channels = [3, 64, 64, 64]
        mask_plus = meta_sparse.GradientMaskPlus(config,
                                                 layer_names=conv_mask_names,
                                                 layer_sizes=mask_channels,
                                                 feature_size=1600).to(config['device'])

        # MASK
        mask_names = []
        shapes = []
        for name, params in classifier.named_parameters():
            if name in inner_loop_params:
                mask_names.append(name)
                shapes.append(params.shape)

        mask = meta_sparse.GradientMask(config, weight_names=mask_names,
                                        weight_shapes=shapes,
                                        mask_plus=mask_plus
                                        ).to(config['device'])

        # NOTE: The parameters of mask_plus are contained in mask
        mask_lr = config['metaSparse_args']['mask_lr']
        if config['optimizer_mask'] in ("ADAM", "Adam"):
            optimizer_mask = torch.optim.Adam(mask.parameters(), mask_lr)
        else:
            optimizer_mask = torch.optim.SGD(mask.parameters(), mask_lr,
                                             momentum=0.9,
                                             nesterov=True)
        print("\nMask optimizer:", optimizer_mask)

        # OPTIMIZER
        theta_lr = config['metaSparse_args']['lr']
        if config['optimizer_theta'] in ("ADAM", "Adam"):
            optimizer_theta = torch.optim.Adam(parameters, theta_lr)
        else:
            optimizer_theta = torch.optim.SGD(parameters, theta_lr,
                                              momentum=config['momentum'],
                                              nesterov=(config['momentum'] > 0.0))
        print("\nTheta optimizer:", optimizer_theta)

        # MAML object
        loss_function = torch.nn.CrossEntropyLoss().to(config['device'])
        metalearner = MAML(classifier, optimizer_theta=optimizer_theta,
                           optimizer_mask=optimizer_mask,
                           loss_function=loss_function,
                           inner_loop_params=inner_loop_params,
                           config=config)

        if config['gradient_mask'] or config['weight_mask']:
            metalearner.mask = mask
        print("\nStart training ... \n")

        # load the data
        transformation = transforms.Compose(
            enlist_transformation(img_resize=config['img_resize'],
                                  is_grayscale=config['is_grayscale'],
                                  device=config['device'],
                                  img_normalise=config['img_normalise'])
        )
        task_per_itr = int(config['batch_size'] / num_dataset_to_run)

        trainloaders = _build_loaders(
            config, config['dataset_ls'][:num_dataset_to_run],
            'metatrain.npy', 'trainset', task_per_itr, transformation)
        testloaders = _build_loaders(
            config, config['test_dataset_ls'][:num_dataset_to_test],
            'metatest.npy', 'testset', config['num_val_task'], transformation)

        # Per test dataset: epoch -> mean accuracy / 95% interval half-width.
        test_acc_mean_dict_list = [dict() for _ in range(num_dataset_to_test)]
        test_acc_std_dict_list = [dict() for _ in range(num_dataset_to_test)]

        checkpoint = 0.0
        max_index = 0
        for epoch in trange(config['num_epoch']):

            metalearner.train(trainloaders, 4, epoch, writer=writer)

            if (epoch + 1) % 1000 == 0 or epoch == 0:
                test_acc_mean_list, test_acc_std_list = metalearner.evaluate(
                    testloaders, 1, epoch, writer=writer)
                for i in range(num_dataset_to_test):
                    test_acc_mean_dict_list[i][epoch] = test_acc_mean_list[i]
                    test_acc_std_dict_list[i][epoch] = test_acc_std_list[i]
                    print('{}: acc_mean={} and acc_std={}'.format(
                        config['test_dataset_ls'][i], test_acc_mean_list[i], test_acc_std_list[i]))

                average_acc = np.mean(test_acc_mean_list)
                print('Average Performance: acc_mean={} and acc_std={}'.format(
                    average_acc, np.mean(test_acc_std_list)))
                if average_acc > checkpoint:
                    checkpoint = average_acc
                    max_index = epoch
                    torch.save(classifier.state_dict(),
                               f=os.path.join(result_dir, 'step{}.pt'.format(epoch)))

        # The summary needs at least one test dataset and one recorded
        # evaluation (none exists when num_epoch is 0).
        if num_dataset_to_test > 0 and max_index in test_acc_mean_dict_list[0]:
            max_avg_mean = 0.0
            max_avg_std = 0.0
            for i in range(num_dataset_to_test):
                best_mean = test_acc_mean_dict_list[i][max_index]
                best_std = test_acc_std_dict_list[i][max_index]
                max_avg_mean += best_mean
                max_avg_std += best_std
                print("the best performance of {} is {:.2f} + {:.2f}".format(
                    config['test_dataset_ls'][i], best_mean * 100, best_std * 100))

            print("the best average performance is {:.2f} + {:.2f}".format(
                max_avg_mean * 100 / num_dataset_to_test,
                max_avg_std * 100 / num_dataset_to_test))
    finally:
        # Flush and release the event file even if training fails.
        writer.close()



if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean so that ``--flag False`` is really false."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    # Required: without it the script would try to open a file named "None".
    parser.add_argument('--config_path',
                        required=True,
                        help='configuration file')
    parser.add_argument('--efficient',
                        help='if True, enables gradient checkpointing',
                        action='store_true',
                        default=False)
    # nargs='?' with const=True lets the bare flag mean True while still
    # accepting an explicit value such as "--adjust True".
    parser.add_argument('--seq_task',
                        help='if True, the mini-batch consists of tasks only from one task distribution',
                        type=_str2bool, nargs='?', const=True,
                        default=False)
    parser.add_argument('--adjust',
                        help='if True, the saved file will be create based on the experimental time',
                        type=_str2bool, nargs='?', const=True,
                        default=False)
    args = parser.parse_args()

    # load config file (the context manager closes the handle)
    with open(str(args.config_path)) as jsonfile:
        config = json.loads(jsonfile.read())
    config['adjust'] = args.adjust
    config['seq_task'] = args.seq_task

    # Run name = configuration file name without directory and extension.
    run_spec = os.path.splitext(os.path.split(args.config_path)[-1])[0]

    main(config, run_spec)