#%%
# Evaluate continuous training on trialized data and vice versa

from typing import List
import os
import shutil
import subprocess
import logging
import sys
import argparse
logging.basicConfig(stream=sys.stdout, level=logging.INFO) # needed to get `logger` to print
from pathlib import Path
import numpy as np
import torch
import pandas as pd
from sklearn.metrics import r2_score
import pytorch_lightning as pl

# Load BrainBertInterface and SpikingDataset to make some predictions
from context_general_bci.config import RootConfig, Output
from context_general_bci.config.hp_sweep_space import sweep_space
from context_general_bci.dataset import SpikingDataset
from context_general_bci.contexts import context_registry
from context_general_bci.model import transfer_model, logger
from context_general_bci.analyze_utils import (
    stack_batch,
    get_dataloader,
    simple_unflatten_batch,
    crop_padding_from_batch,
    streaming_eval
)
from context_general_bci.utils import wandb_query_experiment, get_wandb_run, wandb_query_latest, to_device, get_simple_host
from context_general_bci.inference import load_wandb_run, get_run_config, get_best_ckpt_from_wandb_id


pl.seed_everything(0)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--eval-set", "-e", type=str, required=True, choices=['cursor_cont', 'cursor_trialized']
)
parser.add_argument(
    "--workers", "-w", type=int, default=0, help="Number of parallel workers to use for single GPU eval"
)
parser.add_argument(
    "--experiment", "-x", type=str, required=True, choices=['cursor_trialized', 'cursor_cont'], help="Experiment to evaluate"
)
parser.add_argument(
    "--queries", "-q", type=str, nargs='+', required=True, help="Queries to evaluate"
)
parser.add_argument(
    "--allowed-states", "-a", type=str, nargs='+', default=['finished'], help="Allowed states for runs"
)
parser.add_argument(
    "--scales", "-c", type=float, nargs='+', default=[0.25, 0.5, 1.0], help="Scales to evaluate"
)
args = parser.parse_args()
EVAL_SET = args.eval_set
EXPERIMENT_SET = args.experiment
PARALLEL_WORKERS = args.workers
queries = args.queries
allowed_states = args.allowed_states
SCALES = args.scales

EXPERIMENT_MAP = {
    "cursor_trialized": "v5/analyze/cursor_trialized",
    "cursor_cont": "v5/tune/cursor",
}

UNIQUE_BY = {
    "model.lr_init",
    "model.hidden_size",
    "dataset.scale_ratio",
    "seed",
}

EVAL_DATASET_FUNC_MAP = {
    'cursor_cont': 'eval_REDACT_eval_broad.*',
    'cursor_trialized': 'eval_REDACT_trialized_broad.*',
}

eval_paths = Path('~/projects/ndt3/data/analysis_metrics').expanduser()
eval_paths.mkdir(parents=True, exist_ok=True)
eval_metrics_path = eval_paths / f"cont_trialized.csv"

def load_eval_df_so_far(eval_metrics_path):
    return pd.read_csv(eval_metrics_path) if eval_metrics_path.exists() else pd.DataFrame()

def get_sweep_tags(variant: str):
    if 'scratch' in variant:
        sweep_tags = ["full_scratch"]
    else:
        sweep_tags = ['full_ft']
    return sweep_tags

def get_runs_for_query(variant: str, scale_ratio: float, experiment_set: str, project="ndt3", exp_map=EXPERIMENT_MAP):
    r"""
        variant: init_from_id
    """
    sweep_tags = get_sweep_tags(variant)
    variant_tag = f'{variant}_{int(scale_ratio * 100)}'
    print(f'Querying: {variant_tag} in sweep: {sweep_tags} for experiment: {experiment_set}')
    return wandb_query_experiment(
        exp_map[experiment_set],
        wandb_project=project,
        filter_unique_keys=UNIQUE_BY,
        **{
            "config.tag": {"$regex": variant_tag},
            "config.dataset.scale_ratio": scale_ratio,
            "config.sweep_tag": {"$in": sweep_tags},
            "state": {"$in": allowed_states},
        })

def run_list_to_df(runs, eval_set: str, experiment_set: str):
    filter_runs = [r for r in runs if 'val_kinematic_r2' in r.summary and 'max' in r.summary['val_kinematic_r2']]
    # remove runs that have "smoketest" in tag
    filter_runs = [r for r in filter_runs if 'smoketest' not in r.config['tag']]
    print(f"Filtered {len(runs)} to {len(filter_runs)} with proper metrics")
    df_dict = {
        'id': map(lambda r: r.id, filter_runs),
        'variant': map(lambda r: r.config['tag'], filter_runs),
        'scale_ratio': map(lambda r: r.config['dataset']['scale_ratio'], filter_runs),
        'eval_set': map(lambda r: eval_set, filter_runs),
        'experiment_set': map(lambda r: experiment_set, filter_runs),
        'val_kinematic_r2': map(lambda r: r.summary['val_kinematic_r2']['max'], filter_runs),
        'sweep': list(map(lambda r: get_sweep_tags(r.config['tag'])[0], filter_runs)), # cast to not exhaust when we then query
    }
    # Add sweep HPs
    def nested_get_from_config(config, param: List[str]):
        if len(param) > 1:
            return nested_get_from_config(config[param[0]], param[1:])
        return config[param[0]]
    unique_sweeps = set(df_dict['sweep'])
    for sweep_name in unique_sweeps:
        for p in sweep_space[sweep_name].keys():
            # For some reason if we don't cast, early params get overwritten..
            df_dict[p] = list(map(lambda r: nested_get_from_config(r.config, p.split('.')), filter_runs))
    # Get the precise eval metric associated with this ckpt
    run_histories = [r.history() for r in filter_runs]
    eval_reports = [
        rh.loc[rh['val_kinematic_r2'].idxmax()]['eval_kinematic_r2'] for rh in run_histories
    ]
    df_dict['eval_report'] = eval_reports
    # df_dict['eval_report'] = map(lambda r: r.summary['eval_kinematic_r2']['max'] if 'eval_kinematic_r2' in r.summary else 0, filter_runs)
    return pd.DataFrame(df_dict)

def get_best_run_per_sweep(run_df, metric='val_kinematic_r2'):
    # First group by variant and HPs, and average over seeds.
    if 'seed' in run_df:
        hp_columns = [col for col in run_df.columns if col not in ['id', 'variant', 'eval_set', 'scale_ratio', 'seed', 'val_kinematic_r2', 'eval_report']]
        id_columns = ['variant']
        group_columns = [*hp_columns, *id_columns]
        seed_averaged_df = run_df.groupby(group_columns)[metric].mean().reset_index()
        aug_df = pd.merge(run_df, seed_averaged_df, on=group_columns, suffixes=('', '_seed_avg'))
        filter_metric = f'{metric}_seed_avg'
        run_df = aug_df.groupby('variant').apply(lambda x: x[x[filter_metric] == x[filter_metric].max()]).reset_index(drop=True)
    else: # Then re-group by variant and filter for the best HP.
        run_df = run_df.groupby(['variant']).apply(lambda x: x.nlargest(1, metric)).reset_index(drop=True)
    return run_df

def get_run_df_for_query(variant: str, scale_ratio: float, experiment_set: str, eval_set: str, metric='val_kinematic_r2', **kwargs):
    runs = get_runs_for_query(variant, scale_ratio, experiment_set, **kwargs)
    run_df = run_list_to_df(runs, eval_set=eval_set, experiment_set=experiment_set)
    return get_best_run_per_sweep(run_df, metric=metric)

def get_eval_df(queries):
    query_dfs = []
    scales = SCALES
    for query in queries:
        query_dfs.extend([get_run_df_for_query(query, scale_ratio, EXPERIMENT_SET, EVAL_SET) for scale_ratio in scales])
    eval_df = pd.concat(query_dfs).reset_index(drop=True)
    eval_df['eval_r2'] = 0.
    return eval_df

def trim_df(df: pd.DataFrame, df_so_far: pd.DataFrame) -> pd.DataFrame:
    # Delete the data from eval queue that already exists in so_far
    if len(df_so_far):
        if 'index' in df_so_far:
            df_so_far.drop(columns=['index'], inplace=True)
        # df_so_far zero to nan
        df_so_far['eval_r2'] = df_so_far['eval_r2'].replace(0, np.nan)
        # df_so_far drop nan
        df_so_far = df_so_far.dropna(subset=['eval_r2'])
        # ID + Eval set must uniquely identify
        df = df.merge(df_so_far[['id', 'eval_set']], on=['id', 'eval_set'], how='left', indicator=True)
        df = df[df['_merge'] == 'left_only'].drop(columns=['_merge']).reset_index(drop=True)
    return df

def get_single_eval(cfg: RootConfig, src_model, dataset, stream=False, device=torch.device('cuda')):
    pl.seed_everything(0)
    dataset = SpikingDataset(cfg.dataset, use_augment=False)
    dataset.cfg.max_tokens = 32768
    dataset.cfg.max_trial_length = 1500
    dataset.cfg.max_length_ms = 30000
    dataset.set_no_crop(True) # not expecting to need extra tokens but should fail if so (these are trialized cursor/rtt, all fairly short (a few s))
    dataset.subset_split(splits=['eval'])
    dataset.build_context_index()
    data_attrs = dataset.get_data_attrs()
    print("Eval length: ", len(dataset))
    model = transfer_model(src_model, cfg.model, data_attrs)
    model.eval()
    model = model.to(device)

    if stream:
        r"""
            Evaluation can be done without streaming, to better reflect training-time metrics
            At test time, though, we often are interested in evaluating what performance would look like at test time.
            That is, streaming, with a second of history.
            This is going to be much _slower_ (especially if we don't have batching)
            This may be higher than training metrics because model always has full history,
            but may also be lower because model may be overfit to trial structure (chopping is typically good for cirucmventing.)
            The current implementation will directly bleed data when the actual _session_ changes. (which should be negligible)
        """
        logger.info("Streaming eval: Assuming chopped, continuous data.")
        outputs, r2, mse, loss = streaming_eval(
            model,
            dataset,
            stream_buffer_s=dataset.cfg.REDACT_co.chop_size_ms / 1000,
            temperature=0.,
            use_kv_cache=True if not cfg.dataset.sparse_constraints else False,
            autoregress_cue=False,
            skip_cache_reset=True,
            use_mask_in_metrics=True, # For comparing with eval_scaling
        )
    else:
        dataloader = get_dataloader(dataset, batch_size=16)
        batch_outputs = []
        mask_kin = torch.ones(cfg.dataset.max_length_ms // cfg.dataset.bin_size_ms, device=device)
        for batch in dataloader:
            batch = to_device(batch, device)
            out = model.predict_simple_batch(batch, kin_mask_timesteps=mask_kin)
            del out[Output.behavior_loss]
            del out['covariate_labels']
            del out[Output.behavior_query_mask]
            out_unflat = simple_unflatten_batch(out, ref_batch=batch)
            batch_outputs.append(out_unflat)
        outputs = stack_batch(batch_outputs, merge_tensor='cat')
        outputs = crop_padding_from_batch(outputs) # Note this is redundant with bhvr mask but needed if bhvr_mask isn't available
    from context_general_bci.analyze_utils import stream_to_tensor_dict
    plot_dict = stream_to_tensor_dict(outputs, model)
    # Need to unflatten for variance weighted
    if Output.behavior_mask.name not in plot_dict['kin'].keys():
        masks = np.ones_like(plot_dict['kin'][Output.behavior_pred.name], dtype=bool)
    else:
        masks = plot_dict['kin'][Output.behavior_mask.name]
        if not masks.any() or not (masks.all(1) ^ (~masks).any(1)).all():
            print("Behavior mask is not as expected, tensordict error?")
            masks = outputs[Output.behavior_mask].cpu()
            if not masks.any() or not (masks.all(1) ^ (~masks).any(1)).all():
                print("Behavior mask is still not as expected, aborting")
                return
    pred, true = plot_dict['kin'][Output.behavior_pred.name], plot_dict['kin'][Output.behavior.name]
    masks = masks.any(-1)
    pred = pred[masks]
    true = true[masks]
    r2 = r2_score(true.cpu().numpy(), pred.cpu().numpy(), multioutput='variance_weighted')
    print(f"R2 over {len(pred)} samples: {r2:.3f}")
    return r2

def process_row_wrapper(df_itertuple):
    assert len(df_itertuple) >= 8, "8 or 9, if falcon, check eval-df.columns + 1 for df index"
    index, run_id, variant, scale_ratio, eval_set, val_kinematic_r2, eval_report, *_ = df_itertuple
    return process_row(index, run_id, eval_set)

def process_row(index: int, run_row_id: str, run_row_eval_set: str):
    print(f'Checking local registry: {os.getpid(), id(context_registry), len(context_registry.search_index)}')
    context_registry.query(alias='dummy') # needed to not re-init the registry in mp stream
    print(f"Checking args: {index}, {run_row_id}, {run_row_eval_set}")
    eval_set = EVAL_DATASET_FUNC_MAP[run_row_eval_set]
    run = get_wandb_run(run_row_id)
    device_selected = index % torch.cuda.device_count()
    device = torch.device(f'cuda:{device_selected}')

    try:
        src_model, cfg, data_attrs = load_wandb_run(run, tag='val_kinematic_r2')
    except Exception as e:
        print(f"Failed to load run {run_row_id} (missing ckpt?): {e}, probably missing ckpt")
        return 0
    assert cfg.dataset.eval_ratio == 1.0, "Expecting full eval set"
    cfg.dataset.eval_datasets = [eval_set]
    cfg.dataset.datasets = cfg.dataset.eval_datasets # Ensure alignment
    # Avoid reprocessing dataset with run's dataset config
    # breakpoint()
    if run_row_eval_set == 'cursor_trialized':
        cfg.dataset.REDACT_co.respect_trial_boundaries = True
        cfg.dataset.REDACT_co.chop_size_ms = 0
    elif run_row_eval_set == 'cursor_cont':
        cfg.dataset.REDACT_co.respect_trial_boundaries = False
        cfg.dataset.REDACT_co.chop_size_ms = 1000
    else:
        raise ValueError(f"Unknown eval set: {run_row_eval_set}")
    dataset = SpikingDataset(cfg.dataset, use_augment=False)
    dataset.subset_split(splits=['eval'])
    cfg.model.task.outputs = [
        Output.behavior,
        Output.behavior_pred,
    ]
    return get_single_eval(cfg, src_model, dataset=dataset, device=device, stream='cont' in run_row_eval_set)

def exec_eval():
    eval_df = get_eval_df(queries)
    eval_df_so_far = load_eval_df_so_far(eval_metrics_path)
    eval_df = trim_df(eval_df, eval_df_so_far)
    print(eval_df['variant'].unique())
    eval_df.reset_index(drop=True, inplace=True)

    if PARALLEL_WORKERS:
        from torch.multiprocessing import Pool, set_start_method
        set_start_method('spawn', force=True) # https://github.com/pytorch/pytorch/issues/3492#issuecomment-382660636
        torch.multiprocessing.set_sharing_strategy('file_system')
        with Pool(processes=PARALLEL_WORKERS) as pool:  # Adjust pool size if needed
            results = pool.map(process_row_wrapper, eval_df.itertuples(index=True, name=None))
    else:
        results = [process_row_wrapper(row) for row in eval_df.itertuples(index=True)]
    # breakpoint()
    for idx, result in enumerate(results):
        eval_df.at[idx, 'eval_r2'] = result
        print(eval_df.iloc[idx])
    # merge again
    # reload eval_df_so_far in case some commits occurred during eval
    eval_df_so_far = load_eval_df_so_far(eval_metrics_path)
    eval_df = pd.concat([eval_df, eval_df_so_far]).reset_index(drop=True)
    eval_df = eval_df.drop_duplicates(subset=['variant', 'seed','eval_set','experiment_set'], keep='first').reset_index(drop=True)

    eval_df.to_csv(eval_metrics_path, index=False)


if __name__ == "__main__":
    exec_eval()