import numbers
import sys

import matplotlib
import numpy as np
import pandas as pd
import os
from typing import List
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
from scipy.interpolate import griddata
from matplotlib import cm
import matplotlib.colors as mcolors
from matplotlib.colors import ListedColormap

import sys

sys.path.append("../src")
from models.data_mask_estimators.OracleDataMaskerWithDeltaMinMax import OracleDataMaskerWithUniformDeltaMinMax, \
    OracleDataMaskerWithLeftSidedDeltaMinMax, OracleDataMaskerWithRightSidedDeltaMinMax, \
    OracleDataMaskerWithExtremeTailsDeltaMinMax, \
    OracleDataMaskerWithSmallTailsDeltaMinMax, OracleDataMaskerWithBetaUDeltaMinMax, \
    OracleDataMaskerWithBetaRightDeltaMinMax, \
    OracleDataMaskerWithBetaLeftDeltaMinMax, OracleDataMaskerWithBetaRightHillDeltaMinMax, \
    OracleDataMaskerWithBetaLeftHillDeltaMinMax


def is_number(x):
    """Return True when ``x`` is an instance of ``numbers.Number``.

    Note that ``bool`` counts as a number here, because it subclasses ``int``.
    """
    number_abc = numbers.Number
    return isinstance(x, number_abc)


def read_method_results_aux(folder_path, seeds=20, apply_mean=True, display_errors=False):
    """Load and stack the per-seed result CSVs stored in ``folder_path``.

    Reads ``{folder_path}/seed={i}.csv`` for ``i`` in ``range(seeds)``. Files
    that are missing or fail to process are skipped; the error is printed when
    ``display_errors`` is True.

    Two checks run on each seed:

    * If the seed's ``coverage`` is ~0 and its ``average length`` contains NaN,
      the seed is reported and skipped.
    * If the seed's ``(miscoverage streak) average length`` contains NaN, the
      seed is reported but kept.

    Args:
        folder_path: directory holding the ``seed=<i>.csv`` files.
        seeds: number of seed files to try (``seed=0`` .. ``seed=seeds-1``).
        apply_mean: if True, collapse all rows into one row of column means.
        display_errors: print exceptions raised while reading individual seeds.

    Returns:
        A DataFrame with the rows of every loaded seed, or a single row of
        column means when ``apply_mean`` is True.

    Raises:
        Exception: if no seed could be loaded. When ``seed=0.csv`` itself cannot
            be read, that underlying read error (e.g. FileNotFoundError) is
            raised instead, so the caller sees the root cause.
    """
    import builtins

    # `display` only exists (as a builtin) inside IPython/Jupyter. A bare
    # `display(...)` raises NameError in a plain interpreter, and the broad
    # `except` below would silently drop the seed. Fall back to print.
    show = getattr(builtins, "display", print)
    streak_col = '(miscoverage streak) average length'

    frames = []
    for seed in range(seeds):
        save_path = f"{folder_path}/seed={seed}.csv"
        try:
            seed_df = pd.read_csv(save_path).drop(['Unnamed: 0'], axis=1, errors='ignore')

            if 'coverage' in seed_df and abs(seed_df['coverage'].item()) < 0.01:
                if np.isnan(seed_df['average length']).any():
                    print(
                        f"{folder_path}/seed={seed}.csv has invalid average length. the value is: {seed_df['average length'].item()}")
                    show(seed_df)
                    continue
            # Validate the seed currently being read. Checking the accumulated
            # frame's columns would skip the check for the first seed.
            if streak_col in seed_df.columns and np.isnan(seed_df[streak_col]).any():
                print(
                    f"{folder_path}/seed={seed}.csv has invalid (miscoverage streak) average length")
                print("the value is: ", seed_df[streak_col].item())
                show(seed_df)

            frames.append(seed_df)
        except Exception as e:
            if display_errors:
                print(e)

    if not frames:
        # Re-read seed 0 so that, if it is missing or unreadable, the caller
        # sees that underlying error. If it reads fine, every seed was filtered out.
        pd.read_csv(f"{folder_path}/seed=0.csv")
        raise Exception(f"could not find results in path {folder_path}")

    # Concatenate once instead of growing a DataFrame inside the loop.
    df = pd.concat(frames, axis=0)
    if apply_mean:
        df = df.apply(np.mean).to_frame().T

    return df


def get_method_display_name(method_name):
    """Return the display name registered for ``method_name``.

    Looks the name up in the module-level ``method_name_to_display_name``
    mapping. Unknown names are returned unchanged.
    """
    return method_name_to_display_name.get(method_name, method_name)


pd.set_option('display.max_columns', None)
imputations = ['linear', 'partially_linear', 'dml', 'full', 'full_with_linear', 'indep_partially_linear']
errors_methods = ['marginal', 'kmeans_clustering', 'linear_clustering', 'nn', 'CVAE', 'normal', 'gmm',
                  'rf', 'rfcde', 'nnkcde', 'flex_code', 'nf']
# Maps a method (results sub-folder) name to the label used in tables/plots.
method_name_to_display_name = {}
base_calibrations = ['cqr', 'hps', 'aps']
base_models = ['qr', 'rf_qr', 'xgb_qr']
regressors = ['linear', 'full', 'full_with_linear', 'rf', 'xgb_reg']
# error_samples = ['marginal_error_sampler', 'linear_clustering_error_sampler', 'kmeans_clustering_error_sampler',
#                  'normal_error_sampler', 'gmm_error_sampler', 'rf_error_sampler', 'rfcde_error_sampler', 'nnkcde_error_sampler',
#                 'flex_code_error_sampler', ]
maskers = [
    # 'cnn_use_z=False', 'cnn_use_z=True',
    'network_use_z=False', 'network_use_z=True',
    'xgb_use_z=False', 'xgb_use_z=True',
    'rf_use_z=False', 'rf_use_z=True',
    'oracle',
]

# The dict is updated in place. Rebuilding it with {**d, ...} inside these
# nested loops copied the whole mapping on every iteration (quadratic).
# Insertion order and final values are unchanged.
for cal in base_calibrations:
    for model in base_models:
        # These names do not depend on the masker, so insert them once.
        method_name_to_display_name.update({
            f'{model}_Dummy': 'uncalibrated',
            f"{model}_CQR": 'naive cqr',
            f"{model}_{cal}": f'naive {cal}',
            f"{model}_{cal}_ignore_masked": f'{cal} ignore masked',
            f"{model}_oracle_{cal}": f'clean Y {cal} ',
        })
        for masker in maskers:
            method_name_to_display_name.update({
                f"{model}_weighted_{cal}_{masker}_masker": f'weighted_{masker}_masker',
                f"{model}_two_staged_{cal}_{masker}_masker": f'two_staged_{masker}_masker',
                f"{model}_pcp_{cal}_{masker}_masker": f'pcp_{masker}_masker',
            })
        for regressor_method in regressors:
            for error_method in errors_methods:
                method_name_to_display_name[
                    f"{model}_{regressor_method}_with_{error_method}_error_sampler_imputation_{cal}_calibration"] = \
                    f'{regressor_method}_with_{error_method}_errors_imputation'
                for masker in maskers:
                    method_name_to_display_name[
                        f"{model}_triply_robust_{regressor_method}_with_{error_method}_error_sampler_imputation_{cal}_calibration_pcp_cqr_{masker}_masker"] = \
                        f'triply_robust_{regressor_method}_with_{error_method}_errors_imputation_pcp_{masker}_masker'
            method_name_to_display_name[f"{model}_{regressor_method}_imputation_{cal}_calibration"] = \
                f'{regressor_method}_imputation'

# CQR-calibrated imputation methods. Some keys coincide with ones inserted
# above and are overwritten here with the "full + linear"-style labels.
for model in base_models:
    for imputation in imputations:
        imputation_display = imputation.replace("full_with_linear", "full + linear")
        method_name_to_display_name[f'{model}_{imputation}_imputation_cqr_calibration'] = imputation_display
        for error_method in errors_methods:
            method_name_to_display_name[
                f'{model}_{imputation}_with_{error_method}_error_sampler_imputation_cqr_calibration'] = \
                f"{imputation_display} with {error_method} errors"
            for masker in maskers:
                method_name_to_display_name[
                    f'{model}_triply_robust_{imputation}_with_{error_method}_error_sampler_imputation_cqr_calibration_pcp_cqr_{masker}_masker'] = \
                    f"triply robust {imputation_display} with {error_method} errors pcp {masker} masker"

method_name_to_display_name['qr_gmm_sample_imputator_imputation_cqr_calibration'] = 'none with gmm'

# Normalise every label: underscores become spaces, everything lower-case.
method_name_to_display_name = {
    key: display.replace("_", " ").lower() for key, display in method_name_to_display_name.items()
}

methods = list(method_name_to_display_name)


# methods

def read_method_results(base_path: str, dataset_name: str, method_name: str, seeds=20, apply_mean=True,
                        display_errors=False):
    """Read the per-seed results of one method on one dataset.

    Thin wrapper around ``read_method_results_aux`` for the folder
    ``<base_path>/<dataset_name>/<method_name>``. All other arguments are
    forwarded unchanged.
    """
    return read_method_results_aux(
        os.path.join(base_path, dataset_name, method_name),
        seeds=seeds,
        apply_mean=apply_mean,
        display_errors=display_errors,
    )


def read_methods_results(base_path: str, dataset_name: str, method_names: List[str], seeds=20, display_errors=False,
                         apply_mean=True,
                         tqdm=False):
    """Read the results of several methods on one dataset into one DataFrame.

    Each method is loaded with ``read_method_results_aux`` from
    ``<base_path>/<dataset_name>/<method_name>`` and tagged with a 'Method'
    column holding ``get_method_display_name(method_name)``. Methods that fail
    to load are skipped; the error is printed when ``display_errors`` is True.

    Args:
        base_path: root results directory.
        dataset_name: dataset sub-directory.
        method_names: method sub-directories to read.
        seeds: number of seed files to try per method.
        display_errors: print per-method (and per-seed) loading errors.
        apply_mean: average each method's seeds into a single row.
        tqdm: wrap the iteration in a ``tqdm`` progress bar.

    Returns:
        The concatenated results, or an empty DataFrame if no method could be read.
    """
    if tqdm:
        # The boolean flag shadows the package name, so import it under an alias.
        import tqdm as tqdm_module
        method_names = tqdm_module.tqdm(method_names)

    frames = []
    for method_name in method_names:
        try:
            full_folder_path = os.path.join(base_path, dataset_name, method_name)
            df = read_method_results_aux(full_folder_path, seeds, apply_mean=apply_mean, display_errors=display_errors)
            df['Method'] = get_method_display_name(method_name)
            frames.append(df)
        except Exception as e:
            # Best effort: a missing/broken method is skipped, not fatal.
            if display_errors:
                print(f"got error while trying to read method {method_name}. error: {e}")

    # Concatenate once at the end. Calling pd.concat inside the loop re-copies
    # every accumulated row on each iteration (quadratic).
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames)


import re


def method_to_error_type(method):
    """Classify which error sampler a method uses, from its display name.

    ``method`` is a display name such as
    ``"full + linear with linear clustering errors"``. Only the ``<name>`` part
    of the last ``"with <name> errors"`` clause is classified. This stops
    regressor or masker names elsewhere in the string (e.g. ``"rf with nn
    errors"`` or ``"... pcp rf use z=true masker"``) from being mistaken for
    the error sampler. If no such clause exists, the whole string is used.

    Returns:
        'none' when ``method`` does not mention errors. Otherwise one of
        'marginal', 'kmeans clustering[ with x]', 'linear clustering[ with x]',
        'nnkcde', 'flex_code', 'gmm', 'rfcde', 'normal', 'nf', 'rf', 'cvae'
        or 'qr' (used for the 'nn' sampler).

    Raises:
        Exception: if the error sampler is not recognised.
    """
    if "error" not in method:
        return 'none'

    # The greedy ".*" selects the LAST "with", so in "full with linear with X
    # errors" only "X" is captured.
    match = re.search(r'.*\bwith (.+?) errors', method)
    error_part = match.group(1) if match else method

    if 'marginal' in error_part or 'with errors' in method:
        return 'marginal'
    elif 'kmeans' in error_part and 'use x=true' in method:
        return 'kmeans clustering with x'
    elif 'kmeans' in error_part:
        return 'kmeans clustering'
    elif 'linear clustering use x=true' in error_part:
        return 'linear clustering with x'
    elif 'linear clustering' in error_part:
        return 'linear clustering'
    elif 'nnkcde' in error_part:
        return 'nnkcde'
    elif 'flex' in error_part:
        return 'flex_code'
    elif 'gmm' in error_part:
        return 'gmm'
    elif 'rfcde' in error_part:
        return 'rfcde'
    elif 'normal' in error_part:
        return 'normal'
    elif 'nf' in error_part:
        return 'nf'
    elif 'rf' in error_part:
        return 'rf'
    elif 'cvae' in error_part:
        return 'cvae'
    elif 'nn' in error_part:
        return 'qr'
    else:
        raise Exception(f"don't know how to handle with method: {method}")


def method_to_display_name(method):
    """Shorten a method display name for plotting.

    Three steps are applied in order:

    1. Remove the span from the first "with" to the last "errors" (a greedy
       regex), if one exists.
    2. Delete every "imputation" substring.
    3. Strip trailing spaces.
    """
    errors_clause = re.search(r'with.*errors', method)
    shortened = method if errors_clause is None else method.replace(errors_clause.group(), "")
    return shortened.replace("imputation", "").rstrip(' ')


def process_methods_df(total_df):
    """Convert coverage columns to percentages and round numeric columns, in place.

    Every column whose name contains "coverage" is multiplied by 100. Every
    column whose first value passes ``is_number`` is then rounded to 2
    decimals. ``total_df`` is modified in place and also returned.

    Raises:
        Exception: if ``total_df`` has no rows.
    """
    if not len(total_df):
        raise Exception("no data")
    for column in list(total_df.columns):
        is_coverage_column = 'coverage' in column
        if is_coverage_column:
            total_df[column] *= 100
        first_value = total_df[column].iloc[0]
        if is_number(first_value):
            total_df[column] = np.round(total_df[column], 2)
    return total_df


def process_methods_names(total_df):
    """Return a copy of ``total_df`` with an 'Error' column and shortened 'Method' names.

    'Error' is ``method_to_error_type`` of each original 'Method' value.
    'Method' is replaced by ``method_to_display_name`` of that same original
    value. The input frame is not modified.

    Raises:
        Exception: if ``total_df`` has no rows.
    """
    if not len(total_df):
        raise Exception("no data")
    original_methods = total_df['Method']
    return total_df.assign(
        Error=original_methods.apply(method_to_error_type),
        Method=original_methods.apply(method_to_display_name),
    )


def load_results(seeds, results_base_path, dataset_name, method_names):
    """Load, post-process and tag the results of ``method_names`` on one dataset.

    Args:
        seeds: number of seed files to read per method.
        results_base_path: root results directory; methods are read from
            ``<results_base_path>/<dataset_name>/<method_name>``.
        dataset_name: dataset sub-directory; also stored in a new 'Dataset' column.
        method_names: method sub-directories to read.

    Returns:
        Per-seed rows (no averaging) for every method that could be read.
        Coverage columns are scaled by 100 and numeric columns are rounded
        (see ``process_methods_df``).

    Raises:
        ValueError: if nothing could be loaded, chained from the underlying
            error. Before this change, the failure surfaced only as pandas'
            "No objects to concatenate" after the error was printed.
    """
    try:
        data_df = read_methods_results(results_base_path, dataset_name, method_names, apply_mean=False,
                                       seeds=seeds, display_errors=False)
        data_df = process_methods_df(data_df)
    except Exception as e:
        print(f"data: {dataset_name}, error: {e}")
        raise ValueError(f"could not load results for dataset {dataset_name}: {e}") from e
    return data_df.assign(Dataset=dataset_name)


def load_constant_error_df(seeds, results_base_path, delta_exp_method_name, dataset_name: str):
    """Load the results of the constant-delta ("oracle with delta") experiments.

    Every sub-folder of ``<results_base_path>/<dataset_name>`` whose name
    contains ``qr_<weighted|pcp>_cqr_oracle_with_delta`` is loaded with
    ``load_results``. A 'Delta' column is then parsed from each method name:
    it is the number after ``delta=`` up to the next ``_``, rounded to 2
    decimals.

    Args:
        seeds: number of seed files to read per method.
        results_base_path: root results directory.
        delta_exp_method_name: 'wcp' (folders named ``weighted``) or 'pcp'.
        dataset_name: dataset sub-directory.

    Raises:
        ValueError: if ``delta_exp_method_name`` is neither 'wcp' nor 'pcp'.
            An ``assert`` was used before, which is stripped under ``python -O``.
    """
    if delta_exp_method_name not in ('wcp', 'pcp'):
        raise ValueError(f"delta_exp_method_name must be 'wcp' or 'pcp', got {delta_exp_method_name!r}")
    delta_exp_method = 'weighted' if delta_exp_method_name == 'wcp' else 'pcp'
    folder_marker = f'qr_{delta_exp_method}_cqr_oracle_with_delta'
    delta_methods = [m for m in os.listdir(os.path.join(results_base_path, dataset_name)) if folder_marker in m]
    total_df = load_results(seeds, results_base_path, dataset_name, delta_methods)

    def method_to_delta(method_name):
        # "..._delta=<value>_..." -> round(float(<value>), 2)
        return np.round(float(method_name.split('delta=')[1].split("_")[0]), 2)

    total_df['Delta'] = total_df['Method'].apply(method_to_delta)
    return total_df


def display_constant_error_figure(df: pd.DataFrame, figure_path: str, delta_exp_method_name: str,
                                  is_cp_overcoverage: bool):
    """Plot coverage against the constant shift delta, then save and show it.

    Args:
        df: frame with a 'Delta' column and a 'full y2 coverage' column. The
            coverage is drawn on a 0-105 axis against a 90% reference line.
        figure_path: root figure directory. The image is written to
            ``<figure_path>/w_delta/<delta_exp_method_name>/1_dim/``.
        delta_exp_method_name: 'wcp' labels the curve WCP; anything else labels it PCP.
        is_cp_overcoverage: selects the title, file name and shading colours
            for the "over coverage" scenario instead of "under coverage".
    """
    delta_exp_method_title = r'\texttt{WCP}' if delta_exp_method_name == 'wcp' else r'\texttt{PCP}'
    if is_cp_overcoverage:
        title = "Naive CP achieves over coverage"
        save_name = 'Naive_CP_achieves_over_coverage.png'
        outside_color, inside_color = 'green', 'orange'
    else:
        title = "Naive CP achieves under coverage"
        save_name = 'Naive_CP_achieves_under_coverage.png'
        outside_color, inside_color = 'orange', 'green'
    # Report the scenario actually being plotted. The under-coverage message
    # used to be printed unconditionally.
    print(title)

    sns.set(rc={'figure.figsize': (11, 4)})
    sns.set(font_scale=3)
    sns.lineplot(data=df, x='Delta', y='full y2 coverage',
                 label=f'{delta_exp_method_title} w/ approx. weights', linewidth=6)

    plt.ylabel("Coverage")
    plt.xlabel(r"$\delta$")
    plt.xlim(-2, 2)
    plt.ylim(0, 105)
    # Shade delta outside [-1, 0] with one colour and [-1, 0] with the other.
    plt.axvspan(0, 5, color=outside_color, alpha=0.3)
    plt.axvspan(-5, -1, color=outside_color, alpha=0.3)
    plt.axvspan(-1, 0, color=inside_color, alpha=0.3)
    plt.title(title)

    plt.axhline(y=90, color='r', linestyle='--', label=r'$1-\alpha=90\%$', linewidth=3)
    plt.legend()
    save_dir = os.path.join(figure_path, "w_delta", delta_exp_method_name, '1_dim')
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, save_name), dpi=300, bbox_inches='tight')
    plt.show()


def display_noise_distribution(paper_figures_path):
    """Save and show a 5-bin density histogram of sampled noise for each oracle masker.

    Each masker class is built with 1-d x/z and ``delta_min=-1``,
    ``delta_max=1``. It is then asked for 200000 samples via
    ``generate_noise(size=..., device='cpu')``, and the result is converted
    with ``.numpy()``. Each figure is saved as
    ``<paper_figures_path>/w_delta/error_distribution/<display_name>.png``
    (spaces replaced by underscores).
    """
    masker_kwargs = dict(
        data_scaler=None,
        data_masker=None,
        dataset_name=None,
        x_dim=1,
        z_dim=1,
        delta_min=-1,
        delta_max=1,
    )
    masker_classes = (
        OracleDataMaskerWithUniformDeltaMinMax,
        OracleDataMaskerWithLeftSidedDeltaMinMax,
        OracleDataMaskerWithRightSidedDeltaMinMax,
        OracleDataMaskerWithExtremeTailsDeltaMinMax,
        OracleDataMaskerWithSmallTailsDeltaMinMax,
        OracleDataMaskerWithBetaUDeltaMinMax,
        OracleDataMaskerWithBetaRightDeltaMinMax,
        OracleDataMaskerWithBetaLeftDeltaMinMax,
        OracleDataMaskerWithBetaRightHillDeltaMinMax,
        OracleDataMaskerWithBetaLeftHillDeltaMinMax,
    )
    noise_functions = [masker_cls(**masker_kwargs) for masker_cls in masker_classes]
    save_dir = os.path.join(paper_figures_path, "w_delta", "error_distribution")

    for noise_function in noise_functions:
        samples = noise_function.generate_noise(size=200000, device='cpu').numpy()
        plt.figure(figsize=(5, 4))
        plt.hist(samples, density=True, bins=5)
        plt.xlabel(r"$\delta$")
        plt.ylabel("Density")
        plt.ylim(0, 1)
        plt.title(noise_function.display_name)
        os.makedirs(save_dir, exist_ok=True)
        file_name = f"{noise_function.display_name.replace(' ', '_')}.png"
        plt.savefig(os.path.join(save_dir, file_name), dpi=300, bbox_inches='tight')
        plt.show()
# Colour for each method label; passed as ``palette`` to the seaborn plots in
# this module. Labels that share a colour are grouped together.
color_palette = {
    **dict.fromkeys(("uncalibrated", "Uncalibrated"), "hotpink"),
    **dict.fromkeys(('Naive CP', 'Naive CP\n(clean + noisy)', 'Naive CP\n(only clean)'), "r"),
    **dict.fromkeys(('Infeasible WCP', 'WCP'), "orange"),
    **dict.fromkeys(('PCP', 'PCP (est. weights)'), "b"),
    'PCP (oracle weights)': "purple",
    **dict.fromkeys(('Uncertain Imputation', 'Uncertain Imp.', 'Ours'), "g"),
    **dict.fromkeys(('Naive Imputation', 'Naive Imp.'), 'sienna'),  # alternatives considered: brown, y
    **dict.fromkeys(('Triply', 'Triply Robust', 'TriplyRobust', 'Triply Robust (PCP est. weights)'), 'orange'),
    'Triply Robust (PCP oracle weights)': 'g',
}




def display_pcp_fail_results(dataset_name: str, results_base_path: str, figure_path: str, seeds: int):
    """Box-plot per-seed coverage and interval length of naive CP, PCP and uncertain imputation.

    Results are loaded for every name in the module-level ``methods`` list.
    The function keeps:

    * methods whose display name contains 'naive cqr', 'pcp network' or
      'full + linear';
    * excluding any name that contains 'triply'.

    Masker suffixes are stripped. The remaining names are mapped to the
    labels of ``curr_method_name_to_display_name``. Two figures are written:
    ``<figure_path>/pcp_fail/coverage.png`` and ``length.png``.

    Args:
        dataset_name: dataset sub-directory under ``results_base_path``.
        results_base_path: root results directory.
        figure_path: root figure directory.
        seeds: number of seed files to read per method.
    """
    curr_method_name_to_display_name = {
        'naive cqr': 'Naive CP',
        'pcp ': 'PCP',
        'full + linear with linear clustering errors': 'Uncertain Imp.',
    }
    curr_methods_order = list(curr_method_name_to_display_name.values())

    total_df = read_methods_results(results_base_path, dataset_name, methods, apply_mean=False, seeds=seeds,
                                    display_errors=False)
    total_df = process_methods_df(total_df)

    methods_to_keep = ['naive cqr', 'pcp network', 'full + linear']
    methods_to_exclude = ['triply']

    def keep_method(method_name):
        return (any(a in method_name for a in methods_to_keep)
                and not any(a in method_name for a in methods_to_exclude))

    masker_name_to_display_name = {
        'oracle masker': 'oracle',
        'network use z=true masker': "",
    }

    def get_masker_from_method_name(method_name):
        # First masker suffix found in the name, or 'none'.
        for k in masker_name_to_display_name:
            if k in method_name:
                return masker_name_to_display_name[k]
        return 'none'

    def remove_masker_name_from_method_name(method_name):
        # Replace the masker suffix by its short label. For example,
        # "pcp network use z=true masker" becomes "pcp ".
        masker = get_masker_from_method_name(method_name)
        for k in masker_name_to_display_name:
            method_name = method_name.replace(k, "")
        if masker != 'none':
            method_name += masker
        return method_name

    # Take explicit copies after each row filter so the column assignments
    # below act on independent frames, not on filtered views (which triggers
    # pandas' SettingWithCopyWarning).
    total_df = total_df[total_df['Method'].apply(keep_method)].copy()
    total_df['Masker'] = total_df['Method'].apply(get_masker_from_method_name)
    total_df['Method'] = total_df['Method'].apply(remove_masker_name_from_method_name)
    total_df = total_df[total_df['Method'].isin(list(curr_method_name_to_display_name))].copy()
    total_df['Method'] = total_df['Method'].map(curr_method_name_to_display_name)

    plt.figure(figsize=(5.5, 2))
    sns.set(font_scale=1.5)

    sns.boxplot(data=total_df, x='Method', hue='Method', y='full y2 coverage', order=curr_methods_order,
                palette=color_palette, legend=False)
    plt.ylabel("Coverage")
    plt.xlabel("")
    plt.axhline(y=90, color='r', linestyle='--')
    save_dir = os.path.join(figure_path, "pcp_fail")
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, "coverage.png"), dpi=300, bbox_inches='tight')
    plt.show()

    plt.figure(figsize=(5.5, 2))
    sns.boxplot(data=total_df, x='Method', hue='Method', y='y2 length', order=curr_methods_order,
                palette=color_palette, legend=False)
    plt.ylabel("Interval length")
    plt.xlabel("")
    plt.savefig(os.path.join(save_dir, "length.png"), dpi=300, bbox_inches='tight')
    plt.show()
