import argparse
import torch
import numpy as np
import random
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from copy import deepcopy
from torch.utils.data.sampler import SubsetRandomSampler
# from sample import *
import matplotlib.pyplot as plt
from utils import *
from train import Trainer_pro, Trainer_att
import torchvision.datasets as datasets
# from model import *
from torch.utils.data import random_split
from pre_trained.cifar10_models.vgg import vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from pre_trained.cifar10_models.resnet import resnet18, resnet34
from pre_trained.cifar10_models.mobilenetv2 import mobilenet_v2


def maml(spt_list, qry_list, net, init_net_state, mask, arg, device):
    """Run one meta-update of ``net`` over a list of tasks.

    For every task the network is reset to ``init_net_state``, evaluated on
    the task's query loader, adapted for ``arg.update_step`` optimizer steps
    on the FIRST batch of the task's support loader, and then evaluated on the
    query loader again.  The third value returned by that second
    ``query_eval`` call (treated here as one gradient per parameter, in
    ``net.named_parameters()`` order) is combined across tasks as a weighted
    sum.  Finally the network is reset once more and each parameter is
    shifted by ``arg.maml_lr * summed_gradient * mask[name]``.

    :param spt_list: list of support dataloaders (one per task)
    :param qry_list: list of query dataloaders (one per task, same order)
    :param net: model to update; it is called as ``net(inputs, 0)``
    :param init_net_state: state dict loaded into ``net`` before every task
        and before the final update
    :param mask: mapping from parameter name to a multiplier applied to that
        parameter's summed gradient
    :param arg: namespace providing ``batch_size``, ``update_step``,
        ``maml_lr``, ``weighted_flag`` and ``source_weight``;
        ``update_step`` must be at least 1 because the post-adaptation
        result is stored at index 1
    :param device: device the support batch is moved to; also forwarded to
        ``query_eval``
    :return: ``(accs, updated_state)`` where ``accs`` is an array of length
        ``update_step + 1`` (index 0: before adaptation, index 1: after
        adaptation, remaining entries stay 0) and ``updated_state`` is a deep
        copy of the updated state dict
    :raises ValueError: if ``spt_list`` is empty or a support loader yields
        no batch
    """
    task_num = len(spt_list)
    if task_num == 0:
        raise ValueError("maml() needs at least one task: spt_list is empty")

    # NOTE(review): accuracy is normalised by batch_size * task_num, i.e. this
    # assumes every query loader holds exactly batch_size samples -- confirm.
    querysz = arg.batch_size
    update_step = arg.update_step  # number of gradient steps per task

    # corrects[0]: before adaptation, corrects[1]: after adaptation.
    corrects = [0 for _ in range(update_step + 1)]

    # Per-task weights do not depend on the task being processed, so they are
    # computed once.  Default is a plain average; with weighted_flag == 1 the
    # first (source) task gets source_weight and the rest share the remainder.
    if task_num > 1 and arg.weighted_flag == 1:
        weight_task = [(1 - arg.source_weight) / (task_num - 1)] * task_num
        weight_task[0] = arg.source_weight
    else:
        weight_task = [1 / task_num] * task_num

    criterion = nn.CrossEntropyLoss()

    # NOTE(review): a single Adam instance is shared by all tasks, so its
    # moment estimates carry over from one task's adaptation to the next even
    # though the weights are reset -- confirm this is intended.
    meta_optim = optim.Adam(net.parameters(), lr=arg.maml_lr)

    grad_qry_sum = None
    for i, spt in enumerate(spt_list):
        # Every task starts from the same initial weights.
        net.load_state_dict(init_net_state)

        # Query performance before any update.
        correct, _, _ = query_eval(net, device, qry_list[i])
        corrects[0] += correct

        # Only the first support batch is used for adaptation.  Fail loudly if
        # there is none instead of reusing the previous task's gradient.
        first_batch = next(iter(spt), None)
        if first_batch is None:
            raise ValueError(
                "support loader for task {} produced no batches".format(i)
            )
        inputs, targets = first_batch
        inputs, targets = inputs.to(device), targets.to(device)

        for _ in range(update_step):
            meta_optim.zero_grad()
            outputs = net(inputs, 0)
            loss = criterion(outputs, targets)
            loss.backward()
            meta_optim.step()

        # Query performance and query gradients after adaptation.
        correct, _, grad_qry = query_eval(net, device, qry_list[i])
        corrects[1] += correct

        weighted_grad = [weight_task[i] * g for g in grad_qry]
        if grad_qry_sum is None:
            grad_qry_sum = weighted_grad
        else:
            grad_qry_sum = [
                acc + g for acc, g in zip(grad_qry_sum, weighted_grad)
            ]

    # Apply the combined, masked gradient to the initial weights.
    # NOTE(review): the gradient is ADDED to the parameters; whether this is
    # the intended direction depends on what query_eval returns -- confirm.
    net.load_state_dict(init_net_state)
    with torch.no_grad():
        for (name, para), grad in zip(net.named_parameters(), grad_qry_sum):
            para.add_(arg.maml_lr * (grad * mask[name]))

    accs = np.array(corrects) / (querysz * task_num)
    updated_state = deepcopy(net.state_dict())

    return accs, updated_state
