from dataclasses import dataclass

import numpy as np
import pandas as pd
import torch
import math
from scipy.interpolate import griddata

from main import train_student_governance_game



@dataclass
class Interaction():
    """Record of one round of the game.

    ``GameRunner.results_to_df`` indexes ``round_params`` at [0] and [1] and
    ``losses`` at [0]..[5], so although both fields are annotated (and
    defaulted) as scalars, callers are expected to pass indexable sequences.
    """
    # index of the round this record belongs to
    round: int
    # parameters played in the round; the default was previously written as
    # `0,` (trailing comma), which silently made it the tuple (0,)
    round_params: float = 0
    # loss values recorded for the round
    losses: float = 0


@dataclass
class Metric:
    """Plain container of six float fields describing one evaluated configuration.

    NOTE(review): the field names suggest requested vs. achieved privacy
    (epsilon) and fairness (gamma) values plus accuracy/coverage; the class is
    not used anywhere in the visible code, so confirm against its callers.
    """
    epsilon: float
    gamma: float
    achieved_epsilon: float
    achieved_gamma: float
    accuracy: float
    coverage: float
        

class Agent():
    def __init__(self, args, name: str) -> None:
        """Store the run configuration and reset every loss table to ``None``.

        :param args: configuration object; reads ``algorithm``, ``lambda_priv``,
            ``lambda_fair``, ``step_size_priv``, ``step_size_fair``, ``priority``
            and (conditionally) ``step_size_decay``
        :param name: name of the agent
        """
        self.args = args
        self.name = name
        self.algorithm = args.algorithm

        # raw (table format) losses and the parameter values they belong to
        self.loss_b = self.loss_priv = self.loss_fair = None
        self.priv_values = self.fair_values = None
        # the same quantities after interpolation onto a regular grid
        self.loss_b_inter = self.loss_priv_inter = self.loss_fair_inter = None
        self.priv_values_inter = self.fair_values_inter = None
        # interpolated achieved privacy / fairness values
        self.achieved_priv = self.achieved_fair = None
        # accuracy and coverage, raw and interpolated
        self.acc = self.cov = None
        self.acc_inter = self.cov_inter = None

        # weights applied to the regulators' penalties in the combined loss
        self.lambda_priv = args.lambda_priv
        self.lambda_fair = args.lambda_fair

        # one step size per parameter: [privacy, fairness]
        self.step_size = np.array([args.step_size_priv, args.step_size_fair])
        # when the model builder has priority its step size is decayed once up
        # front so that all agents use the same step size within a round
        builder_has_priority = args.priority == "model_builder" and name == "model_builder"
        if builder_has_priority:
            self.step_size = self.step_size / args.step_size_decay
    
        
    def update_loss(self, loss_b, loss_priv, loss_fair, priv_values, fair_values, acc, cov=None):
        """Store new loss tables and rebuild their interpolated grids.

        :param loss_b: builder loss at each point (table format)
        :param loss_priv: achieved privacy value at each point
        :param loss_fair: achieved fairness value at each point
        :param priv_values: privacy parameter of each point
        :param fair_values: fairness parameter of each point
        :param acc: accuracy at each point
        :param cov: coverage at each point; only interpolated for 'fairPATE'
        """
        # keep the raw tables
        self.loss_b, self.loss_priv, self.loss_fair = loss_b, loss_priv, loss_fair
        self.priv_values, self.fair_values = priv_values, fair_values
        self.acc, self.cov = acc, cov

        # coverage is only interpolated for fairPATE
        with_coverage = self.algorithm == 'fairPATE'
        tables = [loss_b, loss_priv, loss_fair, acc]
        if with_coverage:
            tables.append(cov)

        grids, self.priv_values_inter, self.fair_values_inter = self.interpolate_losses(tables, priv_values, fair_values)
        if with_coverage:
            self.cov_inter = grids[4]
        builder_grid, priv_grid, fair_grid, acc_grid = grids[:4]

        # penalties are the excess over the goals, clamped at zero; np.fmax
        # also replaces NaN cells (outside the interpolation hull) with 0
        priv_excess = np.fmax(np.log10(priv_grid / self.args.goal_priv), 0.0)
        fair_excess = np.fmax(fair_grid - self.args.goal_fair, 0.0)

        # the unadjusted grids are kept as the achieved values
        self.achieved_priv = priv_grid
        self.achieved_fair = fair_grid

        self.loss_b_inter = builder_grid
        self.loss_priv_inter = priv_excess
        self.loss_fair_inter = fair_excess
        self.acc_inter = acc_grid

    
    def interpolate_losses(self, losses, priv_values, fair_values):
        """Linearly interpolate scattered loss values onto a regular 100 x 100 grid.

        :param losses: iterable of value arrays, one value per scattered point
        :param priv_values: privacy parameter (epsilon) of each point
        :param fair_values: fairness parameter (gamma) of each point
        :return: (list of interpolated grids, privacy axis, fairness axis);
            each grid is indexed [fairness index, privacy index]
        """
        priv_axis = np.linspace(priv_values.min(), priv_values.max(), 100)
        fair_axis = np.linspace(fair_values.min(), fair_values.max(), 100)
        # default 'xy' meshgrid: rows follow the fairness axis, columns the privacy axis
        grid = tuple(np.meshgrid(priv_axis, fair_axis))
        scattered = (priv_values, fair_values)
        surfaces = [griddata(scattered, table, grid, method='linear') for table in losses]
        return surfaces, priv_axis, fair_axis
        
        
    def best_response(self, curr_param, C_priv, C_fair):
        """Take one gradient step on the combined loss and report the losses there.

        The gradient is the finite-difference gradient (``np.gradient``, unit
        spacing, i.e. per grid index) of the interpolated combined loss, read
        at the grid cell closest to ``curr_param``.  The step size is decayed
        after every call.

        :param curr_param: current [privacy, fairness] parameter
        :param C_priv: privacy penalty scalar
        :param C_fair: fairness penalty scalar
        :return: tuple of (new params, combined loss, builder loss,
            achieved privacy, achieved fairness, accuracy, coverage)
        """
        priv_axis, fair_axis = self.priv_values_inter, self.fair_values_inter

        def nearest_index(value, axis_values):
            # index of the grid coordinate closest to `value`
            return np.abs(np.asarray(axis_values) - value).argmin()

        # a penalty only contributes while the current parameter violates its goal;
        # for algorithms other than fairPATE the fairness penalty always applies
        within_priv_goal = curr_param[0] <= self.args.goal_priv
        priv_penalty = 0 if within_priv_goal else self.lambda_priv * C_priv * self.loss_priv_inter
        within_fair_goal = self.args.algorithm == 'fairPATE' and curr_param[1] <= self.args.goal_fair
        fair_penalty = 0 if within_fair_goal else self.lambda_fair * C_fair * self.loss_fair_inter

        # grids are indexed [fairness, privacy], so axis 0 is the fairness direction
        d_fair, d_priv = np.gradient(self.loss_b_inter + priv_penalty + fair_penalty)

        col = nearest_index(curr_param[0], priv_axis)
        row = nearest_index(curr_param[1], fair_axis)
        step = np.multiply([d_priv[row, col], d_fair[row, col]], self.step_size)

        # step size decay
        self.step_size = self.step_size / self.args.step_size_decay

        # a NaN gradient (cell outside the interpolated region) means no move
        step[np.isnan(step)] = 0

        # move against the gradient, then clip to the lower bounds
        new_param = curr_param - step
        new_param[0] = max(1, new_param[0])
        new_param[1] = max(0.01, new_param[1])

        # the losses at the new parameter are exactly what get_losses reports
        return (new_param, *self.get_losses(new_param, C_priv, C_fair))

    
    def get_losses(self, curr_param, C_priv, C_fair):
        """Read the interpolated losses at the grid cell closest to ``curr_param``.

        :param curr_param: [privacy, fairness] parameter to look up
        :param C_priv: privacy penalty scalar
        :param C_fair: fairness penalty scalar
        :return: tuple of (combined loss, builder loss, achieved privacy,
            achieved fairness, accuracy, coverage); coverage is 1 unless the
            algorithm is 'fairPATE'
        """
        # nearest grid coordinates; grids are indexed [fairness, privacy]
        col = np.abs(np.asarray(self.priv_values_inter) - curr_param[0]).argmin()
        row = np.abs(np.asarray(self.fair_values_inter) - curr_param[1]).argmin()
        cell = (row, col)

        builder_loss = self.loss_b_inter[cell]
        priv_penalty = self.loss_priv_inter[cell]
        fair_penalty = self.loss_fair_inter[cell]
        combined = builder_loss + self.lambda_priv * C_priv * priv_penalty + self.lambda_fair * C_fair * fair_penalty

        coverage = self.cov_inter[cell] if self.algorithm == 'fairPATE' else 1

        return combined, builder_loss, self.achieved_priv[cell], self.achieved_fair[cell], self.acc_inter[cell], coverage

        
    def update_step_size(self, round):
        """Recompute the step size as it would be after ``round`` decays.

        Used after restarting from a preemption: the initial
        [privacy, fairness] step sizes are scaled by
        ``(1 / step_size_decay) ** round``.

        :param round: number of decays to apply
        """
        initial = np.array([self.args.step_size_priv, self.args.step_size_fair])
        decay_factor = (1/self.args.step_size_decay)**round
        self.step_size = initial * decay_factor
    
    
class ModelBuilder(Agent):
    """Agent named "model_builder" that additionally stores ``builder_lambda``."""

    def __init__(self, args) -> None:
        super().__init__(args, "model_builder")
        # lambda used in the weighting of coverage and accuracy
        self.builder_lambda = args.builder_lambda

    def choose_starting_point(self):
        """Return the [privacy, fairness] parameters of the point with the lowest loss.

        The loss table is ``self.losses`` when a caller has assigned one,
        otherwise the builder loss table ``self.loss_b`` stored by
        ``update_loss``.  (Nothing in this class hierarchy assigns
        ``self.losses``, so reading it unconditionally raised AttributeError.)

        :return: ``[priv_value, fair_value]`` of the minimising point
        :raises ValueError: if no loss table is available yet
        """
        losses = getattr(self, "losses", None)
        if losses is None:
            losses = self.loss_b
        if losses is None:
            raise ValueError("no losses available; call update_loss() before choose_starting_point()")

        min_index = np.argmin(losses)

        return [self.priv_values[min_index], self.fair_values[min_index]]
    
    
class Regulator():
    """Holds the penalty scalar and constraint of one regulator.

    For the names 'privacy_regulator' and 'fairness_regulator' the values are
    read from ``args`` (``C_priv``/``goal_priv`` and ``C_fair``/``goal_fair``
    respectively).  For any other name neither attribute is set.
    """

    def __init__(self, args, name: str) -> None:
        self.args = args
        self.name = name
        # (penalty scalar field, constraint field) on `args` per regulator name
        arg_fields = {
            'privacy_regulator': ('C_priv', 'goal_priv'),
            'fairness_regulator': ('C_fair', 'goal_fair'),
        }
        if name in arg_fields:
            scalar_field, goal_field = arg_fields[name]
            self.penalty_scalar = getattr(args, scalar_field)
            self.constraint = getattr(args, goal_field)

    def get_penalty_scalar(self):
        """Return the penalty scalar chosen at construction time."""
        return self.penalty_scalar
        
        
class PrivacyRegulator(Regulator):
    """Regulator constructed with the fixed name "privacy_regulator"."""

    def __init__(self, args) -> None:
        super().__init__(args, name="privacy_regulator")


class FairnessRegulator(Regulator):
    """Regulator constructed with the fixed name "fairness_regulator"."""

    def __init__(self, args) -> None:
        super().__init__(args, name="fairness_regulator")
        


class GameRunner():
    """
    Entity that keeps track of the game simulation
    """
    def __init__(self, args, losses, priv_values, fair_values, agents, calibration=True) -> None:
        self.args = args
        self.algorithm = args.algorithm
        # all in table format
        self.losses, self.priv_values, self.fair_values= losses, priv_values, fair_values
        
        self.pf_indices = None
        self.agents = agents
        self.interaction_history = []
        self.results_df = None
        self.calibratioin_df = None
        self.time = 0       # for logging results
            
        if self.algorithm == 'fairPATE':
            self.fair_var = "gamma"
        else:
            self.fair_var = "tau"
    
    
    def is_pareto_efficient(self, costs, return_mask = False):
        """
            Find the pareto-efficient points
            :param costs: An (n_points, n_costs) array
            :param return_mask: True to return a mask
            :return: An array of indices of pareto-efficient points.
                If return_mask is True, this will be an (n_points, ) boolean array
                Otherwise it will be a (n_efficient_points, ) integer array of indices.
        """
        is_efficient = np.arange(costs.shape[0])
        n_points = costs.shape[0]
        next_point_index = 0  # Next index in the is_efficient array to search for
        while next_point_index<len(costs):
            nondominated_point_mask = np.any(costs<costs[next_point_index], axis=1)
            nondominated_point_mask[next_point_index] = True
            is_efficient = is_efficient[nondominated_point_mask]  # Remove dominated points
            costs = costs[nondominated_point_mask]
            next_point_index = np.sum(nondominated_point_mask[:next_point_index])+1
        if return_mask:
            is_efficient_mask = np.zeros(n_points, dtype = bool)
            is_efficient_mask[is_efficient] = True
            return is_efficient_mask
        else:
            return is_efficient
    
    
    def interpolate_loss(self, losses, priv_values, fair_values):
        '''
            Interpolate the losses into a grid format
            :param priv_values: epsilon of points
            :param fair_values: gamma of points
            :param loss: loss value specific to each agent
        '''
        x = priv_values[self.pf_indices]
        y = fair_values[self.pf_indices]
        losses = losses[self.pf_indices, :]
        xi = np.linspace(x.min(), x.max(), 40)
        yi = np.linspace(y.min(), y.max(), 40)
        X,Y = np.meshgrid(xi,yi)
        interpolated_losses = []
        for i in range(4):
            z = losses[:, i].flatten()
            Z = griddata((x,y),z,(X,Y), method='linear')
            interpolated_losses.append(Z)
        return interpolated_losses, xi, yi
    
    
    def update_losses(self, result):
        '''
            Accept a result array of the new student model. 
            Update its cost function and input parameter list
        '''
        # TODO: make sure the key names match
        self.losses = np.concatenate((self.losses, 
                                      np.array([[-1 * result['student_accuracy'],  
                                                result['achieved_epsilon'], 
                                                result['achieved_fairness_gaps'],
                                                -1 * result['coverage']]])), axis = 0)
        print([result['student_accuracy'], result['achieved_epsilon'], result['achieved_fairness_gaps'], result['coverage']], flush=True)
        self.priv_values = np.append(self.priv_values, result['epsilon'])
        self.fair_values = np.append(self.fair_values, result['fairness_gaps'])
                                       
        
    def distribute_losses(self):
        '''
            Select points on the PF and distribute them to the agents
        '''
        # select points on the PF
        pf_indices = self.is_pareto_efficient(self.losses)
        self.pf_indices = pf_indices
        pf_losses = self.losses[pf_indices, :]
        pf_priv = self.priv_values[pf_indices]
        pf_fair = self.fair_values[pf_indices]
        # the losses are already negative so there is no need to multiply by -1
        loss_priv = pf_losses[:, 1]
        loss_fair = pf_losses[:, 2]

        # give the agents new losses
        if self.algorithm == 'fairPATE':
            loss_builder_weighted = self.args.builder_lambda *0.01 * pf_losses[:, 0] + (1-self.args.builder_lambda) * pf_losses[:, 3]
            self.agents[0].update_loss(loss_builder_weighted, loss_priv, loss_fair, pf_priv, pf_fair, -1 * pf_losses[:, 0], -1 * pf_losses[:, 3])     # model builder
        else:
            self.agents[0].update_loss(0.01 * pf_losses[:, 0], loss_priv, loss_fair, pf_priv, pf_fair, -1 * pf_losses[:, 0])     # model builder

    
    def regulators_starting_point(self):
        '''
            Fairness and privacy regulators choose the starting point of the game jointly
        '''
        priv_losses = self.losses[self.pf_indices][:,1]
        fair_losses = self.losses[self.pf_indices][:,2]
        combined_losses = self.args.regulators_lambda * np.log(priv_losses/min(priv_losses)) + (1-self.args.regulators_lambda) * (fair_losses-min(fair_losses))
        
        pf_priv = self.priv_values[self.pf_indices]
        pf_fair = self.fair_values[self.pf_indices]
        
        min_index = np.argmin(combined_losses)
        
        return [pf_priv[min_index], pf_fair[min_index]]
    
        
    def register_interaction(self, interaction: Interaction):
        self.interaction_history.append(interaction)
        
        
    def results_to_df(self):
        '''
        Store the current round of game simulation results in a dataframe
        '''
        
        # game parameters
        if self.time == 0:
            self.results_df = pd.DataFrame(columns=["round", "epsilon", "gamma", "agent", "loss_build_combined",
                                           "loss_build", "accuracy", "coverage", "privacy cost", "max fairness gap"])

        # experiment results
        inter = self.interaction_history[-1]
            
        self.results_df = pd.concat([self.results_df, pd.DataFrame({'round': self.time, 
                                            'epsilon': inter.round_params[0], 
                                            self.fair_var: inter.round_params[1], 
                                            'agent': 'build',
                                            'loss_build_combined': inter.losses[0],
                                            'loss_build': inter.losses[1], 
                                            "accuracy": inter.losses[4], 
                                            "coverage": inter.losses[5], 
                                            "privacy cost": inter.losses[2], 
                                            "max fairness gap": inter.losses[3]}, index=[0])], ignore_index=True)
        self.time += 1
            
            
    def results_to_df_archived(self):
        '''
        Archived variant of results_to_df: store the latest interaction in a
        dataframe, one row per parameter pair in ``round_params``.

        For every pair the losses are looked up on a grid interpolated by
        ``interpolate_loss`` from the current loss table, and ``self.time`` is
        incremented once per row.

        NOTE(review): this method reads ``self.sub_rounds`` and
        ``self.agents_dict``, neither of which is assigned anywhere in the
        visible part of this class, so it presumably cannot run as-is --
        confirm before reviving it.
        '''
        def closest_point_2d(target, x, y):
            '''
                returns the closest point using euclidean distance
            '''
            # NOTE(review): defined but not called anywhere in this method
            points = np.stack((x,y), axis=-1)
            dist_2 = np.sum((points - target)**2, axis=1)
            return np.argmin(dist_2)
        
        def closest_point(target, points):
            # index of the entry of `points` closest to `target`
            array = np.asarray(points)
            return  (np.abs(array - target)).argmin()
        
        # game parameters
        if self.time == 0:
            self.results_df = pd.DataFrame(columns=["t", "round", "subround", "epsilon", "gamma", "agent", 
                                           "loss_build", "acc", "cov", "loss_privReg", "loss_fairReg"])

        # experiment results
        inter = self.interaction_history[-1]
        # interpolate the loss
        interpolated_losses, interpolated_priv, interpolated_fair = self.interpolate_loss(self.losses, self.priv_values, self.fair_values)

        # here round_params is iterated as a collection of (epsilon, gamma) pairs
        for p in inter.round_params:
            step_losses = []
            closest_priv = closest_point(p[0], interpolated_priv)
            closest_fair = closest_point(p[1], interpolated_fair)
            
            # get each loss, order: -acc, -cov, priv, fair
            # NOTE(review): update_losses stores the columns as
            # [-accuracy, achieved epsilon, achieved fairness gap, -coverage],
            # which does not match the order assumed here -- confirm.
            for l in interpolated_losses:
                # grids are indexed [fairness index, privacy index]
                step_losses.append(l[closest_fair, closest_priv])
            
            loss_builder_weighted = self.args.builder_lambda * 0.01* step_losses[0] + (1-self.args.builder_lambda) * step_losses[1]
            self.results_df = pd.concat([self.results_df, pd.DataFrame({'t': self.time, 'round': self.time//self.sub_rounds, 'subround': self.time%self.sub_rounds, 
                                            'epsilon': p[0], 'gamma': p[1], 
                                            'agent': self.agents_dict[self.time%self.sub_rounds], # 4 subrounds including calibration
                                            'loss_build': loss_builder_weighted, 
                                            "acc": -1 * step_losses[0], 
                                            "cov": -1 * step_losses[1], 
                                            "loss_privReg": 0.1 * step_losses[2], 
                                            "loss_fairReg": step_losses[3]}, index=[0])], ignore_index=True)
            self.time += 1
    
    
    def calibration_to_df(self, results):
        '''
            Write the student model results to a calibration df
        '''
        if self.time == 1:
            self.calibration_df = pd.DataFrame(columns=["round", "epsilon", "gamma", "agent", "loss_build_combined",
                                           "loss_build", "accuracy", "coverage", "privacy cost", "max fairness gap"])
        loss_builder_weighted = -1 * (self.args.builder_lambda * 0.01* results['student_accuracy'] + (1-self.args.builder_lambda) * results['coverage'])
        loss_build_combined = loss_builder_weighted + self.args.lambda_priv * self.args.C_priv * math.log10(results['achieved_epsilon']/self.args.goal_priv) + self.args.lambda_fair * self.args.C_fair * (results['achieved_fairness_gaps']-self.args.goal_fair)
        self.results_df = pd.concat([self.results_df, pd.DataFrame({'round': self.time, 
                                            'epsilon': results['epsilon'], 
                                            self.fair_var: results['fairness_gaps'], 
                                            'agent': 'calibration',
                                            'loss_build_combined': loss_build_combined,
                                            'loss_build': loss_builder_weighted, 
                                            "accuracy": results['student_accuracy'], 
                                            "coverage": results['coverage'], 
                                            "privacy cost": results['achieved_epsilon'], 
                                            "max fairness gap": results['achieved_fairness_gaps']}, index=[0])], ignore_index=True)
    
    
    def return_results_df(self):
        return self.results_df
    
    def return_calibration_df(self):
        return self.calibration_df
        
        
    def train_student_model(self, param):
        """Train a student model for ``param`` and return the training result.

        Thin wrapper that forwards ``self.args`` and ``param`` to
        ``train_student_governance_game``.
        """
        return train_student_governance_game(self.args, param)
    
    def sync(self, curr_time, results_df):
        # update the time and results df
        self.time = curr_time
        self.results_df = results_df
        # get all the new student model results from df and add them to the loss, priv, and fair
        for index, row in results_df.iterrows():
            if row['agent'] == 'calibration':
                results = {'student_accuracy': float(row['accuracy']),
                           'coverage': float(row['coverage']),
                           'achieved_epsilon': float(row['loss_privReg']),
                           'achieved_fairness_gaps': float(row['loss_fairReg']),
                           'epsilon': float(row['epsilon']),
                           'fairness_gaps': float(row['gamma'])
                    
                }
                self.update_losses(results)

            