import os
import sys
import glob
import torch
import importlib
import inspect
import logging
import numpy as np
import math
from scipy import optimize
from scipy.special import logsumexp


###############################################################################

class MetaMetrics(type):
    """Metaclass that turns instantiation of its classes into a name-based factory.

    Calling a class that uses this metaclass, e.g. ``Metrics("ZeroOneLoss", model)``,
    looks up the class ``ZeroOneLossMetrics`` in the ``core.metrics`` module,
    builds a fresh class with bases ``(<looked-up class>, <called class>,
    torch.nn.Module)`` and returns an instance of that new class. The concrete
    metric therefore comes first in the MRO, followed by the shared helpers of
    the called class and then ``torch.nn.Module``.
    """

    def __get_class_dict(cls):
        """Map metric names to the classes defined in ``core.metrics``.

        Every class of that module except ``MetaMetrics`` and ``Metrics`` is
        included. Its key is the class name with "Metrics" removed
        (``ZeroOneLossMetrics`` -> ``ZeroOneLoss``). If two names collapse to
        the same key, the later one returned by ``inspect.getmembers`` wins.
        """
        module = importlib.import_module("core.metrics")
        return {
            class_name.replace("Metrics", ""): class_
            for class_name, class_ in inspect.getmembers(module, inspect.isclass)
            if class_name not in ("MetaMetrics", "Metrics")
        }

    def __call__(cls, *args, **kwargs):
        """Create an instance of the metric selected by ``name``.

        The metric name is taken from the ``name`` keyword argument if present,
        otherwise from the first positional argument. All arguments, including
        the name, are forwarded unchanged to the new class's ``__init__``.

        Raises:
            TypeError: if no metric name is supplied at all.
            ValueError: if no class matches the requested name.
        """
        if "name" in kwargs:
            class_name = kwargs["name"]
        elif args:
            class_name = args[0]
        else:
            # Previously this surfaced as an opaque IndexError from args[0].
            raise TypeError(f"{cls.__name__}() missing required argument: 'name'")

        class_dict = cls.__get_class_dict()
        if class_name not in class_dict:
            # ValueError subclasses Exception, so existing `except Exception`
            # handlers still catch this.
            raise ValueError(f"{class_name} doesn't exist")

        # Concrete metric first so its methods override the shared helpers of
        # `cls`; torch.nn.Module last.
        bases = (class_dict[class_name], cls, torch.nn.Module)
        new_cls = type(cls.__name__, bases, {})
        # Go straight to type.__call__ so this __call__ is not re-entered for
        # the new class (which would redo the lookup).
        return super(MetaMetrics, new_cls).__call__(*args, **kwargs)


# --------------------------------------------------------------------------- #


class Metrics(metaclass=MetaMetrics):
    """Shared base for the metrics of this module.

    Because of ``MetaMetrics``, ``Metrics(name, model, ...)`` returns an
    instance of a class built from ``<name>Metrics``, this class and
    ``torch.nn.Module``. The conversion helpers below are therefore available
    to every concrete metric.
    """

    def __init__(self, name, model):
        super().__init__()
        # `name` is only used by MetaMetrics to select the class; it is not stored.
        self.model = model
        # Optional tensor state exposed through save()/load(); None in this base.
        self.param = None

    def numpy_to_torch(self, x, y):
        """Convert ``x`` and ``y`` to tensors when they are NumPy arrays.

        Arguments that are not ``np.ndarray`` are returned unchanged.
        """
        if isinstance(x, np.ndarray):
            x = torch.tensor(x)
        if isinstance(y, np.ndarray):
            y = torch.tensor(y)
        return x, y

    def torch_to_numpy(self, x, y, m):
        """Return tensor ``m`` as a NumPy array if both ``x`` and ``y`` are NumPy arrays.

        ``x`` and ``y`` only select the output type and should be the caller's
        original inputs. Otherwise ``m`` is returned unchanged.
        """
        if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
            # .numpy() rejects CUDA tensors; .cpu() is a no-op on CPU tensors.
            m = m.detach().cpu().numpy()
        return m

    def float_to_numpy_torch(self, x, y, m):
        """Wrap the scalar ``m`` in the container type shared by ``x`` and ``y``.

        Two NumPy arrays give ``np.array(m)`` and two tensors give
        ``torch.tensor(m)``. Mixed or other input types return ``m`` unchanged.
        """
        if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
            m = np.array(m)
        elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            m = torch.tensor(m)
        return m

    def load(self, param):
        """Copy the data of tensor ``param`` into ``self.param``.

        Non-tensor values (including None) are ignored.
        NOTE(review): raises AttributeError if ``self.param`` is still None,
        i.e. subclasses are expected to initialise it first.
        """
        if isinstance(param, torch.Tensor):
            self.param.data = param.data

    def save(self):
        """Return ``self.param`` (None unless a subclass sets it)."""
        return self.param

    def fit(self, x, y):
        """Compute the metric from ``x`` and ``y``; must be overridden."""
        raise NotImplementedError


# --------------------------------------------------------------------------- #


class BoundedCrossEntropyLossMetrics:
    """Bounded cross-entropy loss (Dziugaite et al., 2018).

    ``x`` is expected to already hold class probabilities (no softmax is
    applied here). Each probability ``p`` is mapped to
    ``psi(p) = exp(-L_max) + (1 - 2 * exp(-L_max)) * p``, and the loss is the
    mean over samples of ``-log(psi(p_y)) / L_max``. For ``p`` in [0, 1] this
    is bounded by 1.
    """

    def __init__(self, name, model, L_max=4.0):
        # L_max <= 0 would divide by zero or invert the bound.
        if L_max <= 0:
            raise ValueError("L_max must be positive, got " + str(L_max))
        super().__init__(name, model)
        self.L_max = L_max

    def fit(self, x, y):
        """Return the mean bounded cross entropy of probabilities ``x`` w.r.t. labels ``y``.

        Args:
            x: (N, C) class probabilities, NumPy array or tensor.
            y: class indices of shape (N, 1) or (N,), NumPy array or tensor.
                Only the first column of a 2-D ``y`` is used.

        Returns:
            A 0-d NumPy array when both ``x`` and ``y`` are NumPy arrays,
            otherwise a 0-d tensor.
        """
        probs, labels = self.numpy_to_torch(x, y)
        if labels.dim() == 2:
            labels = labels[:, 0]

        exp_l_max = math.exp(-self.L_max)
        bounded = exp_l_max + (1.0 - 2.0 * exp_l_max) * probs
        log_bounded = (1.0 / self.L_max) * torch.log(bounded)
        # nll_loss needs int64 targets; .long() also accepts float-typed labels.
        loss = torch.nn.functional.nll_loss(log_bounded, labels.long())

        # Pass the caller's original x, y: the converted tensors would never
        # trigger the NumPy round-trip.
        return self.torch_to_numpy(x, y, loss)


class ZeroOneLossMetrics:
    """Zero-one loss: the fraction of entries where ``x`` differs from ``y``."""

    def fit(self, x, y):
        """Return the mean of ``x != y`` (e.g. predicted vs. true labels).

        When the shapes of ``x`` and ``y`` differ only by singleton dimensions,
        such as (N,) vs. (N, 1), both are squeezed and compared element-wise.
        Broadcasting them would compare every pair and give a wrong loss.

        Returns:
            A 0-d NumPy array when both ``x`` and ``y`` are NumPy arrays,
            otherwise a 0-d tensor.
        """
        pred, target = self.numpy_to_torch(x, y)
        if pred.shape != target.shape and pred.squeeze().shape == target.squeeze().shape:
            pred, target = pred.squeeze(), target.squeeze()

        loss = torch.mean((pred != target).float())

        # Pass the caller's original x, y: the converted tensors would never
        # trigger the NumPy round-trip.
        return self.torch_to_numpy(x, y, loss)
