import numpy as np
from scipy.optimize import minimize, NonlinearConstraint, LinearConstraint, Bounds
from numpy.linalg import matrix_power
import sys
sys.path.append('../')
import os
import warnings
warnings.filterwarnings("ignore")



import sys
sys.path.append('../')

import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

from simulation.criteria import CriteriaData



## Parse Data
# https://github.com/fairmlbook/fairmlbook.github.io/blob/master/code/creditscore/creditscore.ipynb


# Data / figure locations, resolved relative to the current working directory.
# DATA_DIR is tried first; the loaders below fall back to DATA_DIR2 when
# reading from DATA_DIR fails. Adjust as needed for your checkout layout.
DATA_DIR = 'FICO_data/'
DATA_DIR2 = '../FICO_data/'
FIGURE_DIR = '.'

def cleanup_frame(frame):
    """Normalise the group column names and fix their order.

    The 'Non- Hispanic white' column is renamed to 'White', then the columns
    are reindexed to exactly ['Asian', 'Black', 'Hispanic', 'White']: any
    other column is dropped and a missing one is filled with NaN.
    """
    renamed = frame.rename(columns={'Non- Hispanic white': 'White'})
    return renamed.reindex(columns=['Asian', 'Black', 'Hispanic', 'White'])

def read_totals(files):
    """Read the total number of people in each group.

    The CSV named by ``files['overall']`` is read from DATA_DIR, or from
    DATA_DIR2 if it does not exist under DATA_DIR.

    Args:
        files: mapping whose 'overall' entry is the totals CSV filename.

    Returns:
        dict mapping each column name produced by ``cleanup_frame``
        ('Asian', 'Black', 'Hispanic', 'White') to its value in the 'SSA' row.

    Raises:
        FileNotFoundError: if the file exists in neither directory.
    """
    try:
        raw = pd.read_csv(DATA_DIR + files['overall'], index_col=0)
    except FileNotFoundError:
        # Only a missing file triggers the fallback; other errors (malformed
        # CSV, missing key in ``files``) should surface with their real cause
        # instead of being masked by a second attempt at another path.
        raw = pd.read_csv(DATA_DIR2 + files['overall'], index_col=0)
    frame = cleanup_frame(raw)
    return {group: frame[group]['SSA'] for group in frame.columns}

def parse_data(files):
    """Load the per-group score CDF and performance tables (FICO TransRisk data).

    Both CSVs are read from DATA_DIR, or from DATA_DIR2 if a file is missing
    under DATA_DIR. Columns are normalised with ``cleanup_frame``.

    Args:
        files: mapping with 'cdf_by_race' and 'performance_by_race' CSV filenames.

    Returns:
        Tuple ``(cdfs, performance)`` of DataFrames, each divided by 100
        (the raw tables appear to be in percent -- confirm against the data).
        ``performance`` is ``100 - raw`` before the division, i.e. the
        complement of the raw performance table.

    Raises:
        FileNotFoundError: if the files exist in neither directory.
    """
    def _load(data_dir):
        # Read and clean both tables from one directory.
        cdfs = cleanup_frame(pd.read_csv(data_dir + files['cdf_by_race'], index_col=0))
        raw_perf = cleanup_frame(pd.read_csv(data_dir + files['performance_by_race'], index_col=0))
        return cdfs, 100 - raw_perf

    try:
        cdfs, performance = _load(DATA_DIR)
    except FileNotFoundError:
        # Fall back only when a file is missing; parse errors should surface.
        cdfs, performance = _load(DATA_DIR2)
    return (cdfs / 100., performance / 100.)




def get_Px_pdf_Py1x_pdf():
    """Load the FICO TransRisk data and wrap it in a ``CriteriaData``.

    Returns:
        Tuple ``(pdfs, performance, totals)``: the ``pdfs`` and
        ``performance`` attributes of the ``CriteriaData`` built from the
        parsed CDF/performance tables, plus the per-group totals dict from
        ``read_totals``.
    """
    files = {
        'cdf_by_race': 'transrisk_cdf_by_race_ssa.csv',
        'performance_by_race': 'transrisk_performance_by_race_ssa.csv',
        'overall': 'totals.csv',
    }
    cdfs, performance = parse_data(files)
    totals = read_totals(files)
    criteria = CriteriaData(cdfs, performance, totals)
    return criteria.pdfs, criteria.performance, totals

def get_Py1_new(Px, Py1x, num_cat):
    """Return ``[P(Y=1|S=0), P(Y=1|S=1)]`` by marginalising over the bin x.

    P(Y=1|S=s) = sum_x P(Y=1|X=x, S=s) * P(X=x|S=s), where the sum runs over
    the first ``num_cat`` entries of row ``s`` of ``Py1x`` and ``Px``.
    """
    return [
        sum(Py1x[group][bin_] * Px[group][bin_] for bin_ in range(num_cat))
        for group in (0, 1)
    ]

def get_Ps_given():
    """Return the relative sizes of the Black and White groups.

    The absolute totals returned by ``get_Px_pdf_Py1x_pdf`` are restricted to
    'Black' and 'White', each divided by their combined total, and rounded to
    4 decimals.

    Returns:
        np.ndarray ``[share_black, share_white]`` (in that order).
    """
    _, _, totals = get_Px_pdf_Py1x_pdf()
    counts = [totals['Black'], totals['White']]
    combined = sum(counts)
    return np.array([(count / combined).round(4) for count in counts])

def get_Px_given(num_cat=4):
    """Return P(X=x | S=s) for the Black and White groups on ``num_cat`` score bins.

    The index of ``data.pdfs`` (presumably the score -- confirm) is cut into
    ``num_cat`` equal-width bins with ``pd.cut``, and the pdf values are
    summed within each bin.

    Returns:
        np.ndarray of shape (2, num_cat): row 0 is Black, row 1 is White.

    Raises:
        AssertionError: if a returned row does not sum to 1 within
        floating-point tolerance.
    """
    Px_pdf, _, _ = get_Px_pdf_Py1x_pdf()
    bins = pd.cut(Px_pdf.index, num_cat, labels=range(num_cat))
    # observed=False keeps empty bins (summing to 0), so the result always
    # has num_cat rows regardless of the pandas default for categoricals.
    binned = Px_pdf.groupby(bins, observed=False).sum()

    Px = np.array(binned[['Black', 'White']]).T
    # Use a tolerance: a sum of many float probabilities is rarely exactly 1.
    assert np.allclose(Px.sum(axis=1), 1)
    return Px


def get_Py1x_given(num_cat=4):
    """Return P(Y=1 | X=x, S=s) for the Black and White groups on ``num_cat`` score bins.

    The index of ``data.performance`` is cut into ``num_cat`` equal-width bins
    with ``pd.cut``. Each bin takes ``GroupBy.last()``, i.e. the last non-null
    value falling in that bin, not an average over the bin.

    Returns:
        np.ndarray of shape (2, num_cat): row 0 is Black, row 1 is White.
    """
    _, performance, _ = get_Px_pdf_Py1x_pdf()
    bins = pd.cut(performance.index, num_cat)
    # observed=False keeps empty bins, so there are always exactly num_cat
    # rows and the relabelling below cannot hit a length mismatch.
    per_bin = performance.groupby(bins, observed=False).last()

    # Replace the Interval labels with 0..num_cat-1.
    per_bin.index = range(num_cat)

    return np.array(per_bin[['Black', 'White']]).T

def get_Py1x_estimated(estimation):
    """Load a pre-computed P(Y=1 | X, S) estimate saved as a .npy file.

    Args:
        estimation: one of '_est-random', '_est-threshold', '_est-biased'.
            It selects which file is loaded from the ``estimations``
            directory under the current working directory.

    Returns:
        The array stored in the selected .npy file.

    Raises:
        NotImplementedError: if ``estimation`` is not one of the keys above.
    """
    filenames = {
        '_est-random': 'Py1x_est_random.npy',
        '_est-threshold': 'Py1x_est_threshold.npy',
        '_est-biased': 'Py1x_est_biased.npy',
    }
    try:
        filename = filenames[estimation]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which the original equality
        # chain also rejected with NotImplementedError.
        raise NotImplementedError(
            f"unknown estimation {estimation!r}; expected one of {sorted(filenames)}"
        ) from None
    return np.load(os.path.join(os.getcwd(), 'estimations', filename))