from simulation.data_fico import get_Py1x_given, get_Px_given, get_Ps_given
from simulation.optim_utils import get_kernel_from_policy, get_stationary_dist_from_kernel
from simulation.constants import Cte
from simulation.problem import FairOptimizationProblem
from simulation.constraints import get_total_qualifications, get_group_qualifications, get_total_utility
import time
import argparse
import numpy as np
import os

class MaxQualificationOptimizationProblem(FairOptimizationProblem):
    """Optimization problem whose objective rewards total qualification.

    ``objective`` returns the negated, 10x-scaled total qualification obtained
    from the kernel induced by a policy ``x``. Minimising it therefore favours
    high qualification (assumes the base class minimises -- confirm in
    ``FairOptimizationProblem``). ``get_specific_constraints`` adds one
    inequality constraint built on the policy's total utility.
    """

    def __init__(self, num_cat, seed, slack, eps, verbose, c, data_name, lambda_fair, dynamics, exp_name, estimation):
        # ``exp_name`` is accepted so the parsed CLI arguments can be forwarded
        # wholesale (``**args``); it is deliberately not passed to the base
        # class, which receives the remaining arguments positionally.
        super().__init__(num_cat, seed, slack, eps, verbose, c, data_name,
                         lambda_fair, dynamics, estimation)

    def constraint1(self, x):
        """Return the total utility of policy ``x`` (inequality constraint).

        ``x`` is the flattened policy. For the utility computation it is
        viewed as a ``(2, num_cat)`` array (presumably P(D=1 | group, X
        category) -- confirm against ``get_total_utility``).
        """
        n_cat = self.num_cat

        # NOTE: helper call order is kept as-is in case any of them touch
        # global RNG state.
        transition = get_kernel_from_policy(x, n_cat, self.seed, self.dynamics, self.estimation)
        py1x = get_Py1x_given(n_cat)
        ps = get_Ps_given()
        policy = x.reshape(2, n_cat)

        return get_total_utility(transition, py1x, ps, policy, n_cat, self.c)

    def get_specific_constraints(self):
        """Return the problem-specific constraints as scipy-style dicts.

        The ``{'type': 'ineq', 'fun': ...}`` format presumably follows
        ``scipy.optimize.minimize``, where 'ineq' means ``fun(x) >= 0``
        (confirm in the base class that consumes it).
        """
        return [dict(type='ineq', fun=self.constraint1)]

    def objective(self, x):
        """Return ``-10 * total_qualifications`` for policy ``x``.

        The sign flip turns qualification maximisation into minimisation; the
        factor 10 only rescales the objective for the optimiser.
        """
        n_cat = self.num_cat

        # Same helper order as before: data marginals first, then the kernel.
        py1x = get_Py1x_given(n_cat)
        ps = get_Ps_given()
        transition = get_kernel_from_policy(x, n_cat, self.seed, self.dynamics, self.estimation)

        total_qual = get_total_qualifications(transition, py1x, ps, n_cat)
        return -(10 * total_qual)


if __name__ == '__main__':
    # Command-line entry point: solve the max-qualification problem and save
    # the resulting policy as a CSV under policies/<experiment>/.
    print(" -------------------- START MAIN --------------------")
    start_time = time.time()
    parser = argparse.ArgumentParser()

    parser.add_argument('--verbose', '-v', action='store_true', default=False, help='Print verbose output')

    #####################################
    # CHANGE THESE TO GET DIFF. RESULTS #
    #####################################
    # num_cat, seed, slack, data_name, c, lambda_fair, dynamics, exp_name, estimation
    parser.add_argument('--num_cat', '-n', type=int, default=4, help='number of X categories, needs to be even')
    parser.add_argument('--seed', '-s', type=int, default=6, help='Seed for random initializations of distributions')
    parser.add_argument('--slack', '-e', type=float, default=0.2, help='Unfairness tradeoff')

    parser.add_argument('--data_name', '-d', type=str, default=Cte.FICO,
                        help='SYNTH for synthetic or FICO for fico data')
    parser.add_argument('--c', '-c', type=float, default=0.3, help='cost of positive prediction')
    parser.add_argument('--lambda_fair', '-l', type=float, default=0.3, help='weight of fairness in objective')
    parser.add_argument('--dynamics', '-dyn', type=str, default=Cte.DEFAULT,
                        help='dynamics model passed to the solver')
    parser.add_argument('--exp_name', type=int, help='experiment name for saving the policies')
    # Implicit string concatenation keeps the help text free of the
    # indentation that a backslash continuation inside the literal would embed.
    parser.add_argument('--estimation', type=str, default="_true",
                        help='using estimated distributions and dynamics or true, '
                             'can be _true, _est-random, _est-threshold.')

    #####################################
    # FIXED #
    #####################################
    parser.add_argument('--eps', '-eps', type=float, default=1.4901161193847656e-10,
                        help='epsilon for optimization algorithm')

    args = vars(parser.parse_args())

    # Experiment id -> sub-directory under which the policy is saved.
    experiments = {
        1: '01_initial-states',
        2: '02_policies-fairness',
        3: '03_dynamics-speed',
        4: '04_dynamics-type',
        5: '05_estimation',
    }

    # Validate --exp_name BEFORE solving so an invalid or missing value fails
    # fast instead of after the (potentially long) optimisation.
    print(args['exp_name'])
    exp = experiments.get(args['exp_name'])
    if exp is None:
        raise ValueError('Experiment name not valid')

    solver = MaxQualificationOptimizationProblem(**args)
    Pd1 = solver.solve()

    # Create exactly the directory the CSV is written to (previously a
    # different "policies_old" directory was created, so savetxt could fail).
    my_path = os.path.join("policies", exp)
    os.makedirs(my_path, exist_ok=True)

    file_name = f"MaxQual_{solver.slack}_{solver.c}_{solver.seed}_{solver.dynamics}{args['estimation']}.csv"
    np.savetxt(os.path.join(my_path, file_name), Pd1.round(4), delimiter=",")
    print("resulting policy", Pd1.round(4))

    # Record the ending time
    end_time = time.time()

    # Calculate the elapsed time
    elapsed_time = end_time - start_time

    # Print the elapsed time in seconds
    print(f"Elapsed time: {elapsed_time} seconds")

    # Inside the main guard so importing this module prints nothing.
    print(" -------------------- END MAIN --------------------")

