
import copy
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from  dataclasses import dataclass
import numpy as np
from collections import OrderedDict
import pdb
import torch.nn.functional as F
import random

def runner_train(args, train_dataset, test_dataset,  epoch):
    """Run one global communication round and return the server model.

    On the first round (``epoch == 1``) the client nodes and the server node
    are created and stored in the module-level globals ``clients`` and
    ``master``; later rounds reuse them.  ``args.input`` is set to the number
    of elements in the first training sample.

    Every client receives a deep copy of the server model, runs its
    ``localround`` and the returned parameters are handed to
    ``master.aggregate``.  On the last round (``epoch == args.global_epochs``)
    each client is additionally run once with ``validation_only=True``.

    Args:
        args: run configuration; read here: ``n_clients``, ``model``,
            ``global_epochs`` (and written: ``input``).
        train_dataset: indexable dataset; ``train_dataset[0][0].shape`` is
            used to size the model input.
        test_dataset: dataset forwarded to the client nodes.
        epoch: 1-based index of the current global round.

    Returns:
        ``master.model`` after aggregation.
    """
    global clients, master

    if epoch == 1:
        # First round: build every client, then the server.
        clients = {}
        for cid in range(args.n_clients):
            clients[cid] = define_localnode(args, train_dataset, test_dataset, cid)

        sample_shape = train_dataset[0][0].shape
        if args.model != 'mlp':
            exit('Error: unrecognized model')
        # The MLP input width is the element count of one sample.
        flat_dim = 1
        for extent in sample_shape:
            flat_dim *= extent
        args.input = flat_dim

        master = define_globalnode(args)

    # Hand each client its own copy of the global model and collect results.
    local_params = {}
    for cid, node in clients.items():
        server_copy = copy.deepcopy(master.distribute_weight())
        local_params[cid] = node.localround(server_copy, epoch)

    master.aggregate(local_params)
    print('Round {:3d} finished'.format(epoch))

    if epoch == args.global_epochs:
        print("\nFinal Results")
        for node in clients.values():
            server_copy = copy.deepcopy(master.distribute_weight())
            node.localround(server_copy, epoch, validation_only=True)

    return master.model

def set_global_seeds(seed_number):
    """Seed every random number generator used by this module.

    When CUDA is available, all CUDA generators are seeded as well and cuDNN
    is switched to deterministic mode with benchmarking disabled.  After that
    Python's ``random``, NumPy and PyTorch are seeded with the same value.

    Args:
        seed_number: seed passed unchanged to each generator.
    """
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed_number)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_number)

def device_check(on_cuda):
    """Report which device training will use and validate the request.

    The message follows ``on_cuda`` (the flag the rest of this module uses to
    pick ``'cuda'`` or ``'cpu'``), not merely whether a GPU happens to be
    installed, so the printed device matches the one actually used.

    Args:
        on_cuda: truthy when training on the GPU was requested.

    Raises:
        ValueError: if the GPU was requested but CUDA is not available.
    """
    if on_cuda:
        if not torch.cuda.is_available():
            message = "GPU is not available"
            raise ValueError(message)
        print("GPU will be used for training\n")
    else:
        message = "Warning!: CPU will be used for training\n"
        print(message, flush=True)





@dataclass
class ClientsParams:
    """Values a client hands back to the server after a local round."""

    # State dict produced by the client's local training.
    weight: OrderedDict = None
    # Loss reported by AFL clients; left as None by FedAvg clients.
    afl_loss: float = None

class CreateDataset(Dataset):
    """Subset view of ``dataset`` restricted to the sample indices ``idxs``.

    Item ``i`` of this dataset is item ``idxs[i]`` of the wrapped dataset,
    with the label returned as a tensor.
    """

    def __init__(self, dataset, idxs):
        """Store the wrapped dataset and the selected indices.

        Args:
            dataset: indexable object yielding ``(image, label)`` pairs.
            idxs: iterable of indices into ``dataset``; each is cast to ``int``.
        """
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.idxs)

    def __getitem__(self, item):
        """Return ``(image, label_tensor)`` for the ``item``-th selected sample."""
        image, label = self.dataset[self.idxs[item]]
        if isinstance(label, torch.Tensor):
            # torch.tensor(<tensor>) emits a copy-construct UserWarning on
            # every sample; detach().clone() is the equivalent, silent copy.
            label = label.detach().clone()
        else:
            label = torch.tensor(label)
        return image, label


class LocalBase():
    """Client-side building blocks: data loading, evaluation and local training.

    The subclasses in this file combine ``local_validate`` and
    ``update_weights`` in their own ``localround`` method.
    """

    def __init__(self, args, train_dataset, test_dataset, client_id):
        """Create this client's data loaders and loss function.

        Args:
            args: run configuration; read here: ``train_distributed_data``,
                ``test_distributed_data`` (both indexed by ``client_id``),
                ``batch_size`` and ``on_cuda``.
            train_dataset: full training dataset shared by all clients.
            test_dataset: full test dataset shared by all clients.
            client_id: key selecting this client's sample indices.
        """
        self.args = args
        self.client_id = client_id
        self.trainDataset = CreateDataset(train_dataset, args.train_distributed_data[client_id])
        self.testDataset = CreateDataset(test_dataset, args.test_distributed_data[client_id])
        self.trainDataloader = DataLoader(self.trainDataset, args.batch_size, shuffle=True)
        self.testDataloader = DataLoader(self.testDataset, args.batch_size, shuffle=True)
        self.device = 'cuda' if args.on_cuda else 'cpu'
        self.criterion = nn.CrossEntropyLoss(reduction="mean")
        # Debug aid: self.show_class_distribution(self.trainDataset, self.testDataset, args)

    def show_class_distribution(self, train, test, args):
        """Print per-class sample counts for the given train and test sets.

        Args:
            train: iterable of ``(sample, label)`` pairs.
            test: iterable of ``(sample, label)`` pairs.
            args: configuration providing ``num_classes``.
        """
        # NOTE(review): this reads ``args.num_classes`` while ``mlp()`` reads
        # ``args.num_class`` -- confirm both attributes exist on ``args``.
        print("Class distribution of id:{}".format(self.client_id))
        class_distribution_train = [0 for _ in range(args.num_classes)]
        class_distribution_test = [0 for _ in range(args.num_classes)]
        for _, c in train:
            class_distribution_train[c] += 1
        for _, c in test:
            class_distribution_test[c] += 1
        print("train", class_distribution_train)
        print("test", class_distribution_test)

    def local_validate(self, model):
        """Evaluate ``model`` on this client's test loader.

        Args:
            model: network to evaluate; switched to eval mode and moved to
                ``self.device``.

        Returns:
            ``(test_acc, test_loss)``: accuracy in percent over the whole test
            set and the mean of the per-batch losses.

        Raises:
            ZeroDivisionError: if the test loader yields no batches.
        """
        model.eval()
        model.to(self.device)
        correct = 0
        batch_loss = []
        with torch.no_grad():
            for images, labels in self.testDataloader:
                images, labels = images.to(self.device), labels.to(self.device)
                if images.shape[1] == 1:
                    # Size-1 second dimension: repeat it three times.
                    images = torch.cat((images, images, images), 1)
                output = model(images)
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(labels.view_as(pred)).sum().item()
                loss = self.criterion(output, labels)
                batch_loss.append(loss.item())
        test_acc = 100. * correct / len(self.testDataloader.dataset)
        test_loss = sum(batch_loss) / len(batch_loss)

        return test_acc, test_loss

    def _build_optimizer(self, model):
        """Return the optimizer selected by ``args.optimizer`` for ``model``."""
        name = self.args.optimizer
        if name == 'sgd':
            return torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                   momentum=self.args.momentum, weight_decay=5e-4)
        if name == 'adam':
            return torch.optim.Adam(model.parameters(), lr=self.args.lr, weight_decay=1e-4)
        # Fail here with a clear message instead of hitting an unbound
        # ``optimizer`` later (or silently skipping training on an empty loader).
        raise ValueError(
            "unsupported optimizer: {!r} (expected 'sgd' or 'adam')".format(name))

    def update_weights(self, model, global_epoch):
        """Train ``model`` on this client's data for ``args.local_epochs`` epochs.

        Args:
            model: network to train; switched to train mode and moved to
                ``self.device``.
            global_epoch: index of the current global round (not used by the
                training loop itself).

        Returns:
            ``model.state_dict()`` after training.

        Raises:
            ValueError: if ``args.optimizer`` is neither ``'sgd'`` nor ``'adam'``.
        """
        model.train()
        model.to(self.device)
        optimizer = self._build_optimizer(model)

        for epoch in range(1, self.args.local_epochs + 1):
            batch_loss = []
            for images, labels in self.trainDataloader:
                images, labels = images.to(self.device), labels.to(self.device)
                if images.shape[1] == 1:
                    # Size-1 second dimension: repeat it three times.
                    images = torch.cat((images, images, images), 1)

                optimizer.zero_grad()
                output = model(images)
                loss = self.criterion(output, labels)
                loss.backward()
                optimizer.step()

                # TODO: write the training loss to a log somewhere later.
                batch_loss.append(loss.item())

        return model.state_dict()
        
class Fedavg_Local(LocalBase):
    """FedAvg client: validates the received model, then trains it locally."""

    def __init__(self, args, train_dataset, val_dataset, client_id):
        super().__init__(args, train_dataset, val_dataset, client_id)

    def localround(self, model, global_epoch, validation_only=False):
        """Run one local round on ``model``.

        The model is always evaluated first (the result is not used here).

        Args:
            model: the client's copy of the global model.
            global_epoch: index of the current global round.
            validation_only: when true, skip training.

        Returns:
            ``None`` when ``validation_only`` is true, otherwise a
            ``ClientsParams`` carrying the trained weights.
        """
        self.local_validate(model)

        if not validation_only:
            # Train locally and keep the result on the instance as well.
            self.updated_weight = self.update_weights(model, global_epoch)
            return ClientsParams(weight=self.updated_weight)

        return None

class Afl_Local(LocalBase):
    """AFL client: reports its pre-training test loss alongside new weights."""

    def __init__(self, args, train_dataset, val_dataset, client_id):
        super().__init__(args, train_dataset, val_dataset, client_id)

    def localround(self, model, global_epoch, validation_only=False):
        """Run one local round on ``model``.

        The loss measured on the received model, before any local training,
        is what gets reported to the server as ``afl_loss``.

        Args:
            model: the client's copy of the global model.
            global_epoch: index of the current global round.
            validation_only: when true, skip training.

        Returns:
            ``None`` when ``validation_only`` is true, otherwise a
            ``ClientsParams`` carrying the trained weights and the loss.
        """
        pre_training_loss = self.local_validate(model)[1]

        if not validation_only:
            # Train locally and keep the result on the instance as well.
            self.updated_weight = self.update_weights(model, global_epoch)
            return ClientsParams(weight=self.updated_weight, afl_loss=pre_training_loss)

        return None


     
def define_localnode(args,train_dataset,val_dataset,client_id):
    """Build the client node matching ``args.federated_type``.

    Args:
        args: run configuration; ``federated_type`` selects the node class
            (``'fedavg'`` or ``'afl'``).
        train_dataset: training dataset forwarded to the node.
        val_dataset: evaluation dataset forwarded to the node.
        client_id: identifier forwarded to the node.

    Returns:
        A ``Fedavg_Local`` or ``Afl_Local`` instance.

    Raises:
        NotImplementedError: for any other ``federated_type``.
    """
    algorithm = args.federated_type

    if algorithm == 'fedavg':  # plain federated averaging
        node_cls = Fedavg_Local
    elif algorithm == 'afl':  # agnostic federated learning
        node_cls = Afl_Local
    else:
        raise NotImplementedError

    return node_cls(args, train_dataset, val_dataset, client_id)
    

class MLP(nn.Module):
    """Two-layer perceptron: ``Linear -> ReLU -> Linear`` producing raw logits."""

    def __init__(self, dim_in, dim_hidden, dim_out):
        """Create the two linear layers.

        Args:
            dim_in: size of the input's last dimension.
            dim_hidden: width of the hidden layer.
            dim_out: number of output logits.
        """
        super().__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        """Return the logits for ``x`` (last dimension of size ``dim_in``)."""
        x = self.layer_input(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return x

    def pred_prob(self, x):
        """Return class probabilities for ``x``.

        The softmax is taken over the last dimension, which is where the
        final linear layer puts the logits.  For a single 1-D sample this is
        the same as ``dim=0``; for a batch it normalises each row instead of
        normalising across the batch.
        """
        logits = self.forward(x)
        return nn.functional.softmax(logits, dim=-1)

def mlp(args):
    """Build an ``MLP`` sized from ``args.input``, ``args.hidden`` and ``args.num_class``."""
    # NOTE(review): LocalBase.show_class_distribution reads ``args.num_classes``
    # while this reads ``args.num_class`` -- confirm both exist on ``args``.
    return MLP(dim_in=args.input, dim_hidden=args.hidden, dim_out=args.num_class)

class GlobalBase:
    """Server-side base: owns the global model and hands it to clients."""

    def __init__(self, args):
        """Pick the device and, for ``args.model == 'mlp'``, build the model.

        For any other ``args.model`` value no ``model`` attribute is created.
        """
        self.args = args
        if args.on_cuda:
            self.device = 'cuda'
        else:
            self.device = 'cpu'

        if args.model == 'mlp':
            network = mlp(args)
            self.model = network.to(self.device)

    def distribute_weight(self):
        """Return the global model object itself (not a copy)."""
        return self.model


class Fedavg_Global(GlobalBase):
    """FedAvg server: replaces the global model by the uniform client average."""

    def __init__(self, args):
        super().__init__(args)

    def aggregate(self, local_params):
        """Average the clients' weights into the global model.

        Args:
            local_params: mapping of client id to an object exposing the
                client's state dict as ``.weight``.
        """
        print("aggregating weights...")
        server_model = self.model
        # Collect the state dicts in the mapping's iteration order.
        client_states = [params.weight for params in local_params.values()]
        averaged = weighted_average_weights(client_states, server_model.state_dict())

        self.model.load_state_dict(averaged)


class Afl_Global(GlobalBase):
    """AFL server: keeps a per-client mixing vector updated from client losses."""

    def __init__(self, args):
        super().__init__(args)
        # Start from equal weight for each of the ``n_clients`` clients.
        self.lambda_vector = torch.Tensor(
            [1 / args.n_clients for _ in range(args.n_clients)]
        )

    def aggregate(self, local_params):
        """Update the mixing vector from the reported losses, then aggregate.

        Steps: add ``args.drfa_gamma * loss`` to the mixing vector, project it
        with ``euclidean_proj_simplex``, raise entries ``<= 1e-3`` to ``1e-3``
        and renormalise, then load the weighted combination of the client
        weights into the global model.

        Args:
            local_params: mapping of client id to an object exposing
                ``.weight`` (state dict) and ``.afl_loss``; the id is used as
                an index into the loss vector.
        """
        # NOTE(review): the mixing vector is passed on by position, so the
        # i-th collected state dict gets entry i; this matches the client ids
        # only if ``local_params`` iterates ids 0..n-1 in order -- confirm.
        server_model = self.model
        mixing = self.lambda_vector
        client_losses = torch.zeros(self.args.n_clients)
        client_states = []
        for cid, params in local_params.items():
            client_losses[cid] = torch.Tensor([params.afl_loss])
            client_states.append(params.weight)

        # In-place step on the stored vector, followed by the projection.
        mixing += self.args.drfa_gamma * client_losses
        mixing = euclidean_proj_simplex(mixing)

        # Keep every client's share at or above the 1e-3 floor.
        below_floor = mixing <= 1e-3
        if below_floor.any():
            mixing[below_floor] = 1e-3
            mixing /= mixing.sum()
        self.lambda_vector = mixing

        averaged = weighted_average_weights(
            client_states, server_model.state_dict(), mixing.to(self.device)
        )
        self.model.load_state_dict(averaged)


def define_globalnode(args):
    """Build the server node matching ``args.federated_type``.

    Args:
        args: run configuration; ``federated_type`` selects the node class
            (``'fedavg'`` or ``'afl'``).

    Returns:
        A ``Fedavg_Global`` or ``Afl_Global`` instance.

    Raises:
        NotImplementedError: for any other ``federated_type``.
    """
    algorithm = args.federated_type

    if algorithm == 'fedavg':  # plain federated averaging
        node_cls = Fedavg_Global
    elif algorithm == 'afl':  # agnostic federated learning
        node_cls = Afl_Global
    else:
        raise NotImplementedError

    return node_cls(args)

def weighted_average_weights(local_weights,global_weight,coff=None):
    """Combine client state dicts into a new state dict.

    For every entry the result is
    ``global + sum_i coff[i] * (local_i - global)``.  Entries whose dtype is
    ``torch.int64`` are copied from ``global_weight`` unchanged.

    Args:
        local_weights: sequence of client state dicts.
        global_weight: the current global state dict; it is not modified.
        coff: per-client coefficients indexed by position in
            ``local_weights``; defaults to ``1 / len(local_weights)`` each.

    Returns:
        A deep copy of ``global_weight`` with the update applied.
    """
    n_clients = len(local_weights)
    if coff is None:
        coff = np.array([1 / n_clients for _ in range(n_clients)])

    merged = copy.deepcopy(global_weight)
    for name in merged:
        for idx, client_state in enumerate(local_weights):
            # Integer entries are left as they are in the global state.
            if merged[name].dtype != torch.int64:
                merged[name] += coff[idx] * (client_state[name] - global_weight[name])
    return merged


def euclidean_proj_simplex(v, s=1):
    """Project the 1-D tensor ``v`` onto the simplex ``{w >= 0, sum(w) = s}``.

    Uses the sort-and-threshold method: sort ``v`` in decreasing order, find
    the last position whose value stays positive after shifting, and clamp
    ``v - theta`` at zero.

    Args:
        v: 1-D tensor to project.
        s: radius of the simplex (target sum); must be strictly positive.

    Returns:
        ``v`` itself when it already lies on the simplex, otherwise a new
        tensor with the projection.

    Raises:
        AssertionError: if ``s`` is not strictly positive.
        ValueError: if ``v`` is not 1-D.
    """
    assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s
    n, = v.shape  # will raise ValueError if v is not 1-D
    # Already on the simplex: the best projection is v itself.
    if v.sum() == s and (v >= 0).all():
        return v
    # Cumulative sums of a copy of v sorted in decreasing order.
    u = torch.sort(v, descending=True)[0]
    cssv = torch.cumsum(u, dim=0)
    # Build the 1..n ranks on v's device so the product below also works
    # for tensors that do not live on the CPU.
    ranks = torch.arange(1, n + 1, device=v.device)
    support = torch.nonzero(u * ranks > (cssv - s), as_tuple=False)
    # Index of the last component kept positive.  The fallback must be an
    # integer: a float cannot be used to index ``cssv``.
    rho = int(support[-1]) if len(support) > 0 else 0
    # Lagrange multiplier associated with the simplex constraint.
    theta = (cssv[rho] - s) / (rho + 1.0)
    # Threshold v by theta to obtain the projection.
    w = (v - theta).clamp(min=0)
    return w