#%%
# The main RTT occlusion figure is assembled by manually stitching the outputs of
# `plot_subject_occlusion.py` and `plot_occlusion.py`.
# NOTE: this script is being deprecated in favor of `plot_rtt_occlusion`.
import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO) # needed to get `logger` to print
from matplotlib import pyplot as plt
from pathlib import Path
import seaborn as sns
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import os
import subprocess
import time

# Shared plotting helpers: axis styling (`prep_plt`), per-variant colors (`colormap`), and marker size
from context_general_bci.plotting import prep_plt, colormap, MARKER_SIZE

# Occlusion metrics for the evaluation set plotted in this script.
eval_set = 'rtt_subject'
df = pd.read_csv(f'data/analysis_metrics/{eval_set}_occ.csv')

# Session-level occlusion metrics; only their `_0d` variants are reused below.
session_0d_csv = 'data/analysis_metrics/rtt_session_occ.csv'
session_df = pd.read_csv(session_0d_csv)

# Keep the `_0d` rows and relabel them `_0s` so they share the naming of the
# subject-level variants. `.copy()` detaches the filtered frame from the loaded
# one, so the column assignment that follows is an unambiguous write (no
# chained-assignment warning). `regex=False` makes both operations plain
# substring matches.
session_df = session_df[session_df['variant'].str.contains('_0d', regex=False)].copy()
session_df['variant'] = session_df['variant'].str.replace('_0d', '_0s', regex=False)
# scratch_df = session_df[(session_df['variant'] == 'scratch_0s-sweep-full_scratch')]
# scratch_df['variant'] = 'scratch_transfer_0s-sweep-full_scratch'
# scratch_df = session_df[(session_df['variant'] == 'scratch_0s-sweep-simple_scratch')]
# scratch_df['variant'] = 'scratch_transfer_0s-sweep-simple_scratch'

# `ignore_index=True` gives the combined frame fresh, unique row labels instead
# of repeating the labels carried over from the two source CSVs.
df = pd.concat([df, session_df], ignore_index=True)

def stem_map(variant):
    """Reduce a full variant label to its model-family stem.

    Only the text before the first '-' is considered. Its last
    '_'-separated token is dropped and the remaining tokens are re-joined
    with '_'. If the full label mentions 'transfer' but the resulting stem
    does not, '_transfer' is appended.

    Example: 'big_350m_2kh_transfer5s-sweep' -> 'big_350m_2kh_transfer'.
    """
    head, _, _ = variant.partition('-')
    tokens = head.split('_')
    stem = '_'.join(tokens[:-1])
    needs_suffix = 'transfer' in variant and 'transfer' not in stem
    return f'{stem}_transfer' if needs_suffix else stem

def day_of(variant):
    """Parse the integer count encoded at the end of a variant's first segment.

    The text before the first '-' is split on '_'; the last token has its
    final character removed (the unit letter, e.g. the "d" in "0d" or the
    "s" in "0s"), an optional leading 'transfer' is stripped, and the
    remainder is parsed as an integer.

    Example: 'big_350m_2kh_transfer5s-sweep' -> 5.

    Raises:
        ValueError: if the remaining text is not an integer.
    """
    # Debug-level logging instead of an unconditional print: this function runs
    # once per DataFrame row, and a print here floods stdout.
    logging.getLogger(__name__).debug('Parsing day from variant %s', variant)
    token = variant.split('-')[0].split('_')[-1]
    day = token[:-1]  # trim the trailing unit letter
    # The slice below only makes sense when 'transfer' is a prefix, so test for that.
    if day.startswith('transfer'):
        day = day[len('transfer'):]
    try:
        return int(day)
    except ValueError as err:
        raise ValueError(
            f'Cannot parse a day/session count from variant {variant!r}'
        ) from err

# Derive the model-family stem and the session count from each variant label.
# `Series.map` applies each parser to the single column it needs, rather than
# building a full row object per record via `DataFrame.apply(axis=1)`.
df['variant_stem'] = df['variant'].map(stem_map)
print(df.variant_stem.unique())
df['subj'] = df['variant'].map(day_of)

# Amount of time (s) in Loco, new subject data. -60 since that's the Indy calib time
subj_session_to_time_map = {
    0: 0,
    2: 4245 - 60,
    5: 11248 - 60,
    10: 20325 - 60,
}

# Mapping through a dict turns unknown keys into NaN, so check up front and
# keep failing with a KeyError when a session count has no entry in the table.
unmapped = sorted(set(df['subj']) - set(subj_session_to_time_map))
if unmapped:
    raise KeyError(f'No cross-subject time defined for session count(s): {unmapped}')

subj_seconds = df['subj'].map(subj_session_to_time_map)
df['subj_hr'] = subj_seconds / 60 / 60
df['subj_min'] = subj_seconds / 60

target_df = df
# print(target_df)
#%%
# One small panel using matplotlib's constrained layout; `prep_plt` applies the
# project's axis styling and returns the axes to draw on.
f = plt.figure(layout='constrained', figsize=(4, 3.5))
ax = prep_plt(f.gca(), big=True)

# Families drawn in this panel: each base stem followed by its `_transfer`
# counterpart. (45m_2kh transfer is excluded: not really different, just visual noise.)
base_stems = ['scratch', 'base_45m_200h', 'big_350m_2kh']
subset_variant = base_stems + [f'{stem}_transfer' for stem in base_stems]

# Metric column plotted on the y-axis.
y = 'eval_r2'

def marker_style_map(variant_stem):
    """Return the matplotlib marker code used for a variant stem.

    Any stem containing '350m' gets 'P'; this check takes precedence.
    Otherwise the stems listed below get 'X', and every other stem gets 'o'.
    """
    cross_marked = (
        'NDT2 Expert',
        'NDT3 Expert',
        'scratch',
        'scratch_transfer',
        'NDT3 mse',
        'wf',
        'ole',
    )
    if '350m' in variant_stem:
        return 'P'
    return 'X' if variant_stem in cross_marked else 'o'
# Marker code for every stem present in the data (not only the plotted subset).
marker_dict = {
    stem: marker_style_map(stem)
    for stem in target_df['variant_stem'].unique()
}

# Dash pattern per stem, passed to seaborn as `dashes`: the non-transfer stems
# share (5, 0) and the `_transfer` stems share (5, 2).
unbroken = (5, 0)
dashed = (5, 2)
style_map = {
    **dict.fromkeys(
        ('scratch', 'base_45m_200h', 'base_45m_2kh', 'big_350m_2kh'),
        unbroken,
    ),
    **dict.fromkeys(
        ('scratch_transfer', 'base_45m_200h_transfer', 'big_350m_2kh_transfer'),
        dashed,
    ),
}

# (An earlier `sns.barplot` version of this panel, x='subj_hr' with
# hue='variant_stem', was replaced by the line plot below.)

# Faint horizontal guide lines at every 0.1 on the y-axis.
ax.yaxis.set_minor_locator(plt.MultipleLocator(0.1))
ax.yaxis.grid(True, which='minor', linestyle='-', linewidth=0.5, alpha=0.5)

# Keep only the plotted families, and only points with non-zero cross-subject
# time. The zero-hour values are annotated as horizontal lines instead, to
# connect panels with `plot_occlusion` and avoid weird log spacing.
x = 'subj_hr'
keep_rows = target_df.variant_stem.isin(subset_variant) & (target_df['subj_hr'] > 0)
subset_df = target_df[keep_rows]

# One line per stem: color from `colormap`, dash pattern from `style_map`,
# with a standard-deviation error band.
line_kwargs = dict(
    x=x,
    y=y,
    hue='variant_stem',
    style='variant_stem',
    palette=colormap,
    dashes=style_map,
    errorbar='sd',
    alpha=0.8,  # Lighten it up
)
sns.lineplot(data=subset_df, ax=ax, **line_kwargs)

# Average the plotted rows within each (subj_hr, variant_stem) group so that one
# marker is drawn per group. Every remaining column is averaged, including `seed`.
mean_cols = ['subj_min', 'subj_hr', y, 'variant_stem', 'seed']
mean_df = (
    subset_df[mean_cols]
    .groupby(['subj_hr', 'variant_stem'])
    .mean()
    .reset_index()
)

# Overlay the group means as markers; marker shape comes from `marker_dict`.
# No legend is requested for this layer.
sns.scatterplot(
    data=mean_df,
    ax=ax,
    x=x,
    y=y,
    hue='variant_stem',
    style='variant_stem',
    palette=colormap,
    markers=marker_dict,
    s=MARKER_SIZE,
    alpha=0.8,
    legend=False,
)

# Mark each non-transfer family's performance at subj_hr == 0 with a dotted
# horizontal line in that family's color.
zero_hr_rows = target_df[target_df['subj_hr'] == 0]
for variant in ['scratch', 'base_45m_200h', 'big_350m_2kh']:
    variant_rows = zero_hr_rows[zero_hr_rows['variant_stem'] == variant]
    # NOTE(review): only the first matching row is used; if several rows (e.g.
    # seeds) exist at zero hours this is not their mean -- confirm intended.
    perf = variant_rows[y].values[0]
    ax.axhline(perf, linestyle=':', color=colormap[variant], alpha=0.5)

# This panel carries no legend and no axis labels.
ax.legend().remove()
ax.set_xlabel('')
ax.set_ylabel("")
# Disabled alternatives:
# ax.set_xlabel('Cross-subject time (hr)')
# ax.annotate('$R^2$', xy=(0, 1), xytext=(-ax.yaxis.labelpad - 15, 1),
#             xycoords='axes fraction', textcoords='offset points',
#             ha='center', va='center', fontsize=24)

# Fixed y-range with major ticks every 0.1; tick labels are blanked.
y_lo, y_hi, y_step = 0.1, 0.8, 0.1
ax.set_ylim(y_lo, y_hi)
ax.set_yticks(np.arange(y_lo, y_hi, y_step))
ax.set_yticklabels([])
# ax.set_yticks([0.3, 0.5, 0.7])
# set minor grid every 0.1
ax.yaxis.set_minor_locator(plt.MultipleLocator(y_step))

# Disabled x-axis experiments:
# ax.set_xscale('symlog', linthresh=0.1)
# ax.set_xlim(-0.01, 10)
# ax.set_xlim(2000, 25000)
# Aesthetics todo
# Figure out overall layout / fontsize
# Insert variant_stem labels at start of plot