from inat.InatLocationEncoder import InatLocationEncoder3D, InatVisionLocationEncoder3D
from config.experiment_params import process_args
from utils.utils import get_project_root
from utils import set_default_if_unset
import yaml
import torch

import lightning as pl

import sys
from pathlib import Path

# Make TorchSpatial's ``main`` directory importable by appending it to sys.path.
# NOTE(review): the path is relative, so Python resolves it against the current
# working directory rather than against this file's location.
torchspatial_path = Path("./location-embeddings-3d/dependencies/TorchSpatial")
sys.path.append(str(torchspatial_path.joinpath("main")))

from datasets import (
  load_dataset,
)

from inat.datamodule import TorchSpatialDataModule, TorchSpatialDataModule_VisionLocation

class InatExperimentRunner:
    """
    Encapsulates a single experiment run for the iNat location encoder.

    ``run`` applies the per-run overrides from ``sub_experiment`` to the base
    argument namespace, loads and harmonizes the hyper-parameter YAML, builds
    the data module and location encoder, trains with Lightning, evaluates on
    the test dataloader and logs to Weights & Biases.

    NOTE: ``args_inat`` is mutated in place by ``run``.
    """

    def __init__(
        self,
        args_inat,
        sub_experiment: dict,
        log_dir: str,
        tag: str,
        project: str = "location-encoder",
        dataset: str = "iNat2018",
    ):
        """
        Args:
            args_inat: Base argument namespace (read for ``hparams``, ``lr``,
                ``wd``, ``locationencoder_args``, ``datamodule_args``, ...).
            sub_experiment: Per-run overrides (seed, epochs, architecture, ...).
            log_dir: Directory handed to the W&B logger as ``save_dir``.
            tag: Tag appended to the W&B run tags and used in the run name.
            project: W&B project name.
            dataset: Dataset label used in the W&B run name.
        """
        self.args_inat = args_inat
        self.sub_experiment = sub_experiment
        self.log_dir = log_dir
        self.project = project
        self.dataset = dataset
        self.tag = tag

    def load_and_harmonize_hparams(self, args):
        """
        Load the hyper-parameter YAML and overlay values taken from ``args``.

        The YAML path ``args.hparams`` is joined onto ``get_project_root()``.
        The resulting dict is also stored in
        ``args.locationencoder_args["hparams"]``.

        Returns:
            ``(args, hparams)``: the same (mutated) ``args`` object and the dict.
        """
        with open(get_project_root() / args.hparams, encoding="utf-8") as f:
            # An empty YAML file loads as None; start from an empty dict instead.
            hparams = yaml.safe_load(f) or {}

        encoder_args = args.locationencoder_args
        hparams["optimizer"] = {"lr": args.lr, "wd": args.wd}
        hparams["legendre_polys"] = encoder_args["legendre_polys"]
        hparams["harmonics_calculation"] = encoder_args["harmonics_calculation"]
        hparams["presence_only_loss"] = args.presence_only_loss
        hparams["patience"] = args.patience
        hparams["regression"] = args.regression
        hparams["max_epochs"] = args.max_epochs
        hparams["min_radius"] = args.min_radius
        # Presumably fills max_radius=360 only when the YAML omits it — verify helper.
        hparams = set_default_if_unset(hparams, "max_radius", 360)
        encoder_args["hparams"] = hparams
        return args, hparams

    @staticmethod
    def _resolve_devices(gpus):
        """Map the ``gpus`` argument to Lightning's ``devices`` (-1 / [-1] -> "auto")."""
        return "auto" if gpus in (-1, [-1]) else gpus

    def fit_model(
        self,
        args_inat,
        locationencoder,
        inat2018datamodule,
        logger=None,
        resume_checkpoint=None,
    ):
        """
        Train ``locationencoder`` on ``inat2018datamodule`` with early stopping.

        Early stopping monitors ``val_loss`` (mode ``min``) with
        ``args_inat.patience``; training runs for at most
        ``args_inat.max_epochs`` epochs with ``precision=64``.

        Args:
            args_inat: Namespace providing ``patience``, ``gpus``,
                ``max_epochs`` and (when CUDA is unavailable) ``accelerator``.
            locationencoder: The Lightning module to train.
            inat2018datamodule: Object exposing ``train_dataloader()`` and
                ``val_dataloader()``.
            logger: Optional Lightning logger handed to the trainer.
            resume_checkpoint: Optional checkpoint forwarded as ``ckpt_path``.

        Returns:
            ``(locationencoder, trainer, inat2018datamodule)``.
        """
        from lightning.pytorch.callbacks.early_stopping import EarlyStopping

        callbacks = [
            EarlyStopping(monitor="val_loss", mode="min", patience=args_inat.patience)
        ]
        devices = self._resolve_devices(args_inat.gpus)
        # When CUDA is present the GPU is always used, overriding args_inat.accelerator.
        accelerator = "gpu" if torch.cuda.is_available() else args_inat.accelerator

        trainer = pl.Trainer(
            max_epochs=args_inat.max_epochs,
            log_every_n_steps=5,
            callbacks=callbacks,
            accelerator=accelerator,
            devices=devices,
            logger=logger,
            precision=64,
            num_sanity_val_steps=0,
        )
        trainer.fit(
            model=locationencoder,
            train_dataloaders=inat2018datamodule.train_dataloader(),
            val_dataloaders=inat2018datamodule.val_dataloader(),
            ckpt_path=resume_checkpoint,
        )
        return locationencoder, trainer, inat2018datamodule

    def _apply_sub_experiment_overrides(self, args):
        """Copy the per-run settings from ``self.sub_experiment`` onto ``args`` (in place) and return it."""
        sub = self.sub_experiment

        args.seed = sub.get("seed", 0)
        args.max_epochs = sub.get("max_epochs", 1)
        args.patience = sub.get("patience", 1)

        encoder_args = args.locationencoder_args
        encoder_args["arch_name"] = sub.get("arch_name", "baseline_arch_v1")
        encoder_args["ortho_weight"] = sub["ortho_weight"]
        encoder_args["ortho_weight_space"] = sub.get("ortho_weight_space", 0)
        encoder_args["ortho_weight_time"] = sub.get("ortho_weight_time", 0)
        # The sub-experiment key "normality_correction" feeds the encoder's "normality_flag".
        encoder_args["normality_flag"] = sub.get("normality_correction", False)
        encoder_args["ortho_exponent"] = sub.get("ortho_exponent", 1)
        encoder_args["legendre_polys"] = sub.get("legendre_polys", 10)
        encoder_args["combination_type"] = sub["combination_type"]
        encoder_args["time_embedding_dim"] = sub["time_embedding_dim"]
        encoder_args["time_embedding_type"] = sub["time_embedding_type"]

        datamodule_args = args.datamodule_args
        datamodule_args["num_workers"] = 8
        datamodule_args["subset_fraction"] = 1
        datamodule_args["batch_size"] = sub.get("batch_size", 2000)
        datamodule_args["num_classes"] = sub["variable_cut_off"]
        return args

    def _build_datamodule(self, datamodule_args, vision_location_training):
        """Return ``(datamodule, encoder_class)`` for the requested training mode."""
        shared_kwargs = {
            "dataset": datamodule_args["dataset_name"],
            "batch_size": datamodule_args["batch_size"],
            "num_workers": datamodule_args["num_workers"],
            "subset_fraction": datamodule_args["subset_fraction"],
        }
        if vision_location_training:
            return (
                TorchSpatialDataModule_VisionLocation(**shared_kwargs),
                InatVisionLocationEncoder3D,
            )
        return (
            TorchSpatialDataModule(
                **shared_kwargs,
                variable_cut_off=self.sub_experiment["variable_cut_off"],
            ),
            InatLocationEncoder3D,
        )

    def run(self):
        """
        Execute one experiment: configure, train, test and log to W&B.

        If anything raises after the W&B logger is created, the W&B run is
        finished with ``exit_code=1`` and the exception is re-raised.
        """
        # Overrides must be applied *before* harmonizing: load_and_harmonize_hparams
        # copies max_epochs, patience and legendre_polys into the hparams dict, which
        # previously kept the pre-override values while trainer/encoder used the new ones.
        args = self._apply_sub_experiment_overrides(self.args_inat)
        args_inat_harmonized, _ = self.load_and_harmonize_hparams(args)
        pl.seed_everything(args_inat_harmonized.seed, workers=True)

        vision_location_training = self.sub_experiment.get("vision_location_training", False)
        locationencoder_args, _ = process_args(
            args_inat_harmonized.locationencoder_args,
            args_inat_harmonized.datamodule_args,
        )

        from lightning.pytorch.loggers import WandbLogger

        wandb_logger = WandbLogger(
            project=self.project,
            name=f"{self.dataset}-{self.tag}-{args_inat_harmonized.seed}",
            config=args_inat_harmonized,
            save_dir=self.log_dir,
            log_model=False,
        )
        wandb_run = wandb_logger.experiment
        # ``tags`` may be absent or None; treat both as "no tags yet".
        wandb_run.tags = list(getattr(wandb_run, "tags", None) or []) + [self.tag]

        try:
            inat_datamodule, LocationEncoderClass = self._build_datamodule(
                args_inat_harmonized.datamodule_args, vision_location_training
            )
            # arch_name is not forwarded to the encoder constructor.
            locationencoder_args.pop("arch_name", None)

            inat_datamodule.setup()
            locationencoder = LocationEncoderClass(**locationencoder_args)

            locationencoder, trainer, _ = self.fit_model(
                args_inat_harmonized, locationencoder, inat_datamodule, logger=wandb_logger
            )
            wandb_run.config.update({"model_summary": str(locationencoder)})

            trainer.test(
                model=locationencoder,
                dataloaders=inat_datamodule.test_dataloader(),
            )
        except Exception:
            # Do not leave the W&B run open; mark it as failed before propagating.
            wandb_run.finish(exit_code=1)
            raise
        wandb_run.finish()
