# TODO:
from model import Decoder_ as Decoder
from pl_wrapper import LightningModel_ as LightningModel

import argparse
from functools import partial
import lightning
from lightning import Trainer
from lightning.pytorch.callbacks import (
    LearningRateMonitor,
    ModelCheckpoint,
    BaseFinetuning,
)
from multiprocessing import Process
from os.path import join as opj
import sys


# Path to a Lightning checkpoint (presumably a pretrained GRU-CTC neural
# encoder, judging by the path).
# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm whether it is still needed.
NEURAL_ENCODER_CKPT_PATH = "/data/XXXXXX/nejm-brain-to-text/optimization/gru_ctc/apricot-sweep-47/best_model_loss.ckpt"
# Root directory for model checkpoints; W&B runs write into
# <CKPTS_DIR>/<sweep id>/<run name>.
CKPTS_DIR = "/data/data/XXXXXXX/whisper-based/ckpts-vanilla"
# Batch size handed to the Willett and Brain-to-Text '25 data loaders
# (the cross-dataset loader is called without it).
BS = 64


# Command-line interface.
parser = argparse.ArgumentParser(
    description="Train the model on one GPU, or run a W&B sweep across several."
)
parser.add_argument(
    "--data",
    choices=["willett", "b2txt25", "cross"],
    default="b2txt25",
    help="dataset to train on (default: b2txt25)",
)
parser.add_argument(
    "--wandb",
    "-w",
    action="store_true",
    help="run as a Weights & Biases sweep (the only mode implemented so far)",
)
parser.add_argument(
    "--processes",
    "-p",
    type=int,
    help="number of parallel sweep agents to start (default: one per device)",
)
parser.add_argument(
    "--devices",
    "-d",
    nargs="+",
    type=int,
    help="GPU device ids to train on (required)",
)
args = parser.parse_args()

# Validate up front: without these checks a missing option only surfaces as a
# TypeError at launch time, after a W&B sweep has already been registered.
if not args.devices:
    parser.error("at least one GPU id must be given with --devices/-d")
if args.wandb:
    if args.processes is None:
        # default to a single sweep agent per listed device
        args.processes = len(args.devices)
    elif args.processes < 1:
        parser.error("--processes/-p must be a positive integer")

# W&B project name; one project per dataset choice.
PROJECT = f"vanilla-whisper-{args.data}"

if args.wandb:
    import wandb
    from lightning.pytorch.loggers import WandbLogger

    def _grid(*values):
        """Wrap the candidate values of one hyperparameter for a grid sweep."""
        return {"values": list(values)}

    # log in to W&B before registering the sweep
    wandb.login()

    # hyperparameters that are swept regardless of the dataset choice
    _sweep_parameters = {
        "whisper_name": _grid("openai/whisper-tiny.en"),
        "learning_rate": _grid(1e-3),
        "learning_rate_min": _grid(1e-5),
        "weight_decay": _grid(1e-5),
        "scheduler": _grid("cosine"),
        "max_epochs": _grid(100),
        "kernel_size_1": _grid(7),
        "stride_2": _grid(2),
        "dropout": _grid(0.4),
        "last_phoneme_layer": _grid(2),
        "attn_window_size": _grid(29),
        "ce_coeff": _grid(1.0),
        "ctc_coeff": _grid(1.0),
        "unfreeze_whisper_decoder_at_epoch": _grid(-1),
        "seed": _grid(1),
    }

    if args.data == "cross":
        # the combined dataset carries one set of keys per source dataset
        _sweep_parameters.update(
            {
                # Brain-to-Text '25
                "encodings_b2txt25": _grid("sin"),
                "day_projections_b2txt25": _grid(True),
                "r_b2txt25": _grid(16),
                # Willett
                "encodings_willett": _grid("learn"),
                "day_projections_willett": _grid(True),
                "r_willett": _grid(0),
            }
        )
    else:
        # single-dataset runs use the unsuffixed keys
        _sweep_parameters.update(
            {
                "encodings": _grid("sin"),
                "day_projections": _grid(True),
                "r": _grid(16),
            }
        )

    # W&B sweep definition: exhaustive grid over the values listed above
    wandb_config = {"method": "grid", "parameters": _sweep_parameters}
    sweep_id = wandb.sweep(wandb_config, project=PROJECT)


def train(device_id):
    """Train and validate one model on a single GPU.

    Takes the hyperparameters from the active W&B run, builds the data
    loaders selected by ``--data``, instantiates ``LightningModel`` and fits
    it with checkpoints on the best ``val_wer`` and/or ``val_per``.

    Relies on the module-level ``args`` and, for W&B runs, ``sweep_id``.

    Args:
        device_id: Index of the GPU passed to the Lightning ``Trainer``.

    Raises:
        NotImplementedError: If the script was started without ``--wandb``.
    """

    # set hyperparameters
    if args.wandb:
        run = wandb.init()
        config = wandb.config
    else:
        # FIXME: runs outside of W&B are disabled for now. The config used
        # previously is kept here for reference:
        # config = {
        #     "whisper_name": "openai/whisper-tiny.en",
        #     "learning_rate": 1e-3,
        #     "learning_rate_min": 1e-5,
        #     "weight_decay": 1e-5,
        #     "scheduler": "cosine",
        #     "max_epochs": 100,
        #     "kernel_size_1": 7,
        #     "stride_2": 2,
        #     "encodings": "sin",
        #     "day_projections": True,
        #     "r": 16,
        #     "dropout": 0.4,
        #     "last_phoneme_layer": 2,
        #     "attn_window_size": 29,
        #     "ce_coeff": 1.0,
        #     "ctc_coeff": 1.0,
        #     "unfreeze_whisper_decoder_at_epoch": -1,
        #     "seed": 1,
        # }
        raise NotImplementedError(
            "Training without --wandb is not implemented yet; "
            "pass --wandb/-w to run a sweep."
        )

    max_epochs = config["max_epochs"]
    ce_coeff = config["ce_coeff"]
    ctc_coeff = config["ctc_coeff"]
    unfreeze_whisper_decoder_at_epoch = config["unfreeze_whisper_decoder_at_epoch"]
    seed = config["seed"]

    # set the global seed (before any data loader is created)
    lightning.seed_everything(seed=seed, workers=True)

    if args.data == "willett":
        # The Willett loaders live in an external checkout. A sweep agent may
        # call train() several times in one process, so only extend sys.path
        # once instead of appending a duplicate entry per trial.
        willett_repo = "/data/data/XXXXXX/speech_decoding_BCI"
        if willett_repo not in sys.path:
            sys.path.append(willett_repo)
        from config import DATASET_AFTERGO_TRIALS_ZSCORE
        from dataset import getDatasetLoaders_V4

        # get data loaders for Willett's data
        train_loader, test_loader, _, _ = getDatasetLoaders_V4(
            DATASET_AFTERGO_TRIALS_ZSCORE, BS, include_prego=True
        )
        config["num_features"] = 256
        config["num_days"] = train_loader.dataset.n_days

    elif args.data == "b2txt25":
        from dataset_XXXXXX_backup import getDatasetLoaders

        # get data loaders for Brain-to-text '25 data
        train_loader, test_loader, _ = getDatasetLoaders(
            BATCH_SIZE=BS, SHUFFLE_TRAIN=True, SEED=seed
        )
        config["num_features"] = 512
        config["num_days"] = train_loader.dataset.n_days

    elif args.data == "cross":
        from cross_dataset import get_data_loaders

        # get data loaders for Brain-to-text '25/Willett combined data
        train_loader, test_loader = get_data_loaders()
        config["num_features_b2txt25"] = 512  # Brain-to-Text '25
        config["num_days_b2txt25"] = 45
        # NOTE(review): the standalone "willett" branch above uses 256
        # features — confirm that 512 is intended for the combined dataset.
        config["num_features_willett"] = 512  # Willett
        config["num_days_willett"] = 24

    # initialize the model (the W&B config object is converted to a plain dict)
    model = LightningModel(dict(config) if args.wandb else config)

    # Lightning callbacks
    # both checkpoint callbacks write into the same per-run directory
    ckpt_dir = opj(CKPTS_DIR, sweep_id, run.name) if args.wandb else CKPTS_DIR
    ckpt_best_wer = ModelCheckpoint(
        dirpath=ckpt_dir,
        filename="{epoch}-{val_wer:.4f}",
        monitor="val_wer",
        mode="min",
    )
    ckpt_best_per = ModelCheckpoint(
        dirpath=ckpt_dir,
        filename="{epoch}-{val_per:.4f}",
        monitor="val_per",
        mode="min",
    )

    lr_monitor = LearningRateMonitor(logging_interval="epoch")

    class UnfreezeWhisperDecoder(BaseFinetuning):
        """Callback that unfreezes the Whisper decoder at a given epoch."""

        def __init__(self, unfreeze_at_epoch):
            super().__init__()
            self._unfreeze_at_epoch = unfreeze_at_epoch

        def freeze_before_training(self, pl_module):
            # nothing is frozen by this callback itself
            pass

        def finetune_function(self, pl_module, current_epoch, optimizer):
            # unfreeze the Whisper decoder
            if current_epoch == self._unfreeze_at_epoch:
                pl_module.model.unfreeze_whisper_decoder()
                print("Whisper decoder is now training...")

    callbacks = [lr_monitor]

    # val_wer checkpointing and decoder unfreezing are only added when the
    # cross-entropy term is active
    if ce_coeff > 0.0:
        callbacks += [ckpt_best_wer]

        # a negative epoch disables unfreezing
        if unfreeze_whisper_decoder_at_epoch >= 0:
            callbacks += [
                UnfreezeWhisperDecoder(
                    unfreeze_at_epoch=unfreeze_whisper_decoder_at_epoch
                )
            ]

    # val_per checkpointing is only added when the CTC term is active
    if ctc_coeff > 0.0:
        callbacks += [ckpt_best_per]

    trainer = Trainer(
        max_epochs=max_epochs,
        callbacks=callbacks,
        accelerator="gpu",
        devices=[device_id],
        logger=WandbLogger() if args.wandb else None,
        log_every_n_steps=1,
    )

    # train (the test loader is passed as the validation loader)
    trainer.fit(model, train_loader, test_loader)

    if args.wandb:
        run.finish()


# NOTE(review): this file has no ``if __name__ == "__main__"`` guard, so the
# launch below relies on the "fork" start method. Under "spawn"/"forkserver"
# the children would re-import this module and re-run its top-level code.
if args.wandb:
    # parallelize sweep on multiple processes
    def run_agent(device_id):
        """Run a W&B sweep agent whose trials train on GPU ``device_id``."""
        wandb.agent(
            sweep_id, function=partial(train, device_id=device_id), project=PROJECT
        )

    # Assign every requested agent to a device. The agents are spread evenly;
    # when the count is not a multiple of the number of devices, the first
    # devices get one extra agent each, so that exactly ``--processes`` agents
    # start (plain floor division would silently drop the remainder, e.g.
    # start no agent at all for ``-p 1 -d 0 1``).
    per_device, leftover = divmod(args.processes, len(args.devices))
    process_devices = []
    for position, gpu in enumerate(args.devices):
        process_devices += [gpu] * (per_device + (position < leftover))

    # start the parallel processes
    processes = []
    for device_id in process_devices:
        process = Process(target=run_agent, args=(device_id,))
        process.start()
        processes.append(process)

    # wait for all processes to finish
    for process in processes:
        process.join()
else:
    # single run on the first listed device
    train(args.devices[0])
