import argparse
import torch
import numpy as np
import random
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from copy import deepcopy
from torch.utils.data.sampler import SubsetRandomSampler
# from sample import *
import matplotlib.pyplot as plt
from utils import *
from train import *
import torchvision.datasets as datasets
from maml import maml
from naive import naive
from dset import *
from torch.utils.data import random_split
from pre_trained.cifar10_models.vgg import vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from pre_trained.cifar10_models.resnet import resnet18, resnet34
from pre_trained.cifar10_models.mobilenetv2 import mobilenet_v2
import logging
import os
# import wandb

def _str2bool(value):
    """Parse a command-line boolean flag value.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is ``True`` because any
    non-empty string is truthy. This converter accepts the usual spellings.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser('model')
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--max_val_iter', type=int, default=3,
                    help='early-stopping patience (meta epochs without improvement)')
parser.add_argument('--net', type=int, default=0,
                    help='backbone: 0 = vgg11_bn, 1 = resnet18')
parser.add_argument('--weighted_flag', type=int, default=1,
                    help='1 = prepend the source-domain loaders to the support/query lists')
parser.add_argument('--source_weight', type=float, default=0.1)
parser.add_argument('--tar', type=int, default=0,
                    help='0 = build auxiliary support/query domains; otherwise start from empty lists')
parser.add_argument('--update_step', type=int, default=1)
parser.add_argument('--pro_lr', type=float, default=0.01)
parser.add_argument('--maml_lr', type=float, default=0.0005)
parser.add_argument('--att_lr', type=float, default=0.001)
parser.add_argument('--num_epoch_meta', type=int, default=20)
parser.add_argument('--num_epoch_tr', type=int, default=50)
parser.add_argument('--num_epoch_naive', type=int, default=30)
parser.add_argument('--num_epoch_att', type=int, default=50)
parser.add_argument('--num_auxi', type=int, default=10)
parser.add_argument('--num_class_sample', type=int, default=100)
# A fraction, so it must be parsed as float (type=int rejected e.g. "0.1").
parser.add_argument('--att_data', type=float, default=0.05,
                    help='ratio of the full target training set given to the attacker')
parser.add_argument('--PATH', type=str, default='./model/tmp.pth',
                    help='where the best protected state_dict is saved')
parser.add_argument('--PATH_vict', type=str, default='./model/svhn_vgg11.pth')
parser.add_argument('--out_dir', default='./log_att_data_test',
                    help='directory for saving results')
parser.add_argument('--source', type=str, default='us')
parser.add_argument('--target', type=str, default='mn')
parser.add_argument('--min_w', type=float, default=-1)
parser.add_argument('--max_w', type=float, default=1)
parser.add_argument('--seed_val', type=int, default=66)
# type=bool would turn "--warm_up False" into True; parse the text explicitly.
parser.add_argument('--warm_up', type=_str2bool, default=True,
                    help='run Trainer_pro warm-up before meta-training (true/false)')
parser.add_argument('--maml_flag', type=int, default=1,
                    help='1 = also evaluate the unprotected victim model')
parser.add_argument('--verbose_flag', type=int, default=0)
args = parser.parse_args()

# Shared log formatting. main() attaches a per-run FileHandler that uses
# `formatter`; the logger itself only passes INFO and above.
log_format = '[%(asctime)s] - %(message)s'
date_format = '%Y/%m/%d %H:%M:%S'
formatter = logging.Formatter(fmt=log_format, datefmt=date_format)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

def _attach_file_logger(out_dir, logfile_name):
    """Attach a FileHandler writing to ``out_dir/logfile_name`` to the module logger.

    Creates ``out_dir`` (including missing parents) if needed and returns the
    handler so the caller can detach it again with ``_detach_file_logger``.
    """
    os.makedirs(out_dir, exist_ok=True)
    handler = logging.FileHandler(os.path.join(out_dir, logfile_name))
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return handler


def _detach_file_logger(handler):
    """Remove ``handler`` from the module logger and close its file."""
    # removeHandler targets exactly this handler; popping the handler list would
    # remove whichever handler happened to be added last.
    logger.removeHandler(handler)
    handler.close()


def _meta_train(args, device, net, start_state, mask, trainloader_att, testloader_att,
                spt_list, qry_list):
    """Run MAML meta-training with early stopping on the attacker's transfer accuracy.

    After every meta epoch the model is fine-tuned on the attacker data via
    ``Trainer_att``. The state with the LOWEST transfer accuracy seen so far is
    kept and saved to ``args.PATH``. Training stops after ``args.max_val_iter``
    consecutive epochs without a new minimum.

    Returns:
        ``(best_state, best_acc)``. When ``args.num_epoch_meta`` is 0, no epoch
        runs: ``best_state`` is ``start_state`` and ``best_acc`` is ``None``.
    """
    updated_state = start_state
    best_state, best_acc = None, None
    stale_epochs = 0
    for _ in range(args.num_epoch_meta):
        updated_state = maml(spt_list, qry_list, net, updated_state, mask, args, device)[1]
        net.load_state_dict(updated_state)

        tran_acc = Trainer_att(args, net, trainloader_att, testloader_att, device)[0]
        logger.info('The transfer result with Detrans protection : {:.3%}'.format(tran_acc))

        # The first epoch always establishes the record, so best_state is never unbound.
        if best_acc is None or tran_acc < best_acc:
            best_acc = tran_acc
            best_state = deepcopy(updated_state)
            torch.save(best_state, args.PATH)
            stale_epochs = 0
        else:
            stale_epochs += 1
        if stale_epochs >= args.max_val_iter:
            msg = ("Finetuning accuracy did not improve for the last {} validation runs. Early stopping..."
                   .format(args.max_val_iter))
            print(msg)
            logger.info(msg)
            break

    if best_state is None:
        best_state = updated_state
    return best_state, best_acc


def main(args, device, net, victim_state_dict, layer_list, mask, trainloader_pro, testloader_pro, trainloader_att, testloader_att, spt_list, qry_list):
    """Evaluate the victim model, apply DeTrans protection, and log the results.

    Logs go to ``args.out_dir/args.logfile``. The file handler is always
    detached again, even if a step raises.

    Args:
        args: parsed options. Reads out_dir, logfile, maml_flag, warm_up,
            num_epoch_meta, max_val_iter and PATH (plus whatever the trainers read).
        device: torch device used for evaluation and training.
        net: model whose weights are overwritten in place via ``load_state_dict``.
        victim_state_dict: weights of the unprotected victim model. A copy is
            used as the meta-training start, so the caller's dict is not handed
            to ``maml`` directly.
        layer_list: unused here; kept for interface compatibility.
        mask: passed through to ``Trainer_pro`` and ``maml``.
        trainloader_pro, testloader_pro: source-domain loaders (the protected task).
        trainloader_att, testloader_att: attacker's target-domain loaders.
        spt_list, qry_list: support/query loader lists consumed by ``maml``.

    Returns:
        dict of the accuracies that were logged. Existing callers may ignore it.
    """
    file_handler = _attach_file_logger(args.out_dir, args.logfile)
    try:
        logger.info(args)
        results = {}

        # %% ========= load net ==========
        logger.info('=========================load net================================')
        net.load_state_dict(victim_state_dict)
        results['base_acc'] = inference(net, device, testloader_pro)
        logger.info('Original accuracy of the victim model (baseline): {:.3%}'.format(results['base_acc']))

        if args.maml_flag == 1:
            logger.info('============ Unprotected =======')
            results['unprotected_direct_acc'] = inference(net, device, testloader_att)
            logger.info('Direct use without protection: {:.3%}'.format(results['unprotected_direct_acc']))

            net.load_state_dict(victim_state_dict)
            results['unprotected_transfer_acc'] = Trainer_att(
                args, net, trainloader_att, testloader_att, device)[0]
            logger.info('The transfer results without protection: {:.3%}'.format(
                results['unprotected_transfer_acc']))

        # =============== implement maml ===============
        logger.info('\n========== DeTrans =========')
        if args.warm_up:
            net.load_state_dict(victim_state_dict)
            warm_state = Trainer_pro(args, net, trainloader_pro, testloader_pro, device, mask)[1]
            start_state = deepcopy(warm_state)  # degraded model from the base protection
        else:
            # Copy so meta-training never works on the caller's victim weights,
            # which are reused for later targets.
            start_state = deepcopy(victim_state_dict)

        best_state, best_acc = _meta_train(args, device, net, start_state, mask,
                                           trainloader_att, testloader_att, spt_list, qry_list)
        results['detrans_transfer_acc'] = best_acc
        if best_acc is None:
            logger.info('num_epoch_meta is 0: no DeTrans meta-training epochs were run.')
        else:
            logger.info('DeTrans protection (cross-domain): {:.3%}'.format(best_acc))

        net.load_state_dict(best_state)
        results['detrans_source_acc'] = inference(net, device, testloader_pro)
        logger.info('DeTrans protection (direct source): {:.3%}'.format(results['detrans_source_acc']))

        # Fine-tune the protected state several times from scratch; presumably to
        # observe run-to-run variance of the attack.
        repeat_accs = []
        for _ in range(3):
            net.load_state_dict(best_state)
            acc = Trainer_att(args, net, trainloader_att, testloader_att, device)[0]
            repeat_accs.append(acc)
            logger.info(f'loop Finetuning: | Acc: {acc:.5f}')
        results['repeat_transfer_accs'] = repeat_accs
        return results
    finally:
        _detach_file_logger(file_handler)


    
if __name__ == '__main__':
    # fix seed for reproduction
    setup_seed(args.seed_val)

    # set up devices
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        print("Using the GPU!")
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
    else:
        print("WARNING: Could not find GPU! Use CPU only")

    if args.net not in (0, 1):
        raise ValueError('--net must be 0 (vgg11_bn) or 1 (resnet18), got %r' % (args.net,))

    # Victim checkpoints, keyed by (source domain, args.net).
    victim_paths = {
        ('mn', 0): './model/mnist_vgg11_wo_pre.pth',
        ('mn', 1): './model/mnist_res18.pth',
        ('us', 0): './model/usps_vgg11.pth',
        ('us', 1): './model/usps_res18.pth',
        ('sv', 0): './model/svhn_vgg11.pth',
        ('sv', 1): './model/svhn_res18.pth',
    }
    # Target domains evaluated for each source domain.
    target_domains = {
        'mn': ['us', 'sv'],
        'us': ['mn'],
        'sv': ['mn', 'us'],
    }

    src = ['us']
    for name in src:
        args.source = name
        if name not in target_domains:
            raise ValueError('unsupported source domain %r; expected one of %s'
                             % (name, sorted(target_domains)))
        tar = target_domains[name]
        args.PATH_vict = victim_paths[(name, args.net)]

        # ========= load net ==============
        if args.net == 0:
            layer_list = [64, 128, 256, 256, 512, 512, 512, 512, 4096, 4096, 10]  # vgg_11
            net = vgg11_bn(pretrained=True)
            if args.source != 'cf10':
                # Non-CIFAR-10 sources use a 1-input-channel first convolution.
                net.features[0] = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        else:
            layer_list = [64, 64, 64, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512,
                          10]  # resnet_18
            net = resnet18(pretrained=True)
            if args.source != 'cf10':
                net.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)

        if args.source != 'cf10':
            # map_location lets checkpoints saved on a GPU load on the CPU fallback too.
            victim_state_dict = torch.load(args.PATH_vict, map_location=device)
        else:
            victim_state_dict = deepcopy(net.state_dict())

        for target in tar:
            print('\n !!========= source domain = %s, target domain = %s =========\n' % (name, target))
            args.target = target

            net.load_state_dict(victim_state_dict)
            net.to(device)

            # ===== Modify source and target domain !
            src_trainset, src_testset = load_data(args.source)
            tar_trainset, tar_testset = load_data(args.target)
            # for cf-100: keep only the first 10 classes
            if args.target == 'cf100':
                cls_idx = np.arange(10).tolist()
                tar_trainset = sub_dataset(tar_trainset, cls_idx, 10, args.num_class_sample, True)
                tar_testset = sub_dataset(tar_testset, cls_idx, 10, args.num_class_sample, False)
            # ======= End of modification ======

            src_trainloader = torch.utils.data.DataLoader(src_trainset,
                                                          batch_size=args.batch_size,
                                                          shuffle=True)
            src_testloader = torch.utils.data.DataLoader(src_testset,
                                                         batch_size=args.batch_size,
                                                         shuffle=False)
            tar_testloader = torch.utils.data.DataLoader(tar_testset,
                                                         batch_size=args.batch_size,
                                                         shuffle=False)

            if args.tar == 0:
                spt_list, qry_list = auxilary(args.num_auxi, 10, args.num_class_sample,
                                              args.batch_size, args.source + '_aug')
            else:
                # Two independent empty lists (no auxiliary domains).
                spt_list, qry_list = [], []
            if args.weighted_flag == 1:
                print('Add source domain...')
                spt_list.insert(0, src_trainloader)
                qry_list.insert(0, src_testloader)

            # The attacker gets an args.att_data fraction of the target training set.
            val_size = int(args.att_data * len(tar_trainset))
            train_size = len(tar_trainset) - val_size

            # =========== generate mask =============
            layer_filters = alpha_select(qry_list, net, victim_state_dict, device)
            print('layer_filters = %s' % (layer_filters))
            mask = mask_layer(layer_filters, net)

            seed = [66, 33, 88]
            print('seed = %s' % seed)

            args.logfile = 'output_' + args.source + '_' + args.target + '.log'

            for i in range(1):
                print('\n*********** Trial %d **************' % (i))
                setup_seed(seed[i])
                _, val_ds = random_split(tar_trainset, [train_size, val_size])

                trainloader_att = torch.utils.data.DataLoader(val_ds, batch_size=args.batch_size,
                                                              shuffle=True, num_workers=0)

                main(args, device, net, victim_state_dict, layer_list, mask, src_trainloader,
                     src_testloader, trainloader_att, tar_testloader, spt_list, qry_list)


