import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import copy
import numpy as np
from torchvision import datasets, transforms
import torch
import pandas as pd
from utils.sampling import mnist_iid, mnist_iid2, mnist_noniid, mnist_noniid2, cifar_iid, cifar_iid2, cifar_noniid, cifar_noniid2
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNCifar, GateCNN, GateMLP
from models.Fed import FedAvg
from models.test import test_img, test_img_mix
from collections import Counter
import os.path


# Semicolon-separated results log shared by all runs: one header line, then
# one line per run.  The header text is kept byte-for-byte so that existing
# result files stay compatible.
# NOTE: the 'val_acc_avg_mixFed' column receives the validation accuracy of
# the end-to-end trained mixture (val_acc_avg_e2e), exactly as before.
RESULTS_PATH = 'results'
RESULTS_HEADER = 'dataset;model;epochs;local_ep;num_users;iid;p;opt;val_acc_avg_mix;val_acc_avg_locals;val_acc_avg_fedavg;acc_test_fedavg;acc_test_locals;acc_test_mix;ft_valacc;ft_testacc;val_acc_avg_mixFed;acc_test_mixFed;run'

# Epoch budget passed to the per-client fine-tuning, local and mixture training.
LOCAL_PHASE_EPOCHS = 100


def _ensure_results_header(path):
    """Create *path* with the header line if the file does not exist yet."""
    if not os.path.isfile(path):
        with open(path, 'a') as f1:
            f1.write(RESULTS_HEADER + '\n')


def _append_result_row(path, row):
    """Append the values of *row* to *path* as one ';'-separated line."""
    # format(value) is what '{}'.format(value) does, so the text written for
    # every value is the same as with the original 19-slot format string.
    with open(path, 'a') as f1:
        f1.write(';'.join(format(value) for value in row) + '\n')


def _mean(values):
    """Return the arithmetic mean of *values*, or NaN when there are none."""
    values = list(values)
    if not values:
        return np.nan
    return sum(values) / len(values)


def _select_device(gpu):
    """Return the CUDA device with index *gpu*, or the CPU as a fallback.

    The CPU is used when CUDA is unavailable or *gpu* is negative, so the
    script also runs on machines without a usable GPU.
    """
    if torch.cuda.is_available() and int(gpu) >= 0:
        return torch.device('cuda:{}'.format(gpu))
    return torch.device('cpu')


def _load_dataset(args):
    """Load the train/test sets for ``args.dataset`` and split the train set.

    Returns ``(dataset_train, dataset_test, dict_users)`` where ``dict_users``
    is whatever the sampling helper returns for ``args.num_users`` users
    (IID when ``args.iid`` is set, otherwise the non-IID variant using
    ``args.p``).  Exits the script for an unknown dataset name.
    """
    if args.dataset == 'mnist':
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid2(dataset_train, args.num_users, args.p)

    elif args.dataset in ('cifar10', 'cifar100'):
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        if args.dataset == 'cifar10':
            dataset_cls, root = datasets.CIFAR10, '../data/cifar'
        else:
            dataset_cls, root = datasets.CIFAR100, '../data/cifar100'
        dataset_train = dataset_cls(root, train=True, download=True, transform=trans_cifar)
        dataset_test = dataset_cls(root, train=False, download=True, transform=trans_cifar)
        # Both CIFAR variants use the same sampling helpers.
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            dict_users = cifar_noniid2(dataset_train, args.num_users, args.p)

    else:
        raise SystemExit('error: dataset not available')

    return dataset_train, dataset_test, dict_users


def _build_models(args, len_in):
    """Create the global model, the per-client models/gates and the opt-out mask.

    Returns ``(net_glob_fedAvg, gates, gates_e2e, net_locals, opt)``.  ``opt``
    is a float array with one entry per client: 1.0 for clients that take part
    in FedAvg training and 0.0 for the randomly chosen fraction ``args.opt``
    that opts out.  Exits the script for an unsupported model/dataset pair.
    """
    if args.model == 'cnn' and args.dataset in ('cifar10', 'cifar100'):
        def make_net():
            return CNNCifar(args=args)

        def make_gate():
            return GateCNN(args=args)

    elif args.model == 'mlp':
        def make_net():
            return MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes)

        def make_gate():
            return GateMLP(dim_in=len_in, dim_hidden=200, dim_out=1)

    else:
        raise SystemExit('error: no such model')

    net_glob_fedAvg = make_net().to(args.device)

    # opt-out fraction
    opt = np.ones(args.num_users)
    opt_out = np.random.choice(range(args.num_users), size=int(args.opt * args.num_users), replace=False)
    opt[opt_out] = 0.0
    print(opt)

    # Construction order (gate, e2e gate, local net per client) is kept as in
    # the original script so that seeded initialisation does not change.
    gates, gates_e2e, net_locals = [], [], []
    for _ in range(args.num_users):
        gates.append(make_gate().to(args.device))
        gates_e2e.append(make_gate().to(args.device))
        net_locals.append(make_net().to(args.device))

    return net_glob_fedAvg, gates, gates_e2e, net_locals, opt


def _train_fedavg(args, net_glob_fedAvg, dataset_train, dict_users, opt):
    """Run ``args.epochs`` FedAvg rounds over the clients that did not opt out."""
    for round_idx in range(args.epochs):
        print('Round {:3d}'.format(round_idx))

        w_locals_fedAvg = []
        alpha = []

        for idx in range(args.num_users):
            print("FedAvg client %d" %(idx))

            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])

            if opt[idx]:
                # train FedAvg on a copy so the global model only changes
                # through the aggregation step below
                w_local, _ = local.train(net=copy.deepcopy(net_glob_fedAvg).to(args.device), n_epochs=args.local_ep)

                w_locals_fedAvg.append(copy.deepcopy(w_local))
                alpha.append(len(dict_users[idx]) / len(dataset_train))

        if not w_locals_fedAvg:
            # Every client opted out, so there is nothing to aggregate.
            # NOTE(review): FedAvg is defined elsewhere; presumably it cannot
            # average an empty list, so the global model is left untouched.
            print('no participating clients, skipping aggregation')
            continue

        # update global model weights and copy them into the global net
        w_glob_fedAvg = FedAvg(w_locals_fedAvg, alpha)
        net_glob_fedAvg.load_state_dict(w_glob_fedAvg)


def _run_once(args, run):
    """Execute one complete experiment run and append its row to the results file."""
    # load dataset and split users
    # dataset_test is still loaded (and downloaded) as before, although every
    # test-set evaluation in this script is currently disabled.
    dataset_train, dataset_test, dict_users = _load_dataset(args)

    # models: the MLP variants take the flattened size of one sample as input
    img_size = dataset_train[0][0].shape
    len_in = int(np.prod(tuple(img_size)))

    net_glob_fedAvg, gates, gates_e2e, net_locals, opt = _build_models(args, len_in)

    print(net_glob_fedAvg)
    for i in range(args.num_users):
        gates[i].train()
        net_locals[i].train()
    net_glob_fedAvg.train()

    # training
    _train_fedavg(args, net_glob_fedAvg, dataset_train, dict_users, opt)

    val_acc_finetuned_avg = []
    val_acc_locals, val_acc_fedavg, val_acc_e2e = [], [], []
    for idx in range(args.num_users):

        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])

        # finetune FedAvg for every client (on a copy of the global model)
        print("Finetune %d" %(idx))
        _, _, val_acc_finetuned = local.train_finetune(net=copy.deepcopy(net_glob_fedAvg).to(args.device), n_epochs=LOCAL_PHASE_EPOCHS)
        val_acc_finetuned_avg.append(val_acc_finetuned)

        # train local model
        print("Local %d" %(idx))
        w_l, _, val_acc_l = local.train_finetune(net=net_locals[idx].to(args.device), n_epochs=LOCAL_PHASE_EPOCHS)
        net_locals[idx].load_state_dict(w_l)
        val_acc_locals.append(val_acc_l)

        # Train mixture end to end (train_gate_only=False) on copies of the
        # local and global models together with this client's gate
        print("E2e %d" %(idx))
        _, _, val_acc_e2e_k, _ = local.train_mix(net_local=copy.deepcopy(net_locals[idx]), net_global=copy.deepcopy(net_glob_fedAvg).to(args.device), gate=gates_e2e[idx].to(args.device), train_gate_only=False, n_epochs=LOCAL_PHASE_EPOCHS, early_stop=True)
        val_acc_e2e.append(val_acc_e2e_k)

        # evaluate FedAvg on local dataset
        val_acc_fed, _ = local.validate(net=net_glob_fedAvg.to(args.device))
        val_acc_fedavg.append(val_acc_fed)

    # Calculate validation accuracies
    val_acc_avg_locals = _mean(val_acc_locals)
    val_acc_avg_e2e = _mean(val_acc_e2e)
    val_acc_avg_fedavg = _mean(val_acc_fedavg)
    ft_val_acc = _mean(val_acc_finetuned_avg)

    # Metrics this script does not compute (gate-only mixture, all test-set
    # evaluations, FedMix) are recorded as NaN so the columns stay numeric.
    # Previously three of them were written as the literal text '[]'.
    val_acc_avg_mix = np.nan
    acc_test_fedavg = np.nan
    acc_test_locals = np.nan
    acc_test_mix = np.nan
    ft_test_acc = np.nan
    acc_test_mixFed = np.nan

    _append_result_row(RESULTS_PATH, [
        args.dataset, args.model, args.epochs, args.local_ep, args.num_users,
        args.iid, args.p, args.opt,
        val_acc_avg_mix, val_acc_avg_locals, val_acc_avg_fedavg,
        acc_test_fedavg, acc_test_locals, acc_test_mix,
        ft_val_acc, ft_test_acc,
        val_acc_avg_e2e, acc_test_mixFed,
        run,
    ])


def _main():
    """Parse the options and execute ``args.runs`` independent experiment runs."""
    _ensure_results_header(RESULTS_PATH)

    args = args_parser()
    # The device does not change between runs, so it is selected once.
    args.device = _select_device(args.gpu)

    if args.fedmix:
        # The FedMix accuracies were read from variables that are never
        # defined in this script, which raised NameError only after all
        # training had finished.  Warn up front and record NaN instead.
        print('warning: args.fedmix is set but FedMix training is not implemented in this script; its metrics are recorded as nan')

    for run in range(args.runs):
        _run_once(args, run)


if __name__ == '__main__':
    _main()

