"""Common code for perturbation experiments."""
import dataclasses
import json
import os
from typing import List, Optional, Union

import numpy as np
from scipy.stats import gmean
import tensorflow as tf
from transformers import PreTrainedTokenizer

from em import datasets as em_datasets
from em.fishers import lrm_pefs
from em.tools.nmf import lrm_npeff
from em.util import flat_pack

from em.projects.pi import qqp_components_context as QCC

###############################################################################


@dataclasses.dataclass
class PerturbationExperiment:
    """Shared state for perturbing a model along LRM-NPEFF components.

    Bundles a decomposition, the logits it was paired with, the model whose
    variables get perturbed, and an evaluation context built from the task's
    dataset. Construction snapshots the current variable values, normalizes the
    decomposition's components, and caches their pairwise |G G^T|.
    """
    nmf: lrm_npeff.LrmNpeffDecomposition
    # Logits handed to the evaluation context; one row per example (the row
    # count is checked against nmf.W in __post_init__).
    logits: np.ndarray

    model: tf.keras.Model
    # Variables that perturbations are applied to.
    variables: List[tf.Variable]

    task: str
    split: str

    n_top_examples: int
    n_total_examples: int

    # Needed by text tasks; leave as None for image tasks.
    tokenizer: Optional[PreTrainedTokenizer] = None

    # Forwarded to em_datasets.load. For image tasks this value controls the
    # image size rather than a sequence length.
    sequence_length: int = 128

    def __post_init__(self):
        n_rows = self.nmf.W.shape[0]
        assert n_rows == self.logits.shape[0]
        self.n_examples = n_rows

        # Snapshot of the unperturbed weights; perturbations are applied on top of these.
        self.og_variables = [tf.identity(var) for var in self.variables]

        dataset = em_datasets.load(
            self.task,
            split=self.split,
            sequence_length=self.sequence_length,
            tokenizer=self.tokenizer,
        )
        self.eval_ctx = QCC.EvaluationContext2.create_from_ds_and_logits(
            ds=dataset,
            logits=self.logits,
        )

        # Defensive: make sure every component has unit norm before it is used.
        self.nmf.normalize_components_to_unit_norm()

        # Cached once so per-component perturbers down the line can reuse it.
        # Read G only after normalization so these are absolute cosine similarities.
        components = self.nmf.G
        self.abs_cos_sims = np.abs(components @ components.T)

    @classmethod
    def from_filepaths(cls, nmf_filepath: str, pef_filepath: str, **kwargs):
        """Build an experiment from a saved decomposition and the logits in a PEF file.

        Remaining keyword arguments are forwarded to the constructor.
        """
        return cls(
            nmf=lrm_npeff.LrmNpeffDecomposition.load(nmf_filepath, read_G=True),
            logits=lrm_pefs.SparseLrmPefs.load_logits(pef_filepath),
            **kwargs,
        )


###############################################################################


@dataclasses.dataclass
class PerturbationResults:
    """Evaluation results of a perturbed model on two example subsets."""
    # Results on the "top" examples (in ComponentPerturber: those with the
    # largest coefficients for the perturbed component).
    top_results: QCC.QqpEvaluationResults
    # Results on the reference set of examples.
    total_results: QCC.QqpEvaluationResults

    def ratio(self):
        """Return `top_results.kl()` divided by `total_results.kl()`."""
        top_kl = self.top_results.kl()
        total_kl = self.total_results.kl()
        return top_kl / total_kl


@dataclasses.dataclass
class PmPerturbationResults:
    """Paired results of perturbing by +magnitude and -magnitude along one direction."""
    # Step size shared by both signs (ComponentPerturber.evaluate_pm requires >= 0).
    magnitude: float
    # Results with weights set to original + magnitude * direction.
    plus_results: PerturbationResults
    # Results with weights set to original - magnitude * direction.
    minus_results: PerturbationResults


@dataclasses.dataclass
class ComponentPerturber:
    """Perturbs the experiment's model along the direction of a single NPEFF component.

    The direction is row `component_index` of `exp.nmf.G`, optionally
    semi-orthogonalized against the other components. It is then scattered into
    the full parameter vector and split into one tensor per entry of
    `exp.variables`. Perturbations are always applied relative to the snapshot
    `exp.og_variables`, and `evaluate_pm` restores that snapshot before
    returning so the shared model is not left perturbed.
    """
    exp: PerturbationExperiment

    component_index: int

    # Threshold for semi-orthogonalization; must lie strictly between 0 and 1.
    # For every other component whose |cos sim| with this one is <= max_sim,
    # its projection is (sequentially) subtracted from the direction.
    # Set to None for no semi-orthogonalization.
    max_sim: Union[float, None]

    def __post_init__(self):
        if self.max_sim is not None:
            assert 0.0 < self.max_sim < 1.0

        self.nmf = self.exp.nmf
        self.abs_cos_sims = self.exp.abs_cos_sims

        self.model = self.exp.model
        self.variables = self.exp.variables
        self.og_variables = self.exp.og_variables

        self.eval_ctx = self.exp.eval_ctx

        self.packer = flat_pack.FlatPacker([v.shape for v in self.variables])

        # Indices of the n_top_examples examples with the largest coefficient for
        # this component, in descending order of that coefficient.
        self._top_inds = np.argsort(-self.nmf.W[:, self.component_index])[:self.exp.n_top_examples]
        # The first n_total_examples examples, used as the reference set.
        self._total_inds = np.arange(self.exp.n_total_examples)

        # Lazily filled in by the `normalized_perturbation` property.
        self._normalized_perturbation = None

    @property
    def normalized_perturbation(self) -> List[tf.Tensor]:
        """Unit-norm perturbation direction split into one tensor per variable.

        Computed on first access and cached afterwards.

        Raises:
            ValueError: if semi-orthogonalization leaves a zero (or non-finite)
                vector, which would otherwise produce NaN weights.
        """
        if self._normalized_perturbation is None:
            # Assumes rows of G have unit norm (PerturbationExperiment normalizes
            # them), so `g.dot(G[i]) * G[i]` is the projection of g onto G[i].
            G = self.nmf.G

            g_main = np.copy(G[self.component_index])

            if self.max_sim is not None:
                for i in range(G.shape[0]):
                    if i == self.component_index:
                        continue
                    # Components more similar than max_sim are left in place
                    # (presumably so the direction keeps most of its own content).
                    if self.abs_cos_sims[self.component_index, i] > self.max_sim:
                        continue
                    g_main -= g_main.dot(G[i]) * G[i]

            norm = np.sqrt(np.sum(g_main**2))
            # `not norm > 0` also catches NaN.
            if not norm > 0:
                raise ValueError(
                    f"Perturbation direction for component {self.component_index} "
                    f"has non-positive norm ({norm}) after semi-orthogonalization "
                    f"with max_sim={self.max_sim}.")
            g_main /= norm

            # Scatter the (column-subset) direction back into the full parameter vector.
            g = np.zeros([self.nmf.n_parameters], dtype=np.float32)
            g[self.nmf.new_to_old_col_indices] = g_main

            self._normalized_perturbation = self.packer.decode_tf(tf.cast(g, tf.float32))

        return self._normalized_perturbation

    def _perturb_weights(self, multiplier: float):
        """Set each variable to its original value plus `multiplier` times the direction."""
        for ogv, v, offset in zip(self.og_variables, self.variables, self.normalized_perturbation):
            v.assign(ogv + multiplier * offset)

    def _restore_weights(self):
        """Reset each variable to its original (unperturbed) value."""
        for ogv, v in zip(self.og_variables, self.variables):
            v.assign(ogv)

    def _evaluate_model(self) -> PerturbationResults:
        """Evaluate the current model on the top and reference example sets."""
        top_results = self.eval_ctx.evaluate(self.model, self._top_inds)
        total_results = self.eval_ctx.evaluate(self.model, self._total_inds)
        return PerturbationResults(top_results=top_results, total_results=total_results)

    def evaluate_pm(self, magnitude: float) -> PmPerturbationResults:
        """Evaluates a perturbation for both +/- magnitude * normalized_perturbation.

        The model's variables are restored to their original values afterwards,
        including when evaluation raises, so the shared model is not left
        perturbed.
        """
        assert magnitude >= 0.0

        try:
            self._perturb_weights(magnitude)
            plus_results = self._evaluate_model()

            self._perturb_weights(-magnitude)
            minus_results = self._evaluate_model()
        finally:
            self._restore_weights()

        return PmPerturbationResults(
            magnitude=magnitude,
            plus_results=plus_results,
            minus_results=minus_results,
        )


###############################################################################

@dataclasses.dataclass
class PerturbationStats:
    """Plain scalar summary of a perturbation's effect.

    Fields prefixed with `top_` describe the top example set; fields prefixed
    with `total_` describe the reference example set.
    """
    top_kl: float
    top_loss: float
    top_acc: float

    total_kl: float
    total_loss: float
    total_acc: float

    def kl_ratio(self) -> float:
        """Return `top_kl / total_kl`; a zero `total_kl` is not special-cased."""
        numerator, denominator = self.top_kl, self.total_kl
        return numerator / denominator


@dataclasses.dataclass
class ComponentPmPerturbationOutput:
    """Serializable stats for the +/- perturbation of one component."""
    component_index: int

    plus_results: PerturbationStats
    minus_results: PerturbationStats

    def _kl_ratios(self) -> List[float]:
        """Return `[plus KL ratio, minus KL ratio]`, computed in that order."""
        return [self.plus_results.kl_ratio(), self.minus_results.kl_ratio()]

    def kl_ratio_max(self) -> float:
        """Larger of the plus/minus KL ratios."""
        return max(self._kl_ratios())

    def kl_ratio_min(self) -> float:
        """Smaller of the plus/minus KL ratios."""
        return min(self._kl_ratios())

    def kl_ratio_amean(self) -> float:
        """Arithmetic mean of the plus/minus KL ratios."""
        return sum(self._kl_ratios()) / 2

    def kl_ratio_gmean(self) -> float:
        """Geometric mean (scipy.stats.gmean) of the plus/minus KL ratios."""
        return gmean(self._kl_ratios())

    @classmethod
    def from_json(cls, obj):
        """Rebuild an instance from the dict form produced by dataclasses.asdict."""
        index = obj['component_index']
        plus = PerturbationStats(**obj['plus_results'])
        minus = PerturbationStats(**obj['minus_results'])
        return cls(component_index=index, plus_results=plus, minus_results=minus)


@dataclasses.dataclass
class PmPerturbationExperimentOutput:
    """Serializable results of a +/- perturbation experiment over many components."""
    model_name: str
    nmf_filepath: str

    task: str
    split: str

    n_top_examples: int
    n_total_examples: int

    # None means no semi-orthogonalization was used.
    max_sim: Union[float, None]
    magnitude: float

    component_outputs: List[ComponentPmPerturbationOutput]

    def to_json(self):
        """Return this output as nested dicts/lists via dataclasses.asdict."""
        return dataclasses.asdict(self)

    def make_kl_ratios_csv_str(self) -> str:
        """Return a CSV string with one row of KL-ratio summaries per component.

        Columns are the component index followed by the max, min, arithmetic
        mean and geometric mean of its plus/minus KL ratios. There is no
        trailing newline.
        """
        ret = [['comp', 'max', 'min', 'amean', 'gmean']]
        for output in self.component_outputs:
            ret.append([
                output.component_index,
                output.kl_ratio_max(),
                output.kl_ratio_min(),
                output.kl_ratio_amean(),
                output.kl_ratio_gmean(),
            ])
        return '\n'.join([','.join([str(cell) for cell in row]) for row in ret])

    def save(self, filepath: str):
        """Write this output as JSON to `filepath` (`~` is expanded).

        The JSON text is fully produced before the file is opened, so a
        serialization error does not truncate or half-write an existing file.
        """
        text = json.dumps(self.to_json())
        with open(os.path.expanduser(filepath), "w", encoding="utf-8") as f:
            f.write(text)

    @classmethod
    def load(cls, filepath: str):
        """Read an output previously written by `save` (`~` is expanded)."""
        with open(os.path.expanduser(filepath), "r", encoding="utf-8") as f:
            obj = json.load(f)
        obj['component_outputs'] = [
            ComponentPmPerturbationOutput.from_json(x)
            for x in obj['component_outputs']
        ]
        return cls(**obj)
