import hydra
import torch
import copy
import re
import os
import pytorch_lightning as pl

from pathlib import Path
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from flow.models.SiT_models import SiT
from flow.models.SiT_models_falcon import SiT_falcon, filter_state_dict
from context_general_bci.config import RootConfig, propagate_config
from context_general_bci.dataset import SpikingDataset, SpikingDataModule
from context_general_bci.model import load_from_checkpoint, load_same_config

# Filename glob for the best-R2 checkpoints written by the pre-train stage
# (same template as the val_kinematic_r2 ModelCheckpoint in main()).
_PRETRAIN_CKPT_GLOB = "best-kinematic-r2-fm-epoch=*-val_kinematic_r2=*.ckpt"
# Captures the R2 value embedded in such a filename. R2 can be negative, and
# Lightning appends "-vN" when a checkpoint filename would otherwise collide.
_R2_PATTERN = re.compile(r"val_kinematic_r2=(-?\d+(?:\.\d+)?)(?:-v\d+)?\.ckpt$")


def _extract_r2(filepath):
    """Return the validation kinematic R2 encoded in a checkpoint filename.

    Args:
        filepath: ``pathlib.Path`` of a checkpoint, e.g.
            ``best-kinematic-r2-fm-epoch=07-val_kinematic_r2=0.53.ckpt``.

    Returns:
        The parsed R2 as a float, or ``float('-inf')`` when the name does not
        match, so unparseable files always rank below any real score.
    """
    match = _R2_PATTERN.search(filepath.name)
    return float(match.group(1)) if match else float('-inf')


def _find_best_pretrain_ckpt(ckpt_dir):
    """Return the checkpoint in ``ckpt_dir`` with the highest validation R2.

    Ties keep the first match in glob order.

    Raises:
        FileNotFoundError: if ``ckpt_dir`` holds no matching checkpoint.
    """
    candidates = list(Path(ckpt_dir).glob(_PRETRAIN_CKPT_GLOB))
    if not candidates:
        raise FileNotFoundError(f"No pre-trained model found in {ckpt_dir}.")
    return max(candidates, key=_extract_r2)


def _available_cpus():
    """Return the number of CPUs this process may run on.

    ``os.sched_getaffinity`` is only available on some Unix platforms, so
    fall back to ``os.cpu_count()`` (0 if that is unknown) elsewhere.
    """
    if hasattr(os, 'sched_getaffinity'):
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 0


@ hydra.main(version_base=None, config_path='context_general_bci/config', config_name="config")
def main(cfg: RootConfig):
    """Fine-tune ``SiT_falcon`` on FALCON data for several seeds.

    For each fine-tuning ratio and seed, the function does the following:

    * Builds train/val datasets, plus an eval dataset when
      ``cfg.dataset.eval_datasets`` is set.
    * Loads the target model from the pre-train checkpoint under
      ``./local_data/checkpoint/falcon/<tag>/`` and freezes its non-embedding
      parameters.
    * Builds a conditional model via ``load_same_config`` on the same path.
    * Wraps both models with a newly constructed ``SiT`` in ``SiT_falcon``.
    * Locates the best pre-train SiT checkpoint for the seed.
    * Freezes the SiT and readout.
    * Fits with early stopping on ``val_loss``, keeping the best ``val_loss``
      and best ``val_kinematic_r2`` checkpoints.

    Args:
        cfg: Hydra-composed root config. ``cfg.tag`` and
            ``cfg.experiment_set`` are overwritten below before
            ``propagate_config`` runs.

    Raises:
        FileNotFoundError: if no pre-train SiT checkpoint exists for a seed.
    """
    cfg.tag = 'm1' # /'m2'
    cfg.experiment_set = 'falcon'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Keep Lightning on the same hardware as the manual .to(device) calls;
    # a hard-coded "gpu" accelerator fails on CPU-only machines.
    accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
    phase = 'ft'

    propagate_config(cfg)

    ft_ratio_list = [1.0]
    seed_list = list(range(0, 5))

    for ft_ratio in ft_ratio_list:
        ft_ratio_name = ft_ratio

        for seed in seed_list:
            print(f"Training with seed {seed}")
            pl.seed_everything(seed=seed, workers=True)

            # Rebuilt every seed: subset_split / subset_scale below modify the
            # dataset in place (their return values are not used).
            dataset = SpikingDataset(cfg.dataset)
            if cfg.debug:
                breakpoint()
            dataset.build_context_index()
            if cfg.dataset.eval_datasets:
                # Copy before the train/val split so the eval split can be taken separately.
                eval_dataset = copy.deepcopy(dataset)
                eval_dataset.subset_split(splits=['eval'], keep_index=True)
            dataset.subset_split(keep_index=True)
            if cfg.dataset.scale_limit_per_session or cfg.dataset.scale_limit_per_eval_session:
                dataset.subset_scale(
                    limit_per_session=cfg.dataset.scale_limit_per_session,
                    limit_per_eval_session=cfg.dataset.scale_limit_per_eval_session,
                    keep_index=True
                )
            elif cfg.dataset.scale_limit:
                dataset.subset_scale(limit=cfg.dataset.scale_limit, keep_index=True)
            elif cfg.dataset.scale_ratio:
                dataset.subset_scale(ratio=cfg.dataset.scale_ratio, keep_index=True)
            train, val = dataset.create_tv_datasets(ft_ratio=ft_ratio, phase=phase)
            print(f"Training on {len(train)} examples")
            data_attrs = dataset.get_data_attrs()

            # === Train ===
            # Use every CPU available to this process. If this is set too high,
            # the dataloader may crash.
            num_workers = _available_cpus()
            if num_workers == 0:
                print("Num workers is 0, DEBUGGING.")
            print("Preparing to fit...")
            val_datasets = [val]
            if cfg.dataset.eval_datasets:
                val_datasets.append(eval_dataset)
            data_module = SpikingDataModule(
                cfg.train.batch_size,
                num_workers,
                train, val_datasets
            )
            train_dataloader = data_module.train_dataloader()
            val_dataloader = data_module.val_dataloader()

            # Target model: weights from the pre-train checkpoint, with only the
            # embedding layers left trainable.
            target_init_ckpt = f'./local_data/checkpoint/falcon/{cfg.tag}/ndt2_{cfg.tag}_pretrain.pth'
            target_model = load_from_checkpoint(target_init_ckpt, cfg=cfg.model, data_attrs=data_attrs)
            target_model.to(device)
            target_model.freeze_non_embed()

            # Conditional model, built from the same checkpoint path via load_same_config.
            cond_model = load_same_config(target_init_ckpt, cfg=cfg.model, data_attrs=data_attrs)
            cond_model.to(device)

            invert_flag = False
            pred_len = 1
            latent_dim = cfg.model.hidden_size
            sit_model = SiT(
                in_channels=cfg.dataset.neurons_per_token,
                window_size=pred_len,
                hidden_size=latent_dim,
                out_dim=cfg.dataset.neurons_per_token,
                beh_dim=cfg.dataset.behavior_dim,

                num_heads=cfg.model.transformer.n_heads,
                depth=4,
                mlp_ratio=2.0,
                model_config=None,
                target_latent_config=None,
                cond_model=None,
                beh_config=None,
                invert_flag=invert_flag,
            )

            dict_cfg = {
                'cond_model': cond_model,
                'target_ndt': target_model,
                'sit': sit_model,
                'latent_dim': latent_dim,
                'phase': phase,
            }
            sit_falcon = SiT_falcon(dict_cfg)
            sit_falcon.to(device)

            # Locate the best pre-train SiT checkpoint for this seed (raises if absent).
            pretrain_dir = Path(f'./checkpoints/{cfg.experiment_set}/{cfg.tag}/pre_train/{seed}')
            pre_ckpt = _find_best_pretrain_ckpt(pretrain_dir)
            print(f"Loading pre-trained model from {pre_ckpt}")
            # NOTE(review): pre_ckpt is only selected and printed. Its weights are
            # never loaded into sit_falcon (filter_state_dict is imported but
            # unused), so the SiT frozen below is presumably the freshly
            # initialised one. TODO confirm the intended loading call
            # (checkpoint key layout / filter_state_dict signature) and restore it.

            sit_falcon.freeze_sit_and_readout()
            sit_falcon.to(device)

            ft_ckpt_dir = f'./checkpoints/{cfg.experiment_set}/{cfg.tag}/{phase}/ft_ratio_{ft_ratio_name}/{seed}'
            call_backs = [
                EarlyStopping(
                    monitor='val_loss',
                    patience=cfg.train.patience,
                    mode='min',
                    verbose=True
                ),
                # Best checkpoint by validation loss.
                ModelCheckpoint(
                    monitor='val_loss',
                    dirpath=ft_ckpt_dir,
                    filename='best-loss-fm-{epoch:02d}-{val_loss:.2f}',
                    save_top_k=1,
                    mode='min',
                ),
                # Best checkpoint by validation kinematic R2.
                ModelCheckpoint(
                    monitor='val_kinematic_r2',
                    dirpath=ft_ckpt_dir,
                    filename='best-kinematic-r2-fm-{epoch:02d}-{val_kinematic_r2:.2f}',
                    save_top_k=1,
                    mode='max',
                ),
            ]

            csv_logger = CSVLogger(
                save_dir=f"./lightning_logs/falcon/{cfg.tag}/{phase}/ft_ratio_{ft_ratio_name}",
                name=f"{cfg.experiment_set}_{cfg.tag}_{seed}",
            )

            trainer = pl.Trainer(
                accelerator=accelerator,
                callbacks=call_backs,
                logger=csv_logger,
            )

            trainer.fit(
                sit_falcon,
                train_dataloader, val_dataloader,
            )

if __name__ == "__main__":
    # Script entry point: the @hydra.main decorator supplies cfg, so main()
    # is invoked with no arguments.
    main()