import functools
import logging
import os
import pathlib
from typing import Dict, Optional, cast

import hydra
import torch
import torch.distributed as dist
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf, open_dict
from torchinfo import summary

import wandb
from crps_retrofitting.optim.optimizer_utils import setup_optimizer_and_scheduler
from crps_retrofitting.utils.utils import load_common_weights


def setup_distributed():
    """Initialise the default process group from launcher env vars, best-effort.

    Does nothing when a process group already exists. Otherwise reads
    ``WORLD_SIZE`` and ``RANK`` from the environment and calls
    ``dist.init_process_group`` with the ``env://`` rendezvous and a combined
    backend (Gloo for CPU tensors, NCCL for CUDA tensors).

    If the variables are absent or malformed, or initialisation fails, the reason
    is logged and the function returns. The caller then proceeds without a
    process group, as before. Unlike the previous bare ``except``, interrupts
    (KeyboardInterrupt / SystemExit) are no longer swallowed.
    """
    if dist.is_initialized():
        return
    log = logging.getLogger("crps_retrofitting")

    try:
        world_size = int(os.environ["WORLD_SIZE"])
        rank = int(os.environ["RANK"])
    except (KeyError, ValueError) as err:
        # Not started by a distributed launcher (or env is malformed): stay single-process.
        log.debug("Skipping process group initialisation: %r", err)
        return

    try:
        dist.init_process_group(
            backend="cpu:gloo,cuda:nccl",  # both CPU + GPU backends
            init_method="env://",
            world_size=world_size,
            rank=rank,
        )
    except Exception as err:  # deliberate best-effort, but no longer silent
        log.warning("Could not initialise the distributed process group: %s", err)


def has_additional_params(module_cfg):
    """Return True if a sub-module config enables extra conditioning dims.

    A config is treated as adding parameters when either ``noise_cond_dim`` or
    ``norm_cond_dim`` is present and differs from 0. Missing attributes count as
    0, so an empty mapping such as ``{}`` yields False.

    The result is computed once and logged (previously it was computed twice and
    ``print``-ed).
    """
    result = any(
        getattr(module_cfg, attr, 0) != 0
        for attr in ("noise_cond_dim", "norm_cond_dim")
    )
    logging.getLogger("crps_retrofitting").info("Has additional params %s", result)
    return result


from crps_retrofitting.data import MixedWellDataModule
from crps_retrofitting.data.well_to_multi_transformer import (
    ChannelsFirstWithTimeFormatter,
)
from crps_retrofitting.optim.distributed_shampoo.shampoo_types import (
    FSDPShampooConfig,
    HSDPShampooConfig,
)
from crps_retrofitting.optim.distributed_shampoo.utils.shampoo_fsdp_utils import (
    compile_fsdp_parameter_metadata,
)
from crps_retrofitting.trainer.checkpoints import CheckPointLoader
from crps_retrofitting.trainer.training import Trainer
from crps_retrofitting.utils.distribution_utils import (
    configure_distribution,
    distribute_model,
)
from crps_retrofitting.utils.experiment_utils import (
    align_checkpoint_with_field_to_index_map,
    configure_experiment,
)

logger = logging.getLogger("crps_retrofitting")

# Retrieve configuration for hydra: the config tree lives in a `configs`
# directory next to this file, with `config.yaml` as the root config.
CONFIG_DIR = pathlib.Path(__file__).parent / "configs"
CONFIG_NAME = "config"
CONFIG_PATH = CONFIG_DIR / f"{CONFIG_NAME}.yaml"
# Fail fast with an explicit error; unlike `assert`, this check is not
# stripped when Python runs with -O.
if not CONFIG_PATH.is_file():
    raise FileNotFoundError(f"Configuration {CONFIG_PATH} is not an existing file.")
logger.info(f"Run training script for {CONFIG_PATH}")


def _log_parameter_counts(model: torch.nn.Module, model_info: Dict) -> None:
    """Log, per list-valued category of ``model_info``, the name count and element total."""
    get_parameter = getattr(model, "get_parameter", None)
    state_dict = None  # built lazily, at most once, only if a fallback lookup is needed
    for category, param_names in model_info.items():
        if not isinstance(param_names, list):
            continue
        total_params = 0
        for param_name in param_names:
            tensor = None
            if get_parameter is not None:
                try:
                    tensor = get_parameter(param_name)
                except AttributeError:
                    # Not a registered nn.Parameter (e.g. a buffer): use the state dict.
                    tensor = None
            if tensor is None:
                if state_dict is None:
                    state_dict = model.state_dict()
                tensor = state_dict.get(param_name)
            if tensor is not None:
                total_params += tensor.numel()

        logger.info(
            f"{category}: {len(param_names)} parameters, {total_params:,} total elements"
        )


def train(
    cfg: DictConfig,
    experiment_name: str,
    experiment_folder: str,
    viz_folder: str,
    old_field_index_map: Optional[Dict] = None,
    is_distributed: bool = False,
    world_size: int = 1,
    rank: int = 0,
    local_rank: int = 0,
    device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
    debug: bool = False,
):
    """Instantiate the different objects required for training and run the training loop.

    Steps, in order:
    1. Build the datamodule.
    2. Build the model sized to the datamodule's field-to-index map.
    3. Optionally load weights from a coalesced checkpoint.
    4. Move the model to its device and distribute it.
    5. Build the optimizer and scheduler.
    6. Optionally restore from the checkpointer (finetune or resume).
    7. Instantiate the trainer and call ``validate()`` when
       ``cfg.validation_mode`` is set, otherwise ``train()``.

    Args:
        cfg: Full Hydra config. Mutated in place: ``cfg.model.noise_field_idx``
            may be set, and ``cfg.data.field_index_map_override`` is overwritten
            with the datamodule's map.
        experiment_name: Run name forwarded to the trainer.
        experiment_folder: Directory where rank 0 saves ``extended_config.yaml``
            (training mode only).
        viz_folder: Directory forwarded to the trainer.
        old_field_index_map: Field-to-index map of the coalesced checkpoint,
            used to realign its weights to the current map.
        is_distributed: Not used by this function; kept for API compatibility.
        world_size: Number of processes, forwarded to the datamodule and trainer.
        rank: Global rank; rank 0 prints the model summary and saves the config.
        local_rank: Index of the CUDA device used when CUDA is available.
        device_mesh: Forwarded to model distribution, the optimizer setup and
            the trainer.
        debug: If True, call ``wandb.watch`` (rank 0, with wandb logging on)
            and pickle the model state dict to ``model_state_run.pkl``.
    """

    logger.info(f"Instantiate datamodule {cfg.data.wandb_data_name}")
    datamodule: MixedWellDataModule = instantiate(
        cfg.data.module_parameters,
        world_size=world_size,
        rank=rank,
        data_workers=cfg.data_workers,
        max_rollout_steps=cfg.trainer.max_rollout_steps,
        well_base_path=cfg.data.well_base_path,
        field_index_map_override=cfg.data.get("field_index_map_override", {}),
        transform=cfg.data.get("transform", None),
    )
    field_to_index_map = datamodule.train_dataset.field_to_index_map
    if "noise" in field_to_index_map:
        with open_dict(cfg):
            cfg.model.noise_field_idx = field_to_index_map["noise"]
    # TODO - currently enforcing MPP format, but should allow for other types
    # Number of fields used in training, derived from the largest index in the
    # field-to-index map (indices are assumed to be 0-based).
    total_input_fields = max(field_to_index_map.values()) + 1

    logger.info(
        f"Instantiate model {cfg.model._target_}",
    )
    model: torch.nn.Module = instantiate(
        cfg.model,
        n_states=total_input_fields,
    )
    # Initialise model_info
    model_info, model_checkpoint = None, None
    finetuning_with_additional_params = False
    if (
        hasattr(cfg.checkpoint, "coalesced_checkpoint_path")
        and cfg.checkpoint.coalesced_checkpoint_path is not None
        and (cfg.finetune or cfg.validation_mode)
    ):
        # A coalesced path means: load model weights only, with plain torch.load
        logger.info(
            f"Loading coalesced checkpoint {cfg.checkpoint.coalesced_checkpoint_path}"
        )
        checkpoint = torch.load(
            cfg.checkpoint.coalesced_checkpoint_path, map_location="cpu"
        )
        # Load the model weights, re-ordered to match the current field layout
        model_checkpoint = checkpoint["app"]["model"]
        model_checkpoint = align_checkpoint_with_field_to_index_map(
            old_state_dict=model_checkpoint,
            new_state_dict=model.state_dict(),
            old_field_to_index_map=old_field_index_map,
            new_field_to_index_map=field_to_index_map,
        )

        # In validation mode the finetuning modifications must exist before the
        # weights are loaded.
        if cfg.validation_mode and hasattr(cfg, "finetuning_mods"):
            if hasattr(model, "add_ft_options"):
                model.add_ft_options(cfg.finetuning_mods)

        # If we are adding new layers (extra conditioning dims or a noise_dim),
        # we cannot use a strict load_state_dict.
        finetuning_with_additional_params = any(
            has_additional_params(getattr(cfg.model, part, {}))
            for part in ("encoder", "processor", "decoder")
        ) or hasattr(cfg.model, "noise_dim")

        if finetuning_with_additional_params:
            # Intended to act like model.load_state_dict(model_checkpoint, strict=False)
            # while reporting which weights were loaded/missing.
            # NOTE(review): strict=True is passed here despite that intent; confirm the
            # meaning of `strict` in load_common_weights.
            model_info = load_common_weights(
                model, model_checkpoint, strict=True, verbose=False
            )
            _log_parameter_counts(model, model_info)
        else:
            model.load_state_dict(model_checkpoint, strict=True)

    # Finetuning changes - technically useable without FT too
    if hasattr(cfg, "finetuning_mods") and not cfg.validation_mode:
        if hasattr(model, "add_ft_options"):
            model.add_ft_options(cfg.finetuning_mods)

    if rank == 0:
        summary(model, depth=5)

    logger.info(
        f"Assigning distribution strategy: {cfg.distribution.distribution_type}"
    )
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{int(local_rank)}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    model = model.to(device)
    if debug and rank == 0 and cfg.logger.wandb:
        wandb.watch(model, log="all", log_freq=5)
    model = distribute_model(model, cfg, device_mesh)

    # Defaults before potential retrieval from a checkpoint (epochs are 1-indexed)
    start_epoch = 1
    last_epoch = -1  # Default for Pytorch
    val_loss = torch.tensor(float("inf"))

    # When resuming a model that was itself finetuned with additional parameters,
    # the very first model weights are loaded so that the optimizer parameter
    # groups are created correctly. A missing or null entry skips this step
    # (previously a null value crashed in torch.load(None)).
    # NOTE(review): this loads into the already-distributed model; the resume
    # branch below presumably overwrites these weights - confirm.
    initial_model_info = None
    initial_model_checkpoint_path = getattr(cfg, "initial_model_checkpoint", None)
    if initial_model_checkpoint_path is not None:
        checkpoint = torch.load(initial_model_checkpoint_path, map_location="cpu")
        # Load the model weights
        initial_model_checkpoint = checkpoint["app"]["model"]
        initial_model_info = load_common_weights(
            model, initial_model_checkpoint, strict=True, verbose=False
        )

    # Setup optimizer and scheduler with staged learning support
    optimizer, lr_scheduler = setup_optimizer_and_scheduler(
        cfg=cfg,
        model=model,
        device_mesh=device_mesh,
        model_info=model_info if model_info is not None else initial_model_info,
        last_epoch=last_epoch,
    )

    logger.info(f"Instantiate checkpointer {cfg.checkpoint._target_}")
    checkpointer: CheckPointLoader = instantiate(cfg.checkpoint, rank=rank)

    if hasattr(checkpointer, "load_checkpoint_path"):
        load_checkpoint_path = checkpointer.load_checkpoint_path
        if load_checkpoint_path is not None and os.path.exists(load_checkpoint_path):
            # A finetuning load (not the run's own last checkpoint) restores model
            # weights only.
            if (
                cfg.finetune or cfg.validation_mode
            ) and load_checkpoint_path != checkpointer.last_checkpoint:
                logger.info(f"Finetuning from checkpoint {load_checkpoint_path}")
                checkpointer.load(model, strict=True)
            # Otherwise this is a resume load: model + optimizer state
            else:
                logger.info(f"Resume from checkpoint {load_checkpoint_path}")
                epoch, val_loss = checkpointer.load(model, optimizer, strict=False)
                # Ensure initial_lr is set for each parameter group (the scheduler
                # step below reads it).
                for param_group in optimizer.param_groups:
                    if "initial_lr" not in param_group:
                        param_group["initial_lr"] = cfg.optimizer.lr

                logger.info(
                    f"Resume from epoch {epoch} with validation loss {val_loss}"
                )
                start_epoch = 1 if epoch is None else epoch + 1
                # last completed epoch
                last_epoch = start_epoch - 1
                if lr_scheduler is not None and epoch is not None:
                    lr_scheduler.step(last_epoch + 1)

    if debug:
        import pickle

        with open("model_state_run.pkl", "wb") as f:
            pickle.dump(model.state_dict(), f)
    # Persist the generated field-to-index map in the config so resumed runs
    # rebuild the same layout.
    with open_dict(cfg):
        cfg.data.field_index_map_override = field_to_index_map
    if rank == 0:
        logger.info(f"Final configuration:\n{OmegaConf.to_yaml(cfg)}")
    logger.info(f"Instantiate trainer {cfg.trainer._target_}")
    trainer: Trainer = instantiate(
        cfg.trainer,
        experiment_name=experiment_name,
        viz_folder=viz_folder,
        model=model,
        datamodule=datamodule,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        checkpointer=checkpointer,
        device=device,
        device_mesh=device_mesh,
        distribution_type=cfg.distribution.distribution_type,
        rank=rank,
        world_size=world_size,
        formatter=ChannelsFirstWithTimeFormatter,  # TODO change this to function of model
        wandb_logging=cfg.logger.wandb,
        start_epoch=start_epoch,
        start_val_loss=val_loss,
    )
    if cfg.validation_mode:
        trainer.validate()
    else:
        # Save config to directory folder
        if rank == 0:
            with open(
                pathlib.Path(experiment_folder) / "extended_config.yaml", "w"
            ) as f:
                OmegaConf.save(cfg, f)
        trainer.train()


@hydra.main(
    version_base=None, config_path=str(CONFIG_DIR), config_name=str(CONFIG_NAME)
)
def main(cfg: DictConfig):
    """Hydra entry point: set up distribution and the experiment, then train.

    Reads ``WORLD_SIZE``, ``RANK`` and ``LOCAL_RANK`` from the environment
    (defaults 1 / 0 / 0). It then:
    - initialises the process group best-effort;
    - builds the device mesh and configures experiment folders;
    - on rank 0 with wandb enabled, starts a wandb run whose config includes the
      world size and the global batch size;
    - runs :func:`train`.

    After a successful run the wandb run is finished and the default process
    group, if one was created, is destroyed.
    """
    # Torch optimization settings
    torch.set_float32_matmul_precision("high")  # Use TF32 when supported
    torch.backends.cudnn.allow_tf32 = True
    # Retrieve multiple processes context to setup DDP
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    is_distributed = (
        cfg.distribution.distribution_type.upper() != "LOCAL" and world_size > 1
    )
    setup_distributed()
    # Since configure_experiment uses distributed logic, distribution must be set up first
    device_mesh = configure_distribution(cfg)
    # checkpoint_folder and artifact_folder are not used directly in this function.
    (
        cfg,
        experiment_name,
        experiment_folder,
        checkpoint_folder,
        artifact_folder,
        viz_folder,
        old_field_index_map,
    ) = configure_experiment(cfg, rank, is_distributed)

    logger.info(f"Run experiment {experiment_name}")
    logger.info(f"Configuration:\n{OmegaConf.to_yaml(cfg)}")

    # Make sure we're logging the true batch size
    config_for_wandb = cast(Dict, OmegaConf.to_container(cfg, resolve=True))
    config_for_wandb["world_size"] = world_size
    # Global batch size is microbatch size * number of GPUs * gradient accumulation steps
    config_for_wandb["global_batch_size"] = (
        cfg.data.module_parameters.batch_size * world_size
    ) * cfg.trainer.grad_acc_steps
    # Initiate wandb logging (rank 0 only)
    if rank == 0 and cfg.logger.wandb:
        wandb.init(
            project=cfg.logger.wandb_project_name,
            group=f"{cfg.data.wandb_data_name}",
            config=config_for_wandb,
            name=experiment_name,
        )
    train(
        cfg,
        experiment_name,
        experiment_folder,
        viz_folder,
        old_field_index_map,
        is_distributed,
        world_size,
        rank,
        local_rank,
        device_mesh=device_mesh,
    )
    if rank == 0 and cfg.logger.wandb:
        wandb.finish()
    # Release the process group created by setup_distributed (previously leaked).
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    # The @hydra.main decorator supplies `cfg`, so main() is called without arguments.
    main()
