from simulation.data_fico import get_Py1x_given, get_Px_given, get_Ps_given
from simulation.optim_utils import get_kernel_from_policy, get_stationary_dist_from_kernel

from simulation.constants import Cte
from simulation.problem import FairOptimizationProblem
from simulation.constraints import get_group_EOP, get_group_qualifications, get_total_qualifications, get_total_utility

import argparse
import numpy as np
import os
import time


class MaxUtilEOPProblem(FairOptimizationProblem):
    """Maximize total utility under an equal-opportunity (EOP) fairness constraint.

    The decision vector ``x`` is reshaped to ``(2, num_cat)``: one row per
    group ``s`` in {0, 1} and one column per feature category. The reshaped
    array is named ``Pd1`` in the helpers (presumably P(D=1 | s, x) — confirm
    against ``simulation.constraints``).
    """

    def __init__(self, num_cat, seed, slack, eps, verbose, c, data_name, lambda_fair, dynamics, exp_name, estimation):
        """Forward every setting except ``exp_name`` to FairOptimizationProblem.

        ``exp_name`` is accepted so the parsed CLI arguments can be passed
        straight through as ``**kwargs``; the problem itself does not use it.
        """
        super().__init__(num_cat, seed, slack, eps, verbose, c, data_name,
                         lambda_fair, dynamics, estimation)

    def constraint0(self, x):
        """Return ``slack - (EOP_0 - EOP_1) ** 2`` for the policy ``x``.

        The value is non-negative exactly when the squared gap between the two
        groups' EOP values stays within ``self.slack``.
        """
        n_cat = self.num_cat
        p_y1_given_x = get_Py1x_given(n_cat)
        kernel = get_kernel_from_policy(x, n_cat, self.seed, self.dynamics, self.estimation)
        policy = x.reshape(2, n_cat)

        # Evaluate group 0 first, then group 1 (same order as a plain loop).
        eop_0, eop_1 = (
            get_group_EOP(kernel, group, p_y1_given_x, n_cat, policy)
            for group in (0, 1)
        )
        gap = eop_0 - eop_1
        return self.slack - gap ** 2

    # overriding
    def get_specific_constraints(self):
        """Return the problem-specific constraint list (only the EOP constraint).

        The dict layout matches scipy.optimize.minimize's constraint format,
        where 'ineq' means ``fun(x) >= 0`` — presumably consumed by the base
        class's solver; confirm in FairOptimizationProblem.
        """
        return [{'type': 'ineq', 'fun': self.constraint0}]

    def objective(self, x):
        """Return the negated total utility of policy ``x``, scaled by 100.

        The sign flip turns utility maximization into a minimization problem,
        presumably for the base-class solver. The reason for the x100 scale is
        not visible here (possibly solver conditioning; unverified).
        """
        n_cat = self.num_cat
        policy = x.reshape(2, n_cat)
        p_y1_given_x = get_Py1x_given(n_cat)
        p_s = get_Ps_given()
        kernel = get_kernel_from_policy(x, n_cat, self.seed, self.dynamics, self.estimation)

        total_utility = get_total_utility(kernel, p_y1_given_x, p_s, policy, n_cat, self.c)
        return -total_utility * 100

if __name__ == '__main__':
    print(" -------------------- START MAIN --------------------")
    # Record the starting time
    start_time = time.time()

    # Integer --exp_name -> sub-folder of policies/ where the result is saved.
    experiments = {
        1: '01_initial-states',
        2: '02_policies-fairness',
        3: '03_dynamics-speed',
        4: '04_dynamics-type',
        5: '05_estimation',
    }

    parser = argparse.ArgumentParser()

    parser.add_argument('--verbose', '-v', action='store_true', default=False, help='Print verbose output')

    #####################################
    # CHANGE THESE TO GET DIFF. RESULTS #
    #####################################
    # num_cat, seed, slack, data_name, c, lambda_fair, dynamics, exp_name, estimation
    parser.add_argument('--num_cat', '-n', type=int, default=4, help='number of X categories, needs to be even')
    parser.add_argument('--seed', '-s', type=int, default=6, help='Seed for random initializations of distributions')
    parser.add_argument('--slack', '-e', type=float, default=0.2, help='Unfairness tradeoff')

    parser.add_argument('--data_name', '-d', type=str, default=Cte.FICO, help='SYNTH for synthetic or FICO for fico data')
    parser.add_argument('--c', '-c', type=float, default=0.3, help='cost of positive prediction')
    parser.add_argument('--lambda_fair', '-l', type=float, default=0.3, help='weight of fairness in objective')
    parser.add_argument('--dynamics', '-dyn', type=str, default=Cte.DEFAULT,
                        help='dynamics setting used to build the kernel (see get_kernel_from_policy)')
    parser.add_argument('--exp_name', type=int,
                        help='experiment number (1-5) selecting the folder the policy is saved to')
    parser.add_argument('--estimation', type=str, default="_true",
                        help='using estimated distributions and dynamics or true, '
                             'can be _true, _est-random, _est-threshold.')

    #####################################
    # FIXED #
    #####################################
    parser.add_argument('--eps', '-eps', type=float, default=1.4901161193847656e-10,
                        help='epsilon for optimization algorithm')

    args = vars(parser.parse_args())

    print(args['exp_name'])
    # Validate the experiment BEFORE solving so that a missing or invalid
    # --exp_name fails immediately instead of after a long optimization run.
    exp = experiments.get(args['exp_name'])
    if exp is None:
        raise ValueError('Experiment name not valid')

    solver = MaxUtilEOPProblem(**args)
    Pd1 = solver.solve()

    # Create the output folder and save the rounded policy as CSV.
    my_path = os.path.join("policies", exp)
    os.makedirs(my_path, exist_ok=True)

    rounded_policy = Pd1.round(4)
    file_name = f"MaxEOP_{solver.slack}_{solver.c}_{solver.seed}_{solver.dynamics}{args['estimation']}.csv"
    np.savetxt(os.path.join(my_path, file_name), rounded_policy, delimiter=",")
    print("resulting policy", rounded_policy)

    # Report the elapsed wall-clock time in seconds.
    elapsed_time = time.time() - start_time
    print(f"Elapsed time: {elapsed_time} seconds")

    print(" -------------------- END MAIN --------------------")

