"""Runs the LRM-NPEFF decomposition."""
import math
import os
import time
from typing import List, Optional

import h5py
import numpy as np
import torch

from npeff_torch.util import hdf5_utils

###############################################################################


def normalize_pefs_in_place(
    # shape = [n_examples, rank, n_parameters]
    pefs: torch.Tensor,
    # shape = [n_examples]
    pef_frobenius_norms: torch.Tensor,
):
    """Rescale every example's PEF factor in place.

    Each ``pefs[e]`` is divided by ``sqrt(pef_frobenius_norms[e])``. The input
    tensor is modified directly and nothing is returned.

    NOTE(review): dividing by the square root (rather than the norm itself)
    presumably reflects that the PEF is quadratic in this factor, so the PEF
    ends up scaled by ``1 / norm`` -- confirm against the caller.
    """
    # Add two trailing singleton axes so the per-example scale broadcasts over
    # the rank and parameter axes.
    per_example_scale = torch.sqrt(pef_frobenius_norms).unsqueeze(1).unsqueeze(1)
    pefs.div_(per_example_scale)


###############################################################################

class LrmNpeffRunner:
    """Fits the LRM-NPEFF decomposition to a fixed set of low-rank PEF factors.

    The runner holds per-example coefficients ``W`` (shape
    [n_examples, n_components]) and component vectors ``G`` (shape
    [n_components, n_parameters]). Fitting happens in two phases:

    1. "G only": gradient steps on ``G`` with ``W`` kept at its random init.
    2. "joint": alternating gradient steps on ``G`` and multiplicative updates
       of ``W``.

    Every ``log_loss_frequency`` steps the loss is computed, printed and
    recorded, and the optional tolerances are checked for early stopping.
    """

    @torch.no_grad()
    def __init__(
        self,
        # These should already be normalized.
        # shape = [n_examples, rank, n_parameters]
        pefs: torch.Tensor,
        *,
        n_components: int,

        seed: int,

        mu_eps: float,

        learning_rate_G_G_only: float,
        learning_rate_G_G_joint: float,

        n_iters_G_only: int,
        n_iters_joint: int,

        log_loss_frequency: int,

        # NOTE: These are tolerances for the change in loss of 'log_loss_frequency' steps, NOT every step.

        # If both relative and absolute tolerances are specified, then training will stop if either are satisfied.

        # If the loss after 'log_loss_frequency' steps drops by less than this, then stop.
        abs_tol_G_only: Optional[float] = None,
        # If the drop in loss after 'log_loss_frequency' steps is less than this fraction of the loss before the
        # 'log_loss_frequency' steps, then stop.
        rel_tol_G_only: Optional[float] = None,

        abs_tol_joint: Optional[float] = None,
        rel_tol_joint: Optional[float] = None,
    ):
        """Stores the hyperparameters and randomly initializes ``W`` and ``G``.

        ``W`` and ``G`` are created on the same device and with the same dtype
        as ``pefs``, using a generator seeded with ``seed``. ``mu_eps`` is the
        constant added to the denominator of the multiplicative ``W`` update.
        """
        self.pefs = pefs
        self.device = pefs.device

        self.mu_eps = mu_eps
        self.learning_rate_G_G_only = learning_rate_G_G_only
        self.learning_rate_G_G_joint = learning_rate_G_G_joint
        self.n_iters_G_only = n_iters_G_only
        self.n_iters_joint = n_iters_joint
        self.log_loss_frequency = log_loss_frequency

        self.abs_tol_G_only = abs_tol_G_only
        self.rel_tol_G_only = rel_tol_G_only
        self.abs_tol_joint = abs_tol_joint
        self.rel_tol_joint = rel_tol_joint

        n_examples, _, n_parameters = pefs.shape

        gen = torch.Generator(self.device)
        gen.manual_seed(seed)

        # Uniform [0, 1) init for the coefficients.
        self.W = torch.rand(n_examples, n_components, generator=gen, dtype=pefs.dtype, device=self.device)

        # Gaussian init for the components, scaled down by sqrt(numel / 2).
        self.G = torch.randn(n_components, n_parameters, generator=gen, dtype=pefs.dtype, device=self.device)
        self.G /= math.sqrt(self.G.numel() / 2)

        # Intermediates shared between steps. They stay None until the first
        # step that is asked to recompute them.
        self.WW: Optional[torch.Tensor] = None
        self.GG: Optional[torch.Tensor] = None
        # I'm pretty sure I refer to this as B in the paper.
        self.AG: Optional[torch.Tensor] = None

        # Timestamp (ns) of the start of the current logging window.
        self.t_step_start_ns: Optional[int] = None

        # Constant term of the loss; depends only on the PEFs.
        self.tr_xx = self._compute_tr_xx(pefs)

        self.losses_G_only: List[float] = []
        self.losses_joint: List[float] = []

    #######################################################

    def _compute_tr_xx(self, pefs: torch.Tensor) -> torch.Tensor:
        """Returns, as a scalar tensor, the sum of squared entries of every
        per-example [rank, rank] Gram matrix of ``pefs``."""
        dps = torch.einsum('erp,eqp->erq', pefs, pefs)
        return torch.einsum('erq,erq->', dps, dps)

    #######################################################

    def _should_stop(self, losses: List[float], loss: float, *, abs_tol: Optional[float], rel_tol: Optional[float]) -> bool:
        """Decides whether training should stop based on the latest loss.

        ``losses`` holds the previously recorded losses (not yet including
        ``loss``). Returns False when there is no previous loss to compare to
        or when both tolerances are None. Otherwise returns True if the drop
        from the last recorded loss is below ``abs_tol``, or below ``rel_tol``
        times the last recorded loss.
        """
        if not losses:
            return False

        last_loss = losses[-1]
        loss_drop = last_loss - loss

        if abs_tol is not None and loss_drop < abs_tol:
            return True

        if rel_tol is not None and loss_drop < last_loss * rel_tol:
            return True

        return False

    #######################################################

    def _compute_WW(self):
        """Caches ``W^T W`` (shape [n_components, n_components]) in ``self.WW``."""
        self.WW = torch.einsum('ec,ek->ck', self.W, self.W)

    def _compute_AG(self):
        """Caches the inner products of every PEF factor row with every
        component (shape [n_examples, rank, n_components]) in ``self.AG``."""
        self.AG = torch.einsum('ijl,kl->ijk', self.pefs, self.G)

    def _compute_GG(self):
        """Caches ``G G^T`` (shape [n_components, n_components]) in ``self.GG``."""
        self.GG = torch.einsum('cp,kp->ck', self.G, self.G)

    #######################################################

    def _W_update_step(
        self, *,
        recompute_AG: bool,
        recompute_GG: bool,
    ):
        """Applies one multiplicative update to ``W`` in place.

        The flags control whether the cached ``AG`` / ``GG`` are refreshed
        from the current ``G`` first; pass False only when they are known to
        be up to date.
        """
        if recompute_AG:
            self._compute_AG()
        if recompute_GG:
            self._compute_GG()

        # Numerator: squared inner products summed over the rank axis.
        N = torch.einsum('ijk,ijk->ik', self.AG, self.AG)

        # Denominator: W @ (GG ** 2), plus mu_eps.
        D = torch.einsum('ec,ck->ek', self.W, torch.square(self.GG))
        D += self.mu_eps

        self.W *= (N / D)

    def _G_update_step(
        self, *,
        learning_rate: float,
        recompute_WW: bool,
        recompute_AG: bool,
        recompute_GG: bool,
    ):
        """Applies one gradient-descent step to ``G`` in place.

        The flags control whether the cached ``WW`` / ``AG`` / ``GG`` are
        refreshed first; pass False only when they are known to be up to date.
        Note that this step does NOT refresh ``AG`` / ``GG`` after changing
        ``G``, so they are stale once it returns.
        """
        if recompute_WW:
            self._compute_WW()
        if recompute_AG:
            self._compute_AG()
        if recompute_GG:
            self._compute_GG()

        # NOTE: These differ by a factor of 4 from paper. The factor of 4 will be used to modify the learning rate.
        T1 = torch.einsum('ck,kp->cp', self.WW * self.GG, self.G)
        neg_T2 = torch.einsum('ji,jki,jkl->il', self.W, self.AG, self.pefs)

        self.G -= (4.0 * learning_rate) * (T1 - neg_T2)

    #######################################################

    def _compute_loss(
        self, *,
        recompute_AG: bool,
        recompute_GG: bool,
    ) -> float:
        """Returns the current loss as a Python float.

        ``WW`` is always recomputed here; ``AG`` and ``GG`` are recomputed only
        when the corresponding flag is set. The loss is
        ``tr_xx - 2 * <W, N> + <WW, GG ** 2>`` where ``N`` is the rank-summed
        square of ``AG``.
        """
        if recompute_AG:
            self._compute_AG()
        if recompute_GG:
            self._compute_GG()

        self._compute_WW()

        tr_WW_HH = torch.einsum('ij,ij->', self.WW, torch.square(self.GG))

        N = torch.einsum('ijk,ijk->ik', self.AG, self.AG)
        tr_WHX = torch.einsum('ij,ij->', self.W, N)

        return float((self.tr_xx - 2.0 * tr_WHX + tr_WW_HH).item())

    def _should_compute_loss_at_step(self, step: int) -> bool:
        """True when the 1-based step count is a multiple of ``log_loss_frequency``."""
        return ((step + 1) % self.log_loss_frequency) == 0

    def _log_loss(self, prefix: str, step: int, loss: float):
        """Prints the loss with the average time per step since the last log,
        then restarts the timing window."""
        t_end_ns = time.time_ns()

        elapsed_ms = (t_end_ns - self.t_step_start_ns) / 1e6

        print(f'{prefix} step {step + 1}: {loss} [{elapsed_ms / self.log_loss_frequency} ms/step]')

        self.t_step_start_ns = time.time_ns()

    #######################################################

    # The loops below run without autograd: the updates are closed-form and
    # in-place, so if `pefs` required grad, tracking would only chain every
    # iteration into an ever-growing graph attached to W and G.

    @torch.no_grad()
    def _run_G_only(self):
        """Runs up to ``n_iters_G_only`` gradient steps on ``G`` with ``W`` fixed."""
        opts_loss = {'recompute_AG': True, 'recompute_GG': True}
        opts_G = {'recompute_AG': True, 'recompute_GG': True, 'recompute_WW': True}

        self.t_step_start_ns = time.time_ns()
        for step in range(self.n_iters_G_only):
            self._G_update_step(learning_rate=self.learning_rate_G_G_only, **opts_G)

            if self._should_compute_loss_at_step(step):
                loss = self._compute_loss(**opts_loss)
                should_stop = self._should_stop(self.losses_G_only, loss, abs_tol=self.abs_tol_G_only, rel_tol=self.rel_tol_G_only)
                self.losses_G_only.append(loss)
                self._log_loss("G_only", step, loss)
                if should_stop:
                    break

    @torch.no_grad()
    def _run_joint(self):
        """Runs up to ``n_iters_joint`` alternating ``G`` and ``W`` updates."""
        # The W step always refreshes AG and GG after the G step, so they are
        # already current when the loss is computed.
        opts_loss = {'recompute_AG': False, 'recompute_GG': False}
        opts_G = {'recompute_AG': True, 'recompute_GG': True, 'recompute_WW': True}
        opts_W = {'recompute_AG': True, 'recompute_GG': True}

        self.t_step_start_ns = time.time_ns()
        for step in range(self.n_iters_joint):
            self._G_update_step(learning_rate=self.learning_rate_G_G_joint, **opts_G)
            self._W_update_step(**opts_W)

            if self._should_compute_loss_at_step(step):
                loss = self._compute_loss(**opts_loss)
                should_stop = self._should_stop(self.losses_joint, loss, abs_tol=self.abs_tol_joint, rel_tol=self.rel_tol_joint)
                self.losses_joint.append(loss)
                self._log_loss("joint", step, loss)
                if should_stop:
                    break

            # From the second iteration on, the AG and GG left behind by the
            # previous W step still match G, so the G step can reuse them.
            opts_G['recompute_AG'] = False
            opts_G['recompute_GG'] = False

    #######################################################

    @torch.no_grad()
    def run(self):
        """Runs the G-only phase followed by the joint phase."""
        self._run_G_only()
        self._run_joint()

    def save(self, filepath: str):
        """Writes ``W``, ``G`` and the recorded losses to an HDF5 file.

        ``filepath`` is user-expanded (``~``) and opened in write mode, so an
        existing file is overwritten.
        """
        with h5py.File(os.path.expanduser(filepath), "w") as f:
            data_grp = f.create_group('data')
            data_grp.attrs['n_parameters'] = self.pefs.shape[-1]
            # Stored under 'n_classes': the size of the second-to-last (rank) axis of pefs.
            data_grp.attrs['n_classes'] = self.pefs.shape[-2]
            hdf5_utils.save_h5_ds(data_grp, 'W', self.W.detach().cpu().numpy())
            hdf5_utils.save_h5_ds(data_grp, 'G', self.G.detach().cpu().numpy())

            losses_grp = f.create_group('losses')
            losses_grp.attrs['log_loss_frequency'] = self.log_loss_frequency
            hdf5_utils.save_h5_ds(losses_grp, 'losses_G_only', np.array(self.losses_G_only))
            hdf5_utils.save_h5_ds(losses_grp, 'losses_joint', np.array(self.losses_joint))
