from collections import OrderedDict
from statistics import mode
import numpy as np
import time

from tqdm import trange
import torch
import torch.nn as nn
import torch.nn.functional as F

import util
import meta_test


def meta_train(model, optim, optim_lr, trainloaders, inner_args, config, itr_index, seq_task=False, writer=None):
    """Run one meta-training iteration and apply a single optimizer update.

    Gradients from every batch of every visited dataloader are accumulated
    (one ``backward()`` per batch) and both optimizers are stepped once at
    the end of the iteration.

    Args:
        model: Wrapper exposing ``model.model`` (the module whose parameters
            are clipped and put in train mode) and
            ``run_batch(support_img, query_img, support_lbl, query_lbl, True)``
            returning ``(loss, acc, kl_div, encoder_penalty,
            orthogonality_penalty)``. Only ``loss`` and ``acc`` are used here.
        optim: Optimizer that is zeroed at the start and stepped at the end.
        optim_lr: Second optimizer, zeroed and stepped alongside ``optim``.
        trainloaders: Sequence of iterables yielding ``(images, labels)``
            batches.
        inner_args: Unused in this function; kept for interface
            compatibility.
        config: Mapping read for ``'device'``, ``'num_way'``, ``'num_shot'``,
            ``'clip_value'`` and ``'num_dataset_to_run'``.
        itr_index: Global step passed to ``writer.add_scalar``.
        seq_task: If False, every dataloader is visited once in a random
            order. If True, one randomly chosen dataloader is visited
            ``len(trainloaders)`` times.
        writer: Optional object with an ``add_scalar(tag=..., scalar_value=...,
            global_step=...)`` method. Logging is skipped when ``None``.

    Returns:
        Tuple ``(mean_loss, mean_acc)`` where ``mean_loss`` is the sum of all
        batch losses divided by ``config['num_dataset_to_run']`` and
        ``mean_acc`` is the mean over visited dataloaders of their mean batch
        accuracy.
    """
    model.model.train()

    optim.zero_grad()
    optim_lr.zero_grad()

    device = config['device']
    num_loaders = len(trainloaders)

    # Choose the order in which dataloaders are visited.
    if not seq_task:
        seq_index = np.random.choice(num_loaders, num_loaders, replace=False).tolist()
    else:
        # Sample a single dataloader and repeat it for the whole iteration.
        one_index = int(np.random.choice(num_loaders))
        seq_index = [one_index] * num_loaders

    # The first num_way * num_shot samples of a batch form the support set;
    # the remainder is the query set.
    supp_idx = config['num_way'] * config['num_shot']

    loss_total = 0.0
    acc_items = []

    for index_ in seq_index:
        acc_one_dis = []

        for images, labels in trainloaders[index_]:
            imgs = images.to(device)
            lbls = labels.to(device)

            support_img, query_img = imgs[:supp_idx], imgs[supp_idx:]
            # Labels get a leading axis of size 1 before being passed on.
            support_lbl = lbls[:supp_idx].unsqueeze(dim=0)
            query_lbl = lbls[supp_idx:].unsqueeze(dim=0)

            val_loss, val_acc, _kl_div, _encoder_penalty, _orthogonality_penalty = model.run_batch(
                support_img, query_img, support_lbl, query_lbl, True
            )

            acc_one_dis.append(val_acc.item())

            # Gradients accumulate across batches; the step happens once below.
            val_loss.backward()
            loss_total += val_loss.item()

        # NOTE(review): an empty dataloader makes this np.mean([]) -> nan.
        acc_items.append(np.mean(acc_one_dis))

    # Update the parameters.
    # NOTE(review): value clipping and norm clipping both use the same
    # threshold, and only model.model's parameters are clipped -- confirm
    # this is intended.
    nn.utils.clip_grad_value_(model.model.parameters(), config['clip_value'])
    nn.utils.clip_grad_norm_(model.model.parameters(), config['clip_value'])
    optim.step()
    optim_lr.step()

    mean_loss = loss_total / config['num_dataset_to_run']
    mean_acc = np.mean(acc_items)

    # Record the performance when a writer was supplied.
    if writer is not None:
        writer.add_scalar(
            tag='loss_meta_train', scalar_value=mean_loss, global_step=itr_index
        )
        writer.add_scalar(
            tag='accuracy_meta_train', scalar_value=mean_acc, global_step=itr_index
        )

    return mean_loss, mean_acc