import atexit
import contextlib
import functools
import logging
import os
import shutil
from collections import defaultdict
from collections.abc import Sequence
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

import h5py
import jax
import numpy as np
import yaml
from seml.utils import merge_dicts
from wonderwords import RandomWord

import wandb
from neural_pfaffian.systems import Systems
from neural_pfaffian.utils import Modules
from neural_pfaffian.vmc import VMCState

if TYPE_CHECKING:
    from seml import Experiment


class LoggerAbc:
    """Common interface of all logging backends.

    Methods whose body is an ellipsis are placeholders that do nothing and
    return ``None``; concrete backends override the ones they support.
    """

    def __init__(self, run_name: str, **_): ...

    def _log_data(self, data: dict[str, Any]): ...

    def _log_matrix(self, name: str, data: np.ndarray, prefix: str): ...

    def log_data(self, data: dict[str, Any], prefix: str | None = None):
        """Pass ``data`` to ``_log_data``, renaming keys to ``prefix/key`` if a prefix is given."""
        if prefix is not None:
            data = {f'{prefix}/{key}': value for key, value in data.items()}
        return self._log_data(data)

    def log_matrix(self, name: str, data: np.ndarray, prefix: str | None = None):
        """Pass a matrix to ``_log_matrix``; a missing prefix becomes 'logged_matrices'."""
        effective_prefix = 'logged_matrices' if prefix is None else prefix
        return self._log_matrix(name, data, effective_prefix)

    def get_config_update(self) -> dict[str, Any]: ...

    def log_config(self, config: dict[str, Any]): ...

    def checkpoint(self, state: VMCState, systems: Systems): ...

    def load_checkpoint(
        self,
        state: VMCState,
        systems: Systems,
    ) -> tuple[VMCState, Systems]: ...

    def has_checkpoint(self) -> bool:
        """Report whether a checkpoint is available; the base class never has one."""
        return False

    def reschedule_hook(
        self,
        state: VMCState | None,
        batches: list[Systems] | None,
    ) -> dict[str, Any]: ...


class WandbLogger(LoggerAbc):
    """Logging backend that sends scalar data to Weights & Biases.

    Matrices are ignored and checkpointing is not supported by this backend.
    """

    def __init__(self, run_name: str, **kwargs):
        """Start a wandb run.

        Args:
            run_name: Default run name; a ``name`` entry in ``kwargs`` wins.
            **kwargs: Further keyword arguments forwarded to ``wandb.init``.
        """
        config: dict[str, Any] = {'name': run_name} | kwargs
        # 'allow' stays the default, but a caller-supplied ``resume`` no longer
        # collides with a hard-coded keyword ("multiple values for keyword").
        config.setdefault('resume', 'allow')
        self.run = wandb.init(**config)

    def _log_data(self, data: dict[str, Any]):
        """Log one dictionary of values via ``wandb.log``."""
        wandb.log(data)

    def _log_matrix(self, name: str, data: np.ndarray, prefix: str):
        """Matrices are intentionally not sent to wandb."""

    def get_config_update(self):
        """Return the run identifiers under ``['logging']['wandb']``."""
        update = defaultdict(dict)
        update['logging']['wandb'] = {
            'id': self.run.id,
            'project': self.run.project,
            'entity': self.run.entity,
            'name': self.run.name,
        }
        return update

    def log_config(self, config: dict[str, Any]):
        """Store ``config`` in the run config unless its first key is already there."""
        if not config:
            # Nothing to record; previously this raised StopIteration.
            return
        # The first key serves as a cheap "was this config already stored" probe.
        if next(iter(config)) not in self.run.config:
            self.run.config.update(config)

    def checkpoint(self, state: VMCState, systems: Systems):
        """Not supported by this backend.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def load_checkpoint(self, state: VMCState, systems: Systems):
        """Not supported by this backend.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def reschedule_hook(
        self,
        state: VMCState | None,
        batches: list[Systems] | None,
    ) -> dict[str, Any]:
        """Mark the run as preempting and return the run identifiers.

        ``state`` and ``batches`` are accepted for interface compatibility and
        are not used.
        """
        self.run.mark_preempting()
        return self.get_config_update()


class CsvLogFile:
    def __init__(self, path: Path | str, delimiter: str = ','):
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        if self.path.exists():
            headers = self.path.open('r').readline().strip().split(',')
            if len(headers) == 0 or headers == ['']:
                headers = None
        else:
            headers = None
        self.headers = headers
        self.delimiter = delimiter
        self._logfile = open(self.path, 'a')  # noqa: SIM115

        atexit.register(self.close)

    def write(self, data: dict[str, Any]):
        if self.headers is None:
            self.headers = list(data.keys())
            self._logfile.write(self.delimiter.join(self.headers) + '\n')
        self._logfile.write(
            self.delimiter.join(str(data.get(h, '')) for h in self.headers) + '\n',
        )
        self._logfile.flush()

    def close(self):
        if not self._logfile.closed:
            self._logfile.close()

    def __del__(self):
        if not self._logfile.closed:
            self._logfile.close()


class H5MatrixLogFile:
    def __init__(self, path: Path | str, flush_every: int = 100):
        self.path = Path(path)
        self._file = h5py.File(self.path, 'a')
        # Dictionary to hold datasets created in this file.
        self._datasets: dict[str, h5py.Dataset] = {}
        self._flush_every = flush_every
        self._write_count = 0
        atexit.register(self.close)

    def write(self, dataset_name: str, data: np.ndarray):
        """
        Append the given numpy array to the dataset identified by dataset_name.
        If the dataset does not exist, it is created with an initial shape of (0, *data.shape)
        and is extendible along the first axis.
        """
        if self._file is None:
            raise ValueError('H5 file is already closed!')

        if dataset_name not in self._datasets:
            # If the dataset already exists in the file, use it.
            if dataset_name in self._file:
                obj = self._file[dataset_name]
                if not isinstance(obj, h5py.Dataset):
                    raise TypeError(
                        f"Found an object of type {type(obj)} named '{dataset_name}' that is not a dataset.",
                    )
                dset: h5py.Dataset = obj
            else:
                dset = self._file.create_dataset(
                    dataset_name,
                    shape=(0, *data.shape),
                    maxshape=(None, *data.shape),
                    chunks=(1, *data.shape),
                    compression='gzip',
                )
            self._datasets[dataset_name] = dset
        else:
            dset = self._datasets[dataset_name]

        # Append new data along axis 0.
        current_size = dset.shape[0]
        new_size = current_size + 1
        dset.resize(new_size, axis=0)
        dset[current_size] = data
        self._write_count += 1
        if self._write_count % self._flush_every == 0:
            # Flush the file every flush_every writes.
            self._file.flush()
            self._write_count = 0

    def close(self):
        if self._file is not None:
            self._file.flush()
            self._file.close()
            self._file = None

    def unlink(self):
        self.close()
        if self.path.exists():
            self.path.unlink()

    def __del__(self):
        self.close()


class FileLogger(LoggerAbc):
    """Logging backend that keeps everything inside one run directory.

    Scalars go to ``<prefix>.csv`` files, matrices to ``<prefix>.h5`` files,
    the config to ``config.yaml`` and checkpoints to ``ring/ep_<epoch>/`` with
    ``state.msgpack`` / ``systems.msgpack`` in the run root referring to the
    most recent one.
    """

    def __init__(
        self,
        run_name: str,
        base_dir: Path | str = '.',
        directory: Path | str | None = None,
        delimiter: str = ';',
        save_interval: int = 1000,
        max_num_checkpoints: int = 5,
    ):
        """Create the run directory and its checkpoint ring directory.

        Args:
            run_name: Used to name the run directory when ``directory`` is None.
            base_dir: Parent of the auto-named run directory.
            directory: Explicit run directory; overrides ``base_dir``/``run_name``.
            delimiter: Column separator for the CSV log files.
            save_interval: Epoch interval used by ``should_save_epoch``.
            max_num_checkpoints: Number of ring checkpoints kept by ``checkpoint``.
        """
        if directory is None:
            directory = Path(base_dir) / str(
                run_name
                + datetime.now().strftime(
                    '_%Y-%m-%d_%H-%M-%S',
                ),
            )

        self.directory = Path(directory).resolve().absolute()
        self.directory.mkdir(parents=True, exist_ok=True)
        self.delimiter = delimiter
        self._csv_log_files: dict[str, CsvLogFile] = {}
        self._matrix_log_files: dict[str, H5MatrixLogFile] = {}

        self._save_interval = save_interval
        self._max_num_checkpoints = max_num_checkpoints
        self.ring_dir.mkdir(parents=True, exist_ok=True)

    @property
    def ring_dir(self) -> Path:
        """Directory holding the per-epoch checkpoint directories."""
        return self.directory / 'ring'

    def _epoch_dir(self, base: Path, epoch: int) -> Path:
        """Return ``base/ep_<epoch>`` with the epoch zero-padded to six digits."""
        return base / f'ep_{int(epoch):06d}'

    def _list_ring_epochs(self) -> list[int]:
        """Return the sorted epochs of all ``ep_<int>`` directories in the ring."""
        eps: list[int] = []
        if self.ring_dir.exists():
            for p in self.ring_dir.iterdir():
                if p.is_dir() and p.name.startswith('ep_'):
                    # Names such as 'ep_000100.tmp' do not parse and are skipped.
                    with contextlib.suppress(ValueError):
                        eps.append(int(p.name.split('_')[1]))
        return sorted(eps)

    @property
    def config_path(self):
        return self.directory / 'config.yaml'

    @property
    def state_path(self):
        return self.directory / 'state.msgpack'

    @property
    def systems_path(self):
        return self.directory / 'systems.msgpack'

    def _atomic_write_dir(self, dst: Path, writer) -> None:
        """Let ``writer(tmp_dir)`` fill a sibling ``<dst>.tmp`` directory, then rename it to ``dst``.

        On any exception the temporary directory is removed (best effort) and
        the exception is re-raised.
        """
        tmp = dst.with_name(dst.name + '.tmp')
        if tmp.exists():
            shutil.rmtree(tmp)
        tmp.mkdir(parents=True, exist_ok=True)
        try:
            writer(tmp)
            os.replace(tmp, dst)  # atomic rename
        except Exception:
            with contextlib.suppress(Exception):
                shutil.rmtree(tmp)
            raise

    def _symlink_or_copy(self, src: Path, dst: Path) -> None:
        """Make ``dst`` a symlink to ``src``; fall back to copying if symlinking fails."""
        if dst.exists() or dst.is_symlink():
            with contextlib.suppress(Exception):
                dst.unlink()
        try:
            dst.symlink_to(src)
        except OSError:
            shutil.copy2(src, dst)

    def _update_latest_symlinks(self, src_dir: Path) -> None:
        """Keep backward-compatible resume files in the run root."""
        self._symlink_or_copy(src_dir / self.state_path.name, self.state_path)
        self._symlink_or_copy(src_dir / self.systems_path.name, self.systems_path)

    def logfile_path(self, prefix: str, filetype: Literal['csv', 'h5'] = 'csv'):
        """Return ``<run directory>/<prefix>.<filetype>``."""
        return self.directory / f'{prefix}.{filetype}'

    def csv_logfile(self, prefix: str):
        """Return the (lazily created, cached) CSV log file for ``prefix``."""
        if prefix not in self._csv_log_files:
            self._csv_log_files[prefix] = CsvLogFile(
                self.logfile_path(prefix),
                self.delimiter,
            )
        return self._csv_log_files[prefix]

    def h5_logfile(self, prefix: str):
        """Return the (lazily created, cached) HDF5 matrix log file for ``prefix``."""
        # Do not use dict.setdefault here: its default argument is evaluated on
        # every call, which re-opened the file and registered another atexit
        # handler each time a matrix was logged.
        if prefix not in self._matrix_log_files:
            self._matrix_log_files[prefix] = H5MatrixLogFile(
                self.logfile_path(prefix, 'h5'),
            )
        return self._matrix_log_files[prefix]

    def get_config_update(self) -> dict[str, Any]:
        """Return the run directory under ``['logging']['file']``."""
        update = defaultdict(dict)
        update['logging']['file'] = {
            'directory': str(self.directory),
        }
        return update

    def log_config(self, config: dict[str, Any]):
        """Write ``config`` as YAML to ``config.yaml`` in the run directory."""
        self.config_path.write_text(yaml.dump(config))

    def log_data(self, data: dict[str, Any], prefix: str | None = None):
        """Append ``data`` as one row to ``<prefix>.csv`` (``main.csv`` by default)."""
        if prefix is None:
            prefix = 'main'
        self.csv_logfile(prefix).write(data)

    def _log_matrix(self, name: str, data: np.ndarray, prefix: str):
        """Append ``data`` to dataset ``name`` in ``<prefix>.h5``."""
        matrix_log_file = self.h5_logfile(prefix)
        matrix_log_file.write(name, data)

    def checkpoint(self, state: VMCState, systems: Systems):
        """
        Ring-only checkpoint:
        - Writes (state, systems) to ring/ep_<epoch>/
        - Updates run-root 'latest' files (symlink/copy) to point to that dir
        - Trims the ring to at most max_num_checkpoints by deleting the
          entries with the lowest epoch numbers
        """
        epoch = int(state.epoch)

        ep_dir = self._epoch_dir(self.ring_dir, epoch)
        if ep_dir.exists():
            shutil.rmtree(ep_dir, ignore_errors=True)
        ep_dir.mkdir(parents=True, exist_ok=True)
        state.to_file(ep_dir / self.state_path.name)
        systems.to_file(ep_dir / self.systems_path.name)
        self._update_latest_symlinks(ep_dir)

        # Trim ring
        ring = self._list_ring_epochs()
        max_keep = self._max_num_checkpoints
        while len(ring) > max_keep:
            oldest = ring.pop(0)
            shutil.rmtree(self._epoch_dir(self.ring_dir, oldest), ignore_errors=True)

    def _no_ring_checkpoint(self, state: VMCState, systems: Systems):
        """Non-ring checkpoint: simply overwrite the run-root latest files."""
        # checkpoint() may have left the run-root files as symlinks into the
        # ring. Remove such links first so that a writer which opens the path
        # in place cannot overwrite the ring checkpoint they point to
        # (to_file is defined elsewhere; presumably it writes in place).
        for latest in (self.state_path, self.systems_path):
            if latest.is_symlink():
                latest.unlink()
        state.to_file(self.state_path)
        systems.to_file(self.systems_path)

    def load_checkpoint(self, state: VMCState, systems: Systems):
        """Load state and systems from the run-root latest files."""
        return state.from_file(self.state_path), systems.from_file(self.systems_path)

    def has_checkpoint(self) -> bool:
        """Return True if both run-root latest files exist."""
        return self.state_path.exists() and self.systems_path.exists()

    def should_save_epoch(self, epoch: int) -> bool:
        """Return True if ``epoch`` is a multiple of a positive save interval."""
        return (self._save_interval > 0) and (epoch % self._save_interval == 0)

    def rollback(self, state: VMCState, systems: Systems):
        """Load the second-newest ring checkpoint at or before ``state.epoch``.

        All ring checkpoints newer than the chosen one are deleted and the
        run-root latest files are pointed at it.

        Raises:
            RuntimeError: if fewer than two ring checkpoints are at or before
                the current epoch.
        """
        current_epoch = int(state.epoch)
        candidates = [e for e in self._list_ring_epochs() if e <= current_epoch]
        candidates = sorted(set(candidates))
        if len(candidates) <= 1:
            raise RuntimeError('No checkpoint in ring before current epoch.')
        candidates = candidates[:-1]  # drop the newest candidate

        target_epoch = candidates[-1]
        target_dir = self._epoch_dir(self.ring_dir, target_epoch)

        # delete anything newer than the target so we don't thrash
        for e in self._list_ring_epochs():
            if e > target_epoch:
                shutil.rmtree(self._epoch_dir(self.ring_dir, e), ignore_errors=True)

        # Update latest and load
        self._update_latest_symlinks(target_dir)
        state = state.from_file(target_dir / self.state_path.name)
        systems = systems.from_file(target_dir / self.systems_path.name)
        return state, systems

    def reschedule_hook(
        self,
        state: VMCState | None,
        batches: list[Systems] | None,
    ) -> dict[str, Any]:
        """Delete the 'eval' matrix log, save a run-root checkpoint, return the config update.

        The checkpoint is only written when both ``state`` and ``batches`` are
        given; if exactly one is given a warning is logged instead.
        """
        # HACK
        self.h5_logfile('eval').unlink()
        # Forget the now-closed handle so a later write re-opens the file
        # instead of failing on the closed one.
        self._matrix_log_files.pop('eval', None)
        if state is not None and batches is not None:
            self._no_ring_checkpoint(state, Systems.merge(batches))
        elif state is not None or batches is not None:
            logging.warning(
                'Reschedule hook was called with only one of state/systems!'
                ' Not saving checkpoint.',
            )
        return self.get_config_update()


class Logger:
    """Facade that fans every logging call out to all configured backends."""

    def __init__(
        self,
        system_name: str,
        logging_config,
        experiment: 'Experiment | None' = None,
    ):
        """Create the backends described by ``logging_config``.

        Args:
            system_name: Prefix of the randomly generated run name.
            logging_config: Either a dict ``{backend: kwargs}`` or a sequence of
                ``(backend, kwargs)`` pairs; ``run_name`` is added to every
                kwargs mapping before it is passed to ``LOGGERS.try_init_many``.
            experiment: If given, ``self.reschedule_hook`` is whatever
                ``experiment.reschedule_hook`` returns for the internal hook;
                otherwise it is a no-op returning ``None``.
        """
        config = deepcopy(logging_config)

        # Generate a random name for the run
        word_gen = RandomWord()
        adj = word_gen.word(include_categories=['adjective'], word_max_length=6)
        noun = word_gen.word(include_categories=['noun'], word_max_length=6)
        run_name = f'{system_name}-{adj}-{noun}'

        # Update the logging config with the run name
        if isinstance(logging_config, dict):
            config = {k: v | {'run_name': run_name} for k, v in config.items()}
        elif isinstance(logging_config, Sequence):
            config = [
                (module[0], module[1] | {'run_name': run_name}) for module in config
            ]

        self.loggers = LOGGERS.try_init_many(config)

        if experiment:
            self.reschedule_hook = experiment.reschedule_hook(self._reschedule_hook)
        else:

            def _dummy_hook(
                state: VMCState | None = None,
                batches: list[Systems] | None = None,
            ):
                return None

            self.reschedule_hook = _dummy_hook

    def log(
        self,
        data: dict[str, Any],
        prefix: str | None = None,
        *,
        file_only: bool = False,
        delimiter: str = ',',
    ):
        """Log scalar data on every backend (only file backends if ``file_only``).

        ``delimiter`` is accepted for backward compatibility and is not used.
        """
        data = jax.device_get(data)
        for logger in self.loggers:
            if not isinstance(logger, FileLogger) and file_only:
                continue
            logger.log_data(data, prefix)

    def log_matrix(
        self,
        name: str,
        data: np.ndarray,
        prefix: str | None = None,
        *,
        file_only: bool = False,
    ):
        """Log a matrix on every backend (only file backends if ``file_only``)."""
        data = jax.device_get(data)
        for logger in self.loggers:
            if not isinstance(logger, FileLogger) and file_only:
                continue
            logger.log_matrix(name, data, prefix)

    def log_directories(self) -> list[Path]:
        """Return the run directories of all file backends."""
        return [
            logger.directory for logger in self.loggers if isinstance(logger, FileLogger)
        ]

    def update_and_log_config(self, config: dict[str, Any]):
        """Retrieves all config updates from loggers and issues all loggers
        to save the updated config."""
        # Backends that keep the base-class placeholder return None; skip
        # those (and empty updates) because merge_dicts cannot merge None.
        updates = [
            update
            for update in (logger.get_config_update() for logger in self.loggers)
            if update
        ]

        config = functools.reduce(merge_dicts, updates, config)
        for logger in self.loggers:
            logger.log_config(config)

    def should_save_checkpoint(self, epoch: int) -> bool:
        """Ask the first file backend whether ``epoch`` should be checkpointed."""
        for logger in self.loggers:
            if isinstance(logger, FileLogger):
                return logger.should_save_epoch(epoch)
        return False

    def checkpoint(self, state: VMCState, systems: Systems):
        """Checkpoint on every backend, skipping those that do not support it."""
        for logger in self.loggers:
            with contextlib.suppress(NotImplementedError):
                logger.checkpoint(state, systems)

    def load_checkpoint(self, state: VMCState, systems: Systems):
        """Load from the first backend that has a checkpoint, else return the inputs."""
        for logger in self.loggers:
            if logger.has_checkpoint():
                return logger.load_checkpoint(state, systems)
        return state, systems

    def rollback(self, state: VMCState, systems: Systems):
        """Roll back using the first file backend that succeeds.

        Raises:
            RuntimeError: if no file backend could roll back; the last
                backend error (if any) is attached as the cause.
        """
        last_error: Exception | None = None
        for logger in self.loggers:
            if not isinstance(logger, FileLogger):
                continue
            try:
                return logger.rollback(state, systems)
            except Exception as err:
                # Best effort: remember and report the failure, then try the
                # next file backend instead of silently discarding the error.
                logging.warning(
                    'Rollback failed for %s: %s',
                    logger.directory,
                    err,
                )
                last_error = err
        raise RuntimeError('No logger could perform a rollback!') from last_error

    def has_checkpoint(self) -> bool:
        """Return True if any backend has a checkpoint."""
        return any(logger.has_checkpoint() for logger in self.loggers)

    def _reschedule_hook(
        self,
        state: VMCState | None = None,
        batches: list[Systems] | None = None,
    ) -> dict[str, Any]:
        """Run every backend's reschedule hook and merge the returned config updates."""
        _config_updates: list[dict[str, Any]] = []
        for logger in self.loggers:
            update = logger.reschedule_hook(state, batches)
            # Skip None/empty results (base-class placeholder returns None).
            if update:
                _config_updates.append(update)
        config_update = functools.reduce(merge_dicts, _config_updates, defaultdict(dict))
        return config_update


# Logger classes keyed by their lower-cased class name with 'logger' removed,
# i.e. 'wandb' -> WandbLogger and 'file' -> FileLogger.
LOGGERS = Modules[LoggerAbc](
    dict(
        (logger_cls.__name__.lower().replace('logger', ''), logger_cls)
        for logger_cls in (WandbLogger, FileLogger)
    ),
)


class TraceController:
    def __init__(
        self,
        config: dict[str, Any] | None,
        log_directories: list[Path],
    ) -> None:
        trace_cfg = (config or {}).get('trace') or {}
        enabled = trace_cfg.get('enabled', True)
        steps = trace_cfg.get('steps')
        start = trace_cfg.get('start_step')
        duration = trace_cfg.get('duration_steps')

        if not enabled:
            self._intervals: list[tuple[int, int | None]] = []
        else:
            if steps:
                sorted_steps = sorted({int(step) for step in steps})
                self._intervals = []
                if sorted_steps:
                    interval_start = sorted_steps[0]
                    prev = sorted_steps[0]
                    for current in sorted_steps[1:]:
                        if current != prev + 1:
                            self._intervals.append((interval_start, prev + 1))
                            interval_start = current
                        prev = current
                    self._intervals.append((interval_start, prev + 1))
            else:
                start_i = 0 if start is None else int(start)
                if duration is None:
                    self._intervals = [(start_i, None)]
                else:
                    self._intervals = [
                        (start_i, start_i + max(int(duration), 1)),
                    ]

        self._subdir = trace_cfg.get('subdir', 'trace')
        self._base_dir = self._resolve_output_dir(config or {}, log_directories)
        self._active_context: contextlib.AbstractContextManager | None = None
        self._current_interval: tuple[int, int | None] | None = None
        if self._intervals:
            logging.info('Tracing enabled for intervals: %s', self._intervals)
        else:
            logging.info('Tracing disabled or no steps configured')

    def _resolve_output_dir(
        self,
        config: dict[str, Any],
        log_directories: list[Path],
    ) -> Path:
        configured = config.get('output_dir')
        if configured is not None:
            base = Path(configured)
        elif log_directories:
            base = log_directories[0] / 'profiles'
        else:
            base = Path.cwd() / 'profiles'
        base.mkdir(parents=True, exist_ok=True)
        return base

    @contextmanager
    def trace(self, step: int):
        interval = self._interval_for(step)
        if interval is None:
            if self._active_context is not None:
                self._close()
            yield
            return
        start, end = interval
        if self._active_context is None:
            end_label = f'{end - 1:06d}' if end is not None else 'end'
            trace_dir = self._base_dir / f'{self._subdir}_steps_{start:06d}_{end_label}'
            trace_dir.mkdir(parents=True, exist_ok=True)
            logging.info('Tracing steps [%d, %s) to %s', start, end_label, trace_dir)
            ctx = jax.profiler.trace(str(trace_dir))
            ctx.__enter__()
            self._active_context = ctx
            self._current_interval = interval
        try:
            yield
        finally:
            if (
                self._current_interval is not None
                and self._current_interval[1] is not None
                and step + 1 >= self._current_interval[1]
            ):
                self._close()

    def _interval_for(self, step: int) -> tuple[int, int | None] | None:
        for start, end in self._intervals:
            if start <= step and (end is None or step < end):
                return (start, end)
        return None

    def _close(self) -> None:
        if self._active_context is None:
            return
        start, end = self._current_interval or (0, None)
        self._active_context.__exit__(None, None, None)
        end_label = f'{end}' if end is not None else 'end'
        logging.info('Finished tracing steps [%d, %s)', start, end_label)
        self._active_context = None
        self._current_interval = None

    def close(self) -> None:
        self._close()


def make_trace_controller(
    config: dict[str, Any] | None,
    log_directories: list[Path],
) -> TraceController:
    """Construct a ``TraceController`` from the config and the logger directories."""
    controller = TraceController(config, log_directories)
    return controller
