from cgi import test
from collections import OrderedDict
from statistics import mode
import numpy as np
import time

from tqdm import trange
import torch
import torch.nn as nn
import torch.nn.functional as F

import util
import meta_test


def meta_train(model, optimizer, lr_scheduler, trainloaders, inner_args, config, itr_index, seq_task=False, writer=None):
    """Run one meta-training iteration over the given datasets.

    Every batch drawn from a visited dataset is treated as one task: its
    first ``num_way * num_shot`` samples form the support set and the rest
    form the query set. The model is adapted on the support set through
    ``inner_update`` and evaluated on the query set. The summed query loss of
    each dataset is back-propagated, and a single optimizer step is taken
    once all datasets have been visited.

    Args:
        model: Network being meta-trained. Called as ``model(x, state_dict)``;
            ``model.reset_classifier()`` is invoked before each task when
            ``inner_args['reset_classifier']`` is truthy.
        optimizer: Optimizer to step. Gradients of the parameters in its
            first parameter group are clipped to ``[-10, 10]`` beforehand.
        lr_scheduler: Not used by this function; kept so existing callers
            keep working.
        trainloaders: Indexable collection of iterables that yield
            ``(images, labels)`` batches.
        inner_args: Mapping; only ``'reset_classifier'`` is read.
        config: Mapping; reads ``'device'``, ``'num_way'``, ``'num_shot'``
            and ``'num_dataset_to_run'``.
        itr_index: Iteration number, used as the logging step.
        seq_task: If false, every dataset is visited once in random order.
            If true, one dataset is drawn at random and visited
            ``len(trainloaders)`` times.
        writer: Optional logger exposing
            ``add_scalar(tag=..., scalar_value=..., global_step=...)``.
            Logging is skipped when it is ``None``.

    Returns:
        Tuple ``(loss, accuracy)``. ``loss`` is the sum over visited datasets
        of the mean per-task query loss, divided by
        ``config['num_dataset_to_run']``. ``accuracy`` is the mean over
        visited datasets of the mean per-task query accuracy.
    """
    model.train()

    device = config['device']
    num_loaders = len(trainloaders)

    # Choose the order in which datasets are visited.
    if not seq_task:
        seq_index = np.random.choice(num_loaders, num_loaders, replace=False).tolist()
        print("{}:{} dataset".format(itr_index, seq_index))
    else:
        # Draw a single dataset index and repeat it for the whole iteration.
        one_index = np.random.choice(num_loaders, 1, replace=False)[0]
        seq_index = np.tile(one_index, num_loaders).tolist()

    # Number of leading samples in each batch that belong to the support set.
    supp_idx = config['num_way'] * config['num_shot']

    loss_itmes = torch.tensor(0., device=device)
    acc_items = []

    for index_ in seq_index:

        acc_one_dis = []
        loss_one_dis = torch.tensor(0., device=device)

        for images, labels in trainloaders[index_]:
            imgs = images.to(device)
            lbls = labels.to(device)

            support_img, query_img = imgs[:supp_idx], imgs[supp_idx:]
            support_lbl, query_lbl = lbls[:supp_idx], lbls[supp_idx:]

            # Optionally start every task from a fresh classifier.
            if inner_args['reset_classifier']:
                model.reset_classifier()

            # Inner loop: adapt on the support set.
            a_dict = inner_update(model, support_img, support_lbl)

            # Outer objective: loss of the adapted weights on the query set.
            Y_meta_hat = model(query_img, a_dict)
            loss_t = F.cross_entropy(Y_meta_hat, query_lbl)
            pred = torch.argmax(Y_meta_hat, dim=1)

            acc_one_dis.append(util.compute_acc(pred, query_lbl))
            loss_one_dis = loss_one_dis + loss_t

        if not acc_one_dis:
            # The loader yielded nothing: there is no graph to back-propagate
            # and no accuracy to average, so skip this dataset.
            continue

        acc_items.append(np.mean(acc_one_dis))

        # Back-propagate per dataset so only one dataset's graph is alive at
        # a time; the gradients accumulate in ``.grad`` until the step below.
        # NOTE(review): this relies on nothing clearing ``.grad`` between
        # datasets -- verify that ``inner_update`` leaves ``.grad`` untouched.
        loss_one_dis.backward()

        # Keep a detached running total purely for reporting.
        loss_itmes += loss_one_dis.detach() / len(acc_one_dis)

    # Clip element-wise, then apply the accumulated meta-gradient.
    nn.utils.clip_grad_value_(optimizer.param_groups[0]['params'], 10)

    optimizer.step()
    optimizer.zero_grad()

    mean_loss = loss_itmes.item() / config['num_dataset_to_run']
    mean_acc = np.mean(acc_items)

    # Record the performance when a writer was supplied.
    if writer is not None:
        writer.add_scalar(
            tag='loss_meta_train', scalar_value=mean_loss, global_step=itr_index
        )
        writer.add_scalar(
            tag='accuracy_meta_train', scalar_value=mean_acc, global_step=itr_index
        )

    return mean_loss, mean_acc


def inner_update(model, support_img, support_lbl):
    """Compute task-adapted weights with one learned-step-size gradient step.

    The cross-entropy loss of ``model`` on the support batch is differentiated
    with respect to every named parameter. Each parameter ``p`` with name
    ``k`` is replaced by ``p - model.task_lr[k] * grad`` in a copy of the
    state dict obtained from ``model.cloned_state_dict()``. The model's own
    parameters and their ``.grad`` buffers are left unchanged.

    Args:
        model: Module callable as ``model(x)`` that provides
            ``cloned_state_dict()`` and a ``task_lr`` mapping indexed by
            parameter name.
        support_img: Support-set inputs passed to ``model``.
        support_lbl: Support-set class targets for ``F.cross_entropy``.

    Returns:
        The cloned state dict with every named parameter replaced by its
        adapted value. The adapted values stay attached to the autograd graph
        (``create_graph=True``), so a loss computed with them can be
        back-propagated to the original parameters.
    """
    logits = model(support_img)
    loss = F.cross_entropy(logits, support_lbl)

    # Materialise once so names and gradients are guaranteed to line up.
    named_params = list(model.named_parameters())

    # ``torch.autograd.grad`` returns the gradients instead of accumulating
    # them into ``.grad``, so ``.grad`` must not be cleared here: the caller
    # may be accumulating outer-loop gradients there across several tasks.
    grads = torch.autograd.grad(
        loss, [param for _, param in named_params], create_graph=True
    )

    adapted_state_dict = model.cloned_state_dict()
    for (key, val), grad in zip(named_params, grads):
        # NOTE Here Meta-SGD is different from naive MAML: the step size is
        # looked up per parameter in ``model.task_lr``. Only a single inner
        # gradient step is taken.
        adapted_state_dict[key] = val - model.task_lr[key] * grad

    return adapted_state_dict
