#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms, models
import torch
# from torchvision.models import models.ResNet18_Weights
# from utils.sampling import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid, mnist_noniid_alpha, mnist_noniid_unequal
from utils.sampling_sep import mnist_iid, mnist_noniid, cifar_iid, cifar_noniid, mnist_noniid_alpha, mnist_noniid_unequal
from utils.sampling_sep import cifar_iid_unseen_train, cifar_noniid_alpha_unseen_train, cifar_noniid_alpha_unseen_test

from utils.options import args_parser
from models.Update import LocalUpdate, LocalUpdate_q_ourpre, LocalUpdate_meta_q_our, LocalUpdate_meta_s, LocalUpdate_meta_q_our_3, ServerUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg, FedAvg_ourpre, meta_agg
from models.test import test_img , test_img_byclients, test_img_byclients_for_meta
import pdb
import random
from tqdm import tqdm
import os
import math
import statistics
from utils.subdataset import custom_subset
from torch.autograd import Variable

# Seed the global RNGs used below (random.shuffle / random.sample for the class and
# client splits, np.random.choice for per-round client sampling) for reproducibility.
random.seed(886)
# np.random.seed (not np.random.rand) actually seeds NumPy's global generator.
np.random.seed(886)

if __name__ == '__main__':
    # Entry point: federated meta pre-training on an 80-class split of CIFAR-100.
    # Names bound here are module globals; several are not read in this section and are
    # kept because code elsewhere in the file may use them.

    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    if args.dataset == 'cifarhund':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR100('data/cifarhund', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR100('data/cifarhund', train=False, download=True, transform=trans_cifar)

        # Randomly split the 100 class ids: 80 classes (label_08) for pre-training,
        # the remaining 20 (label_02) for the fine-tuning splits built below.
        label = [*range(0, 100, 1)]
        random.shuffle(label)
        label_08 = label[:80]
        label_02 = label[80:]
        # Sets make the per-sample membership tests below O(1).
        label_08_set = set(label_08)
        label_02_set = set(label_02)

        train_08_id = [idx for idx, target in enumerate(dataset_train.targets) if target in label_08_set]
        train_02_id = [idx for idx, target in enumerate(dataset_train.targets) if target in label_02_set]

        test_08_id = [idx for idx, target in enumerate(dataset_test.targets) if target in label_08_set]
        test_02_id = [idx for idx, target in enumerate(dataset_test.targets) if target in label_02_set]

        ## partition for hybrid setting: 95% of the 80-class samples go to clients, 5% to the server
        random.shuffle(train_08_id)
        random.shuffle(test_08_id)
        train_08_id_client = train_08_id[ : int(0.95 * len(train_08_id))]
        train_08_id_server = train_08_id[int(0.95 * len(train_08_id)) : ]
        test_08_id_client = test_08_id[ : int(0.95 * len(test_08_id))]
        test_08_id_server = test_08_id[int(0.95 * len(test_08_id)) : ]

        train_labels_08_client = [dataset_train.targets[i] for i in train_08_id_client]
        train_labels_08_server = [dataset_train.targets[i] for i in train_08_id_server]
        # test_08_* are indices into dataset_test, so their labels must come from
        # dataset_test (indexing dataset_train with them yields unrelated labels).
        test_labels_08_client = [dataset_test.targets[i] for i in test_08_id_client]
        test_labels_08_server = [dataset_test.targets[i] for i in test_08_id_server]

        train_labels_02 = [dataset_train.targets[i] for i in train_02_id]
        test_labels_02 = [dataset_test.targets[i] for i in test_02_id]

        train_datasub_08_client = custom_subset(dataset_train, train_08_id_client, train_labels_08_client)
        train_datasub_08_server = custom_subset(dataset_train, train_08_id_server, train_labels_08_server)
        # Same reasoning as above: these subsets draw their samples from dataset_test.
        test_datasub_08_client = custom_subset(dataset_test, test_08_id_client, test_labels_08_client)
        test_datasub_08_server = custom_subset(dataset_test, test_08_id_server, test_labels_08_server)

        train_datasub_02 = custom_subset(dataset_train, train_02_id, train_labels_02)
        test_datasub_02 = custom_subset(dataset_test, test_02_id, test_labels_02)

        # Server pool: the 5% server share of the 80-class train and test samples.
        concat_train_server = torch.utils.data.ConcatDataset([train_datasub_08_server, test_datasub_08_server])

        # Client pool (80 classes) for pre-training; the 20 held-out classes are used
        # here to build the fine-tuning splits.
        concat_train = torch.utils.data.ConcatDataset([train_datasub_08_client, test_datasub_08_client])
        concat_test = torch.utils.data.ConcatDataset([train_datasub_02, test_datasub_02])
        if args.iid:
            train_dict, val_dict = cifar_iid_unseen_train(concat_train, args.num_users)
            finetune_train_dict, finetune_test_dict = cifar_iid_unseen_train(concat_test, args.num_users)
        else:
            train_dict, val_dict = cifar_noniid_alpha_unseen_train(concat_train, args.num_users, args.alpha, label_08)
            finetune_train_dict, finetune_test_dict = cifar_noniid_alpha_unseen_train(concat_test, args.num_users, args.alpha, label_02)

    else:
        exit('Error: unrecognized dataset')

    # build model
    if args.model == 'resnet' and args.dataset == 'cifarhund':
        if args.pretrain == 0:
            # weights=None trains from scratch; same torchvision API as the branch below
            # (the `pretrained=` keyword is deprecated).
            net_glob = models.resnet18(weights=None, num_classes=args.num_classes).to(args.device)
        else:
            # ImageNet weights, then replace the classifier head for args.num_classes.
            net_glob = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).to(args.device)
            num_ftrs = net_glob.fc.in_features
            net_glob.fc = torch.nn.Linear(num_ftrs, args.num_classes).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        # Flattened input size, taken from the shape of one transformed training sample.
        img_size = dataset_train[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)

    # Run directory name encodes the experiment configuration; the alpha and
    # client-selection parts are only included when they apply.
    run_name = "{}_{}_Pre{}_Client{}_LocalE{}_Round{}_iid{}".format(args.dataset, args.model, args.pretrain, args.num_users, args.local_ep, args.epochs, args.iid)
    if not args.iid:
        run_name += "_alpha{}".format(args.alpha)
    run_name += "_selectC{}_balance{}".format(args.select_users, args.balancer)
    if args.clientselection:
        run_name += "_select{}".format(args.select_type)
    # NOTE(review): "/../log" is an absolute path; on POSIX it resolves to /log at the
    # filesystem root. Confirm a relative root (e.g. "../log") was not intended.
    log_path_log = "/../log/" + run_name + "/log"
    log_path_save = "/../log/" + run_name + "/save"

    # Create the output directories if they do not exist yet.
    os.makedirs(log_path_log, exist_ok=True)
    os.makedirs(log_path_save, exist_ok=True)

    # pre-training

    # copy weights
    w_glob = net_glob.state_dict()

    # training bookkeeping (several of these are not read in this section)
    loss_train_l = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0
    net_best = None
    best_loss = None
    val_acc_list, net_list = [], []

    lowacc_init = 0
    lowloss_init = math.inf
    lowstd_init = math.inf

    lowacc = 0
    lowloss = math.inf
    lowstd = math.inf

    lowacc_ = 0
    lowacc_2 = 0
    if args.all_clients:
        print("Aggregation over all clients")
        w_locals = [w_glob for i in range(args.num_users)]
        w_locals_q = [w_glob for i in range(args.num_users)]
    for round_idx in tqdm(range(args.epochs)):
        loss_locals = []

        if not args.all_clients:
            w_locals = []
            m = max(int(args.frac * args.num_users), 1)
            idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        else:
            idxs_users = np.arange(args.num_users)

        print("start local training") # this block take time!

        init_acc_train_local = []
        init_loss_train_local = []
        train_acc_list_local = []
        train_loss_list_local = []
        test_acc_list_local = []
        test_loss_list_local = []
        test_acc_ontrain_local = []
        test_loss_ontrain_local = []
        test_acc_ontrain_afteropt_local = []
        test_loss_ontrain_afteropt_local = []
        test_acc_list_local_ = []

        test_acc_list_local_2 = []
        test_loss_list_local_2 = []
        # select k clients to do iteration and agg
        new_sets = random.sample(idxs_users.tolist(), args.select_users)

        spt_corrects, qry_corrects = 0, 0
        spt_loss, qry_loss = 0.0, 0.0
        spt_sz, qry_sz = 0, 0
        num_size = []
        solns = []
        s_loss_app, s_correct_app, s_num_sample_app = [], [], []
        # Separate model for evaluating each client's weights, so net_glob keeps the
        # round's global weights and every client starts local training from them.
        net_local = copy.deepcopy(net_glob)
        for idx in new_sets:
            net_glob.train()
            local = LocalUpdate_meta_s(args=args, dataset=concat_train, idxs=train_dict[idx])

            w, loss, s_loss, s_correct, s_num_sample = local.train(net=copy.deepcopy(net_glob).to(args.device))
            s_loss_app.append(s_loss)
            s_correct_app.append(s_correct)
            s_num_sample_app.append(s_num_sample)
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))

            net_local.load_state_dict(w)
            net_local.eval()
            acc_train, loss_train, Correct_train, Len_train = test_img_byclients(net_local, concat_train, train_dict[idx], args)
            train_acc_list_local.append(acc_train.item())
            train_loss_list_local.append(loss_train)

        print("finish local training")

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(round_idx, loss_avg))
        loss_train_l.append(loss_avg)

        # # record for analysis
        with open(log_path_log + '/log_train_loss_local_1.txt', 'a') as f:
            f.write(str(train_loss_list_local))
            f.write('\n')
        with open(log_path_log + '/log_train_acc_local_1.txt', 'a') as f:
            f.write(str(train_acc_list_local))
            f.write('\n')

        # # update global weights
        w_glob = FedAvg_ourpre(w_locals, new_sets)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        # do meta-evaluation: aggregated model on each selected client's val split
        net_our_meta = copy.deepcopy(net_glob)
        app_loss = []

        for idx in new_sets:
            net_our_meta.eval()
            acc_test, loss_test, Correct_test, Len_test = test_img_byclients(net_our_meta, concat_train, val_dict[idx], args)
            test_acc_list_local.append(acc_test.item())
            test_loss_list_local.append(loss_test)
            app_loss.append(loss_test)
        avg_loss = statistics.mean(app_loss)
        std_loss = statistics.pstdev(app_loss)
        var_loss = std_loss * std_loss

        val_meanacc = statistics.mean(test_acc_list_local)
        val_stdacc = statistics.pstdev(test_acc_list_local)
        val_meanloss = statistics.mean(test_loss_list_local)
        with open(log_path_log + '/log_eval_acc_local_1.txt', 'a') as f:
            f.write(str(test_acc_list_local))
            f.write('\n')
        with open(log_path_log + '/log_eval_loss_local_1.txt', 'a') as f:
            f.write(str(test_loss_list_local))
            f.write('\n')

        # Query-side meta update per client, paired with that client's support stats.
        for idx, s_loss_i, s_correct_i, s_num_sample_i in zip(new_sets, s_loss_app, s_correct_app, s_num_sample_app):
            net_our_meta.train()
            local_meta = LocalUpdate_meta_q_our(args=args, dataset=concat_train, idxs=val_dict[idx])
            # NOTE(review): var_loss is a variance but is passed as `std`; confirm which
            # quantity LocalUpdate_meta_q_our expects.
            stat, G = local_meta.train(net=copy.deepcopy(net_our_meta).to(args.device), s_loss=s_loss_i, s_correct=s_correct_i, s_num_sample=s_num_sample_i, balance=args.balancer, std=var_loss)
            # NOTE(review): this aggregation weight adds a sample count to a loss sum;
            # confirm 'support_loss_sum' (rather than a support sample count) is intended.
            num_size.append(stat['query_num_samples'] + stat['support_loss_sum'])
            solns.append(G)

        w_glob = meta_agg(solns, list(net_our_meta.parameters()), num_size)
        # copy weight to net_glob for next round initialization. The values must be
        # copied in place: assigning into list(net_glob.parameters())[i] only rebinds an
        # element of a temporary list and leaves the model unchanged.
        # Assumes meta_agg returns one tensor-like entry per parameter, in
        # parameters() order — TODO confirm against models.Fed.meta_agg.
        glob_params = list(net_glob.parameters())
        with torch.no_grad():
            for i, param in enumerate(glob_params):
                param.copy_(torch.as_tensor(w_glob[i]))

        save_path = log_path_save + '/final_mdl.pt'
        torch.save(net_glob, save_path)

    print('done pre-training')
