"""DINO Finetuning for downstream tasks."""
from pprint import pprint
import os
from argparse import ArgumentParser, Namespace
import datetime
from dateutil import tz
import random
import numpy as np
import torch
import warnings
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from melp.datasets.pretrain_datamodule import SleepDataModule
from melp.models.dino_model_cls import DINOCLSModel
from melp.models.ssl_finetuner import SSLFineTuner
from wav2sleep.config import *
from train_config import *

# Process-wide runtime configuration, applied once at import time.

# Suppress every Python warning emitted in this process.
warnings.filterwarnings("ignore")
# NOTE(review): presumably read by the HuggingFace `tokenizers` library to
# avoid fork-related parallelism issues in dataloader workers -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Reproducibility: ask cuDNN for deterministic kernels and keep the
# auto-tuner off. Benchmark mode picks kernels from timing measurements, which
# can differ between runs and so conflicts with the deterministic setup
# (`Trainer(deterministic=True)` in `main` disables it by default as well).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Allow reduced-precision internals for float32 matmuls.
torch.set_float32_matmul_precision('high')


def main(hparams: Namespace):
    """Fine-tune a pretrained DINO backbone on a downstream task, then test it.

    Creates the checkpoint and logging directories under ``CKPT_PATH``, builds
    the Lightning ``Trainer`` and the ``SleepDataModule``, loads the backbone
    from ``hparams.ckpt_path``, wraps it in ``SSLFineTuner`` and runs ``fit``
    followed by ``test``.

    Args:
        hparams: Parsed command-line arguments. The namespace is mutated:
            ``num_classes``, ``training_steps_per_epoch``,
            ``total_training_steps``, ``in_features`` and ``epochs`` are added
            before all of its entries are forwarded to ``SSLFineTuner`` as
            keyword arguments.
    """
    # Init trainer
    now = datetime.datetime.now(tz.tzlocal())
    timestamp = now.strftime("%Y_%m_%d_%H_%M_%S")

    exp_name = "finetune_dino_full" if hparams.finetune_backbone else "finetune_dino"
    run_name = f"{hparams.eval_label}_{hparams.downstream_dataset_name}_{timestamp}"

    ckpt_dir = os.path.join(CKPT_PATH, f"logs/{exp_name}/ckpts/{run_name}")
    os.makedirs(ckpt_dir, exist_ok=True)

    # Keep a handle on the checkpoint callback so the best checkpoint can be
    # looked up after training.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_auc", dirpath=ckpt_dir,
        save_last=False, mode="max", save_top_k=1,
        auto_insert_metric_name=True
    )
    callbacks = [
        LearningRateMonitor(logging_interval="step"),
        checkpoint_callback,
    ]

    if hparams.early_stopping:
        callbacks.append(EarlyStopping(
            monitor="val_auc", patience=hparams.early_stopping_patience,
            mode="max", verbose=True
        ))

    logger_dir = os.path.join(CKPT_PATH, f"logs/{exp_name}")
    os.makedirs(logger_dir, exist_ok=True)
    wandb_logger = WandbLogger(
        project=f"{exp_name}_sleepuni", save_dir=logger_dir, name=run_name
    )

    # A positive `max_steps` is a step budget (no epoch limit, as before).
    # A non-positive value means "train for `max_epochs` epochs", which is the
    # same rule used for `total_training_steps` below. Without forwarding
    # `max_epochs` the trainer would run far longer than the schedule assumes.
    use_step_budget = hparams.max_steps > 0
    trainer = Trainer(
        max_steps=hparams.max_steps if use_step_budget else -1,
        max_epochs=None if use_step_budget else hparams.max_epochs,
        accelerator="gpu",
        accumulate_grad_batches=hparams.accumulate_grad_batches,
        deterministic=True,
        devices=hparams.num_devices,
        strategy="ddp_find_unused_parameters_true",
        precision=hparams.precision,
        callbacks=callbacks,
        logger=wandb_logger
    )

    # Init datamodule
    train_edf_cols = TRAIN_EDF_COLS_MAE

    datamodule = SleepDataModule(
        is_pretrain=0,
        data_pct=hparams.train_data_pct,
        downstream_dataset_name=hparams.downstream_dataset_name,
        csv_dir=SPLIT_DATA_FOLDER,
        train_edf_cols=train_edf_cols,
        event_cols=hparams.eval_label,
        batch_size=hparams.batch_size,
        num_workers=hparams.num_workers,
        sample_rate=hparams.sample_rate,
        window_size=30,
        data_source=hparams.data_source,
        include_datasets=hparams.include_datasets,
        n_train_samples=getattr(hparams, 'n_train_samples', None),
        val_batch_size=getattr(hparams, 'val_batch_size', None),
        val_data_pct=getattr(hparams, 'val_data_pct', None),
    )

    # Build the train dataloader once and reuse it for both the class count
    # and the steps-per-epoch estimate.
    train_loader = datamodule.train_dataloader()

    # Get num_classes from dataset
    train_dataset = train_loader.dataset
    if hasattr(train_dataset, 'dataset'):  # It's a Subset
        hparams.num_classes = train_dataset.dataset.num_classes
    else:
        hparams.num_classes = train_dataset.num_classes

    # Optimizer steps per epoch: batches, divided by gradient accumulation and
    # by the number of devices.
    hparams.training_steps_per_epoch = (
        len(train_loader) // hparams.accumulate_grad_batches // hparams.num_devices
    )
    if use_step_budget:
        hparams.total_training_steps = hparams.max_steps
    else:
        hparams.total_training_steps = hparams.training_steps_per_epoch * hparams.max_epochs

    # Load pretrained DINO model
    pretrain_model = DINOCLSModel.load_from_checkpoint(hparams.ckpt_path)
    vit = pretrain_model.encoders["all"].backbone
    hparams.in_features = vit.width
    hparams.epochs = hparams.max_epochs

    pprint(vars(hparams))

    # Create finetuner
    model = SSLFineTuner(backbones={"all": vit}, **vars(hparams))

    trainer.fit(model, datamodule=datamodule)

    # Only the best `val_auc` checkpoint is kept (`save_last=False`), so test
    # that one. If no checkpoint was written (e.g. training ended before any
    # validation run), fall back to the weights currently held in `model`.
    best_ckpt_path = checkpoint_callback.best_model_path
    trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path or None)


def _build_arg_parser() -> ArgumentParser:
    """Create the command-line parser for the DINO finetuning script."""
    cli = ArgumentParser(description="DINO Finetuning for downstream tasks.")
    add = cli.add_argument

    # Task selection
    add("--eval_label", type=str, default="Stage")
    add("--downstream_dataset_name", type=str, default="mros")
    add("--use_which_backbone", type=str, default="all")
    add("--model_name", type=str, default="dino", help="Model name (only dino supported)")

    # Optimisation and data-loading budget
    add("--seed", type=int, default=42)
    add("--batch_size", type=int, default=800)
    add("--val_batch_size", type=int, default=None)
    add("--lr", type=float, default=1e-2)
    add("--max_epochs", type=int, default=10)
    add("--max_steps", type=int, default=2500)
    add("--num_workers", type=int, default=32)
    add("--num_devices", type=int, default=1)
    add("--accumulate_grad_batches", type=int, default=1)
    add("--precision", type=str, default="32-true")
    add("--train_data_pct", type=float, default=1.0)
    add("--n_train_samples", type=int, default=None)
    add("--val_data_pct", type=float, default=None)
    add("--sample_rate", type=int, default=64)

    # Model options
    add("--patch_size_time", type=int, default=64)
    add("--patch_size_ch", type=int, default=4)
    add("--finetune_backbone", action="store_true")
    add("--use_mean_pool", action="store_true")

    # Learning-rate schedule
    add("--scheduler_type", type=str, default="cosine", choices=["cosine", "step", "constant"])
    add("--decay_epochs", type=int, default=10)
    add("--gamma", type=float, default=0.1)
    add("--final_lr", type=float, default=0)

    # Early stopping on the monitored validation metric
    add("--early_stopping", action="store_true")
    add("--early_stopping_patience", type=int, default=10)

    # Data sources
    add("--data_source", type=str, default="auto")
    add("--include_datasets", type=str, nargs="*", default=None)

    # Pretrained checkpoint to finetune from (mandatory)
    add("--ckpt_path", type=str, required=True)

    return cli


if __name__ == '__main__':
    hparams = _build_arg_parser().parse_args()

    # Seed each random number generator with the same value, in a fixed order,
    # before handing control to the training routine.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        seed_everything,
    )
    for apply_seed in seeders:
        apply_seed(hparams.seed)

    main(hparams)
