import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib
matplotlib.use('Agg')
import copy
import numpy as np
from torch.utils.data import DataLoader, Dataset
from AFL_runner import *
from util import *
import pickle

def FedAvg_FedFair(times, args, zeta=0.05, alpha=0.09, depth=8, k=300, model='Select_1',
                   generate_method='random', generate_param=1, start_seed=110):
    """FedAvg + Fair: train a FedAvg global model per seed, then post-process it with Fed_Fair.

    For every seed the global model is saved to
    ``<out_path>FedAvg_<seed>_alpha_<alpha>.pth``. Fed_Fair is then run for the
    'DEOO' and 'DEO' metrics. Whenever it returns a non-None result table, the table
    is written to ``<out_path><metric>_<seed>_alpha_<alpha>.csv`` and
    ``[best_K, min_error, K]`` is pickled to the matching ``.pkl`` file.

    Input:
        times: exclusive upper bound of the seeds run, i.e. ``range(start_seed, times)``.
            NOTE(review): with the default ``start_seed=110`` nothing runs unless
            ``times > 110``; pass ``start_seed=0`` to run ``times`` experiments as
            ``AFL_FedFair`` does.
        args: hyperparameters and paths (reads ``dataset``, ``data_path``, ``num_users``
            and ``out_path`` here; also forwarded to FedAvg training and Fed_Fair)
        zeta: fairness constraint confidence level
        alpha: fairness constraint
        depth: depth of Q digest
        k: compression parameter of Q digest
        model: selection model name forwarded to Fed_Fair
        generate_method: client split method forwarded to data_processing
        generate_param: split parameter forwarded to data_processing as ``dir_alpha``
        start_seed: first random seed (inclusive); defaults to 110, the value that was
            previously hard-coded
    """
    for random_seed in range(start_seed, times):
        try:
            (dataset_train, dataset_test, dict_users, test_dict_users,
             Sensitive_attribute_train, Sensitive_attribute_test) = data_processing(
                data_name=args.dataset, data_path=args.data_path, num=args.num_users,
                train_rate=0.8, random_seed=random_seed,
                generate_method=generate_method, dir_alpha=generate_param)
        except Exception as err:
            # Skip seeds whose data split fails, but do not swallow KeyboardInterrupt /
            # SystemExit (as a bare ``except:`` would), and report why the seed failed.
            print('Random seed:', random_seed, 'failed', repr(err))
            continue
        net_glob = FedAvg_model(dataset_train=dataset_train, dict_users=dict_users,
                                args=args, random_seed=random_seed)

        # save global model
        torch.save(net_glob.state_dict(),
                   args.out_path + 'FedAvg_{}_alpha_{}.pth'.format(random_seed, alpha))

        for metric in ('DEOO', 'DEO'):
            indicator_df, best_K, min_error, K = Fed_Fair(
                net_glob, dataset_train, dataset_test, dict_users,
                Sensitive_attribute_train, Sensitive_attribute_test,
                zeta=zeta, alpha=alpha, random_seed=random_seed, args=args,
                depth=depth, k=k, metric=metric, model=model)
            if indicator_df is None:
                continue
            prefix = args.out_path + '{}_{}_alpha_{}'.format(metric, random_seed, alpha)
            indicator_df.to_csv(prefix + '.csv')
            with open(prefix + '.pkl', 'wb') as f:
                pickle.dump([best_K, min_error, K], f)
            print(indicator_df)
        print('Fair_{} finished'.format(random_seed))

def AFL_FedFair(times, args, zeta=0.05, alpha=0.09, depth=8, k=300, model='Select_1',
                generate_method='random', generate_param=1):
    """AFL + Fair: train (or reload) an AFL global model per seed, then post-process it with Fed_Fair.

    Seeds ``0 .. times-1`` are run. If ``args.load_model`` is set, the model weights are
    loaded from ``<out_path>AFL_<seed>_alpha_<alpha>.pth``; otherwise the trained weights
    are saved there. Fed_Fair is then run for the 'DEOO' and 'DEO' metrics. Whenever it
    returns a non-None result table, the table is written to
    ``<out_path><metric>_<seed>_alpha_<alpha>.csv`` and ``[best_K, min_error, K]`` is
    pickled to the matching ``.pkl`` file.

    Input:
        times: number of experiments (seeds ``range(times)``)
        args: hyperparameters and paths (reads ``dataset``, ``data_path``, ``num_users``,
            ``out_path`` and ``load_model``; sets ``seed_num``; also forwarded to AFL
            training and Fed_Fair)
        zeta: fairness constraint confidence level
        alpha: fairness constraint
        depth: depth of Q digest
        k: compression parameter of Q digest
        model: selection model name forwarded to Fed_Fair
        generate_method: client split method forwarded to data_processing
        generate_param: split parameter forwarded to data_processing as ``dir_alpha``
    """
    for random_seed in range(times):
        print('Fair_{} start'.format(random_seed))
        args.seed_num = random_seed
        set_global_seeds(args.seed_num)
        try:
            (dataset_train, dataset_test, dict_users, test_dict_users,
             Sensitive_attribute_train, Sensitive_attribute_test) = data_processing(
                data_name=args.dataset, data_path=args.data_path, num=args.num_users,
                train_rate=0.8, random_seed=random_seed,
                generate_method=generate_method, dir_alpha=generate_param)
        except Exception as err:
            # Skip seeds whose data split fails, but do not swallow KeyboardInterrupt /
            # SystemExit (as a bare ``except:`` would), and report why the seed failed.
            print('Random seed:', random_seed, 'failed', repr(err))
            continue
        net_glob = AFL_model(dataset_train=dataset_train, dataset_test=dataset_test,
                             dict_users=dict_users, test_dict_users=test_dict_users,
                             args=args, random_seed=random_seed)

        model_path = args.out_path + 'AFL_{}_alpha_{}.pth'.format(random_seed, alpha)
        if args.load_model:
            # reuse the global model saved by an earlier run with the same seed and alpha
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            net_glob.load_state_dict(torch.load(model_path))
        else:
            # save global model
            torch.save(net_glob.state_dict(), model_path)

        for metric in ('DEOO', 'DEO'):
            indicator_df, best_K, min_error, K = Fed_Fair(
                net_glob, dataset_train, dataset_test, dict_users,
                Sensitive_attribute_train, Sensitive_attribute_test,
                zeta=zeta, alpha=alpha, random_seed=random_seed, args=args,
                depth=depth, k=k, metric=metric, model=model)
            if indicator_df is None:
                continue
            prefix = args.out_path + '{}_{}_alpha_{}'.format(metric, random_seed, alpha)
            indicator_df.to_csv(prefix + '.csv')
            with open(prefix + '.pkl', 'wb') as f:
                pickle.dump([best_K, min_error, K], f)
            print(indicator_df)
        print('Fair_{} finished'.format(random_seed))




#-----------------------------------------------FedAvg----------------------------------------------
class DatasetSplit(Dataset):
    """Subset view over ``dataset`` exposing only the positions listed in ``idxs``.

    Item ``i`` of the view is item ``idxs[i]`` of the wrapped dataset, returned as a
    ``(image, label)`` pair.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # materialise the indices once so they can be indexed and measured
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        sample, target = self.dataset[source_position]
        return sample, target
    
class LocalUpdate(object):
    """Local training of one client: SGD with momentum on a cross-entropy loss.

    The client's samples are ``dataset[idxs]``, served in shuffled mini-batches of
    ``args.local_bs``.
    """

    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []  # not read inside this class
        client_view = DatasetSplit(dataset, idxs)
        self.ldr_train = DataLoader(client_view, batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        """Train ``net`` in place for ``args.local_ep`` epochs over this client's batches.

        Returns:
            (state_dict, mean_loss): the trained weights, and the mean over epochs of
            each epoch's mean batch loss. An epoch with no batches, or
            ``local_ep == 0``, raises ZeroDivisionError.
        """
        net.train()
        cfg = self.args
        optimizer = torch.optim.SGD(net.parameters(), lr=cfg.lr, momentum=cfg.momentum)
        per_epoch_means = []
        for epoch in range(cfg.local_ep):
            losses = []
            for step, (inputs, targets) in enumerate(self.ldr_train):
                inputs = inputs.to(cfg.device)
                targets = targets.to(cfg.device)
                net.zero_grad()
                outputs = net(inputs)
                loss = self.loss_func(outputs, targets)
                loss.backward()
                optimizer.step()
                if cfg.verbose and step % 10 == 0:
                    seen = step * len(inputs)
                    total = len(self.ldr_train.dataset)
                    pct = 100. * step / len(self.ldr_train)
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, seen, total, pct, loss.item()))
                losses.append(loss.item())
            per_epoch_means.append(sum(losses) / len(losses))
        return net.state_dict(), sum(per_epoch_means) / len(per_epoch_means)

class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear, producing raw class scores."""

    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        # attribute names become state_dict keys (saved with torch.save); keep them stable
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        """Map input of shape ``(..., dim_in)`` to scores of shape ``(..., dim_out)``."""
        x = self.layer_input(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return x

    def pred_prob(self, x):
        """Return class probabilities: softmax over the last (class) dimension.

        For a single un-batched sample ``(dim_in,)`` this equals the previous ``dim=0``
        result. For a batch ``(B, dim_in)`` each row is now normalised on its own,
        rather than normalising across the batch.
        """
        return nn.functional.softmax(self.forward(x), dim=-1)

class CelebAModel(nn.Module):
    """Two conv blocks (conv -> ReLU -> 2x2 max-pool), then two dense layers with a sigmoid head.

    ``fc1`` takes 64 * 16 * 16 features, which corresponds to 3-channel 64x64 input
    after the two 2x2 poolings. The output has 2 columns, each squashed by a sigmoid.
    """

    def __init__(self):
        super(CelebAModel, self).__init__()
        # Parameterised layers are created in this order (conv1, conv2, fc1, fc2); that
        # fixes both the RNG draws at init time and the state_dict keys.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.pool(self.relu(conv(x)))
        flat = x.view(x.shape[0], -1)
        hidden = self.relu(self.fc1(flat))
        return self.sigmoid(self.fc2(hidden))

    def pred_prob(self, x):
        """Return a 2-element probability vector for one un-batched sample.

        NOTE(review): softmax is applied to sigmoid outputs in (0, 1), so each
        probability is confined to (1/(1+e), e/(1+e)) ~ (0.27, 0.73). Confirm that
        this is intended.
        """
        as_batch = x[None]
        scores = self.forward(as_batch)[0]
        return nn.functional.softmax(scores, dim=0)
    

def FedAvg(w):
    """Return the unweighted element-wise mean of a sequence of state_dicts.

    ``w[0]`` is deep-copied before accumulation, so none of the inputs are modified.
    Every dict must contain ``w[0]``'s keys (KeyError otherwise). An empty sequence
    raises IndexError.
    """
    client_count = len(w)
    averaged = copy.deepcopy(w[0])
    for name in averaged.keys():
        running = averaged[name]
        for other in w[1:]:
            running += other[name]
        averaged[name] = torch.div(running, client_count)
    return averaged


def FedAvg_model(dataset_train,dict_users,args,random_seed):
    """Train a global model with federated averaging (FedAvg) and return it.

    Input:
        dataset_train: indexable dataset; ``dataset_train[0][0].shape`` sizes the MLP input
        dict_users: maps client index -> that client's sample indices into ``dataset_train``
        args: hyperparameters (reads ``dataset``, ``model``, ``device``, ``hidden``,
            ``num_classes``, ``all_clients``, ``num_users``, ``frac`` and ``epochs``, plus
            whatever ``LocalUpdate`` reads)
        random_seed: seed for the torch (CPU and CUDA) and numpy RNGs
    Output:
        the trained global model: ``CelebAModel`` when ``args.dataset == 'celeba'``,
        otherwise an ``MLP``
    Exits the process via ``sys.exit`` when a non-celeba dataset is combined with
    ``args.model != 'mlp'``.
    """
    # the bare ``exit`` builtin is injected by the ``site`` module and may be missing
    # (e.g. ``python -S``, embedded interpreters); ``sys.exit`` always exists
    import sys

    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    np.random.seed(random_seed)

    img_size = dataset_train[0][0].shape
    if args.dataset == 'celeba':
        net_glob = CelebAModel().to(args.device)
    else:
        if args.model != 'mlp':
            sys.exit('Error: unrecognized model')
        len_in = 1
        for dim in img_size:
            len_in *= dim
        net_glob = MLP(dim_in=len_in, dim_hidden=args.hidden, dim_out=args.num_classes).to(args.device)
        print(net_glob)
    net_glob.train()

    # copy weights
    w_glob = net_glob.state_dict()

    if args.all_clients:
        print("Aggregation over all clients")
        # NOTE(review): every slot aliases the same state_dict object, whose tensors presumably
        # share storage with net_glob's parameters. Slots never overwritten by a client would
        # then follow the current global weights. Verify that this is intended.
        w_locals = [w_glob for _ in range(args.num_users)]
    for round_idx in range(args.epochs):
        loss_locals = []
        if not args.all_clients:
            w_locals = []
        # number of clients sampled this round (at least one)
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            # each client trains its own copy, so net_glob is untouched until aggregation
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))
            loss_locals.append(loss)
        # update global weights: unweighted mean of the collected client weights
        w_glob = FedAvg(w_locals)
        net_glob.load_state_dict(w_glob)

        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(round_idx, loss_avg))
    return net_glob

#-----------------------------------------------AFL----------------------------------------------

def AFL_model(dataset_train,dataset_test,dict_users,test_dict_users,args,random_seed=0):
    """Train a global model with AFL through ``runner_train`` and return it.

    ``runner_train(args, dataset_train, dataset_test, epoch)`` is called with 1-based
    epochs ``1 .. args.global_epochs``, and the network returned by the last call is
    returned. When ``args.load_model`` is set, only the first call is made. This is
    presumably so that the caller can overwrite the weights from a checkpoint; confirm.

    Side effects: seeds the torch (CPU/CUDA) and numpy RNGs, and stores the client splits
    on ``args`` as ``train_distributed_data`` / ``test_distributed_data``. These are
    presumably read by ``runner_train``; verify.

    Raises:
        ValueError: if ``args.global_epochs < 1``. No training call would be made and
            there would be no model to return.
    """
    if args.global_epochs < 1:
        raise ValueError('args.global_epochs must be >= 1, got {}'.format(args.global_epochs))

    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    np.random.seed(random_seed)

    # hand the per-client index splits to the AFL runner via args
    args.train_distributed_data = dict_users
    args.test_distributed_data = test_dict_users

    for epoch in range(args.global_epochs):
        net = runner_train(args, dataset_train, dataset_test, epoch + 1)
        if args.load_model:
            break

    return net


#-----------------------------------------------FedFB----------------------------------------------

