#%%
# Main plot on RTT occlusion is generated by manual stitching of `plot_subject_occlusion.py` and `plot_occlusion.py`
import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO) # needed to get `logger` to print
from matplotlib import pyplot as plt
from pathlib import Path
import seaborn as sns
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import os
import subprocess
import time

# Shared plotting helpers: axis preparation (`prep_plt`), per-variant palette (`colormap`), scatter marker size (`MARKER_SIZE`)
from context_general_bci.plotting import prep_plt, colormap, MARKER_SIZE

eval_set = 'rtt'
# eval_set = 'cursor'
# Per-session occlusion metrics for the selected evaluation set.
df = pd.read_csv(Path('data/analysis_metrics') / f'{eval_set}_session_occ.csv')

def stem_map(variant):
    """Return the model stem of a variant name.

    Keeps the part before the first '-' and drops its final '_'-separated
    token (the day tag, e.g. '0d'). Returns '' when that part has no '_'.
    """
    head = variant.partition('-')[0]
    return head.rpartition('_')[0]

def day_of(variant):
    day = variant.split('-')[0].split('_')[-1][:-1] # trim "d" from "0d"
    return int(day)

# Parse each `variant` string (shaped like "<stem>_<N>d-<suffix>") into its
# model stem and eval day. Series.map calls the parser once per value, without
# building a row Series per row as apply(axis=1) does, and stays a Series on an
# empty frame.
df['variant_stem'] = df['variant'].map(stem_map)
df['day'] = df['variant'].map(day_of)

# Seconds of cross-session data available at each eval day, per eval set
# (converted to hours / minutes below). Offsets subtract a calibration block:
# 59 s for cursor per the note below; the 60 s rtt offset is presumably the
# same kind of calibration — TODO confirm.
day_to_time_map = {
    'rtt': {
        0: 0,
        4: 793 - 60,
        8: 1992 - 60,
        60: 4310 - 60,
        120: 9932 - 60,
        },
    'cursor': {
        0: 0, # 59 calib
        1: 290 - 59,
        3: 515 - 59,
        4: 743 - 59,
        5: 979 - 59,
        6: 1210 - 59,
        7: 1441 - 59,
        },
}
# Look up each row's elapsed seconds once; an unknown day still raises KeyError.
elapsed_s = df['day'].map(lambda day: day_to_time_map[eval_set][day])
df['daytime_hr'] = elapsed_s / 60 / 60
df['daytime_min'] = elapsed_s / 60


target_df = df
#%%
# Variants drawn in this panel; commented entries are toggled off.
subset_variant = [
    'scratch',
    'base_45m_200h',
    # 'base_45m_2kh',
    'big_350m_2kh'
]
y = 'eval_r2'  # metric on the y-axis

f = plt.figure(figsize=(3., 3.5), layout='constrained')
ax = prep_plt(f.gca(), big=True)

# Faint horizontal guide lines at every 0.1 on the y-axis.
ax.yaxis.set_minor_locator(plt.MultipleLocator(0.1))
ax.yaxis.grid(True, which='minor', linestyle='-', linewidth=0.5, alpha=0.5)
def marker_style_map(variant_stem):
    """Pick a matplotlib marker for a variant stem.

    'P' (filled plus) for any stem containing '350m'; 'X' for the named
    reference variants listed below; 'o' (circle) for everything else.
    The '350m' check takes precedence.
    """
    reference_stems = ('NDT2 Expert', 'NDT3 Expert', 'scratch', 'NDT3 mse', 'wf', 'ole')
    if '350m' in variant_stem:
        marker = 'P'
    elif variant_stem in reference_stems:
        marker = 'X'
    else:
        marker = 'o'
    return marker
# One marker per variant stem, built from every stem before subsetting.
marker_dict = {stem: marker_style_map(stem) for stem in target_df['variant_stem'].unique()}
# Elapsed-time axis: hours for rtt, minutes otherwise.
if eval_set == 'rtt':
    x = 'daytime_hr'
else:
    x = 'daytime_min'
# x = 'daytime_min'
target_df = target_df.loc[target_df['variant_stem'].isin(subset_variant)]
print(target_df[x])
# x == 0 rows are left out of the line plot; they are annotated as horizontal
# lines to connect panels with `plot_occlusion`.
subset_df = target_df.loc[target_df[x] > 0]
sns.lineplot(
        data=subset_df,
        x=x,
        y=y,
        hue='variant_stem',
        palette=colormap,
        errorbar='sd',  # band spans one standard deviation
        err_kws={'alpha': 0.05},  # keep the error band very faint
        alpha=0.8,  # lighten the lines
        ax=ax,
    )
def get_mean_perf(df, metric='eval_r2'):
    """Average ``metric`` over repeated rows (e.g. seeds) of each plotted cell.

    Rows are grouped by ``variant_stem``, ``eval_set`` and ``day``, plus
    whichever of the elapsed-time columns ``daytime_hr`` / ``daytime_min`` are
    present. Keeping both time columns lets the result be plotted against
    either unit; previously only ``daytime_hr`` survived, so plotting against
    ``daytime_min`` failed.

    Args:
        df: Frame with the grouping columns above and a ``metric`` column.
        metric: Column to average. Defaults to ``'eval_r2'``.

    Returns:
        A flat DataFrame with one row per group: the grouping columns followed
        by the mean of ``metric``.
    """
    time_cols = [col for col in ('daytime_hr', 'daytime_min') if col in df.columns]
    group_cols = ['variant_stem', 'eval_set', 'day', *time_cols]
    return df.groupby(group_cols).agg({metric: 'mean'}).reset_index()

mean_df = get_mean_perf(subset_df)

# Seed-averaged points drawn over the line plot; marker shape encodes the stem.
sns.scatterplot(
    data=mean_df,
    x=x,
    y=y,
    hue='variant_stem',
    palette=colormap,
    style='variant_stem',
    markers=marker_dict,
    ax=ax,
    s=MARKER_SIZE,
    legend=False,
    # legend=True,
    alpha=0.8,
)

ax.legend().remove()
ax.set_ylabel("")
# ax.annotate('$R^2$', xy=(0, 1), xytext=(-ax.yaxis.labelpad - 15, 1),
            # xycoords='axes fraction', textcoords='offset points',
            # ha='center', va='center', fontsize=24)
if eval_set == 'rtt':
    ax.set_xlabel('Cross-session data (hr)')
    ax.set_xlabel('') # Manual session join
elif eval_set == 'cursor':
    ax.set_xlabel('Cross-session data (min)')

ax.set_ylim(0.1, 0.8)
ax.set_yticks(np.arange(0.1, 0.8, 0.2))
# set minor ticks
ax.yaxis.set_minor_locator(plt.MultipleLocator(0.1))

# ax.spines['left'].set_position(('axes', -0.05))  # Adjust as needed
from matplotlib.ticker import SymmetricalLogLocator

# Dotted reference line per variant at its x == 0 score (no cross-session
# data). Averaged over seeds to match the scatter points, rather than taking
# whichever row happens to come first.
for variant in subset_variant:
    zero_scores = target_df[(target_df[x] == 0) & (target_df['variant_stem'] == variant)][y]
    if zero_scores.empty:
        logging.warning('No %s == 0 rows for %s; skipping its reference line', x, variant)
        continue
    ax.axhline(zero_scores.mean(), linestyle=':', color=colormap[variant], alpha=0.8)

# The fixed ticks are in hours, so they apply only to the hour-scaled rtt axis.
# The unit text under the right end of the x-axis follows the plotted column.
if eval_set == 'rtt':
    ax.set_xticks([0.5, 1])
    ax.set_xticklabels([0.5, 1])
    x_unit = 'hr'
else:
    x_unit = 'min'
ax.text(1.0, -0.03, x_unit, transform=ax.transAxes, ha='right', va='top', fontsize=22)

# if eval_set == 'rtt':
#     ax.set_ylim(0.1, 0.8)
#     ax.set_yticks([0.3, 0.5, 0.7])
#     ax.set_xlim(-0.01, 10)
#     ax.yaxis.set_minor_locator(plt.MultipleLocator(0.1))
#     ax.set_xscale('symlog', linthresh=0.1)
#     ax.xaxis.set_minor_locator(SymmetricalLogLocator(linthresh=0.1, base=10))

# Aesthetics todo
# Figure out overall layout / fontsize
# Insert variant_stem labels at start of plot

#%%
# Stand-alone panel that carries only the "$R^2$" label, for manual stitching.
f = plt.figure(figsize=(4, 3.5), layout='constrained')
ax = prep_plt(f.gca(), big=True)
ax.grid(False)  # label-only panel: no grid lines
ax.annotate(
    '$R^2$',
    xy=(0.5, 0.5),
    xycoords='axes fraction',
    xytext=(-ax.yaxis.labelpad - 15, 1),
    textcoords='offset points',
    ha='center',
    va='center',
    fontsize=24,
)
plt.show()