import string
from copy import deepcopy
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xgboost as xgb
from numba import jit

# Parameters of a two-state hidden-Markov model.
# NOTE(review): nothing in this part of the file passes these arrays to the
# belief helpers, so the reading below is inferred from shapes and values --
# confirm against the caller. State 0 / state 1 presumably correspond to the
# 'PROB_RALLY' / 'PROB_DOWN' belief columns named in combine_dt_dummy_belief.

# Starting belief: both hidden states equally likely.
INIT_DISTRIBUTION = np.array([0.5, 0.5])

# Each COLUMN sums to 1, which fits the `transition_matrix @ belief` product
# used by compute_belief_numba (entry [i, j] = weight of moving from j to i).
TRANSITION = np.array([
    [0.85, 0.20],
    [0.15, 0.80],
])

# One row per observable context (20 rows), one column per hidden state.
# Each column sums to 1, i.e. entry [context, state] reads as the probability
# of seeing `context` given `state`.
EMISSION = np.array([
    [0.0665, 0.0089],
    [0.0956, 0.0202],
    [0.0661, 0.0361],
    [0.1634, 0.0070],
    [0.1242, 0.0373],
    [0.0628, 0.0098],
    [0.0967, 0.0228],
    [0.0760, 0.0320],
    [0.0407, 0.0093],
    [0.0664, 0.0265],
    [0.0048, 0.0420],
    [0.0109, 0.0341],
    [0.0025, 0.0913],
    [0.0002, 0.0746],
    [0.0058, 0.0636],
    [0.0697, 0.0855],
    [0.0107, 0.1423],
    [0.0310, 0.0480],
    [0.0010, 0.1101],
    [0.0050, 0.0986],
])

# Column names of the dummy-encoded design matrix. The '_abc_' infix is the
# `prefix_sep` that transform_dt_dummy hands to pd.get_dummies, so each name
# reads <source column>_abc_<category level>.
LIST_FEATURE_DUMMY = [
    'INTERCEPT',
    'EDUCATION_abc_GRAD_SCHOOL', 'EDUCATION_abc_UNIVERSITY', 'EDUCATION_abc_HIGH_SCHOOL', 'EDUCATION_abc_OTHERS',
    'MARRIAGE_abc_MARRIED', 'MARRIAGE_abc_SINGLE', 'MARRIAGE_abc_OTHERS',
    'AGE_GRP_abc_A', 'AGE_GRP_abc_B', 'AGE_GRP_abc_C', 'AGE_GRP_abc_D', 'AGE_GRP_abc_E',
    'RISK_SCORE_abc_A', 'RISK_SCORE_abc_B', 'RISK_SCORE_abc_C', 'RISK_SCORE_abc_D', 'RISK_SCORE_abc_E',
    'REVENUE_GRP_abc_A', 'REVENUE_GRP_abc_B', 'REVENUE_GRP_abc_C', 'REVENUE_GRP_abc_D',
    'ACTION_abc_call', 'ACTION_abc_email',
    'REVENUE_GRP_ACTION_abc_A_call', 'REVENUE_GRP_ACTION_abc_A_email',
    'REVENUE_GRP_ACTION_abc_D_call', 'REVENUE_GRP_ACTION_abc_D_email',
    'RISK_SCORE_ACTION_abc_A_call', 'RISK_SCORE_ACTION_abc_A_email',
    'RISK_SCORE_ACTION_abc_B_call', 'RISK_SCORE_ACTION_abc_B_email',
    'RISK_SCORE_ACTION_abc_D_call', 'RISK_SCORE_ACTION_abc_D_email',
    'RISK_SCORE_ACTION_abc_E_call', 'RISK_SCORE_ACTION_abc_E_email',
]

# Columns that combine_dt_dummy_belief copies through unscaled. 'PROB_RALLY'
# is not a dummy column: combine_dt_dummy_belief adds it from the belief.
LIST_FEATURE_DUMMY_INVARIANT = [
    'INTERCEPT', 'PROB_RALLY',
    'EDUCATION_abc_GRAD_SCHOOL', 'EDUCATION_abc_UNIVERSITY', 'EDUCATION_abc_HIGH_SCHOOL', 'EDUCATION_abc_OTHERS',
    'MARRIAGE_abc_MARRIED', 'MARRIAGE_abc_SINGLE', 'MARRIAGE_abc_OTHERS',
    'AGE_GRP_abc_A', 'AGE_GRP_abc_B', 'AGE_GRP_abc_C', 'AGE_GRP_abc_D', 'AGE_GRP_abc_E',
    'RISK_SCORE_abc_A', 'RISK_SCORE_abc_B', 'RISK_SCORE_abc_C', 'RISK_SCORE_abc_D', 'RISK_SCORE_abc_E',
    'REVENUE_GRP_abc_A', 'REVENUE_GRP_abc_B', 'REVENUE_GRP_abc_C', 'REVENUE_GRP_abc_D',
]

# Everything in the dummy list that is not invariant (the action columns and
# the action interactions), kept in LIST_FEATURE_DUMMY order.
LIST_FEATURE_DUMMY_VARIANT = [
    name for name in LIST_FEATURE_DUMMY
    if name not in LIST_FEATURE_DUMMY_INVARIANT
]

# Final column order: invariant block, then the variant block suffixed
# '_rally', then the variant block suffixed '_down'.
LIST_FEATURE_DUMMY_CB = [
    *LIST_FEATURE_DUMMY_INVARIANT,
    *(f'{name}_rally' for name in LIST_FEATURE_DUMMY_VARIANT),
    *(f'{name}_down' for name in LIST_FEATURE_DUMMY_VARIANT),
]

def combine_dt_dummy_belief(dt, belief):
    """Attach belief probabilities to a dummy-encoded frame and belief-weight its variant columns.

    Args:
        dt: DataFrame holding the LIST_FEATURE_DUMMY columns (plus an
            'index_env' column when its row count differs from `belief`).
        belief: 2-D array whose column 0 / 1 are used as 'PROB_RALLY' /
            'PROB_DOWN'.

    Returns:
        A new DataFrame with exactly the LIST_FEATURE_DUMMY_CB columns: the
        invariant columns as-is, then every variant column multiplied by the
        rally probability ('<name>_rally') and by the down probability
        ('<name>_down').

    When the row counts differ, belief row k (1-based) is left-joined onto the
    rows of `dt` whose 'index_env' equals k; otherwise belief rows are applied
    positionally.
    """
    n_belief_rows, _ = belief.shape
    n_dt_rows, _ = dt.shape
    join_on_env = n_belief_rows != n_dt_rows

    if join_on_env:
        belief_frame = pd.DataFrame(belief)
        belief_frame.columns = ['PROB_RALLY', 'PROB_DOWN']
        belief_frame['index_env'] = list(range(1, n_belief_rows + 1))
        dt = dt.merge(belief_frame, how='left', on='index_env').reset_index(drop=True)
        # Joined case: weights are Series sharing dt's freshly reset index.
        weight_rally, weight_down = dt['PROB_RALLY'], dt['PROB_DOWN']
    else:
        dt = deepcopy(dt)
        dt['PROB_RALLY'] = belief[:, 0]
        # Same-length case: weights stay plain arrays and apply by position.
        weight_rally, weight_down = belief[:, 0], belief[:, 1]

    def _weighted_variant(weights, suffix):
        # Copy of the variant columns, each scaled by `weights` and renamed.
        block = dt[LIST_FEATURE_DUMMY_VARIANT].reset_index(drop=True)
        for column in LIST_FEATURE_DUMMY_VARIANT:
            block[column] = block[column] * weights
        block.columns = [f'{name}{suffix}' for name in block]
        return block

    pieces = [
        dt[LIST_FEATURE_DUMMY_INVARIANT].reset_index(drop=True),
        _weighted_variant(weight_rally, '_rally'),
        _weighted_variant(weight_down, '_down'),
    ]
    combined = pd.concat(pieces, axis=1).reset_index(drop=True)
    return combined[LIST_FEATURE_DUMMY_CB]

def transform_dt_dummy(dt):
    """Dummy-encode a customer/action frame into the LIST_FEATURE_DUMMY layout.

    Args:
        dt: DataFrame with at least string columns 'REVENUE_GRP',
            'RISK_SCORE' and 'ACTION'. Columns whose name contains 'index'
            are carried through untouched and excluded from the encoding.

    Returns:
        A new DataFrame (input is not modified) made of the 'index' columns,
        the raw 'ACTION' column and every LIST_FEATURE_DUMMY column, with a
        fresh 0..n-1 index.

    A dummy column whose category level does not occur in `dt` (typical for
    small batches) is filled with 0 instead of raising KeyError, mirroring
    how predict_xgb fills features missing from its input.
    """
    dt = dt.copy()
    dt['INTERCEPT'] = 1
    # Interaction terms are encoded as one combined category per row.
    dt['REVENUE_GRP_ACTION'] = dt['REVENUE_GRP'] + '_' + dt['ACTION']
    dt['RISK_SCORE_ACTION'] = dt['RISK_SCORE'] + '_' + dt['ACTION']

    index_columns = [col for col in dt if 'index' in col]
    feature_columns = [col for col in dt if 'index' not in col]

    # `* 1` turns boolean dummies into integers.
    df_dummy = pd.get_dummies(dt[feature_columns], prefix_sep='_abc_') * 1
    # Enforce the expected column set and order; absent levels become 0.
    df_dummy = df_dummy.reindex(columns=LIST_FEATURE_DUMMY, fill_value=0)

    return pd.concat([dt[index_columns + ['ACTION']], df_dummy], axis=1).reset_index(drop=True)


@jit
def compute_belief_numba(context:int, transition_matrix:np.ndarray, emission_matrix: np.ndarray, current_belief: np.ndarray):
    """Return the normalised belief after observing `context`.

    `context` selects a row of `emission_matrix`. With a 1-D emission row and
    a 2-D transition matrix the result is proportional to
    sum_j transition_matrix[i, j] * emission_matrix[context, j] * current_belief[j].
    """
    # `*` and `@` have equal precedence and group left-to-right, so the
    # emission row scales the columns of the transition matrix BEFORE the
    # matrix-vector product.
    # NOTE(review): this weights the current belief by the emission likelihood
    # and then applies the transition; if the propagated belief was meant to
    # be weighted instead, `transition_matrix @ current_belief` would need its
    # own parentheses -- confirm the intended order.
    scaled_transition = emission_matrix[context, :] * transition_matrix
    unnormalised = scaled_transition @ current_belief
    return unnormalised / unnormalised.sum()

def update_belief_numba(transition_matrix:np.ndarray, emission_matrix: np.ndarray, vec_init:np.ndarray, list_contexts: List):
    """Run compute_belief_numba over a sequence of contexts, starting from `vec_init`.

    Returns a list with one belief per context, in order; each belief is the
    input to the next step. The list is empty when `list_contexts` is empty
    (`vec_init` itself is never included).
    """
    def _belief_path():
        belief = vec_init
        for observed in list_contexts:
            belief = compute_belief_numba(observed, transition_matrix, emission_matrix, belief)
            yield belief

    return list(_belief_path())

def load_xgb(MODEL_PATH, FEATURES_PATH, CORES=5):
    """Load an XGBoost booster from disk and attach its feature names.

    Args:
        MODEL_PATH: path handed to `Booster.load_model`.
        FEATURES_PATH: path of a NumPy file holding the feature names.
        CORES: value for the booster's "nthread" parameter.

    Returns:
        The loaded `xgb.Booster` with `feature_names` set.
    """
    booster = xgb.Booster({"nthread": CORES})
    booster.load_model(MODEL_PATH)

    # NOTE(review): allow_pickle=True unpickles arbitrary objects from the
    # file -- only point this at trusted feature files.
    feature_names = np.load(FEATURES_PATH, allow_pickle=True)
    booster.feature_names = list(feature_names)
    return booster

def predict_xgb(DATA, MODEL, PREFIX_SEP = "_abc_"):
    """Dummy-encode `DATA`, align it to the model's features and predict.

    Args:
        DATA: DataFrame passed to `pd.get_dummies`.
        MODEL: object exposing `feature_names` and `predict` (for example the
            booster returned by load_xgb).
        PREFIX_SEP: separator between column name and category level in the
            generated dummy column names.

    Returns:
        Whatever `MODEL.predict` returns for the aligned matrix.

    Features the model expects but the encoded data lacks are filled with 0;
    encoded columns the model does not know are dropped.
    """
    features = MODEL.feature_names
    encoded = pd.get_dummies(DATA, prefix_sep=PREFIX_SEP)

    # One reindex selects/reorders the columns and zero-fills missing ones,
    # instead of inserting columns one at a time (which fragments the frame
    # and triggers pandas' PerformanceWarning for many missing features).
    aligned = encoded.reindex(columns=features, fill_value=0)

    xgb_data = xgb.DMatrix(data=aligned.values, feature_names=features)
    return MODEL.predict(xgb_data)

def quantile_cut(dt, var, bin = 5):
    """Label each row of `dt` with the quantile bin of column `var`, in place.

    Adds (or overwrites values in) the column f"{var}_CLUSTER". Labels look
    like "a.(-inf, 20]", "b.(20, 40]", ..., with the last bin open to +inf.
    Interior edges are the k * 100 / bin percentiles of `dt[var]`, rounded to
    the nearest integer. Rows that fall in no bin (e.g. NaN) are left unset.

    Args:
        dt: DataFrame to modify.
        var: name of the numeric column to bin.
        bin: number of bins, 1..26 (one lowercase letter per bin). The name
            shadows the builtin but is kept for caller compatibility.

    Returns:
        None; `dt` is modified in place.

    Raises:
        ValueError: if `bin` is outside 1..26.
    """
    n_bins = int(bin)
    letters = string.ascii_lowercase
    if not 1 <= n_bins <= len(letters):
        raise ValueError(f"bin must be between 1 and {len(letters)}, got {bin!r}")

    # Evenly spaced percentile ranks. Deriving them from k / n_bins (rather
    # than an integer step of int(100 / bin)) keeps every rank inside (0, 100)
    # even when the bin count does not divide 100.
    upper_bounds = [
        int(np.round(np.percentile(a=dt[var], q=100 * k / n_bins)))
        for k in range(1, n_bins)
    ]
    upper_bounds.append(np.inf)

    column = f"{var}_CLUSTER"
    lower = -np.inf
    for letter, upper in zip(letters, upper_bounds):
        label = f"{letter}.({lower}, {upper}]"
        in_bin = (dt[var] > lower) & (dt[var] <= upper)
        dt.loc[in_bin, column] = label
        lower = upper

def plot_performance_same_strategy(dict_results, T_range, prefix_name, sample_size, log_transform=False,
                                   keep_legend=True, ax=None, sampling=None):
    """Plot every strategy whose name contains `prefix_name`, with an uncertainty band.

    Args:
        dict_results: mapping of strategy name -> series of values; for each
            plotted name a companion entry f"{name} - std" must exist.
        T_range: x values. With `sampling`, each value x is kept when
            x % step == 0 and is read from position x - 1, so the values are
            presumably the 1-based steps 1..T -- confirm with the caller.
        prefix_name: substring a key must contain to be plotted (keys
            containing "std" are skipped).
        sample_size: number of runs; band = value +/- 2 * std / sqrt(sample_size).
        log_transform: plot log(value + 1) for the line and both band edges.
        keep_legend: draw a legend in the upper left.
        ax: existing axes to draw on; a new 16x12 figure is created if None.
        sampling: optional fraction; keeps every int(1 / sampling)-th point.

    Returns:
        (fig, ax) when the figure was created here, otherwise None.
    """
    marker_position = -5000

    index_choose = None
    if sampling is not None:
        step = int(1 / sampling)
        index_choose = [x - 1 for x in T_range if (x % step) == 0]
        marker_position = int(marker_position * sampling)

    pass_ax = ax is not None
    fig = None
    if not pass_ax:
        # Created up front so an empty strategy list still yields valid axes.
        fig, ax = plt.subplots(figsize=(16, 12))

    list_color = ["blue", "steelblue", "deepskyblue", "brown", "saddlebrown",
                  "sandybrown", "darkorange", "limegreen", "green", "darkgreen"]
    marker_ = "^"

    list_name_strategy = [x for x in dict_results if prefix_name in x and "std" not in x]

    # np.asarray lets list inputs take the fancy index used for subsampling.
    x_values = np.asarray(T_range)
    if index_choose is not None:
        x_values = x_values[index_choose]

    # A single marker is drawn `marker_position` points from the end. An
    # out-of-range index makes matplotlib fail when the line is drawn, so it
    # is clamped to the first point for short series.
    n_points = len(x_values)
    markevery = [max(marker_position, -n_points)] if n_points else None

    for index_, name_ in enumerate(list_name_strategy):
        # Cycle the palette instead of raising IndexError past 10 strategies.
        color_ = list_color[index_ % len(list_color)]
        y_ = np.asarray(dict_results[name_])
        std_ = np.asarray(dict_results[f"{name_} - std"])

        ci_ = 2 * (std_ / np.sqrt(sample_size))
        upper_ = y_ + ci_
        lower_ = y_ - ci_

        if log_transform:
            y_ = np.log(y_ + 1)
            lower_ = np.log(lower_ + 1)
            upper_ = np.log(upper_ + 1)

        if index_choose is not None:
            y_ = y_[index_choose]
            lower_ = lower_[index_choose]
            upper_ = upper_[index_choose]

        ax.plot(x_values, y_, marker=marker_, color=color_, label=name_, markevery=markevery)
        ax.fill_between(x_values, lower_, upper_, color=color_, alpha=.1)

    if keep_legend and list_name_strategy:
        ax.legend(loc='upper left', prop={'size': 8})

    ax.tick_params(axis='x', labelsize=16)
    ax.tick_params(axis='y', labelsize=16)

    # The original chose between two identical limits; the range is fixed.
    ax.set_ylim([0, 1000])

    if not pass_ax:
        return fig, ax
    return None

def plot_performance_benchmark(dict_results, T_range, sample_size,
                               log_transform = False, keep_legend = True, ax = None, sampling = None):
    """Plot the strategies without belief against those "with Belief", each with an uncertainty band.

    Keys of `dict_results` containing "std" are never plotted directly. Of the
    rest, keys without "Belief" are drawn in limegreen with '^' markers and
    keys containing "with Belief" in blue with 'D' markers.

    Args:
        dict_results: mapping of strategy name -> series of values; for each
            plotted name a companion entry f"{name} - std" must exist.
        T_range: x values. With `sampling`, each value x is kept when
            x % step == 0 and is read from position x - 1, so the values are
            presumably the 1-based steps 1..T -- confirm with the caller.
        sample_size: number of runs; band = value +/- 2 * std / sqrt(sample_size).
        log_transform: plot log(value + 1) for the line and both band edges.
        keep_legend: draw a legend in the upper left.
        ax: existing axes to draw on; a new 16x12 figure is created if None.
        sampling: optional fraction; keeps every int(1 / sampling)-th point.

    Returns:
        (fig, ax) when the figure was created here, otherwise None.
    """
    marker_position = -5000

    index_choose = None
    if sampling is not None:
        step = int(1 / sampling)
        index_choose = [x - 1 for x in T_range if (x % step) == 0]
        marker_position = int(marker_position * sampling)

    pass_ax = ax is not None
    fig = None
    if not pass_ax:
        # Created up front so an empty strategy list still yields valid axes.
        fig, ax = plt.subplots(figsize=(16, 12))

    list_linear = [x for x in dict_results if "Belief" not in x and "std" not in x]
    list_belief = [x for x in dict_results if "with Belief" in x and "std" not in x]

    # Style follows the group a curve belongs to rather than its position in
    # the combined list, so an empty or multi-entry group can neither swap
    # the styles nor run past the end of a two-element palette.
    styled_strategies = (
        [(name, "limegreen", "^") for name in list_linear]
        + [(name, "blue", "D") for name in list_belief]
    )

    # np.asarray lets list inputs take the fancy index used for subsampling.
    x_values = np.asarray(T_range)
    if index_choose is not None:
        x_values = x_values[index_choose]

    # A single marker is drawn `marker_position` points from the end. An
    # out-of-range index makes matplotlib fail when the line is drawn, so it
    # is clamped to the first point for short series.
    n_points = len(x_values)
    markevery = [max(marker_position, -n_points)] if n_points else None

    for name_, color_, marker_ in styled_strategies:
        y_ = np.asarray(dict_results[name_])
        std_ = np.asarray(dict_results[f"{name_} - std"])

        ci_ = 2 * (std_ / np.sqrt(sample_size))
        upper_ = y_ + ci_
        lower_ = y_ - ci_

        if log_transform:
            y_ = np.log(y_ + 1)
            lower_ = np.log(lower_ + 1)
            upper_ = np.log(upper_ + 1)

        if index_choose is not None:
            y_ = y_[index_choose]
            lower_ = lower_[index_choose]
            upper_ = upper_[index_choose]

        ax.plot(x_values, y_, marker=marker_, color=color_, label=name_, markevery=markevery)
        ax.fill_between(x_values, lower_, upper_, color=color_, alpha=.1)

    if keep_legend and styled_strategies:
        ax.legend(loc='upper left', prop={'size': 8})

    ax.tick_params(axis='x', labelsize=16)
    ax.tick_params(axis='y', labelsize=16)

    # The original chose between two identical limits; the range is fixed.
    ax.set_ylim([0, 1000])

    if not pass_ax:
        return fig, ax
    return None
