from collections import OrderedDict
from statistics import mode
import numpy as np
import time

from tqdm import trange
import torch
import torch.nn as nn
import torch.nn.functional as F

import util
import meta_test


def _sample_task_order(num_loaders, seq_task):
    """Return the list of loader indices to visit during one meta-iteration.

    With ``seq_task`` false every loader index appears exactly once, in random
    order.  With ``seq_task`` true a single random loader index is repeated
    ``num_loaders`` times.
    """
    if not seq_task:
        return np.random.choice(num_loaders, num_loaders, replace=False).tolist()
    chosen = int(np.random.choice(num_loaders))
    return [chosen] * num_loaders


def _query_loss_and_acc(logits, query_lbl):
    """Return ``(cross-entropy loss, accuracy)`` of ``logits`` on ``query_lbl``.

    ``logits`` has its first two dimensions merged before scoring, and
    ``query_lbl`` must already be flat.  Accuracy is whatever
    ``util.compute_acc`` returns for the arg-max predictions.
    """
    flat_logits = logits.flatten(0, 1)
    pred = torch.argmax(flat_logits, dim=1)
    return F.cross_entropy(flat_logits, query_lbl), util.compute_acc(pred, query_lbl)


def meta_train(model, optimizer, lr_scheduler, trainloaders, inner_args, config, itr_index, seq_task=False, writer=None):
    """Run one meta-training iteration and apply a single optimizer step.

    Every batch of every visited loader is split into a support part (the
    first ``num_way * num_shot`` samples) and a query part (the rest).  The
    query loss of each batch is back-propagated immediately, so gradients
    accumulate over all batches; they are then clipped, applied with one
    ``optimizer.step()`` and cleared.

    Args:
        model: Module called as
            ``model(support_img, query_img, support_lbl, inner_args, meta_train=True)``
            and returning ``(logits, prompt_sample)``.  Must expose
            ``is_prompt``; prompt models must also expose ``is_bayesian``,
            ``prompt_mean`` and ``prompt_cov``.  ``reset_classifier()`` is
            called per batch when ``inner_args['reset_classifier']`` is truthy.
        optimizer: Optimizer to step.  Only the parameters of its first param
            group have their gradients clipped.
        lr_scheduler: Not used by this function; kept so existing callers
            keep working.
        trainloaders: Sequence of iterables yielding ``(images, labels)``.
        inner_args: Mapping forwarded to the model; ``'reset_classifier'`` is
            read here.
        config: Mapping providing ``'device'``, ``'num_way'``, ``'num_shot'``,
            ``'optimizer'`` and ``'num_dataset_to_run'``; prompt models also
            need ``'prompt_args'`` (``'sample_num'``, and for Bayesian models
            ``'kl_weight'`` and ``'dim'``) plus ``'num_query_per_cls'``; the
            ``"prompt"`` optimizer needs ``'optimizer_args'``.
        itr_index: Global step used when logging to ``writer``.
        seq_task: If true, one randomly chosen loader is visited
            ``len(trainloaders)`` times instead of visiting each loader once.
        writer: Optional object with an ``add_scalar(tag=..., scalar_value=...,
            global_step=...)`` method.  Logging is skipped when ``None``.

    Returns:
        Tuple ``(loss, acc)``: the summed batch losses divided by
        ``config['num_dataset_to_run']``, and the mean over visited loaders of
        their mean batch accuracy (NaN if no batch was processed).
    """
    model.train()

    device = config['device']
    supp_idx = config['num_way'] * config['num_shot']
    seq_index = _sample_task_order(len(trainloaders), seq_task)

    total_loss = 0.0
    acc_items = []

    for index_ in seq_index:
        acc_one_dis = []

        for images, labels in trainloaders[index_]:
            imgs = images.to(device)
            lbls = labels.to(device)

            # Support and query images / support labels get a leading dimension
            # of size 1 for the model; query labels are only used for scoring,
            # so they are kept flat.
            support_img = imgs[:supp_idx].unsqueeze(dim=0)
            query_img = imgs[supp_idx:].unsqueeze(dim=0)
            support_lbl = lbls[:supp_idx].unsqueeze(dim=0)
            query_lbl = lbls[supp_idx:].flatten()

            if inner_args['reset_classifier']:
                model.reset_classifier()

            if not model.is_prompt:
                logits, _ = model(support_img, query_img, support_lbl, inner_args, meta_train=True)
                batch_loss, batch_acc = _query_loss_and_acc(logits, query_lbl)
            else:
                # Prompt models are evaluated ``sample_num`` times per batch and
                # the loss / accuracy are averaged over those passes.
                prompt_args = config['prompt_args']
                sample_num = prompt_args['sample_num']
                sample_losses = []
                sample_accs = []

                for _ in range(sample_num):
                    logits, prompt_sample = model(support_img, query_img, support_lbl, inner_args, meta_train=True)

                    # Diagnostic dump when the second model output goes NaN.
                    if torch.isnan(prompt_sample).any():
                        print('mean')
                        print(model.prompt_mean)
                        print('cov')
                        print(model.prompt_cov)

                    loss, acc = _query_loss_and_acc(logits, query_lbl)

                    if model.is_bayesian:
                        # softplus(x) == log(1 + exp(x)) but does not overflow
                        # for large x, which would otherwise yield inf / NaN.
                        softplus_cov = F.softplus(model.prompt_cov)
                        kl_loss = util.kl_divergence_gaussian(model.prompt_mean, softplus_cov, input_=prompt_sample)
                        # Normalise by (number of query samples * prompt dim).
                        query_count = config['num_query_per_cls'] * config['num_way']
                        kl_loss = kl_loss / (query_count * prompt_args['dim'])
                        loss = loss + prompt_args['kl_weight'] * kl_loss

                    sample_losses.append(loss)
                    sample_accs.append(acc)

                batch_loss = sum(sample_losses) / sample_num
                batch_acc = np.mean(sample_accs)

            # Back-propagate per batch so only one graph is alive at a time;
            # gradients accumulate in the parameters until the step below.
            batch_loss.backward()
            total_loss += batch_loss.item()
            acc_one_dis.append(batch_acc)

        # A loader that produced no batch contributes nothing instead of
        # poisoning the mean with NaN.
        if acc_one_dis:
            acc_items.append(np.mean(acc_one_dis))

    # NOTE: only the first param group is clipped, as in the original code.
    if config['optimizer'] == "prompt":
        clip_value = config['optimizer_args']['prompt']['clip']
    else:
        clip_value = 20
    nn.utils.clip_grad_value_(optimizer.param_groups[0]['params'], clip_value)

    optimizer.step()
    optimizer.zero_grad()

    mean_loss = total_loss / config['num_dataset_to_run']
    mean_acc = np.mean(acc_items) if acc_items else float('nan')

    # ``writer`` defaults to None, so logging has to be optional.
    if writer is not None:
        writer.add_scalar(
            tag='loss_meta_train', scalar_value=mean_loss, global_step=itr_index
        )
        writer.add_scalar(
            tag='accuracy_meta_train', scalar_value=mean_acc, global_step=itr_index
        )

    return mean_loss, mean_acc