import os
import numpy as np
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly.colors import qualitative
from utils.orchestrator import Evaluator


# parameters
# Number of samples per batch when running through an entire dataset.
batch_size = 4096

# Number of histogram bins to use for each dataset.
n_bins_dict = dict(
    recl=1000,
    rwth=1000,
    rtraffic=1000,
    rillness=100,
    rett=1000,
)

# x-axis limits of each dataset's histogram (relation error).
# Set an entry to None to leave the limits unset, e.g. 'recl': None.
xlims_dict = dict(
    recl=(0, 0.3),
    rwth=(0, 0.3),
    rtraffic=(0, 0.02),
    rillness=(0, 2000),
    rett=(0, 1.0),
)

# Output switches: write the figure to an HTML file / open it for viewing.
save = True
vis = True


# figure
# Grid of histograms: the loops below place each model on its own row and each
# dataset in its own column. The titles label the columns.
column_titles = ['rECL', 'rWTH', 'rTraffic', 'rIllness', 'rETT']
fig_all = make_subplots(
    rows=3,
    cols=len(column_titles),
    subplot_titles=column_titles,
    horizontal_spacing=0.03,
    vertical_spacing=0.05,
)

# loop through datasets
# For every dataset (column) and every model (row), split the model's relation
# errors into three groups according to where each sample's combined error
# falls relative to the train-set quartiles, then draw one histogram per group.
model_names = ['dm', 'moment', 'timer']
data_names = ['recl', 'rwth', 'rtraffic', 'rillness', 'rett']
context_len = 24    # context length passed to the MOMENT evaluation
for i, data_name in enumerate(data_names):
    print(f'\ndata_name: {data_name} ({i+1} / {len(data_names)})')

    # init evaluator
    expt_name = data_name
    evaluator = Evaluator(expt_name)

    # general init
    # Snapshot the task list once so every pass below iterates the same tasks.
    tasks = list(evaluator.allowed_tasks)
    # NOTE(review): this directory is created but nothing in this script
    # writes into it; kept so the on-disk layout stays unchanged.
    save_path = os.path.join('expt', 'logs', data_name, 'quartile')
    os.makedirs(save_path, exist_ok=True)

    # get train set quartiles
    # NOTE(review): the quartiles always come from the 'dm' evaluation in
    # 'train' mode and are reused for every model below - confirm intended.
    train_combined_errors = []
    for task in tasks:
        ce, _, _ = evaluator.compute_ce_and_re_dm(task, batch_size, mode='train')
        train_combined_errors += ce.tolist()
    train_combined_errors = np.array(train_combined_errors)
    quartiles = np.percentile(train_combined_errors, [25, 50, 75])
    median, third_quartile = quartiles[1], quartiles[2]

    # loop through models
    for j, model_name in enumerate(model_names):
        row, col = j + 1, i + 1    # plotly subplot indices are 1-based

        # compute combined errors and relation errors over all the tasks
        combined_errors, relation_errors = [], []
        for task in tasks:
            if model_name == 'dm':
                ce, re, _ = evaluator.compute_ce_and_re_dm(task, batch_size)
            elif model_name == 'moment':
                ce, re, _ = evaluator.compute_ce_and_re_moment(task, batch_size, context_len)
            elif model_name == 'timer':
                ce, re, _ = evaluator.compute_ce_and_re_timer(task, batch_size)
            else:
                raise NotImplementedError(f'model={model_name} not in {model_names}.')
            combined_errors += ce.tolist()
            relation_errors += re.tolist()
        combined_errors = np.array(combined_errors)
        relation_errors = np.array(relation_errors)

        # Partition by combined error. The upper bound of each range is
        # inclusive so that samples lying exactly on the median or on the
        # third quartile are assigned to a group instead of being dropped.

        # up until (and including) the median
        partition_low = relation_errors[np.where(combined_errors <= median)[0]]

        # above the median, up to (and including) the third quartile
        partition_med = relation_errors[
            np.where((median < combined_errors) & (combined_errors <= third_quartile))[0]
        ]

        # above third quartile
        partition_high = relation_errors[np.where(third_quartile < combined_errors)[0]]

        # plot
        # Traces are added in Low, Med, High order; zorder controls which
        # histogram is drawn on top (higher values are drawn above lower ones).
        n_bins = n_bins_dict[data_name]
        groups = [
            (partition_low, 'Low', 1),
            (partition_med, 'Med', 0),
            (partition_high, 'High', 2),
        ]
        for color_idx, (values, label, zorder) in enumerate(groups):
            fig_all.add_trace(
                go.Histogram(
                    x=values, nbinsx=n_bins, histnorm='probability', name=label, zorder=zorder,
                    marker=dict(color=qualitative.Plotly[color_idx], opacity=0.85),
                ), row=row, col=col)

        # format axes
        xlims = xlims_dict[data_name]
        fig_all.update_xaxes(range=xlims, row=row, col=col)
        fig_all.update_yaxes(showticklabels=False, row=row, col=col)
        if j == len(model_names) - 1:
            # only the bottom row carries the x-axis title
            fig_all.update_xaxes(
                title_text="Relation Error", title_font=dict(family="Times New Roman", size=30),
                row=row, col=col
            )
fig_all.update_layout(showlegend=False)

# annotations
# Restyle every annotation already on the figure (at this point these are the
# subplot titles) with a larger Times New Roman font.
for existing in fig_all.layout.annotations:
    existing.font = dict(size=40, family='Times New Roman')

# Row headings: one rotated label per model row, placed just left of the plot
# area. The y values are paper-relative positions chosen for each row.
row_heading_positions = {"DM": 0.84, "MOMENT": 0.49, "TIMER": 0.15}
for heading, y_pos in row_heading_positions.items():
    fig_all.add_annotation(
        text=heading,
        font=dict(size=40, family="Times New Roman"),
        textangle=-90,  # rotate text vertically
        showarrow=False,
        xref="paper", yref="paper",
        x=-0.035, y=y_pos,  # adjust y based on row position
        xanchor="center", yanchor="middle",
    )

# Optionally write the combined figure to an HTML file (without opening it).
if save:
    html_path = os.path.join('expt', 'logs', 'quartile_hist.html')
    pio.write_html(fig_all, file=html_path, auto_open=False)

# Optionally display the figure.
if vis:
    fig_all.show()
