#%%
r"""
    Cross-task summary plot on the gains from scaling at different evaluation points
"""
# Demonstrate pretraining scaling. Assumes evaluation metrics have been computed and merely assembles.
import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO) # needed to get `logger` to print
from matplotlib import pyplot as plt
from pathlib import Path
import seaborn as sns
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import os
import subprocess
import time

# Project helpers: plotting setup/palettes and the lookup of the current host name
from context_general_bci.plotting import prep_plt, MARKER_SIZE, colormap, cont_size_palette, SIZE_PALETTE
from context_general_bci.utils import get_simple_host

# Fix the global seed through pytorch_lightning before anything else runs.
pl.seed_everything(0)

# Session / task data available in total (can include both evaluation sessions and non-evaluation sessions)
# Three duration labels per task; the last entry of each list is the one parsed
# into `task_time` further down in this file.
# NOTE(review): presumably the three entries correspond to increasing data scales - confirm.
pt_volume_labels = {
    'cursor': ['2.5 min', '5 min', '10 min'],
    'falcon_h1': ['15 min', '30 min', '1 hr'],
    'falcon_m1': ['1h', '2h', '4 hr'],
    'falcon_m2': ['11 min', '22 min', '44 min'],
    'rtt': ['3h', '6h', '12 hr'],
    'grasp_h': ['15 min', '30 min', '1 hr'],
    'rtt_s1': ['7 min', '14 min', '28 min'],
    'cst': ['10 min', '21 min', '42 min'],
    'eye': ['2.5h', '5h', '10 hr'],
}

# Data available in evaluation sessions
# r'64s$\times$4' for raw latex string
# Each value is (duration label, count). Only the duration label is read in this
# file (parsed into `session_time`); the count is not used here.
# NOTE(review): the count presumably is the number of evaluation sessions - confirm.
tune_volume_labels = {
    'cursor': ('60 s', 11),
    'falcon_h1': ('80 s', 7),
    'falcon_m1': ('2 min', 4),
    'falcon_m2': ('64 s', 4),
    'rtt': ('60 s', 3),
    'grasp_h': ('10 min', 6),
    'cst': ('60 s', 39),
    'rtt_s1': ('4 min', 7),
    'eye': ("40 min", 1),
}

# Same (duration label, count) layout for the FALCON tasks only.
# Not referenced elsewhere in the visible part of this file.
# NOTE(review): presumably describes the held-in sessions - confirm.
heldin_tune_volume_labels = {
    'falcon_h1': ('8 min', 6),
    'falcon_m1': ('1 hr', 4),
    'falcon_m2': ('10 min', 4),
}
# Metrics of the linear baseline (variant_stem 'wf'), one CSV per evaluation task.
ridge_paths = [
    Path(f'./data/eval_metrics/ridge_{task}.csv')
    for task in (
        'falcon_h1',
        'falcon_m2',
        'falcon_m1',
        'grasp_h',
        'cursor',
        'rtt',
        'cst',
    )
]


def _load_ridge_metrics(csv_path):
    """Read one ridge CSV and reshape it to the columns shared with the other models.

    For every `scale` value only the row(s) with the highest `r2` are kept.
    """
    frame = pd.read_csv(csv_path)
    frame['variant'] = 'linear'
    frame['variant_stem'] = 'wf'
    # File stem is 'ridge_<task>'; the task name becomes the eval set.
    frame['eval_set'] = csv_path.stem[len('ridge_'):]
    # reduce by history
    if 'h1' not in str(csv_path):
        # 1s limit for parity with NDT3
        frame = frame[frame['history'] <= 50]
    frame = (
        frame
        .groupby('scale')
        .apply(lambda grp: grp[grp['r2'] == grp['r2'].max()])
        .reset_index(drop=True)
    )

    frame['id'] = frame['eval_set'] + '-' + frame['scale'].astype(str)
    frame['scale_ratio'] = frame['scale']
    if 'falcon' in csv_path.stem:
        # FALCON files carry separate held-in / held-out scores.
        frame['heldin_eval_r2'] = frame['heldin']
        frame['eval_r2'] = frame['heldout']
    else:
        frame['eval_r2'] = frame['r2']
    return frame


ridge_dfs = [_load_ridge_metrics(csv_path) for csv_path in ridge_paths]
ridge_df = pd.concat(ridge_dfs)


# All evaluation CSVs live in one directory; only the file stems differ.
_EVAL_METRICS_DIR = Path("./data/eval_metrics")

# NDT3 metric files, named '<host>_<task>_eval_ndt3.csv'. The leading token is
# later used as the host to copy the file from.
_NDT3_STEMS = (
    "mind_rtt_s1",
    "crc_cursor",
    "crc_falcon_m1",
    "crc_falcon_m2",
    "crc_grasp_h",
    "crc_cst",
    "crc_rtt",
    "crc_falcon_h1",

    "nid_cursor",
    "nid_cst",
    "nid_grasp_h",
    "nid_falcon_h1",
    "nid_falcon_m1",
    "nid_falcon_m2",
    "nid_rtt",
    "nid_eye",
)
df_paths = [_EVAL_METRICS_DIR / f"{stem}_eval_ndt3.csv" for stem in _NDT3_STEMS]

# NDT2 metric files, named '<task>_eval_ndt2.csv' (no host prefix).
_NDT2_TASKS = (
    'falcon_h1',
    'falcon_m1',
    'falcon_m2',
    'cursor',
    'rtt',
    'grasp_h',
)
ndt2_df_paths = [_EVAL_METRICS_DIR / f'{task}_eval_ndt2.csv' for task in _NDT2_TASKS]
# Import CSVs: pull a local copy of every metrics file that is missing or stale.
# scp is run without check=True, so a failed copy does not raise and any
# existing local file is simply kept.
def _csv_is_fresh(path, max_age_s):
    """Return True if `path` exists and was modified less than `max_age_s` seconds ago."""
    return path.exists() and (time.time() - os.path.getmtime(path)) < max_age_s


EXPIRY = 86400 * 30  # NDT3 CSVs are re-copied once older than 30 days
NDT2_EXPIRY = 86400  # NDT2 CSVs are re-copied once older than 1 day
NDT2_HOST = 'mind'  # host the NDT2 CSVs are copied from

cur_host = get_simple_host()  # loop-invariant, so looked up once

for src_path in df_paths:
    # The leading token of the file name is the host to copy from.
    csv_host = src_path.name.split('_')[0]
    if cur_host == csv_host:
        continue  # already on the source host, nothing to copy
    print(src_path, cur_host, src_path.exists())
    if _csv_is_fresh(src_path, EXPIRY):
        continue
    print(f'Copying {src_path} to {cur_host}')
    # Argument list instead of a shell string: no quoting/injection surprises.
    subprocess.run(['scp', f'{csv_host}:projects/ndt3/{src_path}', './data/eval_metrics'])

for src_path in ndt2_df_paths:
    # Compare against the host actually used in the scp command below. The
    # previous code compared against `csv_host` left over from the loop above.
    if cur_host == NDT2_HOST:
        continue
    # Age is measured against the current time. The previous code subtracted
    # the directory's mtime from the file's mtime, which is not an age.
    if _csv_is_fresh(src_path, NDT2_EXPIRY):
        continue
    print(f'Copying {src_path} to {cur_host}')
    subprocess.run(
        ['scp', f'{NDT2_HOST}:projects/context_general_bci/{src_path}', './data/eval_metrics']
    )

def _read_metric_csvs(paths):
    """Read every CSV in `paths` and stack the rows into one frame."""
    return pd.concat([pd.read_csv(csv_path) for csv_path in paths])


eval_df = _read_metric_csvs(df_paths)
# An empty path list yields an empty frame instead of calling concat on nothing.
ndt2_eval_df = _read_metric_csvs(ndt2_df_paths) if len(ndt2_df_paths) > 0 else pd.DataFrame()

def stem_map(variant):
    """Collapse a full variant name to the stem used for grouping.

    Any variant containing 'scratch' maps to 'NDT3 mse'. Otherwise the text
    from the first '-' onward is discarded, and then the last '_'-separated
    token of what remains is dropped (an empty string results when there is
    no '_' left).
    """
    if 'scratch' in variant:
        return 'NDT3 mse'
    before_dash = variant.partition('-')[0]
    # rpartition gives '' as the head when no '_' is present, matching the
    # "drop the last token" rule for single-token names.
    return before_dash.rpartition('_')[0]

# Tag every row with its grouping stem, then stack NDT3, NDT2 and ridge rows.
eval_df['variant_stem'] = eval_df.apply(lambda row: stem_map(row['variant']), axis=1)
ndt2_eval_df['variant_stem'] = 'NDT2 Expert'
eval_df = pd.concat([eval_df, ndt2_eval_df, ridge_df])
# A stale 'index' column would collide with the one reset_index() creates.
if 'index' in eval_df.columns:
    eval_df = eval_df.drop(columns=['index'])
eval_df = eval_df.reset_index()

# drop 0s
nonzero_r2 = eval_df['eval_r2'] != 0
eval_df = eval_df[nonzero_r2]
# print(eval_df[eval_df['eval_set'] == 'rtt'].sort_values('eval_r2', ascending=False))
# Unique by id first (additional pass needed to not drop linear) ...
eval_df = eval_df.drop_duplicates(subset=['id'])
# ... then multi-sweep into one best candidate per (stem, scale, task, seed).
_candidate_key = ['variant_stem', 'scale_ratio', 'eval_set', 'seed']
eval_df = eval_df.drop_duplicates(subset=_candidate_key)
print(eval_df['variant_stem'].unique())
print(eval_df[eval_df['variant_stem'] == 'wf']['eval_set'].unique())

print(eval_df[['history', 'variant']])
def time_str_to_minutes(time_str):
    """Convert a duration label such as '60 s', '2.5 min', '1h' or '4 hr' to minutes.

    The label is a non-negative number (integer or decimal) followed by a unit,
    with or without whitespace in between. Seconds, minutes and hours are
    recognised in their common spellings, case-insensitively.

    Args:
        time_str: Duration label.

    Returns:
        The duration in minutes as a float, or 0 when the label has no leading
        number or carries an unrecognised unit.
    """
    import re  # local import keeps this helper self-contained

    minutes_per_unit = {
        's': 1 / 60, 'sec': 1 / 60, 'secs': 1 / 60, 'second': 1 / 60, 'seconds': 1 / 60,
        'min': 1, 'mins': 1, 'minute': 1, 'minutes': 1,
        'h': 60, 'hr': 60, 'hrs': 60, 'hour': 60, 'hours': 60,
    }
    parsed = re.match(r'\s*(\d+(?:\.\d+)?)\s*([A-Za-z]+)', time_str)
    if parsed is None:
        return 0
    amount = float(parsed.group(1))
    # The unit is matched as a whole word: a substring test would read the 's'
    # in 'hrs' as seconds.
    factor = minutes_per_unit.get(parsed.group(2).lower())
    if factor is None:
        return 0
    return amount * factor

def marker_style_map(variant_stem):
    """Pick the matplotlib marker code for a variant stem.

    Stems containing '350m' get 'P'; the expert/scratch stems get 'X'; the
    linear baselines ('wf', 'ole') get 'd'; everything else gets 'o'.
    """
    # The substring rule takes precedence over the exact-name rules.
    if '350m' in variant_stem:
        return 'P'
    exact_name_markers = (
        ('X', ('NDT2 Expert', 'NDT3 Expert', 'NDT3 mse')),
        ('d', ('wf', 'ole')),  # ! only in intro plot to distinguish
    )
    for marker, stems in exact_name_markers:
        if variant_stem in stems:
            return marker
    return 'o'

def volume_map(variant_stem):
    """Map a variant stem to a numeric pretraining volume.

    The stem is scanned for known tokens in a fixed priority order and the
    value of the first token found is returned; 0 is returned when none match.
    NOTE(review): values are presumably hours (e.g. '200h' -> 200) - confirm.
    """
    # Order matters: 'Expert' wins over any size token in the same stem.
    token_volumes = (
        ('Expert', 0),
        ('200h', 200),
        ('1kh', 1000),
        ('2kh', 2000),
        ('25h', 25),
        ('70h', 70),
        ('min', 2),
    )
    for token, volume in token_volumes:
        if token in variant_stem:
            return volume
    return 0
# Per-row plotting attributes derived from the variant stem.
eval_df['pt_volume'] = eval_df['variant_stem'].apply(volume_map)

# eval_df['marker_size'] = eval_df['pt_volume'] * 30
eval_df['marker_size'] = MARKER_SIZE
eval_df['marker_style'] = eval_df['variant_stem'].apply(marker_style_map)
marker_dict = {
    stem: marker_style_map(stem) for stem in eval_df['variant_stem'].unique()
}


def _session_minutes(row):
    """Minutes given by the tuning label of the row's eval set."""
    return time_str_to_minutes(tune_volume_labels[row.eval_set][0])


def _task_minutes(row):
    """Minutes given by the last pretraining-volume label of the row's eval set."""
    return time_str_to_minutes(pt_volume_labels[row.eval_set][-1])


# Absolute amounts of data: the label duration multiplied by the row's scale ratio.
eval_df['session_time'] = eval_df.apply(_session_minutes, axis=1)
eval_df['scaled_session_time'] = eval_df['scale_ratio'] * eval_df['session_time']
eval_df['task_time'] = eval_df.apply(_task_minutes, axis=1)
eval_df['scaled_task_time'] = eval_df['scale_ratio'] * eval_df['task_time']



#%%
from statannotations.Annotator import Annotator
# Figure flavour; read below only to position the R^2 label ('SUMMARY' vs. anything else).
FIGURE = 'ALL'

# Variant stem whose score is subtracted from all others. '' plots raw eval_r2;
# use 'NDT3 mse' to plot differences to that variant instead.
# NOTE(review): the baseline rows are looked up after filtering by `variants`,
# so the baseline stem presumably has to be listed there as well - confirm.
BASELINE = ''

# Variant stems to show, in x-axis order. Commented entries are available toggles.
variants = [
    # 'wf', # removing for clarity
    # 'NDT2 Expert', # removing for clarity
    # 'NDT3 mse',
    # 'base_45m_25h',
    # 'base_45m_70h',
    'base_45m_200h',
    'big_350m_200h',
    # 'base_45m_1kh',
    # 'base_45m_1kh_human',
    # 'base_45m_1kh_breadth',
    'base_45m_2kh',
    'big_350m_2kh',
]

# Display name for every variant stem (used for the x tick labels).
labels = {
    'wf': 'Wiener Filter',
    'NDT2 Expert': 'NDT2',
    'NDT3 mse': 'Scratch',
    'base_45m_25h': '25h',
    'base_45m_70h': '70h',
    'base_45m_200h': '200h',
    'base_45m_1kh': '1kh Depth',
    'base_45m_1kh_human': '1kh Human',
    'base_45m_1kh_breadth': '1kh Breadth',
    'base_45m_2kh': '2kh',
    'big_350m_200h': '350M 200h',
    'big_350m_2kh': '350M 2kh',
}

# take mean across seeds for visual clarity
TAKE_MEAN = True

PLOT_TREND_LINES = True  # alternative: False

SHOW_LEGEND = True # Also, legend somehow has wrong labels
SCATTER_ALPHA = 0.8
lower = None  # alternative: 0.0 to clip the y-axis from below

# Evaluation sets to include.
tasks = [
    # 'cursor',
    # 'falcon_h1',
    # 'grasp_h',
    'rtt',
    'falcon_m1',
    'falcon_m2',
    'cst',
]
# Note, these will only show the heldout r2 for falcon tasks
_wanted_rows = eval_df['variant_stem'].isin(variants) & eval_df['eval_set'].isin(tasks)
subset_df = eval_df[_wanted_rows]
print(subset_df.columns)

if TAKE_MEAN:
    # Mean eval_r2 across seeds; add further columns to the aggregation if
    # they are needed downstream (e.g. 'marker_style').
    subset_df = (
        subset_df
        .groupby(['variant_stem', 'eval_set', 'pt_volume', 'scale_ratio'])['eval_r2']
        .mean()
        .reset_index()
    )

if not BASELINE:
    y = 'eval_r2'
else:
    # Baseline eval_r2 for each eval_set and scale_ratio
    baseline_df = subset_df.loc[
        subset_df['variant_stem'] == BASELINE,
        ['eval_set', 'scale_ratio', 'eval_r2'],
    ].rename(columns={'eval_r2': 'baseline_r2'})

    # Attach the baseline score to every row, subtract it, then drop the helper column.
    subset_df = pd.merge(subset_df, baseline_df, on=['eval_set', 'scale_ratio'], how='left')
    subset_df['norm_r2'] = subset_df['eval_r2'] - subset_df['baseline_r2']
    subset_df = subset_df.drop(columns=['baseline_r2'])

    y = 'norm_r2'


# Alternative size: figsize=(4.5, 5.7)
f = plt.figure(figsize=(6.5, 4.5), layout='constrained')
ax = prep_plt(f.gca(), big=True)

# fails due to negative r2s, even if a bit more principled than mean or median
def geo_mean_overflow(iterable):
    """Geometric mean of `iterable`, computed as exp(mean(log(values)))."""
    log_values = np.log(iterable)
    return np.exp(log_values.mean())

colormap['wf'] = '#3A3A3A'  # Deep gray, almost black but still visible

# Faint violin per variant as a backdrop for the individual points.
# Other plot types tried here: boxplot, pointplot, stripplot, barplot.
# NOTE(review): `split` is a boolean option in seaborn; the string 'params' is
# merely truthy - confirm what was intended.
sns.violinplot(
    data=subset_df,
    x='variant_stem',
    order=variants,
    split='params',
    y=y,
    palette=colormap,
    alpha=0.1,
    ax=ax,
)
# One point per row of subset_df, coloured by evaluation set.
sns.stripplot(
    data=subset_df,
    x='variant_stem',
    hue='eval_set',
    order=variants,
    y=y,
    alpha=0.8,
    ax=ax,
)

# Replace the raw variant stems on the x axis with their display names.
ax.set_xticklabels([
    labels[tick.get_text()] for tick in ax.get_xticklabels()
], fontsize=16, rotation=45)

# Put the legend off right, with readable names for the evaluation sets.
# The legend texts get their own name so the `labels` display-name mapping
# used above is not overwritten.
eval_set_names = {
    'cursor': '2D Cursor+Click',
    'falcon_h1': 'Falcon H1',
    'grasp_h': '1D Grasp',
}
handles, legend_labels = ax.get_legend_handles_labels()
new_labels = [eval_set_names.get(text, text) for text in legend_labels]
ax.legend(handles, new_labels, title='Eval Set', bbox_to_anchor=(1.05, 0.5), loc='center left')

# The y label is drawn as free text so it can sit at the top-left of the axes.
ax.set_ylabel('')
if FIGURE == 'SUMMARY':
    ax.text(-0.05, 1.02, '$R^2$', ha='center', va='center', transform=ax.transAxes, fontsize=24)
else:
    ax.text(-0.08, 0.95, '$R^2$', ha='center', va='center', transform=ax.transAxes, fontsize=24)

# Clip lower (skipped for baseline-normalized plots). Compared against None so
# that a bound of 0.0 is honoured.
if not BASELINE and lower is not None:
    ylims = ax.get_ylim()
    ax.set_ylim(lower, ylims[1])

ax.set_xlabel('')

#%%
# One short, very wide facet per variant stem, coloured from the shared colormap.
g = sns.FacetGrid(
    data=subset_df,
    row="variant_stem",
    hue="variant_stem",
    palette=colormap,
    height=.5,
    aspect=15,
)

# Draw the densities in two passes: a filled curve, then a white outline on top.
_kde_shared = dict(bw_adjust=.5, clip_on=False)
g.map(sns.kdeplot, "eval_r2", fill=True, alpha=1, linewidth=1.5, **_kde_shared)
g.map(sns.kdeplot, "eval_r2", color="w", lw=2, **_kde_shared)

# Baseline at y=0 for every facet.
# passing color=None to refline() uses the hue mapping
g.refline(y=0, color=None, linestyle="-", linewidth=2, clip_on=False)


# Facet callback: writes the facet's label in bold at the left edge of the
# current axes, in axes coordinates and in the facet's colour.
def label(x, color, label):
    current_axes = plt.gca()
    current_axes.text(
        0, .2, label,
        transform=current_axes.transAxes,
        ha="left", va="center",
        fontweight="bold", color=color,
    )


# Write each facet's label onto its own axes.
g.map(label, "eval_r2")

# Vertical spacing between the stacked facets.
g.figure.subplots_adjust(hspace=-0.0)

# Strip titles, y ticks/labels and the left/bottom spines, which don't play well with overlap.
g.set_titles("").set(yticks=[], ylabel="").despine(bottom=True, left=True)

#%%
from statannotations.Annotator import Annotator
# Figure flavour; read below only to position the R^2 label ('SUMMARY' vs. anything else).
FIGURE = 'ALL'

# Stem of the variant to normalize against. '' means "plot raw eval_r2";
# 'NDT3 mse' is the alternative used for difference plots.
# NOTE(review): the baseline rows are taken from the frame already filtered by
# `variants`, so the baseline stem presumably must be enabled there - confirm.
BASELINE = ''

# Variant stems on the x axis, left to right. Commented entries are toggles.
variants = [
    # 'wf', # removing for clarity
    # 'NDT2 Expert', # removing for clarity
    # 'NDT3 mse',
    # 'base_45m_25h',
    # 'base_45m_70h',
    'base_45m_200h',
    'big_350m_200h',
    # 'base_45m_1kh',
    # 'base_45m_1kh_human',
    # 'base_45m_1kh_breadth',
    'base_45m_2kh',
    'big_350m_2kh',
]

# Variant stem -> name shown on the x tick labels.
labels = {
    'wf': 'Wiener Filter',
    'NDT2 Expert': 'NDT2',
    'NDT3 mse': 'Scratch',
    'base_45m_25h': '25h',
    'base_45m_70h': '70h',
    'base_45m_200h': '200h',
    'base_45m_1kh': '1kh Depth',
    'base_45m_1kh_human': '1kh Human',
    'base_45m_1kh_breadth': '1kh Breadth',
    'base_45m_2kh': '2kh',
    'big_350m_200h': '350M 200h',
    'big_350m_2kh': '350M 2kh',
}

# take mean across seeds for visual clarity
TAKE_MEAN = True

PLOT_TREND_LINES = True  # alternative: False

SHOW_LEGEND = True # Also, legend somehow has wrong labels
SCATTER_ALPHA = 0.8
lower = None  # alternative: 0.0 to clip the y-axis from below

# Evaluation sets that enter the plot.
tasks = [
    # 'cursor',
    # 'falcon_h1',
    # 'grasp_h',
    'rtt',
    'falcon_m1',
    'falcon_m2',
    'cst',
]
# Note, these will only show the heldout r2 for falcon tasks
_selected = eval_df['variant_stem'].isin(variants) & eval_df['eval_set'].isin(tasks)
subset_df = eval_df[_selected]
print(subset_df.columns)

if TAKE_MEAN:
    # Collapse seeds: one mean eval_r2 per (stem, task, volume, scale).
    # Extend the aggregation if other columns are needed afterwards.
    _seed_group = ['variant_stem', 'eval_set', 'pt_volume', 'scale_ratio']
    subset_df = subset_df.groupby(_seed_group)['eval_r2'].mean().reset_index()

if not BASELINE:
    y = 'eval_r2'
else:
    # Baseline eval_r2 for each eval_set and scale_ratio
    _is_baseline = subset_df['variant_stem'] == BASELINE
    baseline_df = subset_df.loc[
        _is_baseline, ['eval_set', 'scale_ratio', 'eval_r2']
    ].rename(columns={'eval_r2': 'baseline_r2'})

    # Left-merge the baseline score in, take the difference, drop the helper column.
    subset_df = pd.merge(subset_df, baseline_df, on=['eval_set', 'scale_ratio'], how='left')
    subset_df['norm_r2'] = subset_df['eval_r2'] - subset_df['baseline_r2']
    subset_df = subset_df.drop(columns=['baseline_r2'])

    y = 'norm_r2'


# Alternative size: figsize=(4.5, 5.7)
f = plt.figure(figsize=(6.5, 4.5), layout='constrained')
ax = prep_plt(f.gca(), big=True)

# fails due to negative r2s, even if a bit more principled than mean or median
def geo_mean_overflow(iterable):
    """Return the geometric mean of `iterable` via the mean of its logarithms."""
    logs = np.log(iterable)
    mean_log = logs.mean()
    return np.exp(mean_log)

colormap['wf'] = '#3A3A3A'  # Deep gray, almost black but still visible

# One box per variant over the rows of subset_df.
# Other plot types tried here: violinplot, pointplot, stripplot, barplot.
sns.boxplot(
    data=subset_df,
    x='variant_stem',
    order=variants,
    y=y,
    palette=colormap,
    ax=ax,
)

# Replace the raw variant stems on the x axis with their display names.
ax.set_xticklabels([
    labels[tick.get_text()] for tick in ax.get_xticklabels()
], fontsize=16, rotation=45)

# Put the legend off right, with readable names for the evaluation sets.
# The legend texts get their own name so the `labels` display-name mapping
# used above is not overwritten.
eval_set_names = {
    'cursor': '2D Cursor+Click',
    'falcon_h1': 'Falcon H1',
    'grasp_h': '1D Grasp',
}
handles, legend_labels = ax.get_legend_handles_labels()
# Only add a legend when there is something to show; otherwise an empty
# "Eval Set" box would be drawn next to the axes.
if handles:
    new_labels = [eval_set_names.get(text, text) for text in legend_labels]
    ax.legend(handles, new_labels, title='Eval Set', bbox_to_anchor=(1.05, 0.5), loc='center left')

# The y label is drawn as free text so it can sit at the top-left of the axes.
ax.set_ylabel('')
if FIGURE == 'SUMMARY':
    ax.text(-0.05, 1.02, '$R^2$', ha='center', va='center', transform=ax.transAxes, fontsize=24)
else:
    ax.text(-0.08, 0.95, '$R^2$', ha='center', va='center', transform=ax.transAxes, fontsize=24)

# Clip lower (skipped for baseline-normalized plots). Compared against None so
# that a bound of 0.0 is honoured.
if not BASELINE and lower is not None:
    ylims = ax.get_ylim()
    ax.set_ylim(lower, ylims[1])

ax.set_xlabel('')

# Aight, no significance really...
# Variant pairs to compare.
pairs = [
    # ('NDT3 mse', 'base_45m_25h'),
    # ('base_45m_25h', 'base_45m_70h'),
    # ('base_45m_70h', 'base_45m_200h'),
    ('base_45m_200h', 'base_45m_2kh'),
    ('base_45m_200h', 'big_350m_200h'),
    ('base_45m_2kh', 'big_350m_2kh'),
    ('big_350m_200h', 'big_350m_2kh'),
]

# Reuse `variants` as the order so the annotations always line up with the
# boxes drawn above instead of relying on a second hand-maintained list.
annotator = Annotator(ax, pairs, data=subset_df, x='variant_stem', y=y, order=variants)

# Perform and add the statistical annotations.
# NOTE(review): a paired test presumably needs every variant to have the same
# (eval_set, scale_ratio) rows in the same order - confirm for subset_df.
annotator.configure(test='t-test_paired', text_format='star', loc='inside')
annotator.apply_and_annotate()
