"""DINO Pretraining for PSG signals."""
from pprint import pprint
import os
from argparse import ArgumentParser, Namespace
import datetime
from dateutil import tz
import random
import numpy as np
import torch
import warnings
from datetime import timedelta
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy

from melp.datasets.pretrain_datamodule import SleepDataModule
from melp.models.dino_model_cls import DINOCLSModel
from wav2sleep.config import *
from train_config import *

# Process-wide runtime settings, applied once when this script is loaded.
# Silence all Python warnings (hides library noise, but also genuinely useful
# warnings -- remove this line when debugging).
warnings.filterwarnings("ignore")
# Turn off HuggingFace `tokenizers` thread parallelism (commonly set to avoid
# fork warnings when DataLoader workers are spawned).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Reproducibility: restrict cuDNN to deterministic kernels AND disable its
# autotuner. With benchmark=True cuDNN may select a different (numerically
# different) algorithm on every run, which defeats deterministic=True and the
# explicit seeding done at start-up.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Allow faster reduced-precision (e.g. TF32) kernels for float32 matmuls on
# hardware that supports them.
torch.set_float32_matmul_precision('high')


def main(hparams: Namespace):
    """Run DINO pretraining on PSG signals.

    Builds the LR-monitor / checkpoint callbacks, a W&B logger, a DDP
    ``Trainer``, the ``SleepDataModule`` and the ``DINOCLSModel``, then calls
    ``trainer.fit`` (resuming from ``hparams.ckpt_path`` when it is set).

    Args:
        hparams: Parsed command-line arguments. This function sets
            ``hparams.num_leads`` before the whole namespace is forwarded to
            ``DINOCLSModel`` as keyword arguments.
    """
    # Run name = encoder + batch size + local timestamp; shared by the
    # checkpoint directory and the W&B run.
    # NOTE(review): when Lightning launches DDP by re-running this script,
    # every rank computes its own timestamp, so ranks may create differently
    # named (empty) dirs; ModelCheckpoint is expected to sync rank 0's dirpath
    # -- confirm for the installed Lightning version.
    now = datetime.datetime.now(tz.tzlocal())
    timestamp = now.strftime("%Y_%m_%d_%H_%M_%S")
    extension = f"dino_{hparams.psg_encoder_name}_bz{hparams.batch_size}_{timestamp}"

    ckpt_dir = os.path.join(CKPT_PATH, f"logs/dino/ckpts/{extension}")
    os.makedirs(ckpt_dir, exist_ok=True)

    callbacks = [
        LearningRateMonitor(logging_interval="step"),
        # Checkpoint at validation time (not train-epoch end) every 2nd epoch,
        # keeping all of them (save_top_k=-1) plus a rolling "last" checkpoint.
        ModelCheckpoint(
            monitor="val/loss", dirpath=ckpt_dir,
            save_last=True, every_n_epochs=2, mode="min", save_top_k=-1,
            save_on_train_epoch_end=False, auto_insert_metric_name=True
        ),
    ]

    logger_dir = os.path.join(CKPT_PATH, "logs/dino")
    os.makedirs(logger_dir, exist_ok=True)
    # NOTE(review): the W&B project name is derived from encoder/batch size;
    # the --wandb_proj_name CLI flag is not read here.
    wandb_logger = WandbLogger(
        project=f"dino_{hparams.psg_encoder_name}_bz{hparams.batch_size}",
        save_dir=logger_dir,
        name=extension,
    )

    # find_unused_parameters=True lets DDP tolerate parameters that receive no
    # gradient in a step (costs an extra autograd-graph traversal); the timeout
    # bounds how long collectives may block before failing.
    strategy = DDPStrategy(
        find_unused_parameters=True,
        static_graph=False,
        timeout=timedelta(minutes=15),
    )

    trainer = Trainer(
        max_epochs=hparams.max_epochs,
        accelerator="gpu",
        accumulate_grad_batches=hparams.accumulate_grad_batches,
        devices=hparams.num_devices,
        num_nodes=hparams.num_nodes,
        # Clip the global gradient L2 norm to 3.0.
        gradient_clip_val=3.0,
        gradient_clip_algorithm="norm",
        precision=hparams.precision,
        strategy=strategy,
        callbacks=callbacks,
        logger=wandb_logger,
        log_every_n_steps=10,
    )

    # Datamodule: num_leads tells the model how many input channels to expect.
    train_edf_cols = TRAIN_EDF_COLS_MAE
    hparams.num_leads = len(train_edf_cols)

    dm = SleepDataModule(
        is_pretrain=1,
        csv_dir=SPLIT_DATA_FOLDER,
        train_edf_cols=train_edf_cols,
        batch_size=hparams.batch_size,
        num_workers=hparams.num_workers,
        data_pct=hparams.train_data_pct,
        # presumably 30 s windows at 64 Hz -- units are not visible here.
        window_size=30,
        sample_rate=64,
        val_dataset_list=hparams.val_dataset_list,
        data_source=hparams.data_source,
        include_datasets=hparams.include_datasets,
    )

    model = DINOCLSModel(**vars(hparams))

    # Optimizer steps per epoch *per rank*. DDP shards batches across all
    # num_devices * num_nodes processes, and only every
    # accumulate_grad_batches-th batch triggers an optimizer step.
    # Assumes train_dataloader() is not yet sharded at this point (Lightning
    # injects the distributed sampler inside fit) -- TODO confirm.
    world_size = hparams.num_devices * hparams.num_nodes
    model.training_steps_per_epoch = (
        len(dm.train_dataloader()) // hparams.accumulate_grad_batches // world_size
    )
    # Teacher-temperature warm-up spans the first 10% of all optimizer steps;
    # integer arithmetic keeps it an exact iteration count.
    model.teacher_temp_warmup_iters = (
        model.training_steps_per_epoch * hparams.max_epochs // 10
    )

    pprint(vars(hparams))

    # An empty/None ckpt_path trains from scratch; otherwise resume from it.
    trainer.fit(model, datamodule=dm, ckpt_path=hparams.ckpt_path or None)


def _build_parser() -> ArgumentParser:
    """Return the command-line parser for DINO pretraining."""
    parser = ArgumentParser(description="DINO Pretraining for PSG signals.")

    # Model
    parser.add_argument("--psg_encoder_name", type=str, default="vit_small",
                        choices=["vit_nano", "vit_tiny", "vit_small", "vit_base"])
    parser.add_argument("--patch_size_time", type=int, default=64)
    parser.add_argument("--patch_size_ch", type=int, default=4)
    parser.add_argument("--lead_wise", type=int, default=0)

    # DINO specific
    parser.add_argument("--dino_out_dim", type=int, default=2048)
    parser.add_argument("--dino_patch_out_dim", type=int, default=2048)
    parser.add_argument("--dino_hidden_dim", type=int, default=2048)
    parser.add_argument("--dino_bottleneck_dim", type=int, default=256)
    parser.add_argument("--koleo_lambda", type=float, default=0.0)
    parser.add_argument("--ibot_lambda", type=float, default=0.0)

    # Training
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--max_epochs", type=int, default=30)
    parser.add_argument("--num_workers", type=int, default=64)
    parser.add_argument("--num_devices", type=int, default=4)
    parser.add_argument("--num_nodes", type=int, default=1)
    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--precision", type=str, default="32-true")
    parser.add_argument("--train_data_pct", type=float, default=1.0)

    # Data
    # NOTE(review): no type/nargs, so a value given on the CLI arrives as a
    # single str while the default is whatever PRETRAIN_VAL_DATASET_LIST is
    # (presumably a list) -- confirm SleepDataModule handles both.
    parser.add_argument("--val_dataset_list", default=PRETRAIN_VAL_DATASET_LIST)
    parser.add_argument("--data_source", type=str, default="auto")
    parser.add_argument("--include_datasets", type=str, nargs="*", default=None)
    parser.add_argument("--simclr_augmentation", type=str, default="chan_then_pcspan")

    # Checkpoint
    parser.add_argument("--ckpt_path", type=str, default=None)
    parser.add_argument("--wandb_proj_name", type=str, default="dino_pretrain")

    return parser


if __name__ == '__main__':
    hparams = _build_parser().parse_args()
    # seed_everything already seeds Python's `random`, NumPy and torch (which
    # seeds every CUDA device), so no separate per-library seeding is needed.
    seed_everything(hparams.seed)
    main(hparams)
