import os
import hydra
import torch
import pytorch_lightning as pl
import numpy as np
import copy

from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from context_general_bci.config import RootConfig, propagate_config, ModelTask, DataKey, Metric
from context_general_bci.dataset import SpikingDataset, SpikingDataModule
from context_general_bci.model import load_from_checkpoint, load_same_config
from flow.models.SiT_models import SiT
from flow.models.SiT_models_falcon import SiT_falcon

def set_seed(seed):
    """
    Set random seeds for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and,
    when CUDA is available, every CUDA device), and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed applied to every RNG listed above.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select different (non-deterministic) kernels from run
    # to run, which would defeat the `deterministic` flag set just above.
    torch.backends.cudnn.benchmark = False

def _default_num_workers() -> int:
    """Return the number of DataLoader workers: the CPUs this process may run on.

    ``os.sched_getaffinity`` is not available on every platform (e.g. macOS),
    so fall back to ``os.cpu_count()`` there (``0`` if that is unknown).
    """
    if hasattr(os, 'sched_getaffinity'):
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 0


def _build_datasets(cfg: RootConfig):
    """Build the training set, validation set(s) and data attributes for one run.

    Returns:
        ``(train, val_datasets, data_attrs)``. ``val_datasets`` is ``[val]``,
        with the ``'eval'`` split appended when ``cfg.dataset.eval_datasets``
        is set.
    """
    dataset = SpikingDataset(cfg.dataset)
    if cfg.debug:
        breakpoint()
    dataset.build_context_index()

    eval_dataset = None
    if cfg.dataset.eval_datasets:
        # Deep-copy before `dataset.subset_split(...)` below, which is applied
        # to `dataset` in place, so the copy can still select the 'eval' split.
        eval_dataset = copy.deepcopy(dataset)
        eval_dataset.subset_split(splits=['eval'], keep_index=True)
    dataset.subset_split(keep_index=True)

    # Optional data scaling: the per-session limits take precedence over the
    # global limit, which takes precedence over the ratio.
    if cfg.dataset.scale_limit_per_session or cfg.dataset.scale_limit_per_eval_session:
        dataset.subset_scale(
            limit_per_session=cfg.dataset.scale_limit_per_session,
            limit_per_eval_session=cfg.dataset.scale_limit_per_eval_session,
            keep_index=True,
        )
    elif cfg.dataset.scale_limit:
        dataset.subset_scale(limit=cfg.dataset.scale_limit, keep_index=True)
    elif cfg.dataset.scale_ratio:
        dataset.subset_scale(ratio=cfg.dataset.scale_ratio, keep_index=True)

    train, val = dataset.create_tv_datasets()
    print(f"Training on {len(train)} examples")
    data_attrs = dataset.get_data_attrs()

    val_datasets = [val]
    if eval_dataset is not None:
        val_datasets.append(eval_dataset)
    return train, val_datasets, data_attrs


def _build_sit_falcon(cfg: RootConfig, data_attrs, device: torch.device, phase: str):
    """Assemble the ``SiT_falcon`` module around the pretrained NDT2 checkpoint.

    The target model is loaded with ``load_from_checkpoint`` and the
    conditioning model is built with ``load_same_config`` from the same
    checkpoint path. Both are wrapped together with a freshly constructed
    ``SiT`` model. Sets ``cfg.init_ckpt`` as a side effect.
    """
    # ndt2 model
    cfg.init_ckpt = f'./local_data/checkpoint/falcon/{cfg.tag}/ndt2_{cfg.tag}_pretrain.pth'
    target_model = load_from_checkpoint(cfg.init_ckpt, cfg=cfg.model, data_attrs=data_attrs)
    target_model.to(device)
    target_model.freeze_non_embed()  # see model code for exactly which params stay trainable

    # conditional model
    cond_model = load_same_config(cfg.init_ckpt, cfg=cfg.model, data_attrs=data_attrs)
    cond_model.to(device)
    cond_model.freeze_readin_and_out()

    # sit
    invert_flag = False
    pred_len = 1
    latent_dim = cfg.model.hidden_size
    sit_model = SiT(
        in_channels=cfg.dataset.neurons_per_token,
        window_size=pred_len,
        hidden_size=latent_dim,
        out_dim=cfg.dataset.neurons_per_token,
        beh_dim=cfg.dataset.behavior_dim,
        num_heads=cfg.model.transformer.n_heads,
        depth=4,
        mlp_ratio=2.0,
        model_config=None,
        target_latent_config=None,
        cond_model=None,
        beh_config=None,
        invert_flag=invert_flag,
    )

    sit_falcon = SiT_falcon({
        'cond_model': cond_model,
        'target_ndt': target_model,
        'sit': sit_model,
        'latent_dim': latent_dim,
        'phase': phase,
    })
    sit_falcon.to(device)
    return sit_falcon


def _build_trainer(cfg: RootConfig, seed: int, phase: str, device: torch.device) -> pl.Trainer:
    """Create the Lightning trainer with early stopping, checkpointing and CSV logging."""
    ckpt_dir = f'./checkpoints/{cfg.experiment_set}/{cfg.tag}/{phase}/{seed}'
    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=cfg.train.patience,
            mode='min',
            verbose=True,
        ),
        # Keep the single best checkpoint by validation loss ...
        ModelCheckpoint(
            monitor='val_loss',
            dirpath=ckpt_dir,
            filename='best-loss-fm-{epoch:02d}-{val_loss:.2f}',
            save_top_k=1,
            mode='min',
        ),
        # ... and, separately, the single best by validation kinematic R^2.
        ModelCheckpoint(
            monitor='val_kinematic_r2',
            dirpath=ckpt_dir,
            filename='best-kinematic-r2-fm-{epoch:02d}-{val_kinematic_r2:.2f}',
            save_top_k=1,
            mode='max',
        ),
    ]
    csv_logger = CSVLogger(
        save_dir=f"./lightning_logs/falcon/{cfg.tag}/{phase}",
        name=f"{cfg.experiment_set}_{cfg.tag}_{seed}",
    )
    return pl.Trainer(
        # Match the device the models were moved to; a hard-coded "gpu" would
        # crash on CPU-only hosts even though `device` already falls back to CPU.
        accelerator="gpu" if device.type == "cuda" else "cpu",
        max_epochs=1000,
        callbacks=callbacks,
        logger=csv_logger,
    )


def _train_one_seed(cfg: RootConfig, seed: int, phase: str, device: torch.device) -> None:
    """Seed everything, then build data, model and trainer and fit for one seed."""
    print(f"Training with seed {seed}")
    pl.seed_everything(seed, workers=True)

    train, val_datasets, data_attrs = _build_datasets(cfg)

    # === Train ===
    num_workers = _default_num_workers()  # If this is set too high, the dataloader may crash.
    if num_workers == 0:
        print("Num workers is 0, DEBUGGING.")
    print("Preparing to fit...")
    data_module = SpikingDataModule(
        cfg.train.batch_size,
        num_workers,
        train, val_datasets,
    )
    train_dataloader = data_module.train_dataloader()
    val_dataloader = data_module.val_dataloader()

    sit_falcon = _build_sit_falcon(cfg, data_attrs, device, phase)
    trainer = _build_trainer(cfg, seed, phase, device)
    trainer.fit(
        sit_falcon,
        train_dataloader, val_dataloader,
    )


@hydra.main(version_base=None, config_path='context_general_bci/config', config_name="config")
def main(cfg: RootConfig):
    """Pre-train ``SiT_falcon`` for the 'falcon' experiment set over seeds 0-4.

    For each seed, the datasets, models and trainer are rebuilt from scratch
    after seeding. Checkpoints and CSV logs go to per-seed directories.
    """
    import gc

    # NOTE: these assignments override whatever Hydra composed for these keys.
    cfg.tag = 'm1'  # or 'm2'
    cfg.experiment_set = 'falcon'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    propagate_config(cfg)

    phase = 'pre_train'
    for seed in range(5):
        _train_one_seed(cfg, seed, phase, device)
        # The finished run's models/trainer went out of scope with the helper;
        # collect reference cycles and release cached GPU blocks so memory from
        # consecutive seeds does not stack up.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

if __name__ == "__main__":
    # Script entry point; `main` is decorated with `hydra.main`, which supplies `cfg`.
    main()