import lightning as pl
from lightning.pytorch.loggers import WandbLogger
from birdsnap.birdsnap_datasets import TorchSpatialDataModule
from birdsnap.birdsnap_locationencoder3d import BirdsnapLocationEncoder3D
from birdsnap.birdsnap_test_snapshots import check_or_update_snapshot
from config.experiment_params import process_args
# from config.birdsnap import args_birdsnap
from utils import set_default_if_unset
from utils.utils import get_project_root, save_model_parameters
import logging
logger = logging.getLogger(__name__)

# Start from a clean slate (optional): dropping handlers that are already
# attached keeps records from being emitted more than once.
if logger.hasHandlers():
    logger.handlers.clear()

# Console output for this module only; every record carries a fixed
# runner tag so it is easy to spot among other log lines.
formatter = logging.Formatter(
    "%(asctime)s BIRDSNAP-EXPERIMENT-RUNNER %(levelname)s %(name)s: %(message)s"
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)

logger.setLevel(logging.INFO)
logger.addHandler(handler)

import wandb
import torch
import yaml
from lightning.pytorch import LightningDataModule
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

# %% [markdown]
# ### Other
from birdsnap.birdsnap_datasets import TorchSpatialDataModule

def load_and_harmonize_hparams(args) -> tuple:
    """
    Loads hyperparameters from a YAML file and harmonizes them with the provided arguments.

    The YAML file is located at ``get_project_root() / args.hparams``. Values taken
    from ``args`` overwrite whatever the file contains for the same keys.

    Args:
        args (Namespace): The command-line arguments. Reads ``hparams``, ``lr``, ``wd``,
            ``presence_only_loss``, ``patience``, ``regression``, ``max_epochs``,
            ``min_radius`` and the ``legendre_polys`` / ``harmonics_calculation``
            entries of ``locationencoder_args``.

    Returns:
        tuple: (args, harmonized hyperparameters). ``args`` is the same object that
        was passed in; its ``locationencoder_args["hparams"]`` entry is set to the
        harmonized hyperparameters as a side effect.
    """
    # Read the YAML file as UTF-8 explicitly rather than relying on the locale default.
    with open(get_project_root() / args.hparams, encoding="utf-8") as f:
        # safe_load yields None for an empty document; fall back to an empty
        # mapping so the assignments below do not fail with a TypeError.
        hparams = yaml.safe_load(f) or {}

    # "harmonize" hparams and args: the argument values take precedence
    encoder_args = args.locationencoder_args
    hparams.update({
        'optimizer': {
            'lr': args.lr,
            'wd': args.wd
        },
        'legendre_polys': encoder_args["legendre_polys"],
        'harmonics_calculation': encoder_args["harmonics_calculation"],
        "presence_only_loss": args.presence_only_loss,
        'patience': args.patience,
        'regression': args.regression,
        "max_epochs": args.max_epochs,
        "min_radius": args.min_radius,
    })

    # presumably only fills in max_radius when the YAML file did not define it —
    # see utils.set_default_if_unset
    hparams = set_default_if_unset(hparams, "max_radius", 360)

    # Make the harmonized hyperparameters available to the location encoder.
    encoder_args["hparams"] = hparams

    return args, hparams

def fit_model(
    args,
    locationencoder: torch.nn.Module,
    datamodule: LightningDataModule,
    wandb_logger: WandbLogger = None,
    resume_checkpoint: str = None
) -> tuple:
    """
    Fits the model using the provided arguments, location encoder, and data module.

    Args:
        args (Namespace): The command-line arguments. Reads ``seed``, ``patience``,
            ``accelerator``, ``gpus`` and ``max_epochs``.
        locationencoder (torch.nn.Module): The location encoder model.
        datamodule (LightningDataModule): The data module for the dataset.
        wandb_logger (WandbLogger, optional): Logger for experiment tracking.
        resume_checkpoint (str, optional): Path to a checkpoint to resume training from.

    Returns:
        tuple: ``(trainer, locationencoder, datamodule)``, in that order.
    """
    # fix random seeds
    pl.seed_everything(args.seed, workers=True)

    # Stop training once val_loss has not improved for `patience` checks.
    callbacks = [
        EarlyStopping(monitor="val_loss", mode="min", patience=args.patience),
    ]

    # -1 (or [-1]) lets Lightning pick the devices; anything else is passed through.
    if args.gpus == -1 or args.gpus == [-1]:
        devices = 'auto'
    else:
        devices = args.gpus

    # Use GPU if it is available; this overrides the configured accelerator.
    accelerator = args.accelerator
    if torch.cuda.is_available():
        accelerator = 'gpu'

    logger.info(f"Using GPUs: {devices}")
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        log_every_n_steps=5,
        callbacks=callbacks,
        accelerator=accelerator,
        devices=devices,
        logger=wandb_logger,
        precision=64,
        num_sanity_val_steps=0,
    )

    # The dataloaders are built here and handed to the trainer explicitly,
    # so the data module has to be set up by hand first.
    datamodule.setup()
    trainer.fit(
        model=locationencoder,
        train_dataloaders=datamodule.train_dataloader(),
        val_dataloaders=datamodule.val_dataloader(),
        ckpt_path=resume_checkpoint
    )
    return trainer, locationencoder, datamodule

class BirdsnapExperimentRunner:
    """
    Encapsulates a single experiment run for the birdsnap location encoder.

    A run harmonizes the base arguments with the YAML hyperparameters, applies the
    per-run overrides from ``sub_experiment``, trains and tests the model, checks
    the output snapshot, saves the model parameters and closes the wandb run.
    """
    def __init__(
        self,
        args_birdsnap,
        sub_experiment: dict,
        log_dir: str,
        tag: str,
        project: str = "location-encoder",
        dataset: str = "birdsnap",
    ):
        """
        Args:
            args_birdsnap: Base experiment arguments; handed to
                ``load_and_harmonize_hparams`` and then modified in place by ``run``.
            sub_experiment (dict): Per-run overrides (seed, epochs, ortho weights, ...).
            log_dir (str): Directory passed to the wandb logger as ``save_dir``.
            tag (str): Tag added to the wandb run; also part of the run name.
            project (str): wandb project name.
            dataset (str): Dataset name; used for the data module and the run name.
        """
        self.args_birdsnap = args_birdsnap
        self.sub_experiment = sub_experiment
        self.log_dir = log_dir
        self.project = project
        self.dataset = dataset
        self.tag = tag

    def _apply_overrides(self, args):
        """Write the fixed and per-run settings of this experiment into ``args``."""
        sub = self.sub_experiment
        encoder_args = args.locationencoder_args

        args.lr = 1e-3
        args.max_epochs = sub.get("max_epochs", 100)
        args.patience = sub.get("patience", 0)
        encoder_args["arch_name"] = sub.get("arch_name", "baseline_arch_v1")
        encoder_args["ortho_weight"] = sub["ortho_weight"]
        encoder_args["ortho_weight_space"] = sub["ortho_weight_space"]
        encoder_args["ortho_weight_time"] = sub.get("ortho_weight_time", 0)
        encoder_args["normality_flag"] = False
        encoder_args["ortho_exponent"] = sub.get("ortho_exponent", 1)
        encoder_args["legendre_polys"] = 20
        encoder_args["combination_type"] = sub["combination_type"]
        encoder_args["time_embedding_dim"] = sub["time_embedding_dim"]
        encoder_args["time_embedding_type"] = sub["time_embedding_type"]
        args.datamodule_args["subset_fraction"] = sub["subset_fraction"]

    @staticmethod
    def _merge_tags(existing, tag):
        """Return ``existing`` tags plus ``tag`` as a list; ``existing`` may be None."""
        return list(existing or ()) + [tag]

    def _snapshot_path(self, seed):
        """Build the snapshot path that is unique to this sub-experiment and seed."""
        sub = self.sub_experiment
        snapshot_file = (
            f"birdsnap_model_output_"
            f"seed{seed}_"
            f"combtype_{sub['combination_type']}_"
            f"time_emb_{sub['time_embedding_type']}_"
            f"time_emb_dim_{sub['time_embedding_dim']}_"
            f"ortho_{sub['ortho_weight']}_"
            f"ortho_space_{sub.get('ortho_weight_space', 0)}_"
            f"ortho_time_{sub.get('ortho_weight_time', 0)}_"
            f"subset_{sub['subset_fraction']}_"
        )
        return "artifacts/" + snapshot_file

    def run(self):
        """
        Train, test and snapshot-check one sub-experiment, logging to wandb.

        The wandb run is always closed: normally on success, and with a non-zero
        exit code if any step after the run was created raises.

        Raises:
            KeyError: If ``sub_experiment`` lacks one of the required keys
                ``ortho_weight``, ``ortho_weight_space``, ``combination_type``,
                ``time_embedding_dim``, ``time_embedding_type`` or ``subset_fraction``.
        """
        # Harmonize and set up arguments
        args, _ = load_and_harmonize_hparams(self.args_birdsnap)
        args.seed = self.sub_experiment.get("seed", 0)
        pl.seed_everything(args.seed, workers=True)

        # NOTE(review): load_and_harmonize_hparams runs before these overrides, so the
        # hparams dict it stored was built from the pre-override lr / max_epochs /
        # patience / legendre_polys values — confirm this is intended.
        self._apply_overrides(args)

        locationencoder_args, datamodule_args = process_args(
            args.locationencoder_args,
            args.datamodule_args
        )

        # Create wandb logger
        wandb_logger = WandbLogger(
            project=self.project,
            name=f"{self.dataset}-{self.tag}-{args.seed}",
            config=args,
            save_dir=self.log_dir,
            log_model=False,
        )

        # First access to the experiment; kept in a local so the cleanup path
        # below never has to touch the logger property again.
        wandb_run = wandb_logger.experiment

        try:
            # Add experiment_name_root as a tag (tags may be None on the run)
            wandb_run.tags = self._merge_tags(
                getattr(wandb_run, "tags", None), self.tag
            )

            birdsnap_datamodule = TorchSpatialDataModule(
                dataset=self.dataset,
                batch_size=args.datamodule_args['batch_size'],
                num_workers=args.datamodule_args['num_workers'],
                subset_fraction=datamodule_args["subset_fraction"],
            )

            # arch_name is removed before the remaining entries are passed on as
            # keyword arguments — presumably the encoder does not accept it.
            del locationencoder_args["arch_name"]

            locationencoder = BirdsnapLocationEncoder3D(
                **locationencoder_args,
                sub_experiment=self.sub_experiment,
                args_birdsnap_harmonized=args
            )

            trainer, locationencoder, _ = fit_model(
                args, locationencoder, birdsnap_datamodule, wandb_logger=wandb_logger
            )

            # Log model architecture summary
            wandb_run.config.update({"model_summary": str(locationencoder)})

            # Test and log predictions
            trainer.test(
                model=locationencoder,
                dataloaders=birdsnap_datamodule.test_dataloader(),
            )

            check_or_update_snapshot(
                model=locationencoder,
                config=self.sub_experiment,
                snapshot_file=self._snapshot_path(args.seed)
            )

            save_model_parameters(trainer, locationencoder, wandb_logger)
        except BaseException:
            # Close the run as failed so it is not left open for whatever runs next.
            wandb_run.finish(exit_code=1)
            raise

        # stop wandb
        wandb_run.finish()
