"""Generalization measures based on fisher information"""
import numpy as np
import os
import pickle
import logging
from measures.gen_measures import GenMeasure

from third_party.fisher_utils import StochasticFisher
from scipy.optimize import linear_sum_assignment
from numpy.linalg import norm

class FisherEigValues(GenMeasure):
    """Measure: summed held-out Fisher eigenvalues over summed training ones.

    Eigenvalues and eigenvectors come from ``StochasticFisher`` and are cached
    with pickle at ``self._caching_file_name``. That attribute, together with
    ``self._algorithm``, ``self._train_loader`` and
    ``self._union_held_out_loader``, is not assigned in this class; it is
    presumably provided by ``GenMeasure`` -- confirm against the base class.
    """

    _SEED = 42

    def __init__(self, *args, num_eig=100, max_num_examples=1000, **kwargs):
        """Set up the stochastic Fisher estimator.

        Args:
            *args: Forwarded unchanged to ``GenMeasure.__init__``.
            num_eig: Number of eigenpairs requested. ``-1`` or any value
                greater than or equal to ``max_num_examples`` is replaced by
                ``max_num_examples``.
            max_num_examples: Passed to ``StochasticFisher``; also the upper
                bound applied to ``num_eig``.
            **kwargs: Forwarded unchanged to ``GenMeasure.__init__``.
        """
        super().__init__(*args, **kwargs)

        if num_eig >= max_num_examples or num_eig == -1:
            # logging.warn is a deprecated alias; logging.warning is the
            # supported spelling.
            logging.warning("Capping num_eig to max number of examples")
            num_eig = max_num_examples

        self._stochastic_fisher = StochasticFisher(
            seed=self._SEED,
            max_num_examples=max_num_examples,
            num_eig=num_eig,
        )

    @staticmethod
    def _eigvecs_from_cache(blocks):
        """Rebuild the eigenvector matrix from its cached list-of-blocks form.

        The save path stores ``np.vsplit(eigvec.T, eigvec.shape[1])``, i.e. one
        single-row block per column of the original matrix. Only the singleton
        axis introduced by that split is collapsed here, so the result stays
        two-dimensional even when a single eigenvector was cached (an
        axis-less ``squeeze()`` would flatten that case to 1-D and make the
        cached result differ in shape from a freshly computed one).
        """
        stacked = np.array(blocks)
        return stacked.reshape(stacked.shape[0], -1).T

    def _compute_train_test_fisher(self):
        """Return Fisher eigenvalues/eigenvectors for train and held-out data.

        If ``self._caching_file_name`` exists the values are read from it;
        otherwise they are computed with ``StochasticFisher._compute`` on the
        training loader and on the union held-out loader, then written to
        that file.

        Returns:
            Tuple ``(train_eigval, train_eigvec, heldout_eigval,
            heldout_eigvec)``. The eigenvector arrays are two-dimensional; the
            save path treats each column as one cached vector (presumably one
            eigenvector per column -- confirm against ``StochasticFisher``).
        """
        if os.path.isfile(self._caching_file_name):
            logging.info(f"Pre-loading computed eigvals and eigvecs from {self._caching_file_name}")
            with open(self._caching_file_name, 'rb') as f:
                # NOTE(security): unpickling can execute arbitrary code; the
                # cache file must come from a trusted location.
                payload = pickle.load(f)
            cached = payload['fisher']
            train_eigval = cached['train_eigval']
            train_eigvec = self._eigvecs_from_cache(cached['train_eigvec'])
            heldout_eigval = cached['heldout_eigval']
            heldout_eigvec = self._eigvecs_from_cache(cached['heldout_eigvec'])
        else:
            logging.info("Computing training eigenvalues and vectors")
            train_eigval, train_eigvec = self._stochastic_fisher._compute(
                self._algorithm,
                self._train_loader,
            )

            logging.info("Computing heldout eigenvalues and vectors")
            heldout_eigval, heldout_eigvec = self._stochastic_fisher._compute(
                self._algorithm,
                self._union_held_out_loader,
            )
            logging.info(f"Saving computed eigvals and eigvecs to {self._caching_file_name}")
            with open(self._caching_file_name, 'wb') as f:
                # Eigenvectors are stored as a list of single-row blocks, one
                # per column of the matrix; _eigvecs_from_cache reverses this.
                pickle.dump({"fisher": {
                    "train_eigval": train_eigval,
                    "train_eigvec": np.vsplit(train_eigvec.T, train_eigvec.shape[1]),
                    "heldout_eigval": heldout_eigval,
                    "heldout_eigvec": np.vsplit(heldout_eigvec.T, heldout_eigvec.shape[1])
                }}, f)

        return train_eigval, train_eigvec, heldout_eigval, heldout_eigvec

    def _calculate_measure(self):
        """Return ``(sum(heldout_eigval) / sum(train_eigval), {})``.

        The second element is always an empty dict.
        """
        train_eigval, _, heldout_eigval, _ = self._compute_train_test_fisher()
        return heldout_eigval.sum()/train_eigval.sum(), {}
   
    
class FisherEigValuesSumDiff(FisherEigValues):
    """Measure: summed held-out Fisher eigenvalues minus summed training ones."""

    _SEED = 42

    def _calculate_measure(self):
        """Return ``(sum(heldout_eigval) - sum(train_eigval), {})``.

        The second element is always an empty dict.
        """
        fisher_parts = self._compute_train_test_fisher()
        # Positions 0 and 2 of the tuple hold the training and held-out
        # eigenvalues; the eigenvectors at positions 1 and 3 are not needed.
        train_total = fisher_parts[0].sum()
        heldout_total = fisher_parts[2].sum()
        return heldout_total - train_total, {}


class FisherEigVecAlign(FisherEigValues):
    """Measure: best one-to-one alignment of held-out and training eigenvectors.

    Every held-out vector is paired with a distinct training vector so that
    the total dot product between the L2-normalized pairs is maximal; that
    total is the measure.
    """

    _SEED = 42

    def _calculate_measure(self):
        """Return ``(total similarity of the optimal pairing, {})``.

        The second element is always an empty dict.
        """
        fisher_parts = self._compute_train_test_fisher()
        # After the transpose each row is one cached vector.
        train_rows = fisher_parts[1].T
        heldout_rows = fisher_parts[3].T

        # L2-normalize every row. The division is in place, so it writes
        # through the transposed views into the arrays returned above.
        for rows in (train_rows, heldout_rows):
            row_norms = norm(rows, axis=-1)
            rows /= np.expand_dims(row_norms, axis=1)

        # Entry (i, j) is the dot product of held-out row i and training row j.
        # NOTE(review): the sign of each dot product is kept, so a vector and
        # its negation count as anti-aligned -- confirm this is intended.
        similarity = np.dot(heldout_rows, train_rows.T)

        # linear_sum_assignment minimizes cost, so negate to maximize
        # the summed similarity.
        held_idx, train_idx = linear_sum_assignment(np.negative(similarity))

        return similarity[held_idx, train_idx].sum(), {}