"""Stuff for experiments about ablation."""
import dataclasses
import json
import os
from typing import List, Optional, Sequence

import tensorflow as tf

from em import datasets as em_datasets
from em.fishers import diagonal
from em.merging import merging
from em.util import hdf5_util
from em.util import hf_util

from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC


###############################################################################

# TODO: These locations are hard-coded for now; consider making them settable.
# Root directory of the experiment data; every other location lives beneath it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ll1'

# Sub-directories of EXPS_DIR, joined with the platform's path separator.
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################


@dataclasses.dataclass
class AblationExperimentConfig:
    """Settings for a single ablation experiment.

    The two required fields come first and may be passed positionally; the
    remaining fields have defaults.
    """

    # Index of the component being ablated; used to sort example indices.
    component_index: int
    # NOTE: an `n_examples_sign_guide: int` field is currently disabled.
    # Passed to `merging.create_pairwise_grid_coeffs` to build the sweep grid.
    n_coefficients: int
    # Batch sizes; not read anywhere in this module.
    compute_fisher_batch_size: int = 4
    compute_gradient_batch_size: int = 32
    # Forwarded as `fisher_floor` to the merging routine.
    fisher_floor: float = 1e-7


@dataclasses.dataclass
class AblationExperiment:
    """Merge-based ablation sweep for one component of a QQP model.

    The retained model's variables are merged with an "ablating" set of
    variables (see `get_ablating_variables`) over a grid of merging
    coefficients, either lazily (`stream_merge_based_ablation`) or with
    evaluation of every grid point (`perform_merge_based_ablation_run`).

    Attributes set through the dataclass constructor:
        config: Experiment settings.
        mc: Model context supplying the model, tokenizer, variables and the
            default evaluation context.
        dense_fisher: Stored on the instance; not read by any method here.
        extra_eval_contexts: Additional evaluation contexts, run after the
            model context's own one. `None` is replaced by an empty list.

    Attributes derived in `__post_init__`:
        component_index, tokenizer, eval_context, sign_guider,
        sorted_example_indices, all_eval_contexts, retaining_variables.
    """

    config: AblationExperimentConfig
    mc: QCC.QqpModelContext

    dense_fisher: Sequence[tf.Tensor]

    extra_eval_contexts: Optional[Sequence[QCC.EvaluationContext2]] = None

    def __post_init__(self) -> None:
        """Derive the helper objects used by the sweep methods."""
        self.component_index = self.config.component_index

        self.tokenizer = self.mc.tokenizer
        self.eval_context = self.mc.get_evaluation_context()
        self.sign_guider = QMC.SignGuider(mc=self.mc)

        self.sorted_example_indices = self.mc.sort_example_indices_for_component(self.component_index)

        if self.extra_eval_contexts is None:
            self.extra_eval_contexts = []

        # The model context's own evaluation context always comes first.
        self.all_eval_contexts = [self.eval_context, *self.extra_eval_contexts]

        # Snapshot of the list of variables of the model being retained.
        self.retaining_variables = list(self.mc.variables)

    def get_ablating_variables(self, sign_guide: Sequence[tf.Tensor], delta: float):
        """Return the retained variables with the sign guide applied.

        Thin wrapper around `SignGuider.apply_sign_guide`; how `delta` scales
        the guide is defined there.
        """
        return self.sign_guider.apply_sign_guide(self.retaining_variables, sign_guide, delta)

    def stream_merge_based_ablation(
        self,
        retaining_fisher: Sequence[tf.Tensor],
        ablating_fisher: Sequence[tf.Tensor],
        sign_guide: Sequence[tf.Tensor],
        delta: float,
    ):
        """Lazily merge retained and ablating variables over the coefficient grid.

        Args:
            retaining_fisher: Fisher for the retained variables.
            ablating_fisher: Fisher for the ablating variables.
            sign_guide: Passed to `get_ablating_variables`.
            delta: Passed to `get_ablating_variables`.

        Yields:
            `(coefficients, output_model)` for every entry of
            `merging.create_pairwise_grid_coeffs(config.n_coefficients)`.
            The model is loaded once and the *same* object is yielded each
            time, so use it before advancing the generator.
        """
        output_model = self.mc.load_model()
        output_variables = hf_util.get_all_variables(output_model)

        variables_to_merge = [self.retaining_variables, self.get_ablating_variables(sign_guide, delta)]
        fishers_to_merge = [retaining_fisher, ablating_fisher]

        # Loop-invariant: computed once for the whole sweep.
        norm_constants = [merging._l2_norm_of_fisher(f) for f in fishers_to_merge]

        for coefficients in merging.create_pairwise_grid_coeffs(self.config.n_coefficients):
            # NOTE(review): presumably this writes the merged values into
            # `output_variables` (and hence `output_model`) in place - confirm.
            merging._merge_with_coeffs(
                output_variables,
                variables_to_merge,
                coefficients=coefficients,
                fishers=fishers_to_merge,
                fisher_floor=self.config.fisher_floor,
                favor_target_model=True,
                normalization_constants=norm_constants,
            )
            yield coefficients, output_model

    def perform_merge_based_ablation_run(
        self,
        retaining_fisher: Sequence[tf.Tensor],
        ablating_fisher: Sequence[tf.Tensor],
        sign_guide: Sequence[tf.Tensor],
        delta: float,
    ) -> List['AblationEvaluationResult']:
        """Run the full sweep, evaluating the merged model at every grid point.

        Takes the same arguments as `stream_merge_based_ablation` and reuses
        it for the merging, so both entry points share one merge loop.

        Returns:
            One `AblationEvaluationResult` per coefficient setting, in grid
            order. Each holds one evaluation per entry of
            `all_eval_contexts`, in that order.
        """
        results = []
        sweep = self.stream_merge_based_ablation(
            retaining_fisher, ablating_fisher, sign_guide, delta
        )
        for coefficients, merged_model in sweep:
            # Evaluate before the generator advances and re-merges the model.
            result = AblationEvaluationResult(
                coefficients=coefficients,
                results=[ctx.evaluate(merged_model) for ctx in self.all_eval_contexts],
            )
            results.append(result)
            # Report progress as soon as each grid point has been evaluated.
            result.print_summary(
                sorted_hans_example_indices=self.sorted_example_indices
            )

        return results


###############################################################################


@dataclasses.dataclass
class AblationEvaluationResult:
    """Evaluation results gathered for one setting of the merging coefficients."""

    # Merging coefficients the evaluated model was built with.
    coefficients: Sequence[int]
    # One entry per evaluation context, in the order the contexts were run.
    # Quoted so the annotation is not evaluated when the class is created.
    results: Sequence['QCC.QqpEvaluationResults']

    def print_summary(
        self,
        sorted_hans_example_indices: Optional[Sequence[int]] = None,
    ) -> None:
        """Print the coefficients followed by each evaluation result.

        `AblationExperiment.perform_merge_based_ablation_run` calls this
        method after every grid point; without it that call raises
        `AttributeError`.

        Args:
            sorted_hans_example_indices: Accepted because the call site
                passes it by keyword. Currently unused: this module does not
                define how per-context results are indexed by example.
        """
        print(f'coefficients: {self.coefficients}')
        for position, result in enumerate(self.results):
            print(f'  [{position}] {result}')


###############################################################################
###############################################################################


"""
- Examine coefficients of top examples and their correlations/relations.
- Try methods and evaluation metrics for union/intersection of components.
"""


def auto_select_delta():
    """Placeholder for automatic selection of `delta`; not implemented yet.

    Takes no arguments, does nothing and returns `None`.
    """
