import os
import numpy as np
import plotly.io as pio
import plotly.graph_objects as go
from plotly.colors import qualitative
from utils.orchestrator import Evaluator


# parameters
# Edit these values to configure a run; the rest of the script reads them as
# module-level globals.
data_name = 'recl'      # recl, rwth, rtraffic, rillness, rett
n_repeats = 5           # number of repeats (samples for the filtering); one overlap coefficient per repeat
batch_size = 4096       # batch size to use to run through the entire dataset
model = 'dm'            # dm, moment, timer
n_bins = 4000           # number of histogram bins (overlap computation and plots)
xlims = (0.0, 0.3)      # xlims of histogram (relation error), set to None to not set lims
save = True             # write each figure as an HTML file
vis = True              # open each figure with fig.show()


# check model
# Validate with an explicit raise rather than `assert`: assertions are removed
# when Python runs with -O, which would let an unsupported model slip through.
allowed_model = ['dm', 'moment', 'timer']
model_not_implemented_error = f'model={model} not in {allowed_model}.'
if model not in allowed_model:
    raise ValueError(model_not_implemented_error)

# Evaluator setup: the experiment name is simply the dataset name.
expt_name = data_name
evaluator = Evaluator(expt_name)
tasks = evaluator.allowed_tasks

# Output directory for this dataset's quartile figures; created up front so
# later writes cannot fail on a missing folder.
log_dir_parts = ('expt', 'logs', data_name, 'quartile')
save_path = os.path.join(*log_dir_parts)
os.makedirs(save_path, exist_ok=True)

# Train-set quartiles of the combined error, pooled over every allowed task.
# NOTE(review): this always goes through compute_ce_and_re_dm, whatever
# `model` is set to - presumably the 'dm' train errors are the intended
# reference; confirm.
train_errors = []
for task in evaluator.allowed_tasks:
    task_ce, _, _ = evaluator.compute_ce_and_re_dm(task, batch_size, mode='train')
    train_errors.extend(task_ce.tolist())
combined_errors = np.array(train_errors)

# quartiles[0], [1], [2] are the 25th, 50th (median) and 75th percentiles.
quartiles = np.percentile(combined_errors, [25, 50, 75])

# compute overlap coefficient 'n_repeats' times - also generate figures
# Each repeat re-evaluates every task, splits the relation errors into
# low / med / high partitions using the train-set combined-error quartiles,
# and measures how much the low and high relation-error distributions overlap.
overlap_coefficients = []
for i in range(n_repeats):
    print(f'\nrepeat = {i+1} / {n_repeats} -----------------------------------------------------------------------')
    save_name = f'{model}_{i+1}'

    # compute combined errors and relation errors over all the tasks
    combined_errors, relation_errors = [], []
    for task in evaluator.allowed_tasks:
        if model == 'dm':
            ce, re, _ = evaluator.compute_ce_and_re_dm(task, batch_size)
        elif model == 'moment':
            context_len = 24
            ce, re, _ = evaluator.compute_ce_and_re_moment(task, batch_size, context_len, deterministic=False)
        elif model == 'timer':
            ce, re, _ = evaluator.compute_ce_and_re_timer(task, batch_size, deterministic=False)
        else:
            raise NotImplementedError(model_not_implemented_error)
        combined_errors += ce.tolist()
        relation_errors += re.tolist()
    combined_errors = np.array(combined_errors)
    relation_errors = np.array(relation_errors)

    # NOTE(review): all three partitions use strict inequalities, so a sample
    # whose combined error equals a quartile exactly lands in no partition.

    # up until median
    partition_low = relation_errors[np.where(combined_errors < quartiles[1])[0]]

    # median to third quartile
    partition_med = relation_errors[np.where((quartiles[1] < combined_errors) & (combined_errors < quartiles[2]))[0]]

    # above third quartile
    partition_high = relation_errors[np.where(quartiles[2] < combined_errors)[0]]

    # bins for overlap computing
    # The shared bin range must cover BOTH partitions: np.histogram does not
    # count values outside the bin edges, so a range that is too narrow would
    # silently drop samples before normalisation and bias the coefficient.
    min_bin = min(np.min(partition_low), np.min(partition_high))  # Global minimum
    max_bin = max(np.max(partition_low), np.max(partition_high))  # Global maximum
    bin_edges = np.linspace(min_bin, max_bin, n_bins + 1)

    # histogram heights are occurrence frequencies
    hist_low, _ = np.histogram(partition_low, bins=bin_edges, density=False)
    hist_high, _ = np.histogram(partition_high, bins=bin_edges, density=False)

    # normalise such that summing histogram 'heights' give you unity.
    hist_low = hist_low / hist_low.sum()
    hist_high = hist_high / hist_high.sum()

    # overlap coefficient: sum over bins of the smaller of the two normalised
    # heights (1 = identical histograms, 0 = disjoint).
    min_overlap = np.minimum(hist_low, hist_high)
    overlap_coefficients.append(np.sum(min_overlap))

    # plot
    # The three partitions are overlaid semi-transparently; 'bingroup' makes
    # the traces share the same binning.
    fig = go.Figure(data=[
        go.Histogram(
            x=partition_low, nbinsx=n_bins, histnorm='probability', name='Low', zorder=1,
            marker=dict(color=qualitative.Plotly[0], opacity=0.5), bingroup='overlay'
        ),
        go.Histogram(
            x=partition_med, nbinsx=n_bins, histnorm='probability', name='Med', zorder=0,
            marker=dict(color=qualitative.Plotly[1], opacity=0.5), bingroup='overlay'
        ),
        go.Histogram(
            x=partition_high, nbinsx=n_bins, histnorm='probability', name='High', zorder=2,
            marker=dict(color=qualitative.Plotly[2], opacity=0.5), bingroup='overlay'
        ),
    ])
    fig.update_xaxes(title_text="Relation Error")
    fig.update_yaxes(title_text="Probability")
    # xlims=None means "leave the x-axis range on automatic".
    if xlims is not None:
        fig.update_xaxes(range=xlims)
    if save:
        pio.write_html(fig, file=os.path.join(save_path, f'{save_name}.html'), auto_open=False)
    if vis:
        fig.show()

# Report the overlap coefficient as mean +- standard deviation over repeats
# (ndarray.std() with its default arguments).
overlap_coefficients = np.array(overlap_coefficients)
oc_mean = overlap_coefficients.mean()
oc_std = overlap_coefficients.std()
print(f'Model={model}, Dataset={data_name}')
print(f'overlap coefficients = {oc_mean:.4g} +- {oc_std:.4g}')
