import os
import numpy as np
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from utils.orchestrator import Evaluator


# parameters
data_name = 'recl'      # recl, rwth, rtraffic, rillness, rett
batch_size = 4096       # batch size to use to run through the entire dataset
model = 'dm'            # dm, moment, timer
save = True             # intended to toggle writing results to disk
vis = True              # open the figure in an interactive viewer when True


# check model
# Validate with an explicit raise rather than `assert`: asserts are stripped when
# Python runs with -O, which would let an unsupported model name slip through.
# NotImplementedError matches the fallback branch of the per-task model dispatch.
allowed_model = ['dm', 'moment', 'timer']
model_not_implemented_error = f'model={model} not in {allowed_model}.'
if model not in allowed_model:
    raise NotImplementedError(model_not_implemented_error)

# init evaluator
# the experiment is named after the dataset
expt_name = data_name
evaluator = Evaluator(expt_name)

# general init
# tasks to evaluate are taken from the evaluator; outputs go to expt/logs/<expt_name>/ce_v_re
tasks = evaluator.allowed_tasks
save_path = os.path.join('expt', 'logs', expt_name, 'ce_v_re')
os.makedirs(save_path, exist_ok=True)

# prepare figure and strings
result_strings = []
# Shared y-axis limits across all subplots. Seeded with [0, 0], so the final range
# always contains 0 and is only widened by the per-task extremes in the loop below.
y_lims = np.array([0, 0])
# one subplot row per task (row count follows the task list instead of a fixed 3)
fig_dm = make_subplots(
    rows=len(tasks), cols=1,
    shared_xaxes=True,
    vertical_spacing=0.1, horizontal_spacing=0.1,
    subplot_titles=[task.upper() for task in tasks]
)

for i, task in enumerate(tasks):
    # compute combined errors and relation errors with the evaluator method
    # that corresponds to the selected model
    if model == 'dm':
        combined_errors, relation_errors, baseline_relation_errors = evaluator.compute_ce_and_re_dm(
            task, batch_size
        )
    elif model == 'moment':
        context_len = 24    # extra argument only the moment evaluation takes
        combined_errors, relation_errors, baseline_relation_errors = evaluator.compute_ce_and_re_moment(
            task, batch_size, context_len
        )
    elif model == 'timer':
        combined_errors, relation_errors, baseline_relation_errors = evaluator.compute_ce_and_re_timer(
            task, batch_size
        )
    else:
        raise NotImplementedError(model_not_implemented_error)

    # for printing results: one header line per task, then "mean +- std" per error type
    result_strings.append(f'\nTask = {task}')
    for key, value in {
        'Relation Error': relation_errors, 'Combined Error': combined_errors, 'Baseline Error': baseline_relation_errors
    }.items():
        result_strings.append(f'{key}'.ljust(20) + f': {value.mean():.4g} +- {value.std():.4g}')

    # for plotting results
    # Widen the shared y-limits with this task's extremes. The array's own
    # .min()/.max() reductions are used instead of the Python builtins, which
    # would iterate element by element.
    # NOTE(review): assumes relation_errors is a 1-D array-like exposing
    # .min()/.max() (it already exposes .mean()/.std()) - confirm.
    y_lims = np.array([
        min(float(relation_errors.min()), y_lims[0]),
        max(float(relation_errors.max()), y_lims[1])
    ])
    # subplot rows are 1-indexed, hence i+1
    fig_dm.add_trace(go.Scatter(x=combined_errors, y=relation_errors, mode='markers', name=task), row=i+1, col=1)
    fig_dm.update_xaxes(title_text="Combined Error", row=i+1, col=1)
    fig_dm.update_yaxes(title_text="Relation Error", row=i+1, col=1)

# print the collected per-task summary lines under a model/dataset header
print(f'\nModel = {model.upper()}, Dataset = {data_name}')
for rs in result_strings:
    print(rs)

# plot
y_lims = 0.05 * (y_lims[1] - y_lims[0]) * np.array([-1, 1]) + y_lims        # extend limits by a 0.05*range
fig_dm.update_yaxes(range=y_lims)       # no row/col given: same y-range on every subplot
# Only the figure title is set here. The per-subplot axis titles are assigned in the
# task loop (x = Combined Error, y = Relation Error); setting xaxis_title/yaxis_title
# at layout level would overwrite the first subplot's titles, and the previous values
# had the two labels swapped relative to the plotted data.
fig_dm.update_layout(title=f"{model.upper()}: CE VS RE")
# honour the `save` switch instead of always writing the HTML file
if save:
    pio.write_html(fig_dm, file=os.path.join(save_path, f'{model}.html'), auto_open=False)
if vis:
    fig_dm.show()
