"""Results of perturbation experiments."""
import json
import os
from typing import Optional, Sequence

import torch


###############################################################################


class ExamplesPerturbationInfo:
    """Perturbation statistics aggregated over one set of examples.

    Attributes:
        kl: KL-divergence value recorded for this set of examples.
        changed_prediction_fraction: Stored as given; presumably the fraction
            of examples whose prediction changed (confirm against the
            producer of these results).
    """

    def __init__(self, *, kl: float, changed_prediction_fraction: float):
        # Both values are stored untouched; conversion to plain floats only
        # happens when serializing in `to_json`.
        self.kl = kl
        self.changed_prediction_fraction = changed_prediction_fraction

    #######################################################

    def to_json(self):
        """Returns a JSON-serializable dict with both fields cast to `float`."""
        return dict(
            kl=float(self.kl),
            changed_prediction_fraction=float(self.changed_prediction_fraction),
        )

    @classmethod
    def from_json(cls, json_obj):
        """Builds an instance from a dict shaped like the output of `to_json`."""
        kl_value = json_obj['kl']
        changed_fraction = json_obj['changed_prediction_fraction']
        return cls(kl=kl_value, changed_prediction_fraction=changed_fraction)


class PerturbationRunInfo:
    """Results of one perturbation run, for top examples and baseline examples.

    Attributes:
        top_examples_info: Perturbation statistics over the top examples.
        baseline_examples_info: Perturbation statistics over the baseline
            examples.
    """

    def __init__(
        self, *,
        top_examples_info: 'ExamplesPerturbationInfo',
        baseline_examples_info: 'ExamplesPerturbationInfo',
    ):
        self.top_examples_info = top_examples_info
        self.baseline_examples_info = baseline_examples_info

    #######################################################

    @property
    def kl_ratio(self) -> float:
        """Returns the ratio of the KL-divergence for the top examples to the baseline examples.

        When the baseline KL is a plain Python zero, the division would raise
        `ZeroDivisionError`. In that case the IEEE-754 result is returned
        instead: `inf` (or `-inf`) for a non-zero top KL, and `nan` when the
        top KL is also zero or is itself `nan`.
        """
        top_kl = self.top_examples_info.kl
        baseline_kl = self.baseline_examples_info.kl
        try:
            return top_kl / baseline_kl
        except ZeroDivisionError:
            # Only plain Python numbers raise here. Mirror IEEE-754 float
            # division so one degenerate run does not abort aggregation
            # over many runs.
            if top_kl > 0:
                return float('inf')
            if top_kl < 0:
                return float('-inf')
            # Covers 0 / 0 and nan / 0 (all comparisons with nan are False).
            return float('nan')

    #######################################################

    def to_json(self):
        """Returns a JSON-serializable dict holding both nested infos."""
        return {
            'top_examples_info': self.top_examples_info.to_json(),
            'baseline_examples_info': self.baseline_examples_info.to_json(),
        }

    @classmethod
    def from_json(cls, json_obj):
        """Builds an instance from a dict shaped like the output of `to_json`."""
        return cls(
            top_examples_info=ExamplesPerturbationInfo.from_json(json_obj['top_examples_info']),
            baseline_examples_info=ExamplesPerturbationInfo.from_json(json_obj['baseline_examples_info']),
        )


###############################################################################


class ComponentPerturbationResults:
    """Perturbation outcome for one component, in both sign directions.

    Attributes:
        component_index: Index of the component these results belong to.
        perturbation_magnitude: Magnitude of the perturbation (must be positive).
        plus_results: Run results for one sign of the perturbation.
        minus_results: Run results for the same perturbation with the other sign.
    """

    def __init__(
        self, *,
        component_index: int,
        # Must be positive.
        perturbation_magnitude: float,
        # Results from the same perturbation but with a different sign.
        plus_results: 'PerturbationRunInfo',
        minus_results: 'PerturbationRunInfo',
    ):
        self.component_index = component_index
        self.perturbation_magnitude = perturbation_magnitude
        self.plus_results = plus_results
        self.minus_results = minus_results

    #######################################################

    @property
    def greatest_kl_ratio(self) -> float:
        """Returns the larger of the plus and minus kl ratios."""
        plus_ratio = self.plus_results.kl_ratio
        minus_ratio = self.minus_results.kl_ratio
        return max(plus_ratio, minus_ratio)

    @property
    def greatest_kl_ratio_results(self) -> 'PerturbationRunInfo':
        """Returns the run with the strictly greater kl ratio.

        The minus results are returned when the plus ratio is not strictly
        greater, which includes ties.
        """
        plus_ratio = self.plus_results.kl_ratio
        minus_ratio = self.minus_results.kl_ratio
        return self.plus_results if plus_ratio > minus_ratio else self.minus_results

    @property
    def greatest_kl_ratio_top_examples_kl(self) -> float:
        """Returns the top-examples KL of the run picked by `greatest_kl_ratio_results`."""
        selected_run = self.greatest_kl_ratio_results
        return selected_run.top_examples_info.kl

    @property
    def greatest_kl_ratio_baseline_examples_kl(self) -> float:
        """Returns the baseline-examples KL of the run picked by `greatest_kl_ratio_results`."""
        selected_run = self.greatest_kl_ratio_results
        return selected_run.baseline_examples_info.kl

    #######################################################

    def to_json(self):
        """Returns a JSON-serializable dict; scalars are cast to `int`/`float`."""
        serialized = {}
        serialized['component_index'] = int(self.component_index)
        serialized['perturbation_magnitude'] = float(self.perturbation_magnitude)
        serialized['plus_results'] = self.plus_results.to_json()
        serialized['minus_results'] = self.minus_results.to_json()
        return serialized

    @classmethod
    def from_json(cls, json_obj):
        """Builds an instance from a dict shaped like the output of `to_json`."""
        index = json_obj['component_index']
        magnitude = json_obj['perturbation_magnitude']
        plus = PerturbationRunInfo.from_json(json_obj['plus_results'])
        minus = PerturbationRunInfo.from_json(json_obj['minus_results'])
        return cls(
            component_index=index,
            perturbation_magnitude=magnitude,
            plus_results=plus,
            minus_results=minus,
        )


class ExperimentPerturbationResults:
    """Collection of per-component perturbation results for one experiment.

    Attributes:
        component_results: Tuple of the component results, in the order given.
    """

    def __init__(
        self, *,
        # TODO: Decide if these have to be unique for each component.
        component_results: Sequence['ComponentPerturbationResults'],
    ):
        # Stored as a tuple so the collection cannot be mutated in place.
        self.component_results = tuple(component_results)

    #######################################################

    def get_n_components(self) -> int:
        """Returns the number of component results held."""
        n_results = len(self.component_results)
        return n_results

    def get_greatest_kl_ratios(self) -> torch.Tensor:
        """Returns a float32 vector with the greatest kl ratio of each component result."""
        ratios = [result.greatest_kl_ratio for result in self.component_results]
        return torch.tensor(ratios, dtype=torch.float32)

    def get_greatest_kl_ratio_top_examples_kls(self) -> torch.Tensor:
        """Returns a float32 vector of `greatest_kl_ratio_top_examples_kl`, one entry per component result."""
        top_kls = [
            result.greatest_kl_ratio_top_examples_kl
            for result in self.component_results
        ]
        return torch.tensor(top_kls, dtype=torch.float32)

    def get_greatest_kl_ratio_baseline_examples_kls(self) -> torch.Tensor:
        """Returns a float32 vector of `greatest_kl_ratio_baseline_examples_kl`, one entry per component result."""
        baseline_kls = [
            result.greatest_kl_ratio_baseline_examples_kl
            for result in self.component_results
        ]
        return torch.tensor(baseline_kls, dtype=torch.float32)

    #######################################################

    def to_json(self):
        """Returns a JSON-serializable dict listing every component result."""
        serialized_results = [result.to_json() for result in self.component_results]
        return {'component_results': serialized_results}

    @classmethod
    def from_json(cls, json_obj):
        """Builds an instance from a dict shaped like the output of `to_json`."""
        parsed_results = [
            ComponentPerturbationResults.from_json(entry)
            for entry in json_obj['component_results']
        ]
        return cls(component_results=parsed_results)

    #######################################################

    def save(self, filepath: str):
        """Writes the JSON form of these results to `filepath` (`~` is expanded)."""
        target_path = os.path.expanduser(filepath)
        with open(target_path, "w") as out_file:
            json.dump(self.to_json(), out_file)
