import os.path
import warnings
from time import time
from typing import Dict, List, Protocol

import numpy as onp
import pandas as pd
import wandb


class _LossTupleLike(Protocol):
    """Structural type of a record holding one value per loss term.

    Any object that exposes the attributes below satisfies this protocol.
    """

    # Energy terms.
    xc_energy: float
    total_energy: float  # only in dynamic training stage

    # Terms on derived quantities.
    forces: float
    xc_potential: float
    density: float  # only in dynamic training stage

    # Orbital-rotation terms.
    orbital_rotation_gradient: float
    orbital_rotation_hessian: float


class _RelativeLossWeightsTupleLike(Protocol):
    """Structural type of a record holding one relative weight per loss term.

    Any object that exposes the attributes below satisfies this protocol.
    """

    # Energy terms.
    xc_energy: float
    total_energy: float  # only in dynamic training stage

    # Terms on derived quantities.
    forces: float
    xc_potential: float
    density: float  # only in dynamic training stage

    # Orbital-rotation terms.
    orbital_rotation_gradient: float
    orbital_rotation_hessian: float


# Plain Python number; the value type of the metric dictionaries that are
# passed to `Logger.log`.
Scalar = float | int


class Logger:
    """Thin wrapper around a Weights & Biases run for training metrics.

    Every metric goes through :meth:`log`, which forwards it to ``wandb.log``.
    While a mean window is open (see :meth:`start_mean`), the values of the
    selected keys are additionally collected so that their mean can be
    queried (:meth:`get_current_mean`) or logged (:meth:`stop_mean`).
    """

    aggregated_data: Dict[str, pd.DataFrame]
    write_csv: bool = False
    epoch_start_time: float | None = None
    # Class-level defaults so that `benchmark` can detect a missing
    # `benchmark_start` call instead of failing with an AttributeError.
    __benchmark_prefix: str = ''
    __benchmark_start_time: float | None = None
    # backend: str = 'wandb'  # TODO: add support for Tensorboard?

    # Attribute names shared by the loss-component and loss-weight records;
    # the order fixes the order of the logged keys.
    _LOSS_FIELDS = (
        'xc_energy',
        'forces',
        'xc_potential',
        'orbital_rotation_gradient',
        'orbital_rotation_hessian',
        'total_energy',
        'density',
    )

    def __init__(self, project: str, config: Dict, dir: str, name=None) -> None:
        """Initialise a wandb run.

        Args:
            project: Forwarded to ``wandb.init`` as ``project``.
            config: Forwarded to ``wandb.init`` as ``config``.
            dir: Base directory; ``<dir>/wandb`` is passed to ``wandb.init``.
            name: Forwarded to ``wandb.init`` as ``name``.
        """
        wandb_dir = os.path.join(dir, 'wandb')
        wandb.init(project=project, config=config, dir=wandb_dir, name=name)  # type: ignore
        self.accumulate = False
        self.accumulate_keys: Dict[str, List] = {}

    def log(self, dict: Dict[str, Scalar], step: int | None = None) -> None:
        """Send ``dict`` to ``wandb.log`` and feed the open mean window, if any.

        Args:
            dict: Metric name -> value.
            step: Forwarded to ``wandb.log`` as ``step``.
        """
        wandb.log(dict, step=step)
        if not self.accumulate:
            return
        for key, value in dict.items():
            # `None` (e.g. a missing volatility) carries no numeric
            # information and would break the mean computations later on.
            if value is not None and key in self.accumulate_keys:
                self.accumulate_keys[key].append(value)

    def start_mean(self, keys: List[str]) -> None:
        """Open a mean window that collects the values logged under ``keys``.

        Any previously collected values are discarded.
        """
        self.accumulate_keys = {key: [] for key in keys}
        self.accumulate = True

    def _evaluate_mean(self, key=None) -> None:
        """Log ``mean/<key>`` for every collected key that has values.

        Args:
            key: Unused; kept so that existing call sites stay valid.
        """
        for name, values in self.accumulate_keys.items():
            if not values:
                continue
            self.log({f'mean/{name}': sum(values) / len(values)})

    def get_current_mean(self, key: str, max_nans=5) -> float:
        """Return the NaN-ignoring mean of the values collected for ``key``.

        Args:
            key: A key of the open mean window.
            max_nans: If more than this many collected values are NaN, the
                result is NaN.

        Returns:
            The mean of the non-NaN values as a Python ``float``; NaN if
            there are too many NaNs or no non-NaN value at all.

        Raises:
            KeyError: If ``key`` is not part of the current mean window.
        """
        array = onp.asarray(self.accumulate_keys[key], dtype=float)
        isnan = onp.isnan(array)
        if isnan.sum() > max_nans:
            return float('nan')
        valid = array[~isnan]
        if valid.size == 0:
            # Avoid NumPy's "Mean of empty slice" RuntimeWarning.
            return float('nan')
        return float(valid.mean())

    def overlap(self, condition_number: float, smallest_eigenvalue: float) -> None:
        """Log both values under ``debug/overlap/``."""
        self.log(
            {
                'debug/overlap/condition_number': condition_number,
                'debug/overlap/smallest_eigenvalue': smallest_eigenvalue,
            }
        )

    def stop_mean(self) -> None:
        """Log the means of the open window, then close and clear it."""
        self._evaluate_mean()
        self.accumulate = False
        self.accumulate_keys = {}

    def start_epoch(self, e: int, prefix: str = '') -> None:
        """Print an epoch banner and (re)start the epoch timer.

        If the timer was already running, the time since the previous call is
        logged as ``benchmark/<prefix>/Epoch run duration [min]``.
        """
        print('#' * 20, f'{prefix} Epoch {e}:', flush=True)
        current_time = time()
        if self.epoch_start_time is not None:
            time_diff = current_time - self.epoch_start_time
            self.log({f'benchmark/{prefix}/Epoch run duration [min]': time_diff / 60})
        self.epoch_start_time = current_time

    def epoch_training_duration(
        self, n_samples: int | None = None, prefix: str = ''
    ) -> None:
        """Log the time elapsed since the last :meth:`start_epoch` call.

        Args:
            n_samples: If given, the throughput ``n_samples / elapsed`` is
                logged as well.
            prefix: Inserted into the metric names.
        """
        if self.epoch_start_time is None:
            warnings.warn(
                'epoch_training_duration() called before start_epoch(); '
                'nothing is logged.'
            )
            return
        time_diff = time() - self.epoch_start_time
        self.log({f'benchmark/{prefix}/Epoch train duration [min]': time_diff / 60})
        # A coarse clock may report zero elapsed time; skip the throughput
        # instead of dividing by zero.
        if n_samples is not None and time_diff > 0:
            self.log(
                {f'benchmark/{prefix}/Training samples per second': n_samples / time_diff}
            )

    def benchmark_start(self, prefix: str) -> None:
        """Start the benchmark timer; ``prefix`` is used in the metric names."""
        self.__benchmark_prefix = prefix
        self.__benchmark_start_time = time()

    def benchmark(self, name: str) -> None:
        """Log the time since the previous benchmark call and restart the timer.

        The value is logged in milliseconds under
        ``benchmark/<prefix>/<name> duration [ms]``.
        """
        current_time = time()
        start_time = self.__benchmark_start_time
        if start_time is None:
            warnings.warn(
                'benchmark() called before benchmark_start(); nothing is logged.'
            )
            return
        self.__benchmark_start_time = current_time
        elapsed_ms = (current_time - start_time) * 1000
        self.log(
            {f'benchmark/{self.__benchmark_prefix}/{name} duration [ms]': elapsed_ms}
        )

    def deixc_loss(
        self,
        loss: float,
        e_pred: float | None,
        e_xc_pred: float,
        e_target: float,
        e_xc_target: float | None,
        loss_components: _LossTupleLike,
        relative_loss_weights: _RelativeLossWeightsTupleLike,
        prefix: str,
        idx: int,
        volatility: float | None,
        max_energy_volatility: float | None,
    ) -> None:
        """
        Logs loss, energy errors, loss components, and volatility.
        Args:
            loss: Computed loss value.
            e_pred: Predicted total energy [Eh]; total-energy errors are
                skipped if None.
            e_xc_pred: Predicted XC energy [Eh].
            e_target: Target/reference total energy [Eh].
            e_xc_target: Target/reference XC energy [Eh]; XC-energy errors are
                skipped if None.
            loss_components: Individual loss terms.
            relative_loss_weights: Weight of each loss term.
            prefix: Logging prefix (e.g., 'train', 'val').
            idx: Index of the sample in the dataset.
            volatility: Energy volatility [mEh].
            max_energy_volatility: Volatility above which a warning is issued;
                no check if None.
        """
        self._loss(
            loss,
            e_pred,
            e_xc_pred,
            e_target,
            e_xc_target,
            loss_components,
            relative_loss_weights,
            prefix,
            idx,
            volatility,
            max_energy_volatility,
        )

    def egxc_loss(
        self,
        loss: float,
        e_total_pred: float,
        e_xc_pred: float,
        e_total_target: float,
        prefix: str,
        idx: int,
        volatility: float,
        max_energy_volatility: float,
    ) -> None:
        """
        Logs loss, total-energy error, and volatility.
        Args:
            loss: Computed loss value.
            e_total_pred: Predicted total energy [Eh].
            e_xc_pred: Predicted XC energy [Eh].
            e_total_target: Target/reference total energy [Eh].
            prefix: Logging prefix (e.g., 'train', 'val').
            idx: Index of the sample in the dataset.
            volatility: Energy volatility [mEh].
            max_energy_volatility: Volatility above which a warning is issued.
        """
        self._loss(
            loss,
            e_total_pred,
            e_xc_pred,
            e_total_target,
            None,
            None,
            None,
            prefix,
            idx,
            volatility,
            max_energy_volatility,
        )

    def _loss(
        self,
        loss: float,
        e_total_pred: float | None,
        e_xc_pred: float,
        e_total_target: float,
        e_xc_target: float | None,
        loss_components: _LossTupleLike | None,
        relative_loss_weights: _RelativeLossWeightsTupleLike | None,
        prefix: str,
        idx: int,
        volatility: float | None,
        max_energy_volatility: float | None,
    ) -> None:
        """
        Logs loss, energy errors, optional loss components, and volatility.
        Args:
            loss: Computed loss value.
            e_total_pred: Predicted total energy [Eh]; total-energy errors are
                skipped if None.
            e_xc_pred: Predicted XC energy [Eh].
            e_total_target: Target/reference total energy [Eh].
            e_xc_target: Target/reference XC energy [Eh]; XC-energy errors are
                skipped if None.
            loss_components: Individual loss terms; skipped if None.
            relative_loss_weights: Weight of each loss term; weighted terms
                are only logged if loss_components is given as well.
            prefix: Logging prefix (e.g., 'train', 'val').
            idx: Index of the sample in the dataset (used in the warning).
            volatility: Energy volatility [mEh]; required if
                max_energy_volatility is given.
            max_energy_volatility: Volatility above which a warning is issued;
                no check if None. This method only warns, it does not skip
                anything itself.
        """

        if max_energy_volatility is not None:
            assert volatility is not None, (
                'Volatility must be provided if max_energy_volatility is provided'
            )
            if volatility > max_energy_volatility:
                # Adjacent literals instead of a backslash continuation, which
                # would embed the source indentation in the message.
                warnings.warn(
                    f'Energy volatility of sample {idx} is high: {volatility} mEh. '
                    'SCF may not have converged and parameters will not be updated.'
                )

        log_dict = {
            f'{prefix}/loss': loss,
            f'debug/{prefix}/volatility [mEh]': volatility,
            f'debug/{prefix}/xc energy [Eh]': e_xc_pred,
        }

        if e_total_pred is not None:
            total_delta = e_total_pred - e_total_target
            log_dict[f'{prefix}/total energy error [mEh]'] = abs(total_delta) * 1e3
            log_dict[f'debug/{prefix}/energy delta [mEh]'] = total_delta * 1e3

        if e_xc_target is not None:
            xc_delta = e_xc_pred - e_xc_target
            log_dict[f'{prefix}/xc energy error [mEh]'] = abs(xc_delta) * 1e3
            log_dict[f'debug/{prefix}/xc energy delta [mEh]'] = xc_delta * 1e3

        if loss_components is not None:
            components = {
                field: getattr(loss_components, field) for field in self._LOSS_FIELDS
            }
            log_dict.update(
                {f'loss/{prefix}/{field}': value for field, value in components.items()}
            )
            if relative_loss_weights is not None:
                log_dict.update(
                    {
                        f'weighted_loss/{prefix}/{field}': value
                        * getattr(relative_loss_weights, field)
                        for field, value in components.items()
                    }
                )

        self.log(log_dict)

    def updates(
        self,
        grad_norm: float,
        update_norm: float,
        prefix: str,
    ) -> None:
        """Log gradient and update norms under ``debug/<prefix>/``."""
        self.log(
            {
                f'debug/{prefix}/grad norm': grad_norm,
                f'debug/{prefix}/update norm': update_norm,
            }
        )
