import multiprocessing
import os
import shutil
import signal
import subprocess
import tempfile
import time as time
import uuid
from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Dict, Optional

import jax
import jax.numpy as jnp
import numpy as np
import requests
import wandb
from omegaconf import DictConfig, OmegaConf, open_dict

from jadex.base.base_model import BaseModel
from jadex.base.base_state import BaseState
from jadex.base.registrable import register_all
from jadex.data.dataloader import JaxDataLoader, create_dataloader
from jadex.data.dataloader.jax_sampler import SampleBuffer
from jadex.data.datasets.base_dataset import BaseDataset
from jadex.global_configs.constants import JADEX_CHECKPOINT_DIR, PREEMPT_DIR
from jadex.utils import submit_job, update_ablation_progress
from jadex.utils.printing import print_blue, print_green, print_yellow
from jadex.utils.wandb_sweep import create_sweep, merge_sweep_config

# Fallback for a "submit to API" option (presumably read elsewhere; unused in ModelTrainer).
DEFAULT_SUBMIT_TO_API: bool = False
# Fallback for `job.save_preempt` when the config does not set it.
DEFAULT_SAVE_PREEMPT: bool = False


class ModelTrainer(ABC):
    """Abstract base class that configures, submits and runs a JAX training job.

    Subclasses must implement :meth:`create_trainer_kwargs_and_state` and, when
    validation is enabled, :meth:`run_validation`. This base class handles:

    - wandb setup;
    - the ``jax.pmap``-ed update loop;
    - periodic printing, logging and validation;
    - numbered checkpoints;
    - "preempt" checkpoints, which let a later launch with the same
      ``job.uuid`` resume where an interrupted one stopped.
    """

    def __init__(
        self,
        cfg: DictConfig,
        model: BaseModel,
        train_dataset: BaseDataset,
        train_dataloader: JaxDataLoader,
        test_dataset: BaseDataset,
        test_dataloader: Optional[JaxDataLoader] = None,
        fid: Optional[Callable] = None,
        keep_feature_idxs: Optional[Dict[str, np.ndarray]] = None,
        ctx=None,
    ):
        """Store the training components and, if enabled, choose the checkpoint file.

        Args:
            cfg: Full run configuration.
            model: Model whose ``get_loss_args`` is used by :meth:`update_model`.
            train_dataset: Training dataset.
            train_dataloader: Iterable of training batches consumed by :meth:`_train`.
            test_dataset: Evaluation dataset.
            test_dataloader: Optional evaluation dataloader.
            fid: Optional callable. Stored on the instance; not used by this base class.
            keep_feature_idxs: Optional mapping. Stored on the instance; not used by
                this base class.
            ctx: multiprocessing context whose child processes are terminated when a
                handled signal arrives.
        """
        self.cfg = cfg
        self.model = model
        self.train_dataset = train_dataset
        self.train_dataloader = train_dataloader
        self.test_dataset = test_dataset
        self.test_dataloader = test_dataloader
        self.fid = fid
        self.keep_feature_idxs = keep_feature_idxs
        self._preempt_state = None  # last state captured for preempt checkpoints
        self._ctx = ctx

        # `.get` matches `_update_should_run`, which treats a missing frequency key
        # as "disabled".
        if cfg.job.get("checkpoint_frequency_nsteps", 0) > 0:
            self.checkpoint_file = self._create_checkpoint_file(prefix=cfg.wandb.project)
            assert self.checkpoint_file.suffix == ".h5"
            self.ckpt_idx = 0  # index written by the next numbered checkpoint

    @staticmethod
    def get_checkpoint_dir() -> Path:
        """Return the directory that holds numbered checkpoints (``JADEX_CHECKPOINT_DIR``)."""
        return JADEX_CHECKPOINT_DIR

    @classmethod
    @abstractmethod
    def create_trainer_kwargs_and_state(cls, cfg, train_dataset):
        """Build the constructor kwargs for this trainer and its initial state.

        Returns:
            ``(trainer_kwargs, state)``.

            ``trainer_kwargs`` is splatted into the constructor by
            :meth:`_run_training`. It must therefore supply ``model``,
            ``train_dataset`` and ``test_dataset``; ``test_dataset`` may be
            ``None``, which skips the test dataloader.

            ``state`` must expose ``sample_buffer_data``.

        NOTE(review): :meth:`_run_training` passes a multiprocessing context as the
        second positional argument, not a dataset. Confirm the intended parameter name.
        """
        raise NotImplementedError

    def log_expensive(self, state: BaseState, batch: Dict, metrics: Dict, val=False):
        """Hook for metrics computed only on the ``wandb_expensive`` cadence and at validation.

        Returns a dict that is merged into the wandb payload. The default adds nothing.
        """
        return {}

    def extract_features_from_batch(self, p_batch):
        """Hook to select the model inputs from a dataloader batch. The default is the identity."""
        return p_batch

    def get_aux_metrics(self, state, batch, metrics, val=False):
        """Hook for extra train/validation metrics merged into the wandb payload.

        The default adds nothing. ``self`` was previously missing from the
        signature, which made the validation call (with ``val=True``) raise
        ``TypeError``.
        """
        return {}

    @property
    def exclude_metrics(self):
        """Metric keys removed from the wandb payload before it is logged."""
        return ["train_x_hats", "val_x_hats"]

    @classmethod
    def submit(cls, cfg):
        """Entry point: dispatch ``cfg`` to the appropriate action.

        The action depends on the config:

        - If already submitted to a cluster (``cluster.submitted``), train directly.
        - If ``run_sweep`` is set, launch a sweep agent.
        - If ``create_sweep`` is set, create a sweep.
        - Otherwise submit a training job. When ``job.start_checkpoint`` is set, the
          config is first replaced by the checkpoint's config, with the keys listed
          in ``job.checkpoint_overrides`` taken from ``cfg``.
        """
        register_all()

        if cfg.get("cluster") and cfg.cluster.get("submitted", False):
            # If already submitted to the cluster, then start training to prevent recursion
            cls._run_training(cfg)
            return

        create_sweep_cfg = cfg.get("create_sweep")
        run_sweep_cfg = cfg.get("run_sweep")

        # NOTE: This gets called after we run the agent using cls.submit_sweep. This means we may be inside of cluster!
        is_running_sweep = cls._is_running_sweep()

        num_active = (
            int(create_sweep_cfg is not None) + int(run_sweep_cfg is not None) + int(is_running_sweep)
        )
        assert num_active < 2, "only one of these can be set at a time!"

        if not is_running_sweep:
            with open_dict(cfg):
                cfg.wandb = DictConfig(cfg.get("wandb") or {"mode": "disabled"})
                cfg.wandb.project = cfg.wandb.get("project", default_value=cls.get_project_name(cfg))

        if is_running_sweep:
            cfg = merge_sweep_config(cfg)
        elif run_sweep_cfg is not None:
            cls.submit_sweep(cfg)
            return
        elif create_sweep_cfg is not None:
            create_sweep(cfg)
            return

        start_checkpoint = cfg.job.get("start_checkpoint")
        if start_checkpoint is not None:
            ckpt_file = cls.get_checkpoint_dir() / start_checkpoint
            assert ckpt_file.exists(), f"No checkpoint {start_checkpoint} found!"
            ckpt_cfg: DictConfig = BaseState.load_cfg(ckpt_file)
            # override checkpoint config with specified keys
            override_keys = cfg.job.checkpoint_overrides + [
                "job.start_checkpoint",
                "job.start_checkpoint_idx",
            ]
            for key in override_keys:
                value = OmegaConf.select(cfg, key)
                OmegaConf.update(ckpt_cfg, key, value, merge=False)
            cfg = ckpt_cfg

        cluster_cfg = cfg.get("cluster") if not is_running_sweep else None
        cls._submit_training(cluster_cfg, cfg)

    @classmethod
    def _submit_training(cls, cluster_cfg, cfg):
        """Give the job a uuid if it lacks one, then submit :meth:`_run_training`.

        Raises:
            ValueError: if ``cluster_cfg`` names a cluster other than ``"venv"``.
        """
        # Create a uuid if one is not provided
        if cfg.job.get("uuid", None) is None:
            with open_dict(cfg):
                cfg.job.uuid = str(uuid.uuid4())

        if cluster_cfg is None or cluster_cfg.name == "venv":
            submit_job(cls._run_training, cfg)
        else:
            raise ValueError(f"{cluster_cfg.name} unrecognized!")

    @staticmethod
    def _terminate_children(ctx):
        """Call ``terminate()`` on every live child process of multiprocessing context ``ctx``."""
        if ctx is None:
            return
        for child_proc in ctx.active_children():
            child_proc.terminate()

    @classmethod
    def _run_training(cls, cfg):
        """Build the state, dataloaders and trainer for ``cfg``, then train.

        Resumes from ``job.start_checkpoint`` when it is set. After that, resumes
        from this job's preempt checkpoint if one exists.

        Child processes of the forkserver context are terminated afterwards,
        whether training returns normally or raises.
        """
        register_all()

        ctx = multiprocessing.get_context("forkserver")
        try:
            trainer_kwargs, state = cls.create_trainer_kwargs_and_state(cfg, ctx)

            if cls._is_running_sweep():
                # In a sweep the run is presumably initialised by the agent; only record the config.
                wandb.config.update(OmegaConf.to_container(cfg))
            else:
                wandb.init(**cfg.wandb, config=OmegaConf.to_container(cfg))

            print_blue(OmegaConf.to_yaml(cfg))
            state.print_opt_params()

            if cfg.job.get("update_ablation_progress", False):
                update_ablation_progress(cfg.job.ablation_uuid, "RUNNING")

            start_checkpoint = cfg.job.get("start_checkpoint")
            if start_checkpoint is not None:
                state = state.load_checkpoint(
                    cls.get_checkpoint_dir() / start_checkpoint, cfg.job.start_checkpoint_idx
                )

            # A preempt checkpoint from an earlier launch of this job is loaded last,
            # so it takes precedence over the start checkpoint.
            if cfg.job.get("uuid") is not None:
                preempt_ckpt = cls._preempt_checkpoint_path(cfg)
                if preempt_ckpt.exists():
                    print("Loading preempted checkpoint!")
                    state = state.load_checkpoint(preempt_ckpt, -1)

            train_dataloader = create_dataloader(
                cfg=cfg,
                mode="train",
                dataset=trainer_kwargs["train_dataset"],
                sample_buffer=SampleBuffer.create_from_state(state.sample_buffer_data, ctx),
                ctx=ctx,
            )

            test_dataloader = None
            if trainer_kwargs.get("test_dataset", None) is not None:
                test_dataloader = create_dataloader(
                    cfg=cfg, mode="test", dataset=trainer_kwargs["test_dataset"], ctx=ctx
                )

            trainer = cls(
                cfg=cfg,
                train_dataloader=train_dataloader,
                test_dataloader=test_dataloader,
                ctx=ctx,
                **trainer_kwargs,
            )

            trainer._train(state)
        finally:
            # Previously only reached on success, which leaked workers on errors.
            cls._terminate_children(ctx)

    def _train(self, state: BaseState):
        """Run the training loop starting from ``state``.

        The loop stops in one of two ways:

        - ``state.step`` reaches ``job.total_nsteps``. A final numbered checkpoint
          is saved if checkpointing is enabled and none was written at that step.
        - The train dataloader is exhausted.
        """
        self._register_signal_handlers()

        total_nsteps = self.cfg.job.total_nsteps
        if state.step >= total_nsteps:
            print_yellow(
                "Run has already completed! To continue training, increase the job.total_nsteps "
                + "(and ensure checkpoint_overrides contains job.total_nsteps)"
            )
            return

        # Step at which each periodic event last fired. The keys also name the
        # `job.<event>_frequency_nsteps` settings read by `_update_should_run`.
        last_steps = {
            "print": state.step,
            "wandb_inexpensive": state.step,
            "wandb_expensive": state.step,
            "validation": state.step,
            "checkpoint": state.step,
        }
        should_run = dict.fromkeys(last_steps, False)

        # The state is broadcast to every device (in_axes=None). The batch's leading
        # axis is mapped over devices.
        p_update_step = jax.pmap(self.update_model, in_axes=(None, 0), out_axes=(None, 0), axis_name="batch")

        last_step = state.step
        total_time_min = state.total_time_min

        for i, p_full_batch in enumerate(self.train_dataloader):
            p_batch = self.extract_features_from_batch(p_full_batch)
            # perf_counter is monotonic and high-resolution; time.time() can jump.
            start_time = time.perf_counter()

            state, p_train_metrics = p_update_step(state, p_batch)

            # Ensure computation is finished before logging and measuring time
            p_train_metrics = jax.tree_util.tree_map(lambda x: x.block_until_ready(), p_train_metrics)
            time_per_iter = time.perf_counter() - start_time

            # Early iterations include JAX tracing/compilation, so iterations 0-2 are
            # left out of the accumulated training time.
            if i > 2:
                total_time_min += time_per_iter / 60.0

            state = state.replace(sample_buffer_data=self.train_dataloader.synchronize())

            # For gradient accumulation, state.step is only incremented every k minibatches
            if state.step > last_step:
                nsteps_per_min = 60.0 / time_per_iter
                p_train_metrics["nsteps_per_min"] = nsteps_per_min
                p_train_metrics["total_time_min"] = total_time_min

                state = state.replace(total_time_min=total_time_min)

                self._update_should_run(should_run, last_steps, state.step)
                self.post_update(state, p_batch, p_train_metrics, should_run, nsteps_per_min)

                if state.step >= total_nsteps:
                    # save final checkpoint before exiting
                    print_green("Run finished!")
                    self.unlink_preempt_checkpoint()  # remove the preempt checkpoint
                    # if we haven't saved the last checkpoint yet, then save it
                    if self.cfg.job.get("checkpoint_frequency_nsteps", 0) > 0 and not should_run["checkpoint"]:
                        print_green("Saving final checkpoint...")
                        state.save_checkpoint(self.checkpoint_file, self.ckpt_idx)
                    return

                # Record the current step for every event that fired this iteration
                for event, fired in should_run.items():
                    if fired:
                        last_steps[event] = state.step

            last_step = state.step

    def _update_should_run(self, should_run, last_steps, cur_step):
        """Mark each event due when ``job.<event>_frequency_nsteps`` steps have passed.

        Modifies ``should_run`` in place. Events whose frequency is missing or
        non-positive are left unchanged.
        """
        for metric_name in last_steps.keys():
            freq_key = f"{metric_name}_frequency_nsteps"
            metric_freq_nsteps = self.cfg.job.get(freq_key, 0)
            if metric_freq_nsteps > 0:
                should_run[metric_name] = cur_step - last_steps[metric_name] >= metric_freq_nsteps

    def post_update(self, state: BaseState, p_batch, p_train_metrics, should_run, nsteps_per_min):
        """Run the periodic events flagged in ``should_run`` for the current step.

        Depending on ``should_run``, this:

        - prints the loss;
        - runs validation;
        - collects train metrics;
        - writes a preempt checkpoint (on the expensive-logging cadence);
        - writes a numbered checkpoint.

        Everything collected is sent to wandb in a single ``wandb.log`` call.
        """
        wandb_metrics = {}

        if should_run["print"]:
            loss = jnp.mean(p_train_metrics["loss"]).item()
            step_str = f"{state.step:07d}"
            if state.internal_step != state.step:
                step_str = f"{state.step:07d}/{state.internal_step:07d}"
            print(f"step {step_str} ({nsteps_per_min:.2f}/min) | loss {loss:.4f}")

        if should_run["validation"]:
            val_batch, val_metrics = self.run_validation(state)
            wandb_metrics.update(jax.tree.map(jnp.mean, val_metrics))
            wandb_metrics.update(self.log_expensive(state, val_batch, val_metrics, val=True))
            wandb_metrics.update(self.get_aux_metrics(state, val_batch, val_metrics, val=True))

        if should_run["wandb_inexpensive"] or should_run["wandb_expensive"]:
            # Save preempt state at wandb intervals (for cleaner reloading)
            self._preempt_state = state  # used by signal handlers if job is interrupted
            # Fold the leading pmap (device) axis into the following axis.
            train_batch = jax.tree.map(jnp.concatenate, p_batch)
            train_metrics = jax.tree.map(lambda x: jnp.concatenate(jnp.atleast_2d(x)), p_train_metrics)

        if should_run["wandb_inexpensive"]:
            wandb_metrics.update(self.get_aux_metrics(state, train_batch, train_metrics))
            wandb_metrics.update(jax.tree.map(jnp.mean, train_metrics))

        if should_run["wandb_expensive"]:
            wandb_metrics.update(self.log_expensive(state, train_batch, train_metrics))
            self.save_preempt_checkpoint()

        if should_run["checkpoint"]:
            state.save_checkpoint(self.checkpoint_file, self.ckpt_idx)
            print_green(f"##### Saved checkpoint {self.ckpt_idx} #####")
            self.ckpt_idx += 1

        if wandb_metrics:
            wandb_metrics["step"] = state.step
            for exclude_metric in self.exclude_metrics:
                wandb_metrics.pop(exclude_metric, None)
            wandb.log(wandb_metrics)

    def run_validation(self, state: BaseState):
        """Return ``(val_batch, val_metrics)`` for ``state``. Must be overridden when validation is enabled."""
        raise NotImplementedError

    def update_model(self, state: BaseState, batch: jnp.ndarray):
        """Apply gradient updates to model parameters.

        Runs per device under ``jax.pmap`` in :meth:`_train`. ``state.rng_key`` is
        split into one key for building the loss arguments and one for the
        gradient step.
        """
        args_key, grad_key = jax.random.split(state.rng_key)
        loss_args = self.model.get_loss_args(state, batch, args_key)
        state, metrics = state.perform_gradient_update(loss_args, grad_key)
        return state, metrics

    @staticmethod
    def submit_sweep(cfg):
        """
        Submit a sweep using SLURM, if available.
        NOTE: This is only supported for "venv" (not apptainer)

        NOTE(review): `_submit_training` calls `submit_job(fn, cfg)`, but this call
        passes the cluster config first. Confirm against `jadex.utils.submit_job`.
        """
        submit_job(cfg.get("cluster"), wandb.agent, **cfg.run_sweep)

    @staticmethod
    def _is_running_sweep():
        """Return True when the ``WANDB_SWEEP_ID`` environment variable is set."""
        return os.getenv("WANDB_SWEEP_ID") is not None

    @classmethod
    def _create_checkpoint_file(cls, prefix):
        """Return ``<checkpoint dir>/<prefix>_<suffix>.h5`` for this run.

        ``suffix`` is the wandb run name. If there is no usable run name, it is the
        first two groups of a random uuid4.
        """
        try:
            run_name = wandb.run.name
            # when wandb is disabled, the run name starts with "dummy"
            if not run_name or run_name.startswith("dummy"):
                raise ValueError
            suffix = run_name
        except Exception:  # e.g. wandb.run is None; never swallow KeyboardInterrupt/SystemExit
            suffix = "".join(str(uuid.uuid4()).split("-")[:2])

        checkpoint_file = cls.get_checkpoint_dir() / f"{prefix}_{suffix}.h5"
        print(f"Checkpoints for this run will be saved to: {checkpoint_file}")
        return checkpoint_file

    @contextmanager
    def _checkpoint_backup(self, preempt_ckpt: Path):
        """Keep a backup of ``preempt_ckpt`` while it is being rewritten.

        Before the body runs, the existing file (if any) is moved to
        ``<stem>_old.h5``. If the body raises, the backup is moved back. This
        includes ``KeyboardInterrupt`` and ``SystemExit``, which would otherwise
        leave a half-written checkpoint in place. On success the backup is deleted.
        """
        old_ckpt = preempt_ckpt.with_name(f"{preempt_ckpt.stem}_old.h5")

        # Backup existing checkpoint
        if preempt_ckpt.exists():
            shutil.move(preempt_ckpt, old_ckpt)

        try:
            yield
        except BaseException:
            # If new save fails or is interrupted, restore old checkpoint
            if old_ckpt.exists():
                shutil.move(old_ckpt, preempt_ckpt)
            raise
        else:
            # Success: remove old checkpoint
            if old_ckpt.exists():
                old_ckpt.unlink()

    def _preempt_enabled(self) -> bool:
        """Return True when the job has a ``uuid`` and ``job.save_preempt`` is truthy."""
        save_preempt = self.cfg.job.get("save_preempt", DEFAULT_SAVE_PREEMPT)
        return self.cfg.job.get("uuid") is not None and bool(save_preempt)

    @staticmethod
    def _preempt_checkpoint_path(cfg) -> Path:
        """Return the preempt checkpoint path for job ``cfg.job.uuid``."""
        return PREEMPT_DIR / f"{cfg.job.uuid}.h5"

    def save_preempt_checkpoint(self):
        """Write the last captured preempt state to this job's preempt checkpoint (if enabled)."""
        if not self._preempt_enabled():
            return
        if self._preempt_state is None:
            print("No preempt state to save!")
            return

        preempt_ckpt = self._preempt_checkpoint_path(self.cfg)
        print("Saving preempt checkpoint...")

        with self._checkpoint_backup(preempt_ckpt):
            self._preempt_state.save_checkpoint(preempt_ckpt, 0)

        print("Preempt checkpoint saved!")

    def unlink_preempt_checkpoint(self):
        """Delete this job's preempt checkpoint if preempt checkpoints are enabled and it exists.

        Called when the run finishes, so that a later launch with the same uuid
        does not resume from it.
        """
        if not self._preempt_enabled():
            return
        # The file may come from an earlier (preempted) launch, so it is removed even
        # when this process has not captured a preempt state yet.
        preempt_ckpt = self._preempt_checkpoint_path(self.cfg)
        if preempt_ckpt.exists():
            print(f"Removing current preempted checkpoint: {preempt_ckpt}")
            preempt_ckpt.unlink()

    def _register_signal_handlers(self):
        """Install :meth:`_handle_exit` for SIGTERM, SIGINT and SIGHUP when preempt checkpoints are enabled."""
        if self._preempt_enabled():
            print("Registering signal handlers for preemption...")
            for sig in (signal.SIGTERM, signal.SIGINT, signal.SIGHUP):
                signal.signal(sig, self._handle_exit)

    def _handle_exit(self, signum, frame):
        """Signal handler: save a preempt checkpoint, terminate child processes, then exit."""
        signal_name = signal.Signals(signum).name
        print(f"[ModelTrainer] Received {signal_name} ({signum}), saving checkpoint...")
        self.save_preempt_checkpoint()

        self._terminate_children(self._ctx)

        # `exit()` is a site-module helper intended for the interactive shell.
        # Raising SystemExit directly has the same effect and always works.
        raise SystemExit
