
from simulation.data_fico import get_Py1x_given, get_Px_given, get_Ps_given, get_Py1x_estimated
from simulation.optim_utils import get_Pd1_given, get_kernel_from_policy, get_stationary_dist_from_kernel

from scipy.optimize import minimize, Bounds
from numpy.linalg import matrix_power

import numpy as np

class FairOptimizationProblem:
    """Base class for a fairness-aware policy optimisation solved with SciPy SLSQP.

    The decision vector ``x`` has length ``2 * num_cat``: one value per
    (group ``s`` in {0, 1}, category) pair, laid out so that
    ``x.reshape(2, num_cat)[s]`` is group ``s``'s policy. With the default
    bounds every entry lies in [0, 1]. The code calls this vector ``Pd1``
    (presumably P(d=1 | x, s) -- confirm against ``simulation.optim_utils``).

    Subclasses are expected to override ``objective``, ``get_bounds`` and
    ``get_specific_constraints``.
    """

    def __init__(self, num_cat, seed, slack, eps, verbose, c, data_name, lambda_fair, dynamics, estimation):
        """Store the problem configuration.

        Args:
            num_cat: number of categories per group; must be even.
            seed: forwarded to ``get_kernel_from_policy``.
            slack: stored for subclasses (not used by this base class).
            eps: passed to SLSQP as its ``eps`` option (finite-difference step).
            verbose: enables diagnostic printing and SLSQP's ``disp`` output.
            c: cost of a positive prediction (stored for subclasses).
            data_name: stored for subclasses.
            lambda_fair: stored for subclasses.
            dynamics: forwarded to ``get_kernel_from_policy``.
            estimation: forwarded to ``get_kernel_from_policy``; the value
                ``"_true"`` makes ``solve`` use ``get_Py1x_given``, any other
                value is passed to ``get_Py1x_estimated``.

        Raises:
            AssertionError: if ``num_cat`` is odd.
        """
        self.num_cat = num_cat
        self.seed = seed
        self.slack = slack

        self.verbose = verbose
        self.c = c  # cost of positive prediction
        self.data_name = data_name

        self.lambda_fair = lambda_fair
        self.dynamics = dynamics
        self.estimation = estimation

        self.eps = eps
        assert self.num_cat % 2 == 0, "num_cat must be even"

    def get_seed(self):
        return self.seed

    def get_num_cat(self):
        return self.num_cat

    def get_slack(self):
        return self.slack

    # NOTE(review): optim_for_s, fairness and op_eps are never set in this class;
    # presumably subclasses set them. On the base class these getters raise AttributeError.
    def get_optim_for_s(self):
        return self.optim_for_s

    def get_fairness(self):
        return self.fairness

    def get_op_eps(self):
        return self.op_eps

    # to be overridden by subclasses
    def get_bounds(self):
        """Return box bounds [0, 1] for each of the ``2 * num_cat`` decision variables."""
        num_var = self.num_cat * 2
        return Bounds(np.zeros(num_var), np.ones(num_var))

    # to be overridden by subclasses
    def get_specific_constraints(self):
        """Return subclass-specific SciPy constraint dicts (none by default)."""
        return []

    # to be overridden by subclasses
    def objective(self, x=None):
        """Return the objective value for policy ``x``; the base class returns 0.

        ``x`` is accepted (and ignored) because ``solve`` and ``minimize`` both
        call ``self.objective(x)``.
        """
        return 0

    def constraint1(self, x):
        """Irreducibility proxy (SLSQP 'ineq', feasible when >= 0).

        For each group, takes ``min`` over the entries of ``kernel[s]`` raised to
        the matrix power ``num_cat`` (rounded to 4 decimals) minus 1e-4, and
        returns the sum over both groups.
        """
        # NOTE(review): summing over groups lets one group's margin mask a
        # violation in the other; confirm this relaxation is intended.
        kernel = get_kernel_from_policy(x, self.num_cat, self.seed, self.dynamics, self.estimation)
        return sum(matrix_power(kernel[s], self.num_cat).min().round(4) - 0.0001 for s in (0, 1))

    def constraint2(self, x):
        """Aperiodicity proxy (SLSQP 'ineq', feasible when >= 0).

        For each group, takes the smallest diagonal (self-transition) entry of
        ``kernel[s]`` (rounded to 4 decimals) minus 1e-4, and returns the sum
        over both groups.
        """
        kernel = get_kernel_from_policy(x, self.num_cat, self.seed, self.dynamics, self.estimation)
        return sum(kernel[s].diagonal().min().round(4) - 0.0001 for s in (0, 1))

    def constraint3(self, x):
        """Monotonicity of each group's policy across categories.

        Returns 1 if ``x.reshape(2, num_cat)`` is non-decreasing along every
        row, otherwise -1. This is a step function, so it gives SLSQP no
        gradient information.
        """
        steps = np.diff(np.asarray(x).reshape(2, self.num_cat), axis=1)
        return -1 if np.any(steps < 0) else 1

    def get_common_constraints(self):
        """Return the constraints shared by all subclasses (constraint1, constraint2).

        constraint3 (monotonicity) exists but is not enabled here; append
        ``{'type': 'ineq', 'fun': self.constraint3}`` to use it.
        """
        return [
            {'type': 'ineq', 'fun': self.constraint1},
            {'type': 'ineq', 'fun': self.constraint2},
        ]

    def check_constraints(self, x, cons):
        """Return the indices of constraints in ``cons`` that ``x`` violates.

        Follows SciPy's constraint-dict convention: 'ineq' constraints are
        violated if any component is negative. Any other type is treated as
        'eq' and is violated if any component, rounded to 4 decimals, is
        nonzero. Scalar and array-valued constraint functions are both
        supported, and the optional ``'args'`` entry is forwarded.
        """
        violated_constraints = []
        for i, con in enumerate(cons):
            if self.verbose:
                print("check constraints", i)
            value = np.asarray(con["fun"](x, *con.get("args", ())))
            if con["type"] == 'ineq':
                violated = np.any(value < 0)
            else:
                violated = np.any(value.round(4) != 0)
            if violated:
                violated_constraints.append(i)
        return violated_constraints

    @staticmethod
    def _eop_rates(policy, Py1x, Px, num_cat):
        """Return, for s in (0, 1), sum(pi_s * Py1x[s] * Px[s]) / sum(Py1x[s] * Px[s]).

        ``policy`` may be flat (length ``2 * num_cat``) or already shaped
        ``(2, num_cat)``. If ``Py1x``/``Px`` are P(y=1|x,s) and P(x|s), this is
        P(d=1 | y=1, s), i.e. the equal-opportunity rate (assumption).
        """
        policy = np.asarray(policy).reshape(2, num_cat)
        rates = []
        for s in (0, 1):
            weights = np.asarray(Py1x[s]) * np.asarray(Px[s])
            rates.append(np.sum(policy[s] * weights) / np.sum(weights))
        return rates

    def _stationary_distributions(self, kernel):
        """Return the per-group stationary distributions of ``kernel`` as a (2, num_cat) array.

        Raises:
            AssertionError: if any kernel row does not sum to 1, a distribution
                does not sum to 1, or a distribution is not a fixed point of its
                kernel (``rtol=0.001``).
        """
        assert np.allclose(np.sum(kernel, axis=2), 1), "not valid kernel"

        Px = np.empty((2, self.num_cat))
        for s in (0, 1):
            Px[s] = get_stationary_dist_from_kernel(kernel[s])
        assert np.allclose(np.sum(Px, axis=1), 1), "PX not valid dist"

        # A stationary distribution is a left eigenvector: kernel[s].T @ Px[s] == Px[s].
        Px_next = np.array([kernel[s].T @ Px[s] for s in (0, 1)])
        if self.verbose:
            print(Px.round(3))
            print(Px_next.round(3))
        assert np.allclose(Px, Px_next, rtol=0.001), "PX not valid stationary dist"
        return Px

    def solve(self):
        """Optimise the policy with SLSQP and sanity-check the result.

        Starts from ``get_Pd1_given(num_cat)`` and minimises ``self.objective``
        under the common and subclass-specific constraints and ``get_bounds()``.
        It then checks that the induced per-group kernels are row-stochastic
        and that their stationary distributions are fixed points.

        Returns:
            The optimised policy as the flat vector produced by ``minimize``
            (length ``2 * num_cat``), independent of ``self.verbose``.

        Raises:
            AssertionError: if the kernel or stationary-distribution checks fail.
        """
        num_cat = self.num_cat

        x0 = get_Pd1_given(num_cat)
        if self.estimation == "_true":
            Py1x = get_Py1x_given(num_cat)
        else:
            Py1x = get_Py1x_estimated(self.estimation)
        Px = get_Px_given(num_cat)

        # SLSQP works on flat vectors, so evaluate objective/constraints on that shape too.
        x0_flat = np.asarray(x0).flatten()

        if self.verbose:
            print("Initial policy", x0.round(4))
            eq1 = np.matmul(Py1x[0].reshape(1, num_cat), Px[0])
            eq2 = np.matmul(Py1x[1].reshape(1, num_cat), Px[1])
            EOP = self._eop_rates(x0, Py1x, Px, num_cat)
            print('Initial Qual Rates: ', eq1, eq2, "diff", (eq1[0]-eq2[0])**2)
            print('Initial EOP Rates: ', EOP[0], EOP[1], "diff", (EOP[0]-EOP[1])**2)
            print('Initial SSE Objective: ' + str(self.objective(x0_flat)))

        cons = self.get_common_constraints() + self.get_specific_constraints()
        violated = self.check_constraints(x0_flat, cons)
        if self.verbose and violated:
            print("Initial policy violates constraints", violated)

        solution = minimize(self.objective, x0_flat, method='SLSQP',
                            bounds=self.get_bounds(), constraints=cons,
                            options={'disp': self.verbose, 'eps': self.eps, 'maxiter': 200})

        Pd1 = solution.x
        if self.verbose:
            print("Solution (pi)", Pd1.round(4))

        # sanity checks
        violated = self.check_constraints(Pd1, cons)
        if self.verbose and violated:
            print("Solution violates constraints", violated)

        kernel = get_kernel_from_policy(Pd1, num_cat, self.seed, self.dynamics, self.estimation)
        Px_stat = self._stationary_distributions(kernel)

        if self.verbose:
            print("Stat distribution (Px)", Px_stat.round(4))
            print('Final SSE Objective (qual 0): ' + str(self.objective(Pd1)))

            eq = [np.matmul(Py1x[s].reshape(1, num_cat), Px_stat[s]).flatten() for s in (0, 1)]
            print('Final Qual Rates: ', eq[0], eq[1], "diff", (eq[0]-eq[1]).round(4), "diff2", ((eq[0]-eq[1])**2).round(4) )

            EOP = self._eop_rates(Pd1, Py1x, Px_stat, num_cat)
            eop_diff = EOP[0] - EOP[1]
            print('Final EOP Rates: ', EOP[0], EOP[1], "diff", np.round(eop_diff, 4), "diff2", np.round(eop_diff ** 2, 4) )

            print(Pd1)

        return Pd1
