import os
from collections import defaultdict

import numpy as np
import plotly.io as pio
from plotly.subplots import make_subplots

from utils.orchestrator import Evaluator


# --- plot-generation settings ----------------------------------------------
ood_variant = 'test'        # which OOD setting to evaluate: 'test' or 'offset'
batch_size = 4096           # batch size for a single pass over the whole dataset
save = True                 # save figures to disk (also forwarded to Evaluator.plot_m_v_re)
vis = True                  # show the combined figures once all of them are built


# Validate ood_variant up front. An explicit raise is used instead of
# `assert` because assertions are stripped when Python runs with -O.
allowed_ood_variant = ['offset', 'test']
ood_variant_not_implemented_error = f'ood_variant={ood_variant} not in {allowed_ood_variant}.'
if ood_variant not in allowed_ood_variant:
    raise NotImplementedError(ood_variant_not_implemented_error)
# Name handed to Evaluator.plot_m_v_re as its log_dir_name argument
# (presumably the sub-directory its per-task plots are logged under — confirm).
log_dir_name = f'plots_ood_{ood_variant}'

# Fixed axis ranges per dataset, so every figure of one dataset shares limits.
# Each value is (x_range, y_range); either range may be None to leave that
# axis on autoscale, e.g. ((-0.005, 0.03), None).
lims_dict = dict(
    rwth=((-0.005, 0.03), (-0.5, 2.5)),
    recl=((-0.005, 0.025), (-0.4, 1.0)),
    rett=((-0.005, 0.02), (-2.0, 6.0)),
    rtraffic=((-0.01, 0.04), (-0.02, 0.3)),
    rillness=((0.0, 0.015), (-100, 3000)),
)

# Metric figures produced for each (task, dataset) pair; they are collected
# as figs[metric_name][task][data_name].
metric_names = ['pts', 'rts', 'cts', 'pe', 'ce']
# Derived from metric_names so the two can never drift apart; each metric
# gets its own independent defaultdict(dict).
figs = {metric_name: defaultdict(dict) for metric_name in metric_names}

# Loop over datasets: build one Evaluator per dataset, generate one figure per
# (task, metric) and store it in figs[metric_name][task][data_name].
data_names = ['recl', 'rwth', 'rtraffic', 'rillness', 'rett']
for i, data_name in enumerate(data_names):
    print(f'\ndata_name: {data_name} ({i+1} / {len(data_names)})')

    # init evaluator (the experiment name is the dataset name)
    expt_name = data_name
    evaluator = Evaluator(expt_name)

    # choose the evaluation split and the per-task offsets
    if ood_variant == 'offset':
        # training split; three evenly spaced offsets per task: 0, std/2, std
        mode = 'train'
        # presumably a torch tensor (it has .numpy()); flattened to 1-D here
        data_stds = evaluator.config['data_stds'].reshape(-1).numpy()
        # NOTE(review): 'fc' reuses data_stds[0] (same as 'oc') while 'uc'
        # uses data_stds[2] -- confirm this choice is intended.
        offsets = {
            'oc': np.linspace(0, data_stds[0], 3),
            'uc': np.linspace(0, data_stds[2], 3),
            'fc': np.linspace(0, data_stds[0], 3)
        }
        n_data = 10000      # faster computation by not using all data
    elif ood_variant == 'test':
        # test split with a single zero offset for every task
        mode = 'test'
        offsets = {task: np.array([0]) for task in evaluator.allowed_tasks}
        n_data = None       # use all data
    else:
        raise NotImplementedError(f'experiment for log_dir_name={log_dir_name} not implemented.')

    # generate figures, applying this dataset's fixed axis ranges
    x_lims, y_lims = lims_dict[data_name]
    for task in evaluator.allowed_tasks:
        fig = evaluator.plot_m_v_re(
            offsets[task], task, batch_size, n_data, save, log_dir_name, mode
        )

        for metric_name in metric_names:
            metric_fig = fig[metric_name]
            # a None range means "leave this axis on autoscale"
            if x_lims is not None:
                metric_fig.update_xaxes(range=x_lims)
            if y_lims is not None:
                metric_fig.update_yaxes(range=y_lims)
            figs[metric_name][task][data_name] = metric_fig


# Combine results: for each metric, tile the collected per-(task, dataset)
# figures into one grid with one row per task and one column per dataset.
column_titles = ['rECL', 'rWTH', 'rTraffic', 'rIllness', 'rETT']   # same order as data_names
# NOTE(review): the row headings assume the tasks are collected in the
# order oc, uc, fc -- confirm against Evaluator.allowed_tasks.
row_headings = ["OC", "UC", "FC"]
row_heading_ys = [0.84, 0.49, 0.15]     # paper y-coordinate of each row's label
save_dir = os.path.join('expt', 'logs')
fig_alls = {}
for metric_name in metric_names:
    # Task rows come from what was actually collected, rather than from the
    # `evaluator` variable leaked out of the dataset loop.
    tasks = list(figs[metric_name])
    fig_all = make_subplots(
        rows=len(row_headings), cols=len(column_titles),
        vertical_spacing=0.05,
        horizontal_spacing=0.03,
        subplot_titles=column_titles,
    )
    for i, task in enumerate(tasks):
        for j, data_name in enumerate(data_names):
            fig = figs[metric_name][task][data_name]
            for trace in fig.data:
                fig_all.add_trace(trace, row=i + 1, col=j + 1)

    # enlarge the column titles created by make_subplots
    for annotation in fig_all['layout']['annotations']:
        annotation['font'] = {'size': 40, 'family': 'Times New Roman'}

    # vertical row labels to the left of the grid
    for heading, y in zip(row_headings, row_heading_ys):
        fig_all.add_annotation(
            text=heading,
            xref="paper", yref="paper",
            x=-0.025, y=y,  # y chosen to sit at the centre of each row
            xanchor="center", yanchor="middle",
            showarrow=False,
            font=dict(size=40, family="Times New Roman"),
            textangle=-90  # Rotate text vertically
        )

    # remove legend
    fig_all.update_layout(
        showlegend=False
    )

    fig_alls[metric_name] = fig_all

    if save:
        # create the output directory on first use so write_html cannot fail on it
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f'M{metric_name}_v_re_{ood_variant}_all.html')
        pio.write_html(fig_all, file=save_path, auto_open=False)

# Display step: every combined figure has already been built (and, if enabled,
# saved to HTML) above. Titles are added only here, so the saved files carry no
# title; each shown figure is titled with its upper-cased metric name.
if vis:
    for metric_name, fig_all in fig_alls.items():
        fig_all.update_layout(title_text=f"{metric_name.upper()}", title_x=0.5)
        fig_all.show()
