# -*- coding: utf-8 -*-
import copy
import ray
import torch
import numpy as np
from torch import optim, nn
from optimizer.mr_sgd import MR_SGD

from models.regularization import Regularization
from models.HyperNetwork import HyperNetwork

from typing import OrderedDict, Union

@ray.remote(num_cpus=1)
class Client(object):
    """Ray actor wrapping one client's model, optimizers, data and HyperNetwork.

    Each instance keeps its own deep copy of the model, trains it locally,
    applies element-wise signed powers to the weights before returning them,
    and can aggregate a list of models into ``self.agged_model``.
    """

    def __init__(self, client_index, args, model, train_loader, test_loader):
        """Build the client state.

        Args:
            client_index: Identifier of this client; passed to the HyperNetwork.
            args: Configuration object; this class reads ``device``, ``lr``,
                ``iter_method``, ``model`` and ``world_size`` from it.
            model: Model that is deep-copied and moved to ``args.device``.
            train_loader: Iterable yielding ``(inputs, targets)`` batches.
            test_loader: Iterable yielding ``(inputs, targets)`` batches.
        """
        # Candidate exponents; ``p_loc`` arguments index into this sequence.
        # NOTE(review): the original expression ``range(len(21))`` raised
        # TypeError. ``range(21)`` is the minimal repair; index 0 maps to
        # exponent 0, whose reciprocal in ``gradients_step`` is undefined --
        # confirm the intended range.
        self.p_list = range(21)
        self.args = args
        self.client_index = client_index
        self.device = args.device

        self.model = copy.deepcopy(model).to(self.device)
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.lr,
                                   momentum=0.9, weight_decay=5e-4)
        self.mr_optimizer = MR_SGD(self.model.parameters(), lr=self.args.lr)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.layer_name = list(self.model.get_weights().keys())

        if self.args.iter_method == 'iteration':
            # Persistent iterator so each round consumes the next mini-batch.
            self.data_iteration = iter(self.train_loader)

        self.criterion = nn.CrossEntropyLoss()

        self.HN = HyperNetwork(args, self.layer_name, self.client_index)
        self.round = 0

        # Filled by ``aggregation``: layer name -> aggregated tensor.
        self.agged_model = {}

    def get_model(self):
        """Return the model weights as reported by ``model.get_weights()``."""
        return self.model.get_weights()

    def get_p(self):
        """Call the HyperNetwork with this client's index and return its result."""
        return self.HN(self.client_index)

    def _forward(self, inputs):
        """Run the model and return only the classification outputs.

        Models named 'LR' or 'LeNet' return the outputs directly; every other
        model is expected to return an ``(outputs, embedding)`` pair.
        """
        if self.args.model in ('LR', 'LeNet'):
            return self.model(inputs)
        outputs, _embedding = self.model(inputs)
        return outputs

    def _train_batch(self, inputs, targets, regularize):
        """Run one SGD step on a single mini-batch.

        Args:
            inputs: Input batch (moved to ``self.device`` here).
            targets: Target batch (moved to ``self.device`` here).
            regularize: When true, add the Lasso/RR penalty to the loss.
        """
        inputs, targets = inputs.to(self.device), targets.to(self.device)

        self.optimizer.zero_grad()
        loss = self.criterion(self._forward(inputs), targets)

        if regularize:
            if self.args.model == 'Lasso':
                # L1 Regularization
                loss += Regularization(self.model, 0.001, p=1)(self.model)
            elif self.args.model == 'RR':
                # L2 Regularization
                # NOTE(review): p=0 is passed for the L2 case; verify against
                # the convention used by ``Regularization``.
                loss += Regularization(self.model, 0.001, p=0)(self.model)

        loss.backward()
        self.optimizer.step()

    def train(self, p_loc):
        """Train locally for one round and return the powered weights.

        With ``iter_method == 'epoch'`` every mini-batch of the train loader
        is used; with ``'iteration'`` exactly one mini-batch is used and the
        iterator is restarted when exhausted. The parameter change of the
        round is stored in ``self.delta``.

        Args:
            p_loc: Pair of indices into ``self.p_list``; the first selects the
                exponent for 'conv' layers, the second for all other layers.

        Returns:
            Tuple ``(powed_model, gradients)`` where ``powed_model`` is the
            result of ``model_power`` on the current weights and ``gradients``
            is ``self.model.get_gradients()``.
        """
        self.model.train()

        # Snapshot of the parameters before this round's local update.
        frz_model_params = clone_parameters(self.model)

        if self.args.iter_method == 'epoch':
            # NOTE(review): the Lasso/RR penalty is only applied in
            # 'iteration' mode, as in the original code -- confirm intended.
            for data, target in self.train_loader:
                self._train_batch(data, target, regularize=False)
        elif self.args.iter_method == 'iteration':
            try:
                inputs, targets = next(self.data_iteration)
            except StopIteration:
                # Loader exhausted: start a new pass over the data.
                self.data_iteration = iter(self.train_loader)
                inputs, targets = next(self.data_iteration)
            self._train_batch(inputs, targets, regularize=True)

        # Per-entry difference between the updated and the frozen parameters.
        self.delta = OrderedDict(
            {k: p1 - p0 for (k, p1), p0 in zip(
                self.model.state_dict(keep_vars=True).items(),
                frz_model_params.values())})

        # P Power
        conv_p = self.p_list[p_loc[0]]
        other_p = self.p_list[p_loc[1]]
        powed_model = model_power(self.model.get_weights(), conv_p, other_p)

        return powed_model, self.model.get_gradients()

    def set_model(self, model):
        """Load ``model`` into the local model via ``set_weights``."""
        self.model.set_weights(model)

    def gradients_step(self, gradients, scale_factor, p_loc):
        """Apply stored gradients with the MR_SGD optimizer and undo the power.

        Args:
            gradients: Gradients loaded into the model with ``set_gradients``.
            scale_factor: Ignored; kept for interface compatibility. The value
                is recomputed from the selected exponents.
            p_loc: Pair of indices into ``self.p_list``.

        Returns:
            The weights after the step, raised to ``1/p1`` ('conv' layers)
            and ``1/p2`` (other layers) by ``model_power``.
        """
        self.mr_optimizer.zero_grad()
        self.model.set_gradients(gradients)
        p1 = self.p_list[p_loc[0]]
        p2 = self.p_list[p_loc[1]]
        # Mean of the two exponents; the original ``np.mean(p1, p2)`` passed
        # ``p2`` as the ``axis`` argument instead of averaging the pair.
        scale_factor = pow(10, -((p1 + p2) / 2)) if p1 != 1 and p2 != 1 else 1

        self.mr_optimizer.step(scale_factor, self.layer_name)

        return model_power(self.model.get_weights(), 1/p1, 1/p2)

    def test(self):
        """Evaluate the model on the test loader.

        Returns:
            Tuple ``(accuracy_percent, mean_batch_loss)``, both rounded to
            four decimals.

        Raises:
            ZeroDivisionError: If the test loader yields no samples.
        """
        self.model.eval()
        loss_sum = 0.0
        correct = 0
        total = 0
        num_batches = 0

        with torch.no_grad():
            for inputs, targets in self.test_loader:
                # Batches must live on the same device as the model.
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self._forward(inputs)

                loss_sum += self.criterion(outputs, targets).item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                num_batches += 1

        test_acc = format(correct / total * 100, '.4f')
        # Average over the number of batches actually processed.
        test_loss = format(loss_sum / num_batches, '.4f')

        return float(test_acc), float(test_loss)

    def HN_update(self):
        """Update the HyperNetwork from the last delta and aggregated model.

        Requires ``train`` (sets ``self.delta``) to have run beforehand.
        """
        self.HN.update(self.delta, self.agged_model.values())

    def aggregation(self, all_models_list, weight_matrix):
        """Aggregate every layer of the given models into ``self.agged_model``.

        Layers are grouped by the first dotted component of their name with
        its last character removed. Each group starts from uniform weights
        ``1 / world_size``; for every layer the group's weights are multiplied
        by ``weight_matrix`` and renormalised to sum to one.

        NOTE(review): the renormalised weights are written back to the group,
        so groups with several layers apply ``weight_matrix`` repeatedly --
        confirm this compounding is intended.

        Args:
            all_models_list: Sequence of mappings from layer name to tensor.
            weight_matrix: One weight per model for this client.

        Returns:
            Tuple ``(alpha, agged_model)``: the final per-group weights and
            the mapping from layer name to aggregated tensor.
        """
        world_size = self.args.world_size
        alpha = {}
        for group in set(n.split(".")[0][:-1] for n in self.layer_name):
            alpha[group] = torch.tensor([1/world_size for _ in range(world_size)])

        for layer in self.layer_name:
            group = layer.split(".")[0][:-1]

            layer_alpha = alpha[group] * torch.tensor(weight_matrix)
            weight = layer_alpha/sum(layer_alpha)
            alpha[group] = weight

            layer_param = [client_model[layer] for client_model in all_models_list]
            self.agged_model[layer] = weighted_sum(layer_param, weight)

        return alpha, self.agged_model


# Clients End
###########################################################################################################
        
def all_train(clients, p_loc_list):
    """Launch ``train`` on every client and gather the results.

    Args:
        clients: Sequence of client actor handles.
        p_loc_list: Per-client ``p_loc`` argument, indexed like ``clients``.

    Returns:
        Tuple ``(all_models_list, all_gradients_list)`` with one entry per
        client, in client order.
    """
    pending = [
        client.train.remote(p_loc_list[position])
        for position, client in enumerate(clients)
    ]

    # Block until every remote training call has finished.
    ray.wait(pending, num_returns=len(pending))

    all_models_list, all_gradients_list = [], []
    for ref in pending:
        trained_model, trained_gradient = ray.get(ref)
        all_models_list.append(trained_model)
        all_gradients_list.append(trained_gradient)

    return all_models_list, all_gradients_list

def all_get_p(clients):
    """Call ``get_p`` on every client and return the results in client order."""
    pending = [client.get_p.remote() for client in clients]

    # Block until every remote call has finished.
    ray.wait(pending, num_returns=len(pending))

    return [ray.get(ref) for ref in pending]

def all_aggregation(clients, all_models_list, weight_matrix):
    """Run ``aggregation`` on every client with its own row of weights.

    Args:
        clients: Sequence of client actor handles.
        all_models_list: Models passed unchanged to every client.
        weight_matrix: Indexable by client position; row ``i`` (sliced with
            ``[:]``) is sent to client ``i``.

    Returns:
        Tuple ``(all_agged_model, all_alpha)`` with one entry per client.
        Note the order is swapped relative to each client's
        ``(alpha, agged_model)`` result.
    """
    pending = [
        client.aggregation.remote(all_models_list, weight_matrix[row][:])
        for row, client in enumerate(clients)
    ]

    # Block until every remote aggregation has finished.
    ray.wait(pending, num_returns=len(pending))

    all_agged_model, all_alpha = [], []
    for ref in pending:
        client_alpha, client_model = ray.get(ref)
        all_agged_model.append(client_model)
        all_alpha.append(client_alpha)

    return all_agged_model, all_alpha

def all_set_model(clients, agg_models_list):
    """Send each client its model from ``agg_models_list`` via ``set_model``.

    The remote calls are fired without waiting for completion.
    """
    for position, client in enumerate(clients):
        client.set_model.remote(agg_models_list[position])

def all_grad_step(layer_name, clients, agg_models_list, all_gradients_list, p_loc_list):
    """Load aggregated models, apply the stored gradients, and sync the result.

    For every client: set its aggregated model, run ``gradients_step`` with
    the gradients from the previous local training, then write the returned
    model back to the client.

    Args:
        layer_name: Unused; kept for interface compatibility.
        clients: Sequence of client actor handles.
        agg_models_list: Per-client aggregated model.
        all_gradients_list: Per-client gradients.
        p_loc_list: Per-client ``p_loc`` argument.

    Returns:
        List of the models returned by ``gradients_step``, in client order.
    """
    pending = []

    # Second gradient descent with the previous gradients.
    for index, client in enumerate(clients):
        client.set_model.remote(agg_models_list[index])

        # ``Client.gradients_step`` takes (gradients, scale_factor, p_loc) and
        # recomputes the scale factor itself, so a placeholder is passed;
        # omitting it left ``p_loc`` unbound.
        pending.append(
            client.gradients_step.remote(
                all_gradients_list[index], None, p_loc_list[index]
            )
        )

    ray.wait(pending, num_returns=len(pending))
    all_models_list = [ray.get(ref) for ref in pending]

    # Set weights.
    for client, final_model in zip(clients, all_models_list):
        client.set_model.remote(final_model)

    return all_models_list

def all_test(clients, weight_matrix):
    """Evaluate every client and average accuracy and loss over clients.

    Args:
        clients: Sequence of client actor handles.
        weight_matrix: Unused.

    Returns:
        Tuple ``(avg_acc, avg_loss)`` of strings formatted with four decimals.
    """
    pending = [client.test.remote() for client in clients]

    # Block until every remote evaluation has finished.
    ray.wait(pending, num_returns=len(pending))

    acc_sum = 0
    loss_sum = 0
    for ref in pending:
        client_acc, client_loss = ray.get(ref)
        acc_sum += client_acc
        loss_sum += client_loss

    client_count = len(clients)
    avg_acc = format(acc_sum / client_count, '.4f')
    avg_loss = format(loss_sum / client_count, '.4f')

    return avg_acc, avg_loss

def all_HN_update(clients):
    """Trigger ``HN_update`` on every client and wait for all calls to finish."""
    pending = [client.HN_update.remote() for client in clients]

    ray.wait(pending, num_returns=len(pending))
######################################################################

def weighted_sum(tensors, weights):
    """Return the sum of ``tensors[i] * weights[i]`` over all ``i``.

    Args:
        tensors: Sequence of tensors that can be stacked together.
        weights: Sequence of per-tensor weights, same length as ``tensors``.

    Returns:
        A tensor with the shape of one input tensor.
    """
    assert len(tensors) == len(weights), "The number of tensors must match the number of weights."

    scaled = []
    for item, factor in zip(tensors, weights):
        scaled.append(item * factor)

    # Stack along a new leading axis and collapse it by summation.
    return torch.stack(scaled).sum(dim=0)


def tensor_power(tensor: torch.Tensor, p: float) -> torch.Tensor:
    """Return the sign-preserving element-wise power ``sign(x) * |x| ** p``.

    Args:
        tensor: Input tensor; it is not modified.
        p: Exponent applied to the absolute values.

    Returns:
        A new tensor with the same shape as ``tensor``.
    """
    sign = torch.sign(tensor)
    powered = tensor.abs() ** p
    # In-place multiply is safe: ``powered`` is a fresh tensor, not the input.
    powered.mul_(sign)
    return powered

def model_power(model, p1, p2):
    """Apply ``tensor_power`` to every entry of ``model`` in place.

    Entries whose name contains 'conv' use exponent ``p1``; all others use
    ``p2``.

    Args:
        model: Mapping from layer name to tensor; its values are replaced.
        p1: Exponent for names containing 'conv'.
        p2: Exponent for every other name.

    Returns:
        The same ``model`` mapping, after modification.
    """
    for layer in model:
        exponent = p1 if 'conv' in layer else p2
        model[layer] = tensor_power(model[layer], exponent)
    return model

def clone_parameters(
    src: Union[OrderedDict[str, torch.Tensor], torch.nn.Module]
) -> OrderedDict[str, torch.Tensor]:
    """Return detached copies of every tensor in ``src``, keyed by name.

    For an ``OrderedDict`` the items are copied directly; for a
    ``torch.nn.Module`` the items of ``state_dict(keep_vars=True)`` are
    copied. Each copy keeps the ``requires_grad`` flag of its source.
    Any other input type yields ``None``.
    """
    if isinstance(src, OrderedDict):
        named_tensors = src.items()
    elif isinstance(src, torch.nn.Module):
        named_tensors = src.state_dict(keep_vars=True).items()
    else:
        return None

    copies = OrderedDict()
    for key, tensor in named_tensors:
        copies[key] = tensor.clone().detach().requires_grad_(tensor.requires_grad)
    return copies
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        