import os
import pandas as pd
import numpy as np

from agents import GameRunner, ModelBuilder, PrivacyRegulator, FairnessRegulator, Interaction



def init_game(args):
    '''
        Build the game: load the precomputed loss tables and create the agents.

        Tables are read from ``<prev_results_dir>loss_functions/<dataset>/<algorithm>/``.
        The builder's accuracy table (and, for fairPATE, its coverage table) is multiplied
        by -1 before stacking; the stacked array is squeezed and handed to GameRunner
        together with the input (epsilon, gamma) grid.

        Returns:
            (game_runner, model_builder, privacy_regulator, fairness_regulator)
    '''
    print("Loading the losses", flush=True)
    table_dir = args.prev_results_dir + "loss_functions/" + args.dataset + "/" + args.algorithm

    def load_table(name):
        # resolves to <table_dir>/<name>.npy
        return np.load(table_dir + '/' + name + '.npy')

    acc_table = load_table('builder_loss_acc')
    privacy_table = load_table('privacy_loss')
    fairness_table = load_table('fairness_loss')
    # input grid: column 0 is epsilon (privacy), column 1 is gamma (fairness)
    grid = load_table('priv_fair_values')
    epsilons, gammas = grid[:, 0], grid[:, 1]

    columns = [-1 * acc_table, privacy_table, fairness_table]
    # Coverage is only loaded for fairPATE; other methods may not produce it.
    if args.algorithm == 'fairPATE':
        columns.append(-1 * load_table('builder_loss_cov'))
    losses = np.squeeze(np.stack(columns, axis=-1))

    # Agents start without loss functions; the game runner hands them out later.
    model_builder = ModelBuilder(args)
    privacy_regulator = PrivacyRegulator(args)
    fairness_regulator = FairnessRegulator(args)
    agents = [model_builder, privacy_regulator, fairness_regulator]

    # The runner owns the cost tables (outputs) and the input parameter grid.
    game_runner = GameRunner(args, losses, epsilons, gammas, agents, args.calibration)

    return game_runner, model_builder, privacy_regulator, fairness_regulator
    

def run_game(args, game_runner, model_builder, privacy_regulator, fairness_regulator):
    '''
        Run the game up to ``args.num_rounds`` and return the results DataFrame.

        If a checkpoint exists at ``<save_path><experiment_name>/df.parquet.gzip`` the game
        resumes after the last saved round (restoring parameters and the runner's state via
        ``game_runner.sync``); otherwise it starts at round 1. The full results DataFrame is
        rewritten to that checkpoint after every round so a preempted run can be resumed.

        Args:
            args: experiment configuration. Fields read here: init_priv, init_fair,
                save_path, experiment_name, num_rounds, calibration, priority,
                C_priv, C_fair.
            game_runner: runner holding the loss tables and the interaction log.
            model_builder: agent that picks a starting point / plays best responses.
            privacy_regulator, fairness_regulator: not used directly in this function
                (presumably reached through game_runner — confirm).

        Returns:
            The DataFrame from ``game_runner.return_results_df()`` after the last round.
    '''
    def log_init(game_runner, model_builder, curr_param, C_priv, C_fair):
        # Record the starting point as round 0 so the trajectory includes the initial params.
        loss_combined, loss_b, loss_p, loss_f, acc, cov = model_builder.get_losses(curr_param, C_priv, C_fair)
        inter_init = Interaction(round=0)
        inter_init.round_params = curr_param
        inter_init.losses = [loss_combined, loss_b, loss_p, loss_f, acc, cov]

        # log the round
        game_runner.register_interaction(inter_init)
        game_runner.results_to_df()

    # NOTE(review): truthiness check -- an explicit 0 / 0.0 initial value counts as "not given".
    if args.init_priv and args.init_fair:
        curr_param = [args.init_priv, args.init_fair]
    else:
        curr_param = None

    # The weights are constant for the whole run. Bind them up front: on a resumed run the
    # loop starts after round 1, so binding them only in round 1 left them undefined.
    C_priv = args.C_priv
    C_fair = args.C_fair

    # checkpoint with the results of all rounds played so far
    file_path = args.save_path + args.experiment_name + '/df.parquet.gzip'
    if not os.path.exists(file_path):
        curr_round = 1
    else:
        # resume after the last saved row
        df_inter_results = pd.read_parquet(file_path)
        # .iloc[-1] yields scalars; int()/float() on a 1-element Series is deprecated in pandas
        curr_round = int(df_inter_results['round'].iloc[-1]) + 1
        game_runner.sync(int(df_inter_results['t'].iloc[-1]) + 1, df_inter_results)
        curr_param = [float(df_inter_results['epsilon'].iloc[-1]),
                      float(df_inter_results['gamma'].iloc[-1])]

        model_builder.update_step_size(curr_round)

    # make sure the checkpoint directory exists before the first save
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    for i in range(curr_round, args.num_rounds+1):
        print("Game round "+str(i)+"---------------------------------------------")
        # With calibration the losses change every round, so redistribute every round.
        # Without it, distribute once on the first round this run executes -- curr_round,
        # which is not 1 when resuming from a checkpoint.
        if args.calibration or i == curr_round:
            game_runner.distribute_losses()
        # track the current interaction
        inter = Interaction(round=i)
        # check which agent moves first
        if args.priority == 'regulators':
            if i == 1:
                # regulators pick the initial parameters unless they were supplied
                if not curr_param:
                    curr_param = game_runner.regulators_starting_point()
                print("Initial parameters chosen by regulators: "+str(curr_param), flush=True)

                # log the initial params and losses as round 0
                log_init(game_runner, model_builder, curr_param, C_priv, C_fair)

            # model builder responds to the current parameters
            curr_param, loss_combined, loss_b, loss_p, loss_f, acc, cov = model_builder.best_response(curr_param, C_priv, C_fair)

        else:
            if i == 1:
                # model builder picks the initial parameters unless they were supplied
                if not curr_param:
                    curr_param = model_builder.choose_starting_point()
                print("Initial parameters chosen by model builder: "+str(curr_param), flush=True)
                loss_combined, loss_b, loss_p, loss_f, acc, cov = model_builder.get_losses(curr_param, C_priv, C_fair)
                # log the initial params and losses as round 0
                log_init(game_runner, model_builder, curr_param, C_priv, C_fair)
            else:
                # model builder takes a step
                curr_param, loss_combined, loss_b, loss_p, loss_f, acc, cov = model_builder.best_response(curr_param, C_priv, C_fair)

        # update interaction
        inter.round_params = curr_param
        inter.losses = [loss_combined, loss_b, loss_p, loss_f, acc, cov]

        # log the round
        game_runner.register_interaction(inter)
        game_runner.results_to_df()

        # calibration round: train a student at the chosen params and update the losses
        if args.calibration:
            student_result = game_runner.train_student_model(curr_param)
            game_runner.update_losses(student_result)
            game_runner.calibration_to_df(student_result)

        print(inter, flush=True)

        # save now in case of preemption, writing one round at a time
        results_df = game_runner.return_results_df()
        results_df.to_parquet(file_path, compression='gzip')

    return game_runner.return_results_df()


def run_simulation(args):
    '''
        Build the game from ``args`` and play it to completion.

        Returns:
            Whatever ``run_game`` returns for the initialized game.
    '''
    print("Initializing the game", flush=True)
    runner, builder, priv_reg, fair_reg = init_game(args)

    print("Running the game", flush=True)
    results = run_game(args, runner, builder, priv_reg, fair_reg)
    return results
    
