import os
from collections import defaultdict

import numpy as np
import plotly.io as pio
from plotly.subplots import make_subplots

from utils.orchestrator import Evaluator


# parameters for generating plots
batch_size = 4096           # batch size used to run through the entire dataset
ood_variant = 'test'        # one of: 'test', 'offset'
save = True                 # passed to plot_ce_v_re and gates writing the combined HTML
vis = True                  # show the combined figure at the end


# consistent axis limits within each dataset: (x_range, y_range).
# Use None for either entry to leave that axis unset,
# e.g. ((-0.005, 0.03), None).
lims_dict = {
    'rwth': ((-0.005, 0.03), (-0.5, 2.5)),
    'recl': ((-0.005, 0.025), (-0.4, 1.0)),
    'rett': ((-0.005, 0.02), (-2.0, 6.0)),
    'rtraffic': ((-0.01, 0.04), (-0.02, 0.3)),
    'rillness': ((0.0, 0.015), (-100, 3000)),
}


# check ood_variant explicitly: an `assert` would be stripped under `python -O`
allowed_ood_variant = ['offset', 'test']
ood_variant_not_implemented_error = f'ood_variant={ood_variant} not in {allowed_ood_variant}.'
if ood_variant not in allowed_ood_variant:
    raise NotImplementedError(ood_variant_not_implemented_error)
log_dir_name = f'plots_ood_{ood_variant}'

# collected figures, indexed as figs[task][data_name] -> figure
figs = defaultdict(dict)

# datasets to evaluate, in the column order used for the combined plot
data_names = ['recl', 'rwth', 'rtraffic', 'rillness', 'rett']
n_datasets = len(data_names)
for idx, data_name in enumerate(data_names, start=1):
    print(f'\ndata_name: {data_name} ({idx} / {n_datasets})')

    # one evaluator per dataset; the experiment name is the dataset name
    evaluator = Evaluator(data_name)

    # pick the split, per-task offsets and sample count for this OOD variant
    if ood_variant == 'test':
        mode = 'test'
        # a single zero offset per task, i.e. no shift
        offsets = {t: np.array([0]) for t in evaluator.allowed_tasks}
        n_data = None       # use all data
    elif ood_variant == 'offset':
        mode = 'train'
        stds = evaluator.config['data_stds'].reshape(-1).numpy()
        # 3 evenly spaced offsets from 0 up to one std of the chosen channel.
        # NOTE(review): 'fc' reuses stds[0] like 'oc' — confirm this is intended.
        offsets = {
            task_name: np.linspace(0, stds[std_idx], 3)
            for task_name, std_idx in (('oc', 0), ('uc', 2), ('fc', 0))
        }
        n_data = 10000      # faster computation by not using all data
    else:
        raise NotImplementedError(f'experiment for log_dir_name={log_dir_name} not implemented.')

    # build each task's figure, apply this dataset's axis limits, and store it
    lims = lims_dict[data_name]
    x_range, y_range = lims[0], lims[1]
    for task in evaluator.allowed_tasks:
        fig = evaluator.plot_ce_v_re(
            offsets[task], task, batch_size, n_data, save, log_dir_name, mode
        )
        if x_range:     # falsy (None) leaves the x-axis unset
            fig.update_xaxes(range=x_range)
        if y_range:     # falsy (None) leaves the y-axis unset
            fig.update_yaxes(range=y_range)
        figs[task][data_name] = fig


# combine results into one grid figure: one row per task, one column per dataset.
# Tasks come from the collected `figs` (in collection order) rather than from the
# `evaluator` left over from the dataset loop, so rows always match the stored figures.
tasks = list(figs)
# display title for each dataset column; keyed by name so column order follows data_names
subplot_title_by_name = {
    'recl': 'rECL',
    'rwth': 'rWTH',
    'rtraffic': 'rTraffic',
    'rillness': 'rIllness',
    'rett': 'rETT',
}
fig_all = make_subplots(
    rows=len(tasks), cols=len(data_names),
    vertical_spacing=0.05,
    horizontal_spacing=0.03,
    subplot_titles=[subplot_title_by_name.get(name, name) for name in data_names],
)
for row, task in enumerate(tasks, start=1):
    for col, data_name in enumerate(data_names, start=1):
        # copy every trace of the per-dataset figure into its grid cell
        for trace in figs[task][data_name].data:
            fig_all.add_trace(trace, row=row, col=col)


# subplot titles are stored as layout annotations: enlarge their font
for ann in fig_all['layout']['annotations']:
    ann['font'] = dict(size=40, family='Times New Roman')

# vertical row labels placed just left of the grid, one per task row
row_labels = (("OC", 0.84), ("UC", 0.49), ("FC", 0.15))
for label, y_pos in row_labels:
    fig_all.add_annotation(
        text=label,
        textangle=-90,  # rotate text vertically
        showarrow=False,
        xref="paper",
        yref="paper",
        x=-0.025,
        y=y_pos,  # hand-tuned to the centre of each row
        xanchor="center",
        yanchor="middle",
        font=dict(size=40, family="Times New Roman"),
    )

# remove legend from the combined grid
fig_all.update_layout(
    showlegend=False
)

if save:
    save_dir = os.path.join('expt', 'logs')
    # write_html does not create missing parent directories; ensure it exists
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f'ce_v_re_{ood_variant}_all.html')
    pio.write_html(fig_all, file=save_path, auto_open=False)
if vis:
    fig_all.show()
