from collections import OrderedDict
from statistics import mode
import numpy as np
import time

from tqdm import trange
import torch
import torch.nn as nn
import torch.nn.functional as F

import util


def meta_train(model, optimizer, lr_scheduler, trainloaders, inner_args, config, itr_index, seq_task=False, writer=None):
    """Run one meta-training iteration over a set of episode loaders.

    Every batch of every visited loader is split into a support part and a
    query part, ``model.forward_loss`` is evaluated on it and its loss is
    back-propagated immediately. Gradients therefore accumulate over the whole
    iteration; they are clipped and applied with a single ``optimizer.step()``
    at the end, and then cleared with ``optimizer.zero_grad()``.

    Args:
        model: module exposing ``forward_loss(support_img, query_img,
            query_lbl) -> (loss, acc)``, ``reset_classifier()`` and a writable
            ``n_query`` attribute.
        optimizer: torch optimizer; only the parameters of
            ``optimizer.param_groups[0]`` are gradient-clipped.
        lr_scheduler: not used by this function; kept for interface
            compatibility.
        trainloaders: indexable sequence of iterables yielding
            ``(images, labels)`` batches.
        inner_args: dict; when ``inner_args['reset_classifier']`` is truthy,
            ``model.reset_classifier()`` is called after every batch.
        config: dict read for 'device', 'num_way', 'num_shot',
            'num_query_per_cls', 'optimizer', 'num_dataset_to_run' and, when
            ``config['optimizer'] == "prompt"``,
            ``config['optimizer_args']['prompt']['clip']``.
        itr_index: iteration number; used as the logging step and in the
            printed summary.
        seq_task: if False, every loader is visited once in a random order;
            if True, a single loader is chosen at random and visited
            ``len(trainloaders)`` times.
        writer: optional object with ``add_scalar(tag=..., scalar_value=...,
            global_step=...)``; logging is skipped when it is None.

    Returns:
        Tuple ``(loss, acc)``. ``loss`` is the summed batch loss divided by
        ``config['num_dataset_to_run']``. ``acc`` is the mean, over visited
        loaders, of each loader's mean batch accuracy (assumes
        ``forward_loss`` returns a scalar accuracy — confirm).
    """
    model.train()

    num_loaders = len(trainloaders)
    if seq_task:
        # Repeat one randomly chosen dataset num_loaders times.
        # (The sample size was previously `len(1)`, which always raised TypeError.)
        one_index = np.random.choice(num_loaders, 1, replace=False)[0]
        seq_index = np.tile(one_index, num_loaders).tolist()
    else:
        # Visit every dataset exactly once, in a random order.
        seq_index = np.random.choice(num_loaders, num_loaders, replace=False).tolist()

    device = config['device']
    # Number of support samples at the front of each batch.
    supp_idx = config['num_way'] * config['num_shot']
    total_loss = 0.0
    acc_items = []

    for index_ in seq_index:
        acc_one_dis = []

        for images, labels in trainloaders[index_]:
            imgs = images.to(device)
            lbls = labels.to(device)

            # Assumes the sampler places the way*shot support samples first,
            # followed by the query samples — TODO confirm against the loader.
            support_img, query_img = imgs[:supp_idx], imgs[supp_idx:]
            query_lbl = lbls[supp_idx:]

            model.n_query = config['num_query_per_cls']
            loss, acc = model.forward_loss(support_img, query_img, query_lbl)

            # Gradients accumulate across all batches; the single optimizer
            # step happens after the loop.
            loss.backward()
            total_loss += loss.item()
            acc_one_dis.append(acc)

            if inner_args['reset_classifier']:
                model.reset_classifier()

        acc_items.append(np.mean(acc_one_dis))

    if config['optimizer'] == "prompt":
        clip_value = config['optimizer_args']['prompt']['clip']
    else:
        clip_value = 5
    # Value clipping is element-wise, so one call over the whole list is
    # equivalent to clipping each parameter separately.
    nn.utils.clip_grad_value_(optimizer.param_groups[0]['params'], clip_value)

    optimizer.step()
    optimizer.zero_grad()

    mean_loss = total_loss / config['num_dataset_to_run']
    mean_acc = np.mean(acc_items)

    # `writer` defaults to None; previously that default crashed here.
    if writer is not None:
        writer.add_scalar(
            tag='loss_meta_train', scalar_value=mean_loss, global_step=itr_index
        )
        writer.add_scalar(
            tag='accuracy_meta_train', scalar_value=mean_acc, global_step=itr_index
        )

    print('\n epoch {} : loss {} and acc {}'.format(itr_index, mean_loss, mean_acc))

    return mean_loss, mean_acc