"""
Main script for training treatment-images pair for cell painting

Mainly adopted from open_clip[1]

[1] https://github.com/mlfoundations/open_clip/blob/main/src/open_clip_train/main.py
[2] https://amsword.medium.com/gradient-backpropagation-with-torch-distributed
"""

import argparse
import glob
import os
import time

import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, set_seed
from graphium.config._loader import load_yaml_config
from graphium.data.collate import graphium_collate_fn
from hflayers import Hopfield
from torch import nn, optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from tqdm import tqdm

import wandb
from src import constants
from src.clip.clip import load_model
from src.clip.methods import clip, cloob, s2l_loss, sigmoid_loss
from src.helpler import (
    all_gather,
    compute_grad_norm,
    compute_param_norm,
    get_max_steps,
    get_metrics,
    print_args,
)
from src.mpnn.data.datamodule import MultitaskFromSmilesDataModuleBray
from src.scheduler import (
    const_lr,
    const_lr_cooldown,
    cosine_lr,
    get_cosine_with_hard_restarts_schedule_with_warmup,
)

torch.backends.cuda.matmul.allow_tf32 = True


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Training Contrastive Learning.")

    parser.add_argument(
        "--outdir", type=str, help="output parent directory", default=constants.OUT_DIR
    )
    parser.add_argument(
        "--split_label_dir",
        type=str,
        help="output parent directory",
        default=constants.SPLIT_LABEL_DIR,
    )
    parser.add_argument(
        "--split",
        type=int,
        help="index of dataset split file",
    )
    parser.add_argument(
        "--is_train",
        help="whether to use training index",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--opt_seed",
        type=int,
        help="random seed for model training",
        default=42,
    )
    parser.add_argument(
        "--wandb",
        help="whether to monitor model training with wandb",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--wandb_id",
        help="id for monitor training if laod from checkpoint",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="cell_clip",
        help=("Model types, e.g. cloome, cell_clip."),
    )
    parser.add_argument(
        "--embedding_name",
        type=str,
        default=None,
        help=("image embeddings, e.g. siglip@224, clip@224."),
    )
    parser.add_argument(
        "--img_dir",
        type=str,
        default=None,
        help=("Path to training input directory."),
    )
    parser.add_argument(
        "--input_dim",
        type=int,
        help="Dimension of input emebddings.",
        default=768,
    )
    parser.add_argument(
        "--molecule_path",
        type=str,
        default=None,
        help=("Path to molecule (text) data."),
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="bray2017",
        help=("dataset name, e.g. bray2017 or jumpcp"),
    )
    parser.add_argument(
        "--loss_type",
        type=str,
        default="clip",
        help=("Loss types, e.g. cloob, clip."),
    )
    parser.add_argument(
        "--pretrained",
        help="whether to use pretrained text encoder from CLIP.",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--resume",
        help="whether to use resume from previous training.",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--fine_tune_ckpt",
        help="path to ckpt for fine tuning.",
        default=None,
    )
    parser.add_argument(
        "--epochs",
        type=int,
        help="training epochs",
        default=50,
    )
    parser.add_argument(
        "--image_resolution_train",
        default=224,
        nargs="+",
        type=int,
        help="resolution for training set ",
    )
    parser.add_argument(
        "--image_resolution_val",
        default=224,
        nargs="+",
        type=int,
        help="resolution for validation set ",
    )
    parser.add_argument(
        "--val_subset_ratio",
        type=float,
        default=1.0,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help=(
            "Training batch size. When training with Accelerate, "
            "the batch size passed to the dataloader is the batch size per GPU"
        ),
        default=32,
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=5.0e-4,
    )
    parser.add_argument(
        "--beta1",
        type=float,
        default=0.9,
    )
    parser.add_argument(
        "--beta2",
        type=float,
        default=0.999,
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="cosine",
        help=(
            "LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown'"
            " (constant w/ cooldown). Default: cosine"
        ),
    )
    parser.add_argument("--eps", type=float, default=1.0e-8, help="Adam epsilon.")
    parser.add_argument(
        "--epochs-cooldown",
        type=int,
        default=None,
        help=(
            "When scheduler w/ cooldown used, "
            "perform cooldown from total_epochs - cooldown_epochs onwards."
        ),
    )
    parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
    parser.add_argument(
        "--warmup", type=int, default=1000, help="Number of steps to warmup for."
    )
    parser.add_argument(
        "--num_cycles", type=int, default=5, help="Number of cosine cycle during training."
    )

    # CLIP temperature
    parser.add_argument(
        "--init-inv-tau", type=float, default=14.3, help="Initial inverse tau."
    )
    parser.add_argument(
        "--learnable-inv-tau",
        default=False,
        action="store_true",
        help="Use a trainable logit scale for the nce loss.",
    )
    # Cloome hopfield params
    parser.add_argument(
        "--scale-hopfield", type=float, default=14.3, help="Scale for Hopfield retrieval."
    )
    parser.add_argument(
        "--learnable-scale-hopfield",
        default=False,
        action="store_true",
        help="Use a trainable logit scale for the Hopfield retrieval.",
    )
    parser.add_argument(
        "--ckpt_freq", type=int, default=1000, help="How often to save checkpoints."
    )
    parser.add_argument(
        "--log_freq", type=int, default=20, help="How often to check model training."
    )
    parser.add_argument(
        "--eval_freq", type=int, default=500, help="How often to evaluate model training."
    )
    parser.add_argument(
        "--keep_all_ckpts",
        help="whether to keep all the checkpoints",
        action="store_true",
        default=False,
    )
    return parser.parse_args()


def main(args):
    """Training scripts for contrastive learning."""
    set_seed(args.opt_seed)  # Seed for model optimization.

    ddp_kwargs = DistributedDataParallelKwargs(
        find_unused_parameters=True,
    )
    accelerator = Accelerator(
        step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs]
    )

    accelerator.print("number of GPU available", torch.cuda.device_count())

    if accelerator.is_main_process:
        print_args(args)

    if args.wandb and accelerator.is_main_process:
        assert wandb is not None, "Please install wandb."
        accelerator.print("Starting wandb.")
        wandb.init(
            project=(
                f"Cell Painting {args.dataset}-{args.model_type}-{args.loss_type}-"
                f"{args.epochs}-{args.batch_size}-{args.embedding_name}-{args.lr}"
            ),
            dir="/XXXX-3/XXXX-4/XXXX-2/cell_painting/wandb",
            name=f"{args.model_type}-split_{args.split}-train_{args.is_train}",
            config=vars(args),
            id=args.wandb_id,
            resume="allow",
        )

    # Obtain training & evaluation data

    train_cfg = load_yaml_config(
        os.path.join(constants.OUT_DIR, "configs/graphium_configs/bray2017_train.yaml")
    )

    # Prepare training dataset

    train_datamodule = MultitaskFromSmilesDataModuleBray(
        **train_cfg["datamodule"]["args"],
    )
    train_datamodule.prepare_data()
    train_datamodule.setup("fit")

    num_workers = (
        4 * torch.cuda.device_count()
        if torch.get_num_threads() >= 4
        else torch.get_num_threads()
    )

    train_dataloader = DataLoader(
        train_datamodule.train_ds,
        batch_size=int(args.batch_size / accelerator.num_processes),
        collate_fn=graphium_collate_fn,
        num_workers=num_workers,
        pin_memory=True,
    )

    eval_cfg = load_yaml_config(
        os.path.join(constants.OUT_DIR, "configs/graphium_configs/bray2017_eval.yaml")
    )

    # Prepare validation dataset
    eval_datamodule = MultitaskFromSmilesDataModuleBray(
        **eval_cfg["datamodule"]["args"],
    )
    eval_datamodule.prepare_data()
    eval_datamodule.setup("fit")

    eval_dataloader = DataLoader(
        eval_datamodule.train_ds,
        batch_size=int(args.batch_size / accelerator.num_processes),
        collate_fn=graphium_collate_fn,
        num_workers=num_workers,
        pin_memory=True,
    )

    accelerator.print(
        "Initialize training and eval loader. Number of samples,",
        f"train:{len(train_datamodule.train_ds)}",
        f"eval:{len(eval_datamodule.train_ds)}.",
    )

    # Initlialize model.

    model = load_model(
        args.model_type,
        args.pretrained,
        args.image_resolution_train,
        vision_width=args.input_dim,
    )

    model_outdir = os.path.join(
        constants.OUT_DIR,
        "results",
        args.dataset,
        "models",
        args.model_type,
        (
            f"epochs_{args.epochs}_{args.img_dir.split('/')[-1]}_"
            f"{args.loss_type}_batch_size={args.batch_size}_"
            f"lr={args.lr}_pretrained={args.pretrained}_"
            f"cycle={args.num_cycles}_warmup={args.warmup}"
        ),
    )

    if accelerator.is_main_process:
        os.makedirs(model_outdir, exist_ok=True)

    exclude = (
        lambda n, p: p.ndim < 2
        or "bn" in n
        or "ln" in n
        or "bias" in n
        or "logit_scale" in n
    )

    def include(n, p):
        return not exclude(n, p)

    # include = lambda n, p: not exclude(n, p)

    named_parameters = list(model.named_parameters())
    gain_or_bias_params = [
        p for n, p in named_parameters if exclude(n, p) and p.requires_grad
    ]
    rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]

    optimizer = optim.AdamW(
        [
            {"params": gain_or_bias_params, "weight_decay": 0.0},
            {"params": rest_params, "weight_decay": args.wd},
        ],
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        eps=args.eps,
    )

    # Adjust steps for DDP.
    steps_per_epoch = len(train_dataloader) // accelerator.num_processes
    total_steps = steps_per_epoch * args.epochs

    if args.lr_scheduler == "cosine":
        scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)
    elif args.lr_scheduler == "const":
        scheduler = const_lr(optimizer, args.lr, args.warmup, total_steps)
    elif args.lr_scheduler == "const-cooldown":
        assert (
            args.epochs_cooldown is not None
        ), "Please specify the number of cooldown epochs for this lr schedule."
        cooldown_steps = steps_per_epoch * args.epochs_cooldown
        scheduler = const_lr_cooldown(
            optimizer,
            args.lr,
            args.warmup,
            total_steps,
            cooldown_steps,
            args.lr_cooldown_power,
            args.lr_cooldown_end,
        )
    elif args.lr_scheduler == "cosine-restarts":
        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer,
            warmup=args.warmup,
            num_cycles=args.num_cycles,
            num_training_steps=total_steps,
        )
    elif args.lr_scheduler == "one_cycle":
        scheduler = OneCycleLR(
            optimizer,
            max_lr=0.1,
            total_steps=total_steps,
            pct_start=0.1,
            anneal_strategy="cos",
        )

    else:
        raise ValueError(
            f"Unknown scheduler, {args.lr_scheduler}. "
            f"Available options are: cosine, const, const-cooldown."
        )

    epoch = 0
    steps = 0
    total_steps_time = 0

    if args.fine_tune_ckpt:

        try:
            ckpt = torch.load(args.fine_tune_ckpt, map_location="cpu")
            accelerator.print(f"Loading pretrained checkpoint at {args.fine_tune_ckpt} ")

        except RuntimeError:
            accelerator.print(
                f"Pretrained check point at {args.fine_tune_ckpt} does not exist."
            )

    elif args.resume:
        # Check if there is an existing checkpoint to resume from. This occurs when
        # model runs are interrupted (e.g., exceeding job time limit).

        existing_steps = get_max_steps(model_outdir)

        if existing_steps is not None:
            ckpt_path = os.path.join(model_outdir, f"ckpt_steps_{existing_steps:0>8}.pt")

            try:
                ckpt = torch.load(ckpt_path, map_location="cpu")
                steps = ckpt["steps"]
                epoch = ckpt["epoch"]
                total_steps_time = ckpt["total_steps_time"]
                model.load_state_dict(ckpt["model"])
                optimizer.load_state_dict(ckpt["optimizer"])
                scheduler.load_state_dict(ckpt["scheduler"])

                accelerator.print(
                    f"Resuming checkpoint at epoch {epoch}; steps {steps} from {ckpt_path}"
                )

            except RuntimeError:
                accelerator.print(f"Check point at {ckpt_path} does not exist.")

    if args.model_type in ["cloome", "cloome_old"]:
        hopfield_layer = Hopfield(
            input_size=512,
            scaling=args.scale_hopfield,
            normalize_hopfield_space=False,
            normalize_hopfield_space_affine=False,
            normalize_pattern_projection=False,
            normalize_pattern_projection_affine=False,
            normalize_state_pattern=False,
            normalize_state_pattern_affine=False,
            normalize_stored_pattern=False,
            normalize_stored_pattern_affine=False,
            state_pattern_as_static=True,
            pattern_projection_as_static=True,
            stored_pattern_as_static=True,
            disable_out_projection=True,
            num_heads=1,
            dropout=False,
        )
        model, hopfield_layer = accelerator.prepare(model, hopfield_layer)
    else:
        model = accelerator.prepare(model)

    loss_fct_img = nn.CrossEntropyLoss()
    loss_fct_tx = nn.CrossEntropyLoss()

    optimizer, scheduler, train_dataloader, eval_dataloader = accelerator.prepare(
        optimizer, scheduler, train_dataloader, eval_dataloader
    )

    progress_bar = tqdm(
        range(total_steps),
        initial=steps,
        desc="Steps",
        disable=not accelerator.is_main_process,
    )
    steps_start_time = time.time()
    while steps < total_steps:
        for _, batch in enumerate(train_dataloader):

            optimizer.zero_grad()

            model.train()

            images, treatments = batch["images"], batch["features"]

            if args.loss_type in ["sigclip", "s2l"]:
                img_features, mol_features, logit_scale, bias = model(images, treatments)
            else:
                img_features, mol_features, logit_scale = model(images, treatments)

            if accelerator.use_distributed:
                # Gather all image and text features from all GPUs
                all_image_features = all_gather(img_features)
                all_mol_features = all_gather(mol_features)
            else:
                all_image_features = img_features
                all_mol_features = mol_features

            if args.loss_type == "clip":
                loss = clip(
                    all_image_features,
                    all_mol_features,
                    logit_scale,
                    loss_fct_img,
                    loss_fct_tx,
                )
            elif args.loss_type == "cloob":
                loss = cloob(
                    all_image_features,
                    all_mol_features,
                    logit_scale.exp(),
                    hopfield_layer,
                )
            elif args.loss_type == "sigclip":
                loss = sigmoid_loss(
                    all_image_features,
                    all_mol_features,
                    logit_scale,
                    bias,
                )
            elif args.loss_type == "s2l":
                loss = s2l_loss(
                    all_image_features,
                    all_mol_features,
                    logit_scale,
                    bias,
                )
            else:
                raise ValueError(f"Loss type {args.loss_type} undefined")

            accelerator.backward(loss)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 20.0)

            optimizer.step()
            scheduler.step()

            # m.logit_inv_tau.data = torch.clamp(m.logit_inv_tau.data, 0, 4.6052)
            m = model.module if accelerator.use_distributed else model

            with torch.no_grad():
                m.logit_scale.data = torch.clamp(m.logit_scale.data, 0, 4.6052)

            if accelerator.sync_gradients:
                steps += 1
                progress_bar.update(1)

                if steps % steps_per_epoch == 0:
                    epoch += 1

                if steps % args.log_freq == 0 and accelerator.is_main_process:

                    steps_time = time.time() - steps_start_time
                    total_steps_time += steps_time

                    # Check gradient norm and parameter norm.
                    grad_norm = compute_grad_norm(accelerator, model)
                    param_norm = compute_param_norm(accelerator, model)

                    info = f"Step[{steps}/{total_steps}]"
                    info += f", steps_time: {steps_time:.3f}"
                    info += f", loss: {loss.detach().cpu().item():.5f}"
                    info += f", temperature: {m.logit_scale.data.exp():.6f}"
                    info += f", gradient norms: {grad_norm:.5f}"
                    info += f", parameters norms: {param_norm:.5f}"
                    info += f", lr: { scheduler.get_last_lr()[0]:.6f}"
                    accelerator.print(info, flush=True)

                    if args.wandb:
                        wandb.log(
                            {
                                "step": steps,
                                "loss": loss.detach().cpu().item(),
                                "temperature": m.logit_scale.data.exp(),
                                "steps_time": steps_time,
                                "gradient norm": grad_norm,
                                "parameter norm": param_norm,
                                "lr": scheduler.get_last_lr()[0],
                            }
                        )

                    steps_start_time = time.time()

                if (
                    steps % args.ckpt_freq == 0 or steps == total_steps
                ) and accelerator.is_main_process:
                    if not args.keep_all_ckpts:
                        pattern = os.path.join(model_outdir, "ckpt_steps_*.pt")
                        for filename in glob.glob(pattern):
                            os.remove(filename)

                    torch.save(
                        {
                            "model": accelerator.get_state_dict(model),
                            "optimizer": optimizer.state_dict(),
                            "scheduler": scheduler.state_dict(),
                            "steps": steps,
                            "epoch": epoch,
                            "total_steps_time": total_steps_time,
                        },
                        os.path.join(model_outdir, f"ckpt_steps_{steps:0>8}.pt"),
                    )
                    accelerator.print(f"Checkpoint saved at step: {steps} epoch:{epoch}")
                    steps_start_time = time.time()

                if steps % args.eval_freq == 0 or steps == 1 or steps == total_steps:

                    model.eval()
                    accelerator.print("Evaluation with retrieval task..")

                    with torch.no_grad():
                        all_eval_image_features = []
                        all_eval_mol_features = []

                        for _, batch in enumerate(eval_dataloader):

                            images, treatments = batch["images"], batch["features"]

                            with accelerator.autocast():
                                if args.loss_type in ["sigclip", "s2l"]:
                                    (
                                        img_features,
                                        mol_features,
                                        logit_scale,
                                        bias,
                                    ) = model(images, treatments)
                                else:
                                    img_features, mol_features, logit_scale = model(
                                        images, treatments
                                    )

                            all_eval_image_features.append(img_features)
                            all_eval_mol_features.append(mol_features)

                        all_eval_image_features = torch.cat(all_eval_image_features)
                        all_eval_mol_features = torch.cat(all_eval_mol_features)

                        if accelerator.use_distributed:
                            all_eval_image_features = accelerator.gather_for_metrics(
                                all_eval_image_features
                            )
                            all_eval_mol_features = accelerator.gather_for_metrics(
                                all_eval_mol_features
                            )

                        if args.loss_type == "clip":
                            val_loss = clip(
                                all_image_features,
                                all_mol_features,
                                logit_scale,
                                loss_fct_img,
                                loss_fct_tx,
                            )
                        elif args.loss_type == "sigclip":
                            val_loss = sigmoid_loss(
                                all_image_features,
                                all_mol_features,
                                logit_scale,
                                bias,
                            )
                        elif args.loss_type == "cloob":
                            val_loss = cloob(
                                all_image_features,
                                all_mol_features,
                                logit_scale.exp(),
                                hopfield_layer,
                            )
                        elif args.loss_type == "s2l":
                            val_loss = s2l_loss(
                                all_image_features,
                                all_mol_features,
                                logit_scale,
                                bias,
                            )

                        if accelerator.is_main_process:
                            # Evaluation only in the main process
                            metrics = get_metrics(
                                all_eval_image_features, all_eval_mol_features
                            )
                            modalities = {"image_to_text", "text_to_image"}
                            for name in modalities:
                                info = (
                                    f"Evaluation: {name} retrieval at epoch:{epoch}, "
                                    f"Steps: {steps}/{total_steps}, "
                                    f"Val loss: {val_loss:.4f}, "
                                    f"mean_rank: {metrics[f'{name}_mean_rank']:.1f}, "
                                    f"median_rank: {metrics[f'{name}_median_rank']:.1f}, "
                                    f"R@1: {metrics[f'{name}_R@1']:.4f}, "
                                    f"R@5: {metrics[f'{name}_R@5']:.4f}, "
                                    f"R@10: {metrics[f'{name}_R@10']:.4f}."
                                )
                                accelerator.print(info, flush=True)

                                if args.wandb:
                                    wandb.log(
                                        {
                                            "step": steps,
                                            "val loss": val_loss,
                                            f"{name}_mean_rank": metrics[
                                                f"{name}_mean_rank"
                                            ],
                                            f"{name}_median_rank": metrics[
                                                f"{name}_median_rank"
                                            ],
                                            f"{name}_R@{1}": metrics[f"{name}_R@1"],
                                            f"{name}_R@{5}": metrics[f"{name}_R@5"],
                                            f"{name}_R@{10}": metrics[f"{name}_R@10"],
                                        }
                                    )
                    steps_start_time = time.time()

                if steps == total_steps:
                    break


if __name__ == "__main__":
    args = parse_args()
    main(args)
