# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import inspect
import os
import random
import string
from collections import OrderedDict

import wandb
import hydra
import torch
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch import Callback
from omegaconf import DictConfig, OmegaConf, ListConfig
from solo.args.pretrain import parse_cfg
from solo.data.classification_dataloader import prepare_data as prepare_data_classification
from solo.data.pretrain_dataloader import (
    FullTransformPipeline,
    NCropAugmentation,
    build_transform_pipeline,
    prepare_dataloader,
    prepare_datasets,
)
from solo.methods import METHODS
from solo.utils.auto_resumer import AutoResumer
from solo.utils.checkpointer import Checkpointer
from solo.utils.misc import make_contiguous, omegaconf_select

# DALI support is optional: record whether its dataloader helpers could be imported so
# that the DALI code path can fail with a clear message instead of a NameError.
try:
    from solo.data.dali_dataloader import PretrainDALIDataModule, build_transform_pipeline_dali
except ImportError:
    _dali_avaliable = False
else:
    _dali_avaliable = True

# AutoUMAP is optional as well; a RuntimeError raised while importing it is also treated
# as "not available".
try:
    from solo.utils.auto_umap import AutoUMAP
except (ImportError, RuntimeError):
    _umap_available = False
else:
    _umap_available = True

from torchinfo import summary
from fvcore.nn import FlopCountAnalysis

def generate_random_name(length=8):
    """Build a random identifier made of uppercase ASCII letters and digits.

    Args:
        length: number of characters to draw (with replacement). Defaults to 8.

    Returns:
        A string of ``length`` characters sampled with the global ``random`` generator.
    """
    alphabet = string.ascii_uppercase + string.digits
    picked = random.choices(alphabet, k=length)
    return "".join(picked)

def exclude_bias_and_norm(p):
    """Tell whether ``p`` should be excluded from weight decay.

    Bias and norm weights are singled out by being one-dimensional, so the check is
    simply whether ``p.ndim`` equals 1.

    Args:
        p: an object exposing an ``ndim`` attribute (e.g. a parameter tensor).

    Returns:
        True if ``p`` has exactly one dimension, False otherwise.
    """
    is_one_dimensional = p.ndim == 1
    return is_one_dimensional

def compute_model_FLOPs(model, dataset_name):
    """Print a parameter summary of ``model`` and count its FLOPs for one dummy image.

    The dummy input shape is chosen from the first known tag contained in
    ``dataset_name``. The model is always left in training mode when this function
    returns or raises after the FLOP analysis has started.

    Args:
        model: the module to analyse; it is called with a single random tensor.
        dataset_name: dataset identifier; must contain one of "cifar", "imagenet",
            "CelebA" or "3dshapes" (case-sensitive substring match).

    Returns:
        The total FLOP count reported by ``FlopCountAnalysis`` for one input.

    Raises:
        ValueError: if ``dataset_name`` matches none of the known tags.
    """
    # Checked in order; sizes are presumably (batch, channels, height, width).
    known_input_sizes = (
        ("cifar", (1, 3, 32, 32)),
        ("imagenet", (1, 3, 224, 224)),
        ("CelebA", (1, 3, 128, 128)),
        ("3dshapes", (1, 3, 64, 64)),
    )
    input_size = next(
        (size for tag, size in known_input_sizes if tag in dataset_name), None
    )
    if input_size is None:
        raise ValueError(f"FLOPs computation not configured for dataset: {dataset_name}")

    # get parameter counts (called for the summary it reports; the result is not needed)
    summary(model, input_size=input_size)

    # Build the dummy input on whatever device the model lives on *after* summary():
    # torchinfo picks a device itself when none is given and may relocate the model, so a
    # hard-coded .cuda() breaks on CPU-only machines. The tensor is still sampled on the
    # CPU first so the CPU RNG stream is consumed exactly as before.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        # parameter-less model: nothing to match against, stay on CPU
        device = torch.device("cpu")
    fake_input = (torch.randn(input_size).to(device),)

    # get FLOPs counts
    try:
        with torch.no_grad():
            model.eval()
            model_flops_per_image = FlopCountAnalysis(model, fake_input).total()
            print(f'model_flops_per_image: {model_flops_per_image / 1e6} M')
    finally:
        # reset model to training mode, even if the analysis failed
        model.train()

    return model_flops_per_image

class FLOPsLogger(Callback):
    """Callback that logs an estimate of the cumulative training FLOPs after each epoch.

    The per-epoch cost is ``model_flops_per_image * num_images_per_train_dataset *
    num_views``, where the number of views depends on the SSL method in ``cfg.method``.

    Args:
        model_flops_per_image: FLOPs of one forward pass on one image.
        num_images_per_train_dataset: number of images in the training set.
        cfg: config exposing ``method`` and, for "nat", ``data.num_large_crops`` and
            ``data.num_small_crops``.

    Raises:
        NotImplementedError: if ``cfg.method`` is not one of the supported methods.
    """

    # methods that primarily compute their loss between two main views
    _TWO_VIEW_METHODS = frozenset(
        {
            "vicreg",
            "radialbyol",
            "byol",
            "wmse",
            "radialvicreg",
            "vicreg_e2mc",
            "radialmocov3",
            "radialswav",
            "swav",
            "simclr",
            "radialsimclr",
        }
    )

    def __init__(self, model_flops_per_image, num_images_per_train_dataset, cfg):
        super().__init__()

        method_name = cfg.method
        if method_name in self._TWO_VIEW_METHODS:
            views_per_image = 2
        elif method_name == "nat":
            # NAT processes every crop (large and small) and aligns them to targets
            views_per_image = cfg.data.num_large_crops + cfg.data.num_small_crops
            if views_per_image == 0:
                # not expected with typical configs; fall back to one view so the
                # estimate does not collapse to zero
                print("Warning: num_augs is 0 for NAT. FLOPs per epoch might be underestimated.")
                views_per_image = 1
        else:
            raise NotImplementedError(f"implement FLOPs computation for {method_name} method")

        self.flops_per_epoch = (
            model_flops_per_image * num_images_per_train_dataset * views_per_image
        )

    def on_train_epoch_end(self, trainer, pl_module):
        """Log ``flops_per_epoch`` times the number of epochs finished so far."""
        # current_epoch + 1 epochs are counted (assumes current_epoch is zero-based)
        epochs_done = trainer.current_epoch + 1
        total_flops = self.flops_per_epoch * epochs_done
        pl_module.log("cumulative_flops", float(total_flops), sync_dist=True)

# Ordered (foreign prefix, solo-learn prefix) pairs used to translate the keys of official
# MoCo-V3 checkpoints. The more specific "*.fc." prefixes must come before the generic
# encoder prefixes, because the first match wins.
_CPT_KEY_PREFIXES = (
    ("base_encoder.fc.", "projector."),
    ("base_encoder.", "backbone."),
    ("momentum_encoder.fc.", "momentum_projector."),
    ("momentum_encoder.", "momentum_backbone."),
)

# File extensions (lower-case) counted as training images when sizing the dataset for CPT.
_CPT_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"})


def _remap_foreign_state_dict(model_sd):
    """Translate foreign (MoCo-V3 style) state-dict keys to solo-learn's naming.

    A leading ``module.`` is stripped first; then the first matching prefix from
    ``_CPT_KEY_PREFIXES`` is substituted. Keys matching no prefix (e.g. ``predictor.*``)
    are kept unchanged. Only *prefixes* are rewritten, so an ``fc.`` that appears deeper
    inside an encoder key is treated as a backbone weight.
    """
    remapped = OrderedDict()
    for key, value in model_sd.items():
        if key.startswith("module."):
            key = key[len("module.") :]
        for old_prefix, new_prefix in _CPT_KEY_PREFIXES:
            if key.startswith(old_prefix):
                key = new_prefix + key[len(old_prefix) :]
                break
        remapped[key] = value
    return remapped


def _estimate_cpt_steps_per_epoch(cfg):
    """Estimate optimizer steps per epoch by counting image files under the train path.

    Raises:
        FileNotFoundError: if ``cfg.data.train_path`` is not a directory.
    """
    print("INFO: CPT: Efficiently calculating dataset size to configure scheduler...")
    train_path = cfg.data.train_path
    # Counting files avoids instantiating a Dataset object, which is slow for ImageFolder.
    if not os.path.isdir(train_path):
        raise FileNotFoundError(f"Training data path not found: {train_path}")

    # Walk through all subdirectories and count image files
    dataset_size = sum(
        1
        for _, _, filenames in os.walk(train_path)
        for filename in filenames
        if os.path.splitext(filename)[1].lower() in _CPT_IMAGE_EXTENSIONS
    )

    fraction = cfg.data.fraction
    if 0 < fraction < 1.0:
        dataset_size = int(dataset_size * fraction)

    print(f"INFO: CPT: Found {dataset_size} training examples.")

    devices = cfg.devices
    num_devices = len(devices) if isinstance(devices, (list, ListConfig)) else devices
    num_devices = max(1, num_devices)

    accumulate_grad_batches = cfg.accumulate_grad_batches or 1
    global_batch_size = cfg.optimizer.batch_size * num_devices * cfg.num_nodes
    return (dataset_size // global_batch_size) // accumulate_grad_batches


def _prepare_cpt_resume_checkpoint(cfg, model, cpt_ckpt_path, run_tag):
    """Load a foreign checkpoint into ``model`` and write a Lightning-style resume file.

    Side effects: loads weights into ``model`` (``strict=False``), writes a temporary
    ``.ckpt`` file in the working directory, sets ``cfg.resume_from_checkpoint`` to it and
    disables ``cfg.auto_resume``.

    Args:
        cfg: the parsed run config.
        model: the freshly instantiated method.
        cpt_ckpt_path: path of the foreign checkpoint to continue from.
        run_tag: fallback identifier used in the file name when ``wandb.id`` is unset.

    Returns:
        Path of the temporary checkpoint; the caller is responsible for deleting it.

    Raises:
        ValueError: if the foreign checkpoint lacks a model state or an ``epoch`` entry.
        FileNotFoundError: if the training data directory does not exist.
    """
    print(f"INFO: Continued Pretraining (CPT) mode enabled. Loading from foreign checkpoint: {cpt_ckpt_path}")

    steps_per_epoch = _estimate_cpt_steps_per_epoch(cfg)

    # HACK: configure_optimizers() reads `trainer.estimated_stepping_batches`, which is a
    # read-only property of a real Trainer, so a minimal stand-in object is attached.
    # NOTE(review): the returned optimizers/schedulers are not used (both restart from
    # scratch below); the call is kept to preserve any side effects - confirm whether it
    # can be dropped.
    from argparse import Namespace

    model.trainer = Namespace(estimated_stepping_batches=steps_per_epoch)
    model.configure_optimizers()
    model.trainer = None  # Detach trainer

    # SECURITY: weights_only=False unpickles arbitrary objects. Only point
    # `cpt_checkpoint_path` at checkpoints from a trusted source.
    foreign_ckpt = torch.load(cpt_ckpt_path, map_location="cpu", weights_only=False)

    # 1. Load model state
    if "model" in foreign_ckpt:
        model_sd = foreign_ckpt["model"]
    elif "state_dict" in foreign_ckpt:
        model_sd = foreign_ckpt["state_dict"]
    else:
        raise ValueError("Foreign checkpoint does not have 'model' or 'state_dict' key.")

    incompatible_keys = model.load_state_dict(_remap_foreign_state_dict(model_sd), strict=False)
    print("INFO: CPT: Loaded model state_dict with strict=False.")
    if incompatible_keys.missing_keys:
        print(f"INFO: CPT: Missing keys: {incompatible_keys.missing_keys}")
    if incompatible_keys.unexpected_keys:
        print(f"INFO: CPT: Unexpected keys: {incompatible_keys.unexpected_keys}")

    # Store the *current* model state so Lightning can reload it without missing keys.
    lightning_ckpt = {"state_dict": model.state_dict()}

    # 2./3. Optimizer and LR-scheduler state deliberately start empty: learning-rate
    # schedules, momentum buffers, etc. should restart when continuing pre-training with
    # different hyper-parameters.
    lightning_ckpt["optimizer_states"] = []
    print("INFO: CPT: Initialising optimizer from scratch for continued pretraining.")
    lightning_ckpt["lr_schedulers"] = []
    print("INFO: CPT: Initialising LR scheduler from scratch for continued pretraining.")

    # 4. Epoch and global step for a valid resume state
    if "epoch" not in foreign_ckpt:
        raise ValueError("Foreign checkpoint must have 'epoch' key for continued pretraining.")

    # The stored epoch is treated as the (0-indexed) epoch that has just been completed,
    # so `completed_epoch + 1` full epochs worth of batches have been seen.
    completed_epoch = foreign_ckpt["epoch"]
    global_step = (completed_epoch + 1) * steps_per_epoch
    lightning_ckpt["epoch"] = completed_epoch
    lightning_ckpt["global_step"] = global_step
    print(
        f"INFO: CPT: Resuming from epoch {completed_epoch}. "
        f"Calculated global_step: {lightning_ckpt['global_step']}."
    )

    # Other keys needed for the file to look like a Lightning checkpoint
    lightning_ckpt["pytorch-lightning_version"] = "1.7.4"
    lightning_ckpt["callbacks"] = {}

    # Minimal loop state needed for the trainer to resume at the right epoch.
    epoch_progress = {
        "ready": completed_epoch + 1,
        "started": completed_epoch + 1,
        "processed": completed_epoch + 1,
        "completed": completed_epoch,
    }
    lightning_ckpt["loops"] = {
        "fit_loop": {
            "state_dict": {},
            "epoch_progress": {
                "total": dict(epoch_progress),
                "current": dict(epoch_progress),
            },
            "epoch_loop.state_dict": {"_batches_that_stepped": global_step},
        }
    }

    # The wandb run id (if available) and the process ID make the file name unique, which
    # avoids clashes between parallel CPT jobs.
    wandb_id = omegaconf_select(cfg, "wandb.id", run_tag)
    temp_ckpt_path = f"cpt_resume_temp_{wandb_id}_{os.getpid()}.ckpt"
    torch.save(lightning_ckpt, temp_ckpt_path)

    # Point the main resume logic to the temporary checkpoint
    cfg.resume_from_checkpoint = temp_ckpt_path
    # Disable solo-learn's auto-resume to avoid conflicts
    if cfg.auto_resume.enabled:
        print("WARNING: CPT: Disabling auto_resume to prioritize CPT checkpoint.")
        cfg.auto_resume.enabled = False

    return temp_ckpt_path


def _run_pretraining(cfg, model, run_tag):
    """Build data, callbacks, logger and Trainer, then validate once and fit ``model``.

    Args:
        cfg: the parsed run config (``cfg.resume_from_checkpoint`` is consumed here).
        model: the instantiated method to train.
        run_tag: random suffix used for the default wandb run name.
    """
    make_contiguous(model)
    # can provide up to ~20% speed up
    if not cfg.performance.disable_channel_last:
        model = model.to(memory_format=torch.channels_last)

    # compute FLOPs (skipped for spawn strategies to avoid pickling issues)
    model_flops_per_image = None
    if "spawn" not in cfg.strategy:
        model_flops_per_image = compute_model_FLOPs(model, cfg.data.dataset)
    else:
        print("Skipping FLOPs computation for spawn-based DDP strategy to avoid pickling errors.")

    use_dali = cfg.data.format == "dali"

    # validation dataloader for when it is available
    if cfg.data.dataset == "custom" and (cfg.data.no_labels or cfg.data.val_path is None):
        val_loader = None
    elif cfg.data.dataset in ["imagenet100", "imagenet"] and cfg.data.val_path is None:
        val_loader = None
    else:
        # validation never goes through DALI; fall back to plain image folders
        val_data_format = "image_folder" if use_dali else cfg.data.format

        _, val_loader = prepare_data_classification(
            cfg.data.dataset,
            train_data_path=cfg.data.train_path,
            val_data_path=cfg.data.val_path,
            data_format=val_data_format,
            batch_size=cfg.optimizer.batch_size,
            num_workers=cfg.data.num_workers,
        )

    # pretrain dataloader
    dali_datamodule = None
    train_dataset = None
    train_loader = None
    if use_dali:
        assert (
            _dali_avaliable
        ), "Dali is not currently avaiable, please install it first with pip3 install .[dali]."
        pipelines = [
            NCropAugmentation(
                build_transform_pipeline_dali(
                    cfg.data.dataset, aug_cfg, dali_device=cfg.dali.device
                ),
                aug_cfg.num_crops,
            )
            for aug_cfg in cfg.augmentations
        ]
        transform = FullTransformPipeline(pipelines)

        dali_datamodule = PretrainDALIDataModule(
            dataset=cfg.data.dataset,
            train_data_path=cfg.data.train_path,
            transforms=transform,
            num_large_crops=cfg.data.num_large_crops,
            num_small_crops=cfg.data.num_small_crops,
            num_workers=cfg.data.num_workers,
            batch_size=cfg.optimizer.batch_size,
            no_labels=cfg.data.no_labels,
            data_fraction=cfg.data.fraction,
            dali_device=cfg.dali.device,
            encode_indexes_into_labels=cfg.dali.encode_indexes_into_labels,
        )
        dali_datamodule.val_dataloader = lambda: val_loader
    else:
        pipelines = [
            NCropAugmentation(build_transform_pipeline(cfg.data.dataset, aug_cfg), aug_cfg.num_crops)
            for aug_cfg in cfg.augmentations
        ]
        transform = FullTransformPipeline(pipelines)

        if cfg.debug_augmentations:
            print("Transforms:")
            print(transform)

        train_dataset = prepare_datasets(
            cfg.data.dataset,
            transform,
            train_data_path=cfg.data.train_path,
            data_format=cfg.data.format,
            no_labels=cfg.data.no_labels,
            data_fraction=cfg.data.fraction,
        )
        train_loader = prepare_dataloader(
            train_dataset, batch_size=cfg.optimizer.batch_size, num_workers=cfg.data.num_workers
        )

    # 1.7 will deprecate resume_from_checkpoint, but for the moment
    # the argument is the same, but we need to pass it as ckpt_path to trainer.fit
    ckpt_path, wandb_run_id = None, None
    if cfg.auto_resume.enabled and cfg.resume_from_checkpoint is None:
        auto_resumer = AutoResumer(
            checkpoint_dir=os.path.join(cfg.checkpoint.dir, cfg.method),
            max_hours=cfg.auto_resume.max_hours,
        )
        resume_from_checkpoint, wandb_run_id = auto_resumer.find_checkpoint(cfg)
        if resume_from_checkpoint is not None:
            print(
                "Resuming from previous checkpoint that matches specifications:",
                f"'{resume_from_checkpoint}'",
            )
            ckpt_path = resume_from_checkpoint
    elif cfg.resume_from_checkpoint is not None:
        ckpt_path = cfg.resume_from_checkpoint
        del cfg.resume_from_checkpoint

    callbacks = []

    # FLOPs logging needs both the per-image FLOPs (unavailable for spawn strategies) and
    # the dataset length (unavailable here for the DALI format, which builds no dataset
    # object in this function).
    if model_flops_per_image is not None:
        if train_dataset is not None:
            callbacks.append(FLOPsLogger(model_flops_per_image, len(train_dataset), cfg))
        else:
            print(
                "WARNING: Skipping FLOPs logging because the training dataset size "
                "is not available for the DALI data format."
            )

    if cfg.checkpoint.enabled:
        ckpt = Checkpointer(
            cfg,
            logdir=os.path.join(cfg.checkpoint.dir, cfg.method),
            frequency=cfg.checkpoint.frequency,
            keep_prev=cfg.checkpoint.keep_prev,
        )
        callbacks.append(ckpt)

    if omegaconf_select(cfg, "auto_umap.enabled", False):
        assert (
            _umap_available
        ), "UMAP is not currently avaiable, please install it first with [umap]."
        auto_umap = AutoUMAP(
            cfg.name,
            logdir=os.path.join(cfg.auto_umap.dir, cfg.method),
            frequency=cfg.auto_umap.frequency,
        )
        callbacks.append(auto_umap)

    # wandb logging
    wandb_logger = None
    if cfg.wandb.enabled:
        # Run name priority:
        # 1. cfg.wandb.name (e.g. set by the CLI override ++wandb.name=...)
        # 2. cfg.name (from YAML) + random suffix
        final_wandb_name = omegaconf_select(cfg, "wandb.name", None)
        if not final_wandb_name:
            final_wandb_name = cfg.name + "_" + run_tag

        # If a wandb run id is passed, this will be used for resuming.
        if wandb_run_id is None:
            wandb_run_id = omegaconf_select(cfg, "wandb.id", None)

        # optional free-text notes describing the run
        wandb_notes = omegaconf_select(cfg, "wandb.notes", None)

        wandb_logger = WandbLogger(
            name=final_wandb_name,
            project=cfg.wandb.project,
            entity=cfg.wandb.entity,
            offline=cfg.wandb.offline,
            resume="allow" if wandb_run_id else None,
            id=wandb_run_id,
            notes=wandb_notes,
            settings=wandb.Settings(init_timeout=600),
        )
        # Disable wandb.watch for spawn strategies to avoid pickling issues
        if "spawn" not in cfg.strategy:
            wandb_logger.watch(model, log="gradients", log_freq=100)
        else:
            print("Skipping wandb.watch for spawn-based DDP strategy to avoid pickling errors.")
        wandb_logger.log_hyperparams(OmegaConf.to_container(cfg))

        # lr logging
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    trainer_kwargs = OmegaConf.to_container(cfg)
    # we only want to pass in valid Trainer args, the rest may be user specific
    valid_kwargs = inspect.signature(Trainer.__init__).parameters
    trainer_kwargs = {name: trainer_kwargs[name] for name in valid_kwargs if name in trainer_kwargs}
    trainer_kwargs.update(
        {
            "logger": wandb_logger,
            "callbacks": callbacks,
            "enable_checkpointing": False,
            # plain "ddp" gets an explicit strategy object; every other value
            # (including "ddp_spawn") is forwarded unchanged
            "strategy": DDPStrategy(find_unused_parameters=True)
            if cfg.strategy == "ddp"
            else cfg.strategy,
        }
    )

    trainer = Trainer(**trainer_kwargs)

    # Always run validation before training
    if val_loader is not None:
        print("INFO: Running a validation epoch before starting training.")
        if use_dali:
            trainer.validate(model, datamodule=dali_datamodule)
        else:
            trainer.validate(model, val_loader)
    else:
        print("WARNING: No validation loader is available. Skipping initial validation.")

    if use_dali:
        trainer.fit(model, ckpt_path=ckpt_path, datamodule=dali_datamodule)
    else:
        trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)


@hydra.main(version_base="1.2")
def main(cfg: DictConfig):
    """Hydra entry point: parse the config, build the method and run pretraining.

    When ``cpt_checkpoint_path`` is set, a foreign checkpoint is first converted into a
    temporary Lightning checkpoint that training resumes from; that file is removed when
    training finishes or fails.
    """
    # hydra doesn't allow us to add new keys for "safety"
    # set_struct(..., False) disables this behavior and allows us to add more parameters
    # without making the user specify every single thing about the model
    OmegaConf.set_struct(cfg, False)
    cfg = parse_cfg(cfg)

    # Unique string for this run. Drawn *before* seeding on purpose, so that two runs with
    # the same seed still get different names.
    rand_str_for_wandb_name = generate_random_name()

    # Seed before the model is built so that weight initialisation is reproducible.
    seed_everything(cfg.seed)

    assert cfg.method in METHODS, f"Choose from {METHODS.keys()}"

    # Instantiate the model.
    model = METHODS[cfg.method](cfg)

    if cfg.data.num_large_crops != 2:
        assert cfg.method in ["wmse", "mae"]

    # Continued pretraining (CPT): adapt a foreign checkpoint for Lightning's resume logic
    temp_ckpt_for_resume = None
    cpt_ckpt_path = omegaconf_select(cfg, "cpt_checkpoint_path", None)
    if cpt_ckpt_path:
        temp_ckpt_for_resume = _prepare_cpt_resume_checkpoint(
            cfg, model, cpt_ckpt_path, rand_str_for_wandb_name
        )

    try:
        _run_pretraining(cfg, model, rand_str_for_wandb_name)
    finally:
        # Remove the temporary checkpoint even if training failed, so stale files do not
        # pile up in the working directory.
        if temp_ckpt_for_resume and os.path.exists(temp_ckpt_for_resume):
            os.remove(temp_ckpt_for_resume)
            print(f"INFO: CPT: Removed temporary checkpoint file: {temp_ckpt_for_resume}")


# Run the Hydra-decorated entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
