import numpy as np
import ot
import torch
import scipy as sp
from scipy.stats import describe
from time import time
import matplotlib.pyplot as plt
from src.metrics import compute_metrics
from tqdm.auto import trange
import wandb
from src.utils import cosine_similarity
import matplotlib.pyplot as plt

from sklearn.neural_network import MLPRegressor
from tqdm import tqdm_notebook as tqdm

from ot import bregman
from ot.gromov import gwggrad, gwloss

from src.models.general_solvers import sinkhorn_knopp


def report_wandb_fn(metrics_dict, metrics_names, epoch, prefix):
    """Log the latest value of every requested metric to Weights & Biases.

    Args:
        metrics_dict: mapping from metric name to a sequence of recorded
            values; only the last element of each sequence is logged.
        metrics_names: iterable of keys of ``metrics_dict`` to report.
        epoch: value passed to ``wandb.log`` as ``step``.
        prefix: namespace prepended to each metric name as ``"<prefix>/<name>"``.
    """
    for name in metrics_names:
        latest_value = metrics_dict[name][-1]
        wandb_key = f'{prefix}/{name}'
        # One log call per metric, all sharing the same step.
        wandb.log({wandb_key: latest_value}, step=epoch)
            

class AlignGW():
    """Entropic Gromov-Wasserstein alignment between two point clouds.

    Every outer iteration calls ``gwggrad`` on the current coupling and feeds
    the result as the cost matrix of an entropic OT problem solved with
    Sinkhorn. On reporting iterations an ``MLPRegressor`` is fitted from the
    source points to their hard-assigned target points (row-wise argmax of the
    coupling), so that held-out source points can be mapped and evaluated.
    """

    def __init__(self, cost_name, normalize_dists, loss_fun, eps, tol, metric_names, toy_type=False):
        """Store the solver configuration.

        Args:
            cost_name: metric name forwarded to ``scipy.spatial.distance.cdist``.
            normalize_dists: ``'max'``, ``'mean'`` or ``'median'`` to rescale
                the intra-domain distance matrices; any other value leaves
                them unscaled.
            loss_fun: ``'square_loss'`` or ``'kl_loss'`` (see ``init_matrix``).
            eps: regularisation value handed to the Sinkhorn solvers.
            tol: threshold on the Frobenius norm of the coupling change; it
                only triggers a printed message and does not stop the loop.
            metric_names: names of the metrics tracked during ``solve``.
            toy_type: ``False`` to compute and log metrics, or
                ``'toy_2d_3d'`` / ``'toy_3d_2d'`` to plot the barycentric
                projection instead.
        """
        self.cost_name = cost_name
        self.normalize_dists = normalize_dists
        self.loss_fun = loss_fun
        self.eps = eps
        self.tol = tol
        # Warm-start Sinkhorn with the previous iteration's scaling vectors.
        self.ot_warm = True
        self.toy_type = toy_type
        self.metric_names = metric_names

    def compute_distances(self, x, y):
        """Compute and store the intra-domain distance matrices.

        Args:
            x: source samples, array-like accepted by ``cdist``.
            y: target samples, array-like accepted by ``cdist``.

        Side effects:
            Sets ``self.Cx`` and ``self.Cy`` to the pairwise distance matrices
            of ``x`` and ``y``, rescaled according to ``self.normalize_dists``.
        """
        Cx = sp.spatial.distance.cdist(x, x, metric=self.cost_name)
        Cy = sp.spatial.distance.cdist(y, y, metric=self.cost_name)

        if self.normalize_dists == 'max':
            Cx /= Cx.max()
            Cy /= Cy.max()
        elif self.normalize_dists == 'mean':
            Cx /= Cx.mean()
            Cy /= Cy.mean()
        elif self.normalize_dists == 'median':
            # cdist returns NumPy arrays, which torch.median rejects, so the
            # median has to be taken with NumPy.
            Cx /= np.median(Cx)
            Cy /= np.median(Cy)

        self.Cx, self.Cy = Cx, Cy

    def compute_gamma_entropy(self, G):
        """Return ``sum(G * (log(G) - 1))`` with non-finite terms zeroed.

        Entries where the product is NaN (e.g. ``G == 0``) contribute 0 via
        ``np.nan_to_num``.
        """
        # The log of zero entries is expected here and handled by nan_to_num,
        # so silence the corresponding NumPy warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            prod = G * (np.log(G) - 1)

        return np.nan_to_num(prod).sum()

    def init_matrix(self, Cx, Cy, T, p, q, loss_fun='square_loss'):
        """Precompute the constant terms used by ``gwggrad``.

        Args:
            Cx: square matrix for the source domain.
            Cy: square matrix for the target domain.
            T: unused; kept so the signature stays unchanged.
            p: 1-D source weights (length matches ``Cx``).
            q: 1-D target weights (length matches ``Cy``).
            loss_fun: ``'square_loss'`` or ``'kl_loss'``.

        Returns:
            Tuple ``(constC, hCx, hCy)`` where
            ``constC = f1(Cx) p 1^T + 1 q^T f2(Cy)^T`` and ``hCx = h1(Cx)``,
            ``hCy = h2(Cy)`` for the chosen loss decomposition.

        Raises:
            ValueError: if ``loss_fun`` is not one of the supported names.
        """
        if loss_fun == 'square_loss':
            # in POT this is only a**2 (no 1/2 factor), the same for f2
            fCx = (Cx ** 2) / 2
            fCy = (Cy ** 2) / 2
            hCx = Cx
            hCy = Cy
        elif loss_fun == 'kl_loss':
            # The 1e-15 offset keeps log() finite on zero entries.
            fCx = Cx * np.log(Cx + 1e-15) - Cx
            fCy = Cy
            hCx = Cx
            hCy = np.log(Cy + 1e-15)
        else:
            raise ValueError(
                f"Unknown loss_fun {loss_fun!r}; expected 'square_loss' or 'kl_loss'."
            )

        constC1 = np.dot(np.dot(fCx, p.reshape(-1, 1)),
                         np.ones(len(q)).reshape(1, -1))
        constC2 = np.dot(np.ones(len(p)).reshape(-1, 1),
                         np.dot(q.reshape(1, -1), fCy.T))

        return constC1 + constC2, hCx, hCy

    def solve(self, x_dict, y_dict, labels_dict, labels_red_dict, target_vectors, target_vectors_red, maxiter=200, report_every=20):
        """Run the entropic GW iterations and periodically evaluate.

        Args:
            x_dict, y_dict: dicts with ``'train'`` and ``'test'`` entries
                holding torch tensors of source / target samples.
            labels_dict, labels_red_dict: dicts with ``'train'`` and ``'test'``
                entries holding torch tensors of labels, forwarded to
                ``compute_metrics``.
            target_vectors, target_vectors_red: forwarded to ``compute_metrics``.
            maxiter: number of outer iterations.
            report_every: evaluation period; evaluation also happens on the
                last iteration. Iteration 0 is never a periodic report.

        Returns:
            Tuple ``(T, continuous_solver, metrics_dict_train, metrics_dict_test)``.
            ``continuous_solver`` is the regressor fitted at the last
            reporting iteration, or ``None`` when no report happened
            (``maxiter == 0``).

        Raises:
            ValueError: if ``self.toy_type`` is neither ``False`` nor a
                supported toy configuration.
        """
        # Fail fast instead of hitting an unbound plotting axis mid-run.
        if self.toy_type is not False and self.toy_type not in ('toy_2d_3d', 'toy_3d_2d'):
            raise ValueError(
                f"Unsupported toy_type {self.toy_type!r}; expected False, "
                "'toy_2d_3d' or 'toy_3d_2d'."
            )

        x_train, y_train = x_dict['train'], y_dict['train']
        labels_train, labels_train_red = labels_dict['train'], labels_red_dict['train']
        x_test, y_test = x_dict['test'], y_dict['test']
        labels_test, labels_test_red = labels_dict['test'], labels_red_dict['test']

        x_train_np = x_train.cpu().numpy()
        y_train_np = y_train.cpu().numpy()

        # Uniform marginals over the training samples.
        p = ot.unif(x_train_np.shape[0])
        q = ot.unif(y_train_np.shape[0])

        self.compute_distances(x_train_np, y_train_np)
        metric_names = self.metric_names

        # Start from the independent coupling.
        T = np.outer(p, q)
        constC, hCx, hCy = self.init_matrix(self.Cx, self.Cy, T, p, q, self.loss_fun)

        metrics_dict_train = {metric_name: [] for metric_name in metric_names}
        metrics_dict_test = {metric_name: [] for metric_name in metric_names}
        continuous_solver = None
        log = None

        for it in tqdm(range(maxiter)):

            Tprev = T
            tens = gwggrad(constC, hCx, hCy, T)

            if self.ot_warm and it > 0:
                # Warm start from the scaling vectors of the previous solve.
                T, log = sinkhorn_knopp(p, q, tens, self.eps, init_u=log['u'], init_v=log['v'], log=True)
            elif self.ot_warm:
                T, log = bregman.sinkhorn(p, q, tens, self.eps, log=True)
            else:
                T = bregman.sinkhorn(p, q, tens, self.eps)

            is_report_iter = (it % report_every == 0 and it != 0) or it == maxiter - 1
            if not is_report_iter:
                continue

            # Convergence is only reported, the loop keeps running.
            if np.linalg.norm(T - Tprev) < self.tol:
                print(f'Converged after {it}...')

            # Hard assignment: each source point is matched to the target
            # with the largest coupling mass in its row.
            ids = np.argmax(T, axis=1)
            y_sampled_train = torch.tensor(y_train_np[ids]).to(torch.float32)

            continuous_solver = MLPRegressor(hidden_layer_sizes=256, random_state=1, max_iter=500)
            continuous_solver.fit(x_train_np, y_sampled_train.numpy())

            if self.toy_type is False:
                metrics_dict_train = compute_metrics(x_train.cpu(), y_train.cpu(), y_sampled_train, labels_train.cpu(), labels_train_red.cpu(),
                                                     target_vectors, target_vectors_red, metrics_dict_train, 'cosine')
                report_wandb_fn(metrics_dict_train, metric_names, it, 'train')

                # Out-of-sample mapping through the fitted regressor.
                y_sampled_test = continuous_solver.predict(x_test.cpu().numpy())
                y_sampled_test = torch.tensor(y_sampled_test).to(torch.float32)

                metrics_dict_test = compute_metrics(x_test.cpu(), y_test.cpu(), y_sampled_test, labels_test.cpu(), labels_test_red.cpu(),
                                                    target_vectors, target_vectors_red, metrics_dict_test, 'cosine')
                report_wandb_fn(metrics_dict_test, metric_names, it, 'test')
            else:
                # Barycentric projection of the source points into target space.
                y_sampled_np = (T @ y_train_np) / T.sum(axis=1, keepdims=True)

                fig = plt.figure(figsize=(6, 6))
                projection = '3d' if self.toy_type == 'toy_2d_3d' else None
                ax = fig.add_subplot(projection=projection)

                ax.scatter(*y_sampled_np.T, c=labels_train.cpu().numpy(), cmap="Spectral", alpha=.8)
                ax.set_title('AlignGW')

                plt.show()

        return T, continuous_solver, metrics_dict_train, metrics_dict_test

    def fit(self, x_dict, y_dict, labels_dict, labels_red_dict, target_vectors, target_vectors_red, metric_names, max_iters, report_every):
        """Store the data on the instance and run ``solve``.

        Args:
            x_dict, y_dict, labels_dict, labels_red_dict: data dicts, stored
                on ``self`` and forwarded to ``solve``.
            target_vectors, target_vectors_red: forwarded to ``solve``.
            metric_names: unused here; ``solve`` reads ``self.metric_names``.
            max_iters: number of outer iterations, also stored as
                ``self.max_iter``.
            report_every: evaluation period forwarded to ``solve``.

        Returns:
            Tuple ``(metrics_dict_train, metrics_dict_test)`` from ``solve``.
        """
        self.max_iter = max_iters
        self.x_dict, self.y_dict = x_dict, y_dict
        self.labels_dict, self.labels_red_dict = labels_dict, labels_red_dict

        # The coupling and the fitted regressor are intentionally discarded.
        _, _, metrics_dict_train, metrics_dict_test = self.solve(
            self.x_dict, self.y_dict, self.labels_dict, self.labels_red_dict,
            target_vectors, target_vectors_red, max_iters, report_every)

        return metrics_dict_train, metrics_dict_test
       #self.coupling = coupling
       #self.continuous_solver = continuous_solver
    
    #def valid_step(self, sampler_source, sampler_target, n_samples, metric_names, target_vectors, target_vectors_red, n_eval):
    #        
    #    metrics_dict = {metric_name:[] for metric_name in metric_names}
    #    #print(metrics_dict)
    #    with torch.no_grad():
    #    
    #        #sampler_source.reset_sampler()
#
    #        for _ in trange(n_eval, leave=False, desc="Evaluation"):
    #            
    #            if sampler_target is None:
    #                x, y, labels, labels_red = sampler_source.sample(n_samples)
    #            else:
    #                #sampler_target.reset_sampler()
    #                x, labels, labels_red = sampler_source.sample(n_samples)
    #                y, _ , _              = sampler_target.sample(n_samples)
    #                
    #            x, y, labels = x.cpu(), y.cpu(), labels.cpu()
    #            y_sampled = self.continuous_solver.predict(x.numpy())
#
    #            y_sampled = torch.tensor(y_sampled).to(torch.float32)
#
    #            metrics_dict = compute_metrics(x, y, y_sampled, labels, labels_red, target_vectors, target_vectors_red, metrics_dict, 'cosine')
    #            
    #            if sampler_target is None:
    #                report_wandb_fn(metrics_dict, metric_names, self.max_iter, 'train')
    #            else:
    #                report_wandb_fn(metrics_dict, metric_names, self.max_iter, 'test')
    #                
    #        
    #        return metrics_dict
