#%%
# The main RTT occlusion figure is produced by manually stitching together the outputs of `plot_subject_occlusion.py` and `plot_occlusion.py`.
# ! Deprecated in favor of `plot_rtt_occlusion`.
import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO) # needed to get `logger` to print
from matplotlib import pyplot as plt
from pathlib import Path
import seaborn as sns
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import os
import subprocess
import time

# Shared plotting helpers: axis styling, per-variant color palette, and marker size
from context_general_bci.plotting import prep_plt, colormap, MARKER_SIZE

eval_set = 'rtt_subject'
df = pd.read_csv(f'data/analysis_metrics/{eval_set}_occ.csv')

# Session-occlusion results: keep only the '_0d' runs and relabel them to '_0s'
# so `day_of` below parses them as subj == 0 alongside the cross-subject sweep.
session_0d_csv = 'data/analysis_metrics/rtt_session_occ.csv'
session_df = pd.read_csv(session_0d_csv)

# `.copy()` so the column rewrite below modifies an independent frame rather
# than a filtered slice of the original (avoids SettingWithCopyWarning).
session_df = session_df[session_df['variant'].str.contains('_0d')].copy()
session_df['variant'] = session_df['variant'].str.replace('_0d', '_0s')
# scratch_df = session_df[(session_df['variant'] == 'scratch_0s-sweep-full_scratch')]
# scratch_df['variant'] = 'scratch_transfer_0s-sweep-full_scratch'
# scratch_df = session_df[(session_df['variant'] == 'scratch_0s-sweep-simple_scratch')]
# scratch_df['variant'] = 'scratch_transfer_0s-sweep-simple_scratch'

# ignore_index: give the merged frame unique row labels instead of two overlapping 0..n ranges.
df = pd.concat([df, session_df], ignore_index=True)

def stem_map(variant):
    """Return the model stem encoded in a variant name.

    The stem is the text before the first '-' with its last '_'-separated
    token (the session tag, e.g. '0s') removed. When 'transfer' occurs
    somewhere in ``variant`` but not in that stem, '_transfer' is appended.
    """
    head = variant.partition('-')[0]
    stem = head.rpartition('_')[0]
    if 'transfer' in stem or 'transfer' not in variant:
        return stem
    return stem + '_transfer'

def day_of(variant):
    """Parse the integer session index encoded in a variant name.

    The index is the last '_'-separated token of the text before the first
    '-', with its trailing unit character ('d' or 's') dropped, e.g.
    'scratch_0s-...' -> 0, 'big_350m_2kh_10d-...' -> 10. A leading
    'transfer' on that token ('..._transfer2s-...') is stripped as well.

    Raises:
        ValueError: if what remains is not an integer.
    """
    # Was a bare print() that fired once per DataFrame row; keep it available at DEBUG.
    logging.debug('day_of(%s)', variant)
    token = variant.split('-')[0].split('_')[-1][:-1]  # drop trailing unit char
    if token.startswith('transfer'):
        token = token[len('transfer'):]
    return int(token)

# Derived per-row columns: model stem and cross-subject session index.
df['variant_stem'] = df['variant'].map(stem_map)
print(df.variant_stem.unique())
df['subj'] = df['variant'].map(day_of)

# Amount of time (s) in Loco, new subject data. -60 since that's the Indy calib time
subj_session_to_time_map = {
    0: 0,
    2: 4245 - 60,
    5: 11248 - 60,
    10: 20325 - 60,
}
# Direct indexing (not Series.map(dict)) so an unknown session still raises KeyError.
df['subj_hr'] = df['subj'].map(lambda session: subj_session_to_time_map[session] / 60 / 60)
df['subj_min'] = df['subj'].map(lambda session: subj_session_to_time_map[session] / 60)

target_df = df
# print(target_df)
# print(target_df)
#%%
f = plt.figure(figsize=(3., 3.5), layout='constrained')
ax = prep_plt(f.gca(), big=True)

# Variants drawn: each backbone both without and with the '_transfer' suffix.
# base_45m_2kh variants are excluded: not really different, just visual noise.
subset_variant = [
    f'{stem}{suffix}'
    for suffix in ('', '_transfer')
    for stem in ('scratch', 'base_45m_200h', 'big_350m_2kh')
]
y = 'eval_r2'

def marker_style_map(variant_stem):
    """Return the matplotlib marker code used for ``variant_stem``.

    Any stem containing '350m' gets 'P'; the listed expert/scratch/baseline
    stems get 'X'; everything else gets 'o'.
    """
    x_marker_stems = (
        'NDT2 Expert', 'NDT3 Expert', 'scratch', 'scratch_transfer',
        'NDT3 mse', 'wf', 'ole',
    )
    if '350m' in variant_stem:
        return 'P'
    return 'X' if variant_stem in x_marker_stems else 'o'
# One marker per variant stem present in the data.
marker_dict = dict(
    (stem, marker_style_map(stem)) for stem in target_df['variant_stem'].unique()
)
# seaborn `dashes` spec as (segment, gap): solid for base stems,
# dashed for the '_transfer' stems.
style_map = {
    **dict.fromkeys(
        ('scratch', 'base_45m_200h', 'base_45m_2kh', 'big_350m_2kh'), (5, 0)
    ),
    **dict.fromkeys(
        ('scratch_transfer', 'base_45m_200h_transfer', 'big_350m_2kh_transfer'), (5, 2)
    ),
}

# (A commented-out grouped `sns.barplot` alternative used to live here; removed
# since it referenced an undefined `order` hue list.)

# Cross-subject results for the selected variants. subj_hr == 0 rows are left out
# of the log-scaled x-axis; they are annotated as horizontal lines at the end to
# connect panels with `plot_occlusion` and avoid weird log spacing.
subset_df = target_df[target_df.variant_stem.isin(subset_variant)]
subset_df = subset_df[subset_df['subj_hr'] > 0]
x = 'subj_hr'
sns.lineplot(
    data=subset_df,
    x=x,
    y=y,
    hue='variant_stem',
    palette=colormap,
    style='variant_stem',
    dashes=style_map,
    ax=ax,
    alpha=0.8,  # Lighten it up
    err_kws={'alpha': 0.05},  # This makes the error band lighter
    # errorbar='sd',
    legend=False,
)

# Mean over all rows sharing (subj_hr, variant_stem), e.g. seeds, drawn as markers.
mean_df = subset_df[[
    'subj_min', 'subj_hr', y, 'variant_stem', 'seed'
]]
mean_df = mean_df.groupby(['subj_hr', 'variant_stem']).mean().reset_index()
sns.scatterplot(
    data=mean_df,
    x=x,
    y=y,
    hue='variant_stem',
    palette=colormap,
    style='variant_stem',
    markers=marker_dict,
    ax=ax,
    s=MARKER_SIZE,
    legend=False,
    alpha=0.8,
)

# Both seaborn calls pass legend=False; remove a legend only if one exists rather
# than building one via `ax.legend()` just to delete it.
if ax.get_legend() is not None:
    ax.get_legend().remove()

# --- Aesthetics ---
ax.set_ylabel("")
# ax.annotate('$R^2$', xy=(0, 1), xytext=(-ax.yaxis.labelpad - 15, 1),
#             xycoords='axes fraction', textcoords='offset points',
#             ha='center', va='center', fontsize=24)
ax.set_xlabel('')
# ax.set_xlabel('Cross-subject time (hr)')
ax.set_ylim(0.1, 0.8)
ax.set_yticks(np.arange(0.1, 0.8, 0.2))  # 0.1, 0.3, 0.5, 0.7 (unlabelled)
ax.set_yticklabels([])
# Minor y-grid every 0.1, set once after all plotting.
ax.yaxis.set_minor_locator(plt.MultipleLocator(0.1))
ax.yaxis.grid(True, which='minor', linestyle='-', linewidth=0.5, alpha=0.5)
ax.text(1.0, -0.03, 'hrs', transform=ax.transAxes, ha='right', va='top', fontsize=22)
ax.set_xscale('log')
ax.set_xticks([1, 2, 4])
ax.set_xticklabels([1, 2, 4])
ax.xaxis.set_minor_locator(plt.NullLocator())  # Remove minor ticks
ax.tick_params(axis='x', which='minor', bottom=False)  # Ensure minor ticks are not drawn

# Dotted reference lines for subj_hr == 0 performance. Averaged over all matching
# rows (e.g. seeds) to stay consistent with the mean markers above; previously
# only the first matching row was used.
for variant in ['scratch', 'base_45m_200h', 'big_350m_2kh']:
    baseline_rows = target_df[
        (target_df['subj_hr'] == 0) & (target_df['variant_stem'] == variant)
    ]
    if baseline_rows.empty:
        logging.warning('No subj_hr == 0 rows for %s; skipping its baseline line', variant)
        continue
    perf = baseline_rows[y].mean()
    ax.axhline(perf, linestyle=':', color=colormap[variant], alpha=0.8)


# ax.set_xlim(-0.01, 10)
# ax.set_xlim(2000, 25000)
# Aesthetics todo
# Figure out overall layout / fontsize
# Insert variant_stem labels at start of plot