import os
import shutil
import math
import torch as th
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
from matplotlib.patches import *
from matplotlib.lines import Line2D
from tueplots import bundles, fontsizes
# Apply the tueplots NeurIPS 2024 bundle to matplotlib's global rcParams.
plt.rcParams.update(bundles.neurips2024())
# NOTE(review): the return value of this call is discarded. If tueplots'
# fontsizes helpers only return an rcParams dict (the way the bundle above is
# consumed), this line has no effect -- confirm whether it should be passed to
# plt.rcParams.update().
fontsizes.neurips2024()

# Root directory under which get_log() looks for run outputs.
base_dir = 'output/effect_estimation'
# SCM names iterated when the cached CSVs are built.
scms = ['fairness', 'fairness_xw', 'fairness_xy', 'fairness_yw']
# Model names; they end up in the `m=` part of a run directory name.
models = ['maf', 'nice']
# Method names understood by get_log().
methods = ['exom', 'gan_ncm', 'rs', 'rs_gan_ncm']
# Query names; they end up in the `q=` part of a run directory name.
qs = ['ate', 'ett', 'nde', 'ctfde']
# Seeds iterated for both the plain seed suffix and the `cfs=` suffix.
seeds = [0, 7, 42, 3407, 65535]


def get_log(method, scm, model, q, seed, seed2):
    """Load the saved effect log of a single run.

    The run directory is ``{method}/{scm}/{run_name}`` under ``base_dir``,
    optionally followed by ``,{seed}`` and ``,cfs={seed2}``; the file loaded
    from it is ``logs/effect_logs.pt``.

    Args:
        method: One of 'exom', 'gan_ncm', 'rs', 'rs_gan_ncm'.
        scm: SCM name (second path component).
        model: Model name; only used in the run name for 'exom' and 'gan_ncm'.
        q: Query name, placed in the ``q=`` part of the run name.
        seed: Appended as ``,{seed}`` only when strictly positive.
        seed2: Appended as ``,cfs={seed2}`` only when strictly positive.

    Returns:
        Whatever ``th.load`` returns for the log file.

    Raises:
        ValueError: If ``method`` is not one of the known method names.
    """
    if method in ('exom', 'gan_ncm'):
        run_name = f'm={model},q={q},c=e+t.w_e.w_t,k=mb1.mb1.em,r=attn,t=5,h=64x2'
    elif method in ('rs', 'rs_gan_ncm'):
        run_name = f'q={q}'
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError that did not mention the offending value.
        raise ValueError(f'unknown method: {method!r}')

    sub_dir = f'{method}/{scm}/{run_name}'
    # A seed of 0 carries no suffix in the directory name.
    if seed > 0:
        sub_dir += f',{seed}'
    if seed2 > 0:
        sub_dir += f',cfs={seed2}'
    return th.load(os.path.join(base_dir, sub_dir, 'logs/effect_logs.pt'))


def df_csv_o(path):
    """Return the estimates table for methods 'rs' and 'exom'.

    If ``path`` already exists it is read back with ``pd.read_csv`` and
    returned as is. Otherwise every log is loaded through ``get_log`` (with
    ``seed2`` fixed to 0), one row is emitted per element yielded by the log,
    and the resulting frame is written to ``path`` before being returned.
    """
    if os.path.exists(path):
        return pd.read_csv(path)

    columns = ('scm', 'method', 'model', 'seed', 'q', 'estimate')
    data = {name: [] for name in columns}

    def add_log(scm, method, model):
        # One row per estimate, for every (query, seed) combination.
        for q in qs:
            for s in seeds:
                for estimate in get_log(method, scm, model, q, s, 0):
                    record = (scm, method, model, s, q, estimate)
                    for name, value in zip(columns, record):
                        data[name].append(value)

    for scm in scms:
        # 'rs' has no model, so it is recorded with the placeholder '-'.
        add_log(scm, 'rs', '-')
        for model in models:
            add_log(scm, 'exom', model)

    df = pd.DataFrame.from_dict(data)
    df.to_csv(path)
    return df


# Estimates for methods 'rs' and 'exom'; table() turns these into its `_o` rows.
dfo = df_csv_o('script/figure/effect_estimation_original.csv')


def df_csv_p(path):
    """Return the estimates table for methods 'rs_gan_ncm' and 'gan_ncm'.

    If ``path`` already exists it is read back with ``pd.read_csv`` and
    returned as is. Otherwise every log is loaded through ``get_log`` for each
    pair ``(seed, seed2)`` drawn from ``seeds``, one row is emitted per element
    yielded by the log, and the resulting frame is written to ``path`` before
    being returned.
    """
    if os.path.exists(path):
        return pd.read_csv(path)

    columns = ('scm', 'method', 'model', 'seed', 'seed2', 'q', 'estimate')
    data = {name: [] for name in columns}

    def add_log(scm, method, model):
        # One row per estimate, for every (query, seed, seed2) combination.
        for q in qs:
            for s in seeds:
                for s2 in seeds:
                    for estimate in get_log(method, scm, model, q, s, s2):
                        record = (scm, method, model, s, s2, q, estimate)
                        for name, value in zip(columns, record):
                            data[name].append(value)

    for scm in scms:
        # 'rs_gan_ncm' has no model, so it is recorded with the placeholder '-'.
        add_log(scm, 'rs_gan_ncm', '-')
        for model in models:
            add_log(scm, 'gan_ncm', model)

    df = pd.DataFrame.from_dict(data)
    df.to_csv(path)
    return df


# Estimates for methods 'rs_gan_ncm' and 'gan_ncm'; table() turns these into its `_p` rows.
dfp = df_csv_p('script/figure/effect_estimation_proxy.csv')


def select(df: pd.DataFrame, scm, q, method, model=None, seed2=None):
    """Return the ``seed`` and ``estimate`` columns of the matching rows.

    Rows must match ``scm``, ``q`` and ``method``. ``model`` and ``seed2`` are
    additional equality filters that are applied only when they are not None.
    """
    mask = (df['scm'] == scm) & (df['q'] == q) & (df['method'] == method)
    if model is not None:
        mask = mask & (df['model'] == model)
    if seed2 is not None:
        mask = mask & (df['seed2'] == seed2)
    return df[mask][['seed', 'estimate']]


def to_np(df: pd.DataFrame):
    """Return the last column of ``df`` as a NumPy array of shape (1, N)."""
    last_column = df.to_numpy()[:, -1]
    return last_column.reshape(1, -1)


def ci_95(X: np.ndarray):
    """Return twice the standard deviation of all elements of ``X``.

    The array is flattened first, so its shape does not matter. The standard
    deviation is NumPy's default one (``ndarray.std()`` with no arguments).
    ``X`` is not modified.

    The former per-column copy loop was a no-op whose only observable effect
    was an IndexError on 1-D input, so it has been removed.
    """
    return 2 * np.asarray(X).reshape(-1).std()


def table(scms=('fairness', 'fairness_xw', 'fairness_xy', 'fairness_yw')):
    """Build the per-(item, scm, q) summary of ``ci_95`` values.

    For each of 'rs', 'maf' and 'nice', two groups of rows are appended, in
    this order:

    * ``{model}_o``: ``ci_95`` of the matching estimates in ``dfo``.
    * ``{model}_p``: mean over ``seed2`` of the ``ci_95`` of the matching
      estimates in ``dfp``; a (scm, q) pair with no matching rows for any
      ``seed2`` is skipped entirely.

    Estimates outside the closed interval [-1, 1] are dropped before
    ``ci_95`` is computed.

    Args:
        scms: Iterable of SCM names to include. The default is an immutable
            tuple so that it cannot be altered between calls.

    Returns:
        A DataFrame with columns ``item``, ``scm``, ``q`` and ``ci95``.
    """
    data = {'item': [], 'scm': [], 'q': [], 'ci95': []}

    def add_data(item, scm, q, ci95):
        data['item'].append(item)
        data['scm'].append(scm)
        data['q'].append(q)
        data['ci95'].append(ci95)

    def clipped(frame):
        # Keep only estimates within [-1, 1], as a column vector.
        Y = to_np(frame)
        return Y[(Y <= 1) & (Y >= -1)].reshape(-1, 1)

    def bias_o(model: str):
        for scm in scms:
            for q in qs:
                if model == 'rs':
                    rows = select(dfo, scm, q, 'rs')
                else:
                    rows = select(dfo, scm, q, 'exom', model=model)
                add_data(f'{model}_o', scm, q, ci_95(clipped(rows)))

    def bias_p(model: str):
        for scm in scms:
            for q in qs:
                per_seed2 = []
                for s2 in seeds:
                    if model == 'rs':
                        rows = select(dfp, scm, q, 'rs_gan_ncm', seed2=s2)
                    else:
                        rows = select(dfp, scm, q, 'gan_ncm',
                                      model=model, seed2=s2)
                    # Emptiness is checked before clipping, as before.
                    if len(rows) > 0:
                        per_seed2.append(ci_95(clipped(rows)))
                if len(per_seed2) == 0:
                    continue
                add_data(f'{model}_p', scm, q, np.mean(per_seed2))

    # Row order matters to readers of the CSV: o-rows then p-rows per model.
    for model in ('rs', 'maf', 'nice'):
        bias_o(model)
        bias_p(model)

    return pd.DataFrame.from_dict(data)


# Compute the summary for the default SCMs and persist it; the tab*()
# functions below read this CSV back.
table().to_csv('script/figure/effect_ci95.csv')


# LaTeX template consumed by tab2() through str.format(): 6 data rows with 7
# positional `{}` cells each (42 in total). Literal braces are escaped as
# `{{` / `}}`. The name `tmpl` is rebound further down, so tab2() has to be
# called before that happens.
tmpl = """
\\label{{tab:2}}
\\centering
\\begin{{tabular}}{{ccccccccc}}
    \\toprule
    \\multicolumn{{2}}{{c}}{{}} & \\multicolumn{{3}}{{c}}{{SIMPSON-NLIN}} & \\multicolumn{{4}}{{c}}{{FAIRNESS}}\\\\
    \\cmidrule(r){{3-5}} \\cmidrule(r){{6-9}}
    Method & SCM & $|s|=1$ & $|s|=3$ & $|s|=5$ & ATE & ETT & NDE & CtfDE\\\\
    \\midrule
    \\multirow{{2}}{{*}}{{RS}} & O & {} & {} & {} & {} & {} & {} & {}\\\\
    & P & {} & {} & {} & {} & {} & {} & {}\\\\
    \\midrule
    \\multirow{{2}}{{*}}{{EXOM[MAF]}} & O & {} & {} & {} & {} & {} & {} & {}\\\\
    & P & {} & {} & {} & {} & {} & {} & {}\\\\
    \\midrule
    \\multirow{{2}}{{*}}{{EXOM[NICE]}} & O & {} & {} & {} & {} & {} & {} & {}\\\\
    & P & {} & {} & {} & {} & {} & {} & {}\\\\
    \\bottomrule
\\end{{tabular}}
"""


def tab2():
    """Render ``script/figure/tabs/tab2.tex`` from the two ci95 CSV files.

    Does nothing when the density CSV is absent. Otherwise, for each of
    'rs', 'maf', 'nice' and each of 'o', 'p', seven cells are produced: the
    ``simpson_nlin`` rows of the density CSV for ``j`` in 1, 3, 5, followed by
    the ``fairness`` rows of the effect CSV for ate, ett, nde, ctfde. The
    cells are substituted into the module-level ``tmpl`` and written out.

    A cell is rendered as ``$\\pm<value>$`` with three decimals, or as ``-``
    when the value is NaN or no matching row exists.

    Raises:
        ValueError: If more than one row matches a cell.
    """
    # NOTE(review): the original guard tested 'density_ci95.csv' but then read
    # 'denstiy_ci95.csv'. A single, correctly spelled path is used for both --
    # confirm it matches the file name written by the density script.
    density_csv = 'script/figure/density_ci95.csv'
    if not os.path.exists(density_csv):
        return
    df1 = pd.read_csv(density_csv)
    df2 = pd.read_csv('script/figure/effect_ci95.csv')

    def cell(result):
        # Format one looked-up ci95 value as a table cell.
        found = result.to_numpy()
        if found.size == 0:
            return '-'
        val = float(found.item())
        # NaN never compares equal to itself, so it must be tested explicitly.
        if math.isnan(val):
            return '-'
        return '$\\pm{:0.3f}$'.format(val)

    values = []
    for method in ['rs', 'maf', 'nice']:
        for i in ['o', 'p']:
            item = f'{method}_{i}'
            for j in [1, 3, 5]:
                result = df1[(df1['item'] == item) &
                             (df1['j'] == j) &
                             (df1['scm'] == 'simpson_nlin')][['ci95']]
                values.append(cell(result))
            for q in ['ate', 'ett', 'nde', 'ctfde']:
                # The value is now taken from this lookup; previously the
                # last value of the loop above was reused for every query.
                result = df2[(df2['item'] == item) &
                             (df2['q'] == q) &
                             (df2['scm'] == 'fairness')][['ci95']]
                values.append(cell(result))

    out_path = 'script/figure/tabs/tab2.tex'
    # Create the output directory so that a fresh checkout does not fail here.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'w+', encoding='utf-8') as f:
        f.write(tmpl.format(*values))


# Render tab2.tex now, while `tmpl` still holds the tab:2 template (it is rebound below).
tab2()


# LaTeX template consumed by tabc92() through str.format(): 6 data rows with 8
# positional `{}` cells each (48 in total). Literal braces are escaped as
# `{{` / `}}`. Both group headers span 4 columns, matching the \cmidrule
# ranges 3-6 and 7-10 and the 10-column tabular.
tmpl = """
\\label{{tab:92}}
\\centering
\\begin{{tabular}}{{cccccccccc}}
    \\toprule
    \\multicolumn{{2}}{{c}}{{}} & \\multicolumn{{4}}{{c}}{{FAIRNESS}} & \\multicolumn{{4}}{{c}}{{FAIRNESS-XW}}\\\\
    \\cmidrule(r){{3-6}} \\cmidrule(r){{7-10}}
    Method & SCM & ATE & ETT & NDE & CtfDE & ATE & ETT & NDE & CtfDE\\\\
    \\midrule
    \\multirow{{2}}{{*}}{{RS}} & O & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$\\\\
    & P & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$\\\\
    \\midrule
    \\multirow{{2}}{{*}}{{EXOM[MAF]}} & O & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$\\\\
    & P & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$\\\\
    \\midrule
    \\multirow{{2}}{{*}}{{EXOM[NICE]}} & O & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$\\\\
    & P & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$\\\\
    \\bottomrule
\\end{{tabular}}
"""


def tabc92():
    """Render ``script/figure/tabs/tabc92.tex`` from the effect ci95 CSV.

    For each of 'rs', 'maf', 'nice' and each of 'o', 'p', eight cells are
    produced: the queries ate, ett, nde, ctfde for SCM ``fairness`` followed
    by the same queries for ``fairness_xw``. Each cell is ``\\pm <value>``
    with three decimals. The cells are substituted into the module-level
    ``tmpl`` and written out; the output directory is created if needed.

    Raises:
        ValueError: If a cell does not match exactly one row of the CSV.
    """
    df1 = pd.read_csv('script/figure/effect_ci95.csv')
    values = []
    for method in ['rs', 'maf', 'nice']:
        for i in ['o', 'p']:
            item = f'{method}_{i}'
            for scm in ['fairness', 'fairness_xw']:
                for q in ['ate', 'ett', 'nde', 'ctfde']:
                    result = df1[(df1['item'] == item) &
                                 (df1['q'] == q) &
                                 (df1['scm'] == scm)][['ci95']]
                    # Backslash written explicitly: '\p' is an invalid escape.
                    values.append(
                        '\\pm {:0.3f}'.format(result.to_numpy().item()))

    out_path = 'script/figure/tabs/tabc92.tex'
    # Create the output directory so that a fresh checkout does not fail here.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'w+', encoding='utf-8') as f:
        f.write(tmpl.format(*values))


# Render tabc92.tex now, while `tmpl` still holds the tab:92 template (it is rebound below).
tabc92()


# LaTeX template consumed by tabc93() through str.format(): 6 data rows with 8
# positional `{}` cells each (48 in total). Literal braces are escaped as
# `{{` / `}}`. Both group headers span 4 columns, matching the \cmidrule
# ranges 3-6 and 7-10 and the 10-column tabular.
tmpl = """
\\label{{tab:93}}
\\centering
\\begin{{tabular}}{{cccccccccc}}
    \\toprule
    \\multicolumn{{2}}{{c}}{{}} & \\multicolumn{{4}}{{c}}{{FAIRNESS-XY}} & \\multicolumn{{4}}{{c}}{{FAIRNESS-YW}}\\\\
    \\cmidrule(r){{3-6}} \\cmidrule(r){{7-10}}
    Method & SCM & ATE & ETT & NDE & CtfDE & ATE & ETT & NDE & CtfDE\\\\
    \\midrule
    \\multirow{{2}}{{*}}{{RS}} & O & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$\\\\
    & P & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$\\\\
    \\midrule
    \\multirow{{2}}{{*}}{{EXOM[MAF]}} & O & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$\\\\
    & P & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$\\\\
    \\midrule
    \\multirow{{2}}{{*}}{{EXOM[NICE]}} & O & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$\\\\
    & P & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$ & ${}$\\\\
    \\bottomrule
\\end{{tabular}}
"""


def tabc93():
    """Render ``script/figure/tabs/tabc93.tex`` from the effect ci95 CSV.

    For each of 'rs', 'maf', 'nice' and each of 'o', 'p', eight cells are
    produced: the queries ate, ett, nde, ctfde for SCM ``fairness_xy``
    followed by the same queries for ``fairness_yw``. Each cell is
    ``\\pm <value>`` with three decimals. The cells are substituted into the
    module-level ``tmpl`` and written out; the output directory is created if
    needed.

    Raises:
        ValueError: If a cell does not match exactly one row of the CSV.
    """
    df1 = pd.read_csv('script/figure/effect_ci95.csv')
    values = []
    for method in ['rs', 'maf', 'nice']:
        for i in ['o', 'p']:
            item = f'{method}_{i}'
            for scm in ['fairness_xy', 'fairness_yw']:
                for q in ['ate', 'ett', 'nde', 'ctfde']:
                    result = df1[(df1['item'] == item) &
                                 (df1['q'] == q) &
                                 (df1['scm'] == scm)][['ci95']]
                    # Backslash written explicitly: '\p' is an invalid escape.
                    values.append(
                        '\\pm {:0.3f}'.format(result.to_numpy().item()))

    out_path = 'script/figure/tabs/tabc93.tex'
    # Create the output directory so that a fresh checkout does not fail here.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'w+', encoding='utf-8') as f:
        f.write(tmpl.format(*values))


# Render tabc93.tex using the tab:93 template assigned just above.
tabc93()
