import os
import ot

from scipy.spatial.distance import cdist
from src.tools import create_dir
from .tools import visualize_alpha


class ToyEvaluator:
    """Produce the figures and printouts for the toy FROT experiment.

    Compares the transport plan of a fitted FROT model with vanilla entropic
    OT (``ot.sinkhorn``) computed with and without the noisy features. It also
    plots the learned group importances ``alpha_`` of several models.

    Parameters
    ----------
    data : object
        Toy dataset exposing 2-D arrays ``X`` and ``Y``. It must also expose a
        ``visualize(plan, titlename=..., savename=..., show=...)`` method.
    model : object
        Fitted FROT model exposing ``a_`` and ``b_``, which are passed to
        ``ot.sinkhorn`` as source/target weights. It also exposes ``PI_``
        (its transport plan) and ``alpha_`` (per-group importance weights,
        indexed as data group first, then noise group).
    """

    def __init__(self, data, model):
        self.data = data
        self.model = model

    def __call__(self, folder="toy_experiments", show=True, reg=0.02):
        """Save the FROT vs. vanilla-OT comparison plots and print ``alpha_``.

        Three figures are written to *folder*, which is created if needed:

        1. the FROT plan on the full data;
        2. vanilla Sinkhorn OT on all features;
        3. vanilla Sinkhorn OT on the first two features only.

        Parameters
        ----------
        folder : str
            Output directory for the PNG files.
        show : bool
            Forwarded to ``data.visualize`` to control interactive display.
        reg : float
            Entropic regularization for the vanilla ``ot.sinkhorn`` baselines.
            The default ``0.02`` is the value this method always used.
        """
        a = self.model.a_
        b = self.model.b_
        X, Y = self.data.X, self.data.Y

        create_dir(folder)

        self.data.visualize(
            self.model.PI_,
            # Raw string: "\e" is not a valid escape (SyntaxWarning on 3.12+);
            # the resulting text is unchanged.
            titlename=r"FROT on data (with noisy features) ($\eta=1$)",
            savename=os.path.join(folder, "frot_with_noisy_features.png"),
            show=show,
        )

        # Baseline 1: vanilla OT using every feature, noise included.
        cost_all = cdist(X, Y)
        plan_all = ot.sinkhorn(a, b, cost_all, reg)
        self.data.visualize(
            plan_all,
            titlename="Vanilla OT on data (with noisy features)",
            savename=os.path.join(folder, "vanilla_ot_with_noisy_features.png"),
            show=show,
        )

        # Baseline 2: vanilla OT on the first two columns only.
        # NOTE(review): assumes those columns are the informative
        # (non-noise) features, as the plot title states; confirm.
        cost_clean = cdist(X[:, :2], Y[:, :2])
        plan_clean = ot.sinkhorn(a, b, cost_clean, reg)
        self.data.visualize(
            plan_clean,
            titlename="Vanilla OT on data (without noisy features)",
            savename=os.path.join(folder, "vanilla_ot_without_noisy_features.png"),
            show=show,
        )

        print("Alpha: ")
        print("  Data group: {}".format(self.model.alpha_[0]))
        print("  Noise group: {}".format(self.model.alpha_[1]))

    def alpha_importance_eta(self, models, etas, folder="toy_experiments", show=True):
        """Plot grouped bars of ``alpha_`` for models fitted with different eta.

        For each feature group, one bar per model is drawn side by side around
        the group's tick. The figure is saved as
        ``frot_alpha_group_importance.png`` in *folder*, which is created if
        needed.

        Parameters
        ----------
        models : sequence
            Fitted models, each exposing ``alpha_``.
        etas : sequence
            The eta value used for each model; used for the legend labels.
            Must have the same length as *models*.
        folder : str
            Output directory for the PNG file.
        show : bool
            If true, display the figure interactively after saving it.

        Raises
        ------
        ValueError
            If *models* is empty or *models* and *etas* differ in length.
        """
        models = list(models)
        etas = list(etas)
        # Validate before any plotting work. An empty list would otherwise
        # divide by zero below. A length mismatch would be truncated silently
        # by zip and leave gaps in the bar layout.
        if not models:
            raise ValueError("models must contain at least one fitted model")
        if len(models) != len(etas):
            raise ValueError(
                "models and etas must have the same length "
                "(got {} and {})".format(len(models), len(etas))
            )

        import matplotlib.pyplot as plt
        import numpy as np

        create_dir(folder)
        alphafrotname = os.path.join(folder, "frot_alpha_group_importance.png")

        ngroups = len(self.model.alpha_)
        ind = np.arange(ngroups)

        fig = plt.figure()
        plt.title('Group importance with FROT')
        plt.ylabel('Group importance')

        nmethod = len(models)
        step = 0.4
        width = 2 * step / nmethod
        # Split [-step, step] into nmethod slots of size `width`. The
        # linspace points after the first are the slots' right edges, so
        # subtracting width / 2 gives each (center-aligned) bar's center.
        offsets = np.linspace(-step, step, nmethod + 1)[1:] - width / 2
        for model, eta, offset in zip(models, etas, offsets):
            plt.bar(ind + offset, model.alpha_, width=width,
                    label=r'$\eta$={}'.format(eta))

        groups = ['Data group', 'Noise group']
        plt.xticks(ind, groups)
        plt.yticks(np.linspace(0, 1, 11))
        plt.legend(bbox_to_anchor=(0.98, 0.04),
                   borderaxespad=0.,
                   loc='lower right',
                   frameon=True, framealpha=1, fancybox=True)

        plt.savefig(alphafrotname)
        if show:
            plt.show()
        plt.close(fig)
