import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import init_param, make_batchnorm, loss_fn
from config import cfg

class SimpObjective(nn.Module):
    """Participation objective parameterised by the pricing parameters theta_0 and theta_1.

    theta_0 scales every participation cost and theta_1 shifts the argument of the
    pricing sigmoid (see ``client_participation_cost``). ``forward`` returns the
    negated sum, over all clients, of a soft participation indicator multiplied by
    that client's marginal gain; minimising it w.r.t. theta_0/theta_1 therefore
    maximises that sum.

    Hyper-parameters are read from the global ``cfg`` mapping at call time:
    ``objective_func_sigmoid_s``, ``active_rate``, ``pricing_plan_coefficient``,
    ``lambda``, ``lambda_for_infinite_situation`` and ``num_clients``.
    """

    # Value of cfg['lambda'] that marginal_gain treats as "lambda is infinite".
    _INFINITE_LAMBDA = 999999

    def __init__(self, theta_0, theta_1):
        super().__init__()
        self.theta_0 = self._to_float_parameter(theta_0)
        self.theta_1 = self._to_float_parameter(theta_1)

    @staticmethod
    def _to_float_parameter(value):
        """Wrap a scalar (or tensor) as a trainable floating-point ``nn.Parameter``.

        ``nn.Parameter`` can only require gradients on floating-point/complex
        tensors, so integer (or bool) inputs such as ``1`` are promoted to the
        default float dtype instead of raising. Floating inputs keep their dtype.
        """
        if torch.is_tensor(value):
            # Copy so the parameter does not share storage with the caller's tensor.
            tensor = value.clone().detach()
        else:
            tensor = torch.tensor(value)
        if not (tensor.is_floating_point() or tensor.is_complex()):
            tensor = tensor.to(torch.get_default_dtype())
        return nn.Parameter(tensor)

    @staticmethod
    def _detached_cpu(value):
        """Return an independent, gradient-free CPU tensor copy of ``value``."""
        return torch.as_tensor(value).clone().detach().to('cpu')

    def utility_income(self, x):
        """Map a collaboration gain to money (the utility income function U).

        U is known to and shared by all clients and the server; it is currently
        the identity mapping.
        """
        return x

    def sigmoid(self, x):
        """Logistic sigmoid of ``x / cfg['objective_func_sigmoid_s']``."""
        return torch.sigmoid(x / cfg['objective_func_sigmoid_s'])

    def client_participation_cost(
        self,
        client_id,
        client_participated_active_round_client_ig,
        client_participated_active_round_server_cg,
        malicious_client_ids,
        print_indicator=False
    ):
        """Return the active-rate-weighted participation cost charged to a client.

        With ``cg`` / ``ig`` the server-cg / client-ig arguments,
        ``a = cfg['active_rate']`` and ``p = cfg['pricing_plan_coefficient']``::

            cost = theta_0 * cg * (1 + a * (-1 + p * sigmoid(cg - ig - theta_1)))

        Algebraically this is ``a * (theta_0 * cg * p * sigmoid(...)) +
        (1 - a) * (theta_0 * cg)``: a mix, weighted by ``a``, of the active and
        inactive costs that ``calculate_client_participation_cost`` applies.

        ``client_id``, ``malicious_client_ids`` and ``print_indicator`` are
        accepted for interface compatibility but are not used.
        """
        server_cg = client_participated_active_round_server_cg
        client_ig = client_participated_active_round_client_ig
        pricing = cfg['pricing_plan_coefficient'] * self.sigmoid(server_cg - client_ig - self.theta_1)
        return self.theta_0 * server_cg * (1 + cfg['active_rate'] * (-1 + pricing))

    def _net_participation_utility(
        self,
        client_id,
        last_round_server_cg,
        client_participated_active_round_client_ig,
        client_participated_active_round_server_cg,
        malicious_client_ids,
        print_indicator,
    ):
        """Return U(last_round_server_cg) - U(client ig) - participation cost.

        Shared by the soft (``indicator_variable``) and hard
        (``client_participation_indicator``) participation decisions.
        """
        income = self.utility_income(last_round_server_cg)
        own_gain = self.utility_income(client_participated_active_round_client_ig)
        cost = self.client_participation_cost(
            client_id,
            client_participated_active_round_client_ig,
            client_participated_active_round_server_cg,
            malicious_client_ids,
            print_indicator=print_indicator,
        )
        return income - own_gain - cost

    def indicator_variable(
        self,
        client_id,
        last_round_server_cg,
        client_participated_active_round_client_ig,
        client_participated_active_round_server_cg,
        malicious_client_ids,
        print_indicator=False
    ):
        """Soft (differentiable) participation indicator: ``sigmoid(net utility)``."""
        net_utility = self._net_participation_utility(
            client_id,
            last_round_server_cg,
            client_participated_active_round_client_ig,
            client_participated_active_round_server_cg,
            malicious_client_ids,
            print_indicator,
        )
        return self.sigmoid(net_utility)

    def client_participation_indicator(
        self,
        client_id,
        last_round_server_cg,
        client_participated_active_round_client_ig,
        client_participated_active_round_server_cg,
        malicious_client_ids,
        print_indicator=False
    ):
        """Hard participation decision: ``net utility > 0``."""
        net_utility = self._net_participation_utility(
            client_id,
            last_round_server_cg,
            client_participated_active_round_client_ig,
            client_participated_active_round_server_cg,
            malicious_client_ids,
            print_indicator,
        )
        return net_utility > 0

    def marginal_gain(
        self,
        client_id,
        client_participated_active_round_client_ig,
        client_participated_active_round_server_cg,
        cur_client_last_active_round_marginal_gain_second_term,
        malicious_client_ids
    ):
        """Return the client's marginal gain as a function of ``cfg['lambda']``.

        - ``lambda == 0``: ``-second_term``
        - ``lambda == 999999`` (treated as infinite):
          ``cfg['lambda_for_infinite_situation'] * cost``
        - otherwise: ``lambda * cost - second_term``

        where ``cost`` is ``client_participation_cost`` for this client and
        ``second_term`` is ``cur_client_last_active_round_marginal_gain_second_term``.
        """
        cost = self.client_participation_cost(
            client_id,
            client_participated_active_round_client_ig,
            client_participated_active_round_server_cg,
            malicious_client_ids
        )
        second_term = cur_client_last_active_round_marginal_gain_second_term
        lam = cfg['lambda']
        if lam == 0:
            return -second_term
        if lam == self._INFINITE_LAMBDA:
            return cfg['lambda_for_infinite_situation'] * cost
        return lam * cost - second_term

    def forward(self, input):
        """Return the objective to minimise for the current round.

        Keys read from ``input``:
            'server_collaboration_gains': non-empty sequence; its maximum is used
                as every client's income term.
            'client_participated_active_round_client_igs',
            'client_participated_active_round_server_cgs',
            'clients_last_active_round_marginal_gain_second_terms':
                per-client values indexed by client id ``0 .. cfg['num_clients'] - 1``.
            'malicious_client_ids': forwarded to the helpers (unused there).
        Any other keys are ignored.

        All inputs are detached copies, so gradients flow only into theta_0 and
        theta_1. Returns ``-sum_i indicator_i * marginal_gain_i`` (the int ``0``
        when ``cfg['num_clients']`` is 0).
        """
        # The inputs are only read and are copied by _detached_cpu below, so no
        # deepcopy of the (potentially large) input containers is needed.
        client_igs = input['client_participated_active_round_client_igs']
        server_cgs = input['client_participated_active_round_server_cgs']
        second_terms = input['clients_last_active_round_marginal_gain_second_terms']
        malicious_client_ids = input['malicious_client_ids']

        # Loop-invariant: the best server collaboration gain seen so far.
        best_server_cg = self._detached_cpu(max(input['server_collaboration_gains']))

        objective = 0
        for client_id in range(cfg['num_clients']):
            client_ig = self._detached_cpu(client_igs[client_id])
            server_cg = self._detached_cpu(server_cgs[client_id])
            second_term = self._detached_cpu(second_terms[client_id])

            indicator = self.indicator_variable(
                client_id=client_id,
                last_round_server_cg=best_server_cg,
                client_participated_active_round_client_ig=client_ig,
                client_participated_active_round_server_cg=server_cg,
                malicious_client_ids=malicious_client_ids,
                print_indicator=True
            )
            gain = self.marginal_gain(
                client_id=client_id,
                client_participated_active_round_client_ig=client_ig,
                client_participated_active_round_server_cg=server_cg,
                cur_client_last_active_round_marginal_gain_second_term=second_term,
                malicious_client_ids=malicious_client_ids,
            )
            objective -= indicator * gain
        return objective

    def calculate_client_participation_decision(self, server_collaboration_gains, client_participated_active_round_client_igs, client_participated_active_round_server_cgs, malicious_client_ids):
        """Return one hard participation decision per client id ``0 .. cfg['num_clients'] - 1``.

        Each entry is ``client_participation_indicator(...)`` evaluated against
        ``max(server_collaboration_gains)``; that sequence must be non-empty.
        """
        # Loop-invariant: computed once rather than once per client.
        best_server_cg = max(server_collaboration_gains)
        return [
            self.client_participation_indicator(
                client_id=client_id,
                last_round_server_cg=best_server_cg,
                client_participated_active_round_client_ig=client_participated_active_round_client_igs[client_id],
                client_participated_active_round_server_cg=client_participated_active_round_server_cgs[client_id],
                malicious_client_ids=malicious_client_ids,
                print_indicator=False
            )
            for client_id in range(cfg['num_clients'])
        ]

    def calculate_client_participation_cost(
            self,
            cur_round_server_cg,
            cur_round_active_clients_igs,
            malicious_client_ids,
            cur_round_participation_clients_id,
            cur_round_active_client_ids,
        ):
        """Compute the realized cost of every participating client this round.

        A participating client that is also in ``cur_round_active_client_ids``
        (its ig is taken from the matching position of
        ``cur_round_active_clients_igs``) pays
        ``theta_0 * cg * (1 + (-1 + p * sigmoid(cg - ig - theta_1)))``. Any other
        participating client pays ``theta_0 * cg``, with
        ``cg = cur_round_server_cg`` and ``p = cfg['pricing_plan_coefficient']``.
        Unlike ``client_participation_cost``, no ``active_rate`` weighting is
        applied, because each client's active status is known here.

        Returns four lists of Python floats (tensor costs are converted with
        ``.item()``), each in ``cur_round_participation_clients_id`` order:
        all participants, active participants, malicious participants, and
        benign (non-malicious) participants.
        """
        clients_participation_costs = []
        malicious_clients_participation_costs = []
        benign_clients_participation_costs = []
        cur_round_active_clients_participation_costs = []
        server_cg = cur_round_server_cg

        for client_id in cur_round_participation_clients_id:
            is_active = client_id in cur_round_active_client_ids
            if is_active:
                client_ig = cur_round_active_clients_igs[cur_round_active_client_ids.index(client_id)]
                cost = self.theta_0 * server_cg * (1 + (-1 + cfg['pricing_plan_coefficient'] * self.sigmoid(server_cg - client_ig - self.theta_1)))
            else:
                cost = self.theta_0 * server_cg

            if torch.is_tensor(cost):
                cost = cost.item()

            clients_participation_costs.append(cost)
            if client_id in malicious_client_ids:
                malicious_clients_participation_costs.append(cost)
            else:
                benign_clients_participation_costs.append(cost)
            if is_active:
                cur_round_active_clients_participation_costs.append(cost)

        return clients_participation_costs, cur_round_active_clients_participation_costs, malicious_clients_participation_costs, benign_clients_participation_costs


