#%%
r"""
    Primary + secondary scaling results.
"""
# Demonstrate pretraining scaling. Assumes evaluation metrics have been computed and merely assembles.
import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO) # needed to get `logger` to print
from matplotlib import pyplot as plt
from pathlib import Path
import seaborn as sns
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import os
import subprocess
import time

# Project plotting helpers / style constants and the host-name lookup used when syncing CSVs
from context_general_bci.plotting import prep_plt, MARKER_SIZE, SIZE_PALETTE
from context_general_bci.plotting.styleguide import colormap, SIZE_PALETTE, cont_size_palette
from context_general_bci.utils import get_simple_host

pl.seed_everything(0)

# Session / task data available in total (can include both evaluation sessions and non-evaluation sessions)
# Maps each eval set to three human-readable data-volume labels, listed smallest to largest.
# NOTE(review): presumably one label per `scale_ratio` level — confirm against where the labels are consumed.
pt_volume_labels = {
    'cursor': ['2.5 min', '5 min', '10 min'],
    'falcon_h1': ['15 min', '30 min', '1 hr'],
    'falcon_m1': ['1h', '2h', '4 hr'],  # NOTE(review): largest label is 4 hr but task_time['falcon_m1'] is 2 * 60 — confirm which is right
    'falcon_m2': ['11 min', '22 min', '44 min'],
    'rtt': ['3h', '6h', '12 hr'],
    'grasp_h': ['15 min', '30 min', '1 hr'],
    'rtt_s1': ['7 min', '14 min', '28 min'],
    'cst': ['10 min', '21 min', '42 min'],
    'eye': ['2.5h', '5h', '10 hr'],
}
# Numeric total per eval set; multiplied by `scale_ratio` to produce `scaled_task_time`.
# Values line up with the largest label above when read as minutes
# (e.g. cursor: 10 <-> '10 min', rtt: 12 * 60 <-> '12 hr'), except for falcon_m1 (see note above).
task_time = {
    'cursor': 10,
    'falcon_h1': 1 * 60,
    'falcon_m1': 2 * 60,
    'falcon_m2': 44,
    'rtt': 12 * 60,
    'grasp_h': 1 * 60,
    'rtt_s1': 28,
    'cst': 42,
    'eye': 10 * 60,
}

# Data available in evaluation sessions
# Each entry is (duration label, integer count); every eval set has both fields filled in.
# NOTE(review): presumably the pair is rendered as "<duration> x <count>", e.g.
# r'64s$\times$4' as a raw LaTeX string — confirm what the count enumerates (sessions?).
tune_volume_labels = {
    'cursor': ('60 s', 11),
    'falcon_h1': ('80 s', 7),
    'falcon_m1': ('2 min', 4),
    'falcon_m2': ('64s', 4),
    'rtt': ('60 s', 3),
    'grasp_h': ('10 min', 6),
    'cst': ('60 s', 39),
    'rtt_s1': ('4 min', 7),
    'eye': ("40 min", 1),
}

# Enforced mins for visual clarity
# NOTE(review): presumably y-axis lower bounds keyed by eval set; the consumer is not in this section.
y_min = {
    'rtt': 0.35,
}

# This second assignment replaces the table above with an empty dict, so by default
# no minimum is configured. It acts as a manual on/off switch for the table.
y_min = {} # Comment out to enforce y min

# Same (duration label, integer count) layout as `tune_volume_labels`, defined only for the FALCON sets.
# NOTE(review): presumably describes the held-in tuning data for those sets — confirm against the consumer.
heldin_tune_volume_labels = {
    'falcon_h1': ('8 min', 6),
    'falcon_m1': ('1 hr', 4),
    'falcon_m2': ('10 min', 4),
}
# Linear (ridge) baseline metrics, one CSV per eval set, named `ridge_<eval_set>.csv`.
ridge_paths = [
    Path('./data/eval_metrics/ridge_falcon_h1.csv'),
    Path('./data/eval_metrics/ridge_falcon_m2.csv'),
    Path('./data/eval_metrics/ridge_falcon_m1.csv'),
    Path('./data/eval_metrics/ridge_grasp_h.csv'),
    Path('./data/eval_metrics/ridge_cursor.csv'),
    Path('./data/eval_metrics/ridge_rtt.csv'),
    Path('./data/eval_metrics/ridge_cst.csv'),
]
# Normalize each baseline CSV to the column layout of the model metrics
# (variant, variant_stem, eval_set, id, scale_ratio, eval_r2, ...) so they can be concatenated.
ridge_dfs = []
for src_path in ridge_paths:
    ridge_df = pd.read_csv(src_path)
    ridge_df['variant'] = 'linear'
    ridge_df['variant_stem'] = 'wf'
    # Strip the `ridge_` prefix from the file stem to recover the eval set name.
    ridge_df['eval_set'] = src_path.stem[len('ridge_'):]
    # reduce by history
    if 'h1' not in str(src_path):
        ridge_df = ridge_df[ridge_df['history'] <= 50] # 1s limit for parity with NDT3
    # Keep, for every scale, only the row(s) whose r2 equals that scale's best r2.
    # A transform-based mask is used instead of `groupby(...).apply(...)`: apply on the
    # grouping column is deprecated in recent pandas, and without it the `scale` column
    # needed below would be lost. The stable sort reproduces apply's ordering by scale
    # while preserving the original row order inside each scale.
    best_r2 = ridge_df.groupby('scale')['r2'].transform('max')
    ridge_df = (
        ridge_df[ridge_df['r2'] == best_r2]
        .sort_values('scale', kind='stable')
        .reset_index(drop=True)
    )

    ridge_df['id'] = ridge_df['eval_set'] + '-' + ridge_df['scale'].astype(str)
    ridge_df['scale_ratio'] = ridge_df['scale']
    if 'falcon' in src_path.stem:
        # FALCON files report held-in and held-out scores separately; held-out is the eval score.
        ridge_df['heldin_eval_r2'] = ridge_df['heldin']
        ridge_df['eval_r2'] = ridge_df['heldout']
    else:
        ridge_df['eval_r2'] = ridge_df['r2']
    ridge_df['val_kinematic_r2'] = ridge_df['eval_r2'] # Not true, but spoof for now
    ridge_dfs.append(ridge_df)
ridge_df = pd.concat(ridge_dfs)


# NDT3 metric CSVs. The first underscore-separated token of each file name is the
# host that produced the file (mind / crc / nid); the rest names the eval set.
df_paths = [
    Path(f"./data/eval_metrics/{host_and_task}_eval_ndt3.csv")
    for host_and_task in (
        "mind_rtt_s1",
        "crc_cursor",
        "crc_falcon_m2",
        "crc_grasp_h",
        "crc_cst",
        "crc_rtt",
        "crc_falcon_h1",

        "nid_cursor",
        "nid_cst",
        "nid_grasp_h",
        "nid_falcon_h1",
        "nid_falcon_m1",
        "nid_falcon_m2",
        "nid_rtt",
        "nid_eye",
    )
]

# NDT2 metric CSVs; these file names carry no host prefix.
ndt2_df_paths = [
    Path(f'./data/eval_metrics/{task}_eval_ndt2.csv')
    for task in (
        'falcon_h1',
        'falcon_m1',
        'falcon_m2',
        'cursor',
        'rtt',
        'grasp_h',
    )
]
# import csvs
# Pull metric CSVs that were produced on another host into ./data/eval_metrics.
# A local copy is reused while it is younger than EXPIRY; otherwise it is re-fetched with scp.
EXPIRY = 86400 * 30  # seconds (30 days)
NDT2_HOST = 'mind'  # every NDT2 CSV is fetched from this host


def _is_fresh(path, max_age=EXPIRY):
    """Return True when `path` exists and was modified less than `max_age` seconds ago."""
    return path.exists() and (time.time() - os.path.getmtime(path) < max_age)


# assumes the host name does not change while the script runs, so it is looked up once
cur_host = get_simple_host()
for src_path in df_paths:
    # NDT3 file names start with the host that produced them, e.g. `crc_cursor_eval_ndt3.csv`.
    csv_host = src_path.name.split('_')[0]
    if cur_host == csv_host:
        continue  # produced locally, nothing to copy
    print(src_path, cur_host, src_path.exists())
    if _is_fresh(src_path):
        continue
    print(f'Copying {src_path} to {cur_host}')
    # Argument list instead of a shell string: no quoting/injection surprises.
    # Best effort: a failed scp is not fatal here; a missing file surfaces when the CSV is read.
    subprocess.run(['scp', f'{csv_host}:projects/ndt3/{src_path}', './data/eval_metrics'])
for src_path in ndt2_df_paths:
    # Compare against the actual NDT2 source host. NDT2 file names have no host prefix,
    # so the `csv_host` left over from the loop above must not be reused here.
    if cur_host == NDT2_HOST:
        continue
    # Age is measured against the current time, with the same expiry as the NDT3 files.
    if _is_fresh(src_path):
        continue
    print(f'Copying {src_path} to {cur_host}')
    subprocess.run(['scp', f'{NDT2_HOST}:projects/context_general_bci/{src_path}', './data/eval_metrics'])

# Read every NDT3 metrics CSV into a single frame. The NDT2 list may be empty,
# in which case an empty frame stands in for it.
eval_df = pd.concat([pd.read_csv(csv_path) for csv_path in df_paths])
ndt2_eval_df = (
    pd.concat([pd.read_csv(csv_path) for csv_path in ndt2_df_paths])
    if ndt2_df_paths
    else pd.DataFrame()
)

def stem_map(variant):
    """Collapse a full variant name into its family stem.

    Any variant containing 'scratch' maps to the fixed label 'NDT3 mse'.
    Otherwise the text before the first '-' is taken and its last
    '_'-separated token is removed (empty string if there is no '_').
    """
    if 'scratch' in variant:
        return 'NDT3 mse'
    head, _, _ = variant.partition('-')
    stem, _, _ = head.rpartition('_')
    return stem

# Tag every NDT3 row with its family stem, give NDT2 rows a fixed stem, then stack
# NDT3, NDT2 and ridge results into one frame.
eval_df['variant_stem'] = eval_df['variant'].map(stem_map)
ndt2_eval_df['variant_stem'] = 'NDT2 Expert'
eval_df = pd.concat([eval_df, ndt2_eval_df, ridge_df])
# Remove a pre-existing `index` column (if any) before reset_index() turns the
# current index into a column.
if 'index' in eval_df.columns:
    eval_df = eval_df.drop(columns=['index'])
eval_df = eval_df.reset_index()

# drop 0s
eval_df = eval_df.loc[eval_df['eval_r2'] != 0]
# print(eval_df[eval_df['eval_set'] == 'rtt'].sort_values('eval_r2', ascending=False))
# Unique by id
eval_df = eval_df.drop_duplicates(subset=['id']) # additional needed to not drop linear
# Diagnostic: ids of the from-scratch NDT3 runs on falcon_h1.
print(
    eval_df.loc[
        (eval_df['variant_stem'] == 'NDT3 mse') & (eval_df['eval_set'] == 'falcon_h1'),
        'id',
    ].unique()
)
# multi-sweep into one best candidate
eval_df = eval_df.drop_duplicates(subset=['variant_stem', 'scale_ratio', 'eval_set', 'seed'])
# Diagnostics: which stems survived, and which eval sets have a 'wf' (ridge) row.
print(eval_df['variant_stem'].unique())
print(eval_df.loc[eval_df['variant_stem'] == 'wf', 'eval_set'].unique())

def volume_map(variant_stem):
    """Return the numeric volume encoded in a variant stem.

    Stems containing 'Expert' always map to 0. Otherwise the first matching
    tag in the table below decides the value; stems with no tag map to 0.
    NOTE(review): values presumably are hours ('200h' -> 200, '1kh' -> 1000) — confirm.
    """
    if 'Expert' in variant_stem:
        return 0
    # Order matters: tags are tested top to bottom and the first hit wins.
    tagged_volumes = (
        ('200h', 200),
        ('1kh', 1000),
        ('2kh', 2000),
        ('25h', 25),
        ('70h', 70),
        ('min', 2),
    )
    for tag, volume in tagged_volumes:
        if tag in variant_stem:
            return volume
    return 0
# Numeric volume per row, derived from the variant stem.
eval_df['pt_volume'] = eval_df.variant_stem.apply(volume_map)

#%%

# Mean DF
# Average eval_r2 over everything that shares a (variant_stem, eval_set, scale_ratio, pt_volume) key.
columns = ['eval_r2']
mean_df = eval_df.groupby(['variant_stem', 'eval_set', 'scale_ratio', 'pt_volume'])[columns].mean().reset_index()
# Task data actually used = total task time * fraction used. The dict lookup raises
# KeyError for an eval set missing from `task_time`, which surfaces table gaps early.
mean_df['scaled_task_time'] = mean_df.apply(lambda row: task_time[row['eval_set']] * row['scale_ratio'], axis=1)
print(mean_df)

f, ax = plt.subplots()
# Draw on the axes created above explicitly rather than relying on the current-axes state.
sns.scatterplot(data=mean_df, x='scaled_task_time', y='eval_r2', hue='pt_volume', ax=ax)
ax.set_xscale('log')
# ax.set_yscale('log')
# `task_time` entries are minutes (e.g. cursor: 10 <-> '10 min'), so the axis unit is minutes.
ax.set_xlabel('Scaled Task Time (min)')
ax.set_ylabel('R2')
plt.show()
