import torch
from tqdm.auto import trange
from tqdm import tqdm_notebook as tqdm
from src.metrics import compute_metrics
import wandb
import random
import numpy as np
import jax
from functools import partial
import ot
import jax.numpy as jnp
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from tqdm.notebook import tqdm
from ott.geometry.costs import CostFn

from ott.geometry.pointcloud import PointCloud
import jax.tree_util as jtu

from typing import Optional
import torch
from tqdm.auto import trange
from tqdm import tqdm_notebook as tqdm
from src.metrics import compute_metrics
import wandb
import random
import numpy as np
import jax
from functools import partial

import jax.numpy as jnp
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from tqdm.notebook import tqdm
from ott.geometry.costs import CostFn

from ott.geometry.pointcloud import PointCloud

from typing import Optional
from typing import Tuple

import jax
import jax.numpy as jnp

from jaxopt._src import tree_util
from typing import Any

from jax import tree_util as tu
from sklearn.neural_network import MLPRegressor
import matplotlib.pyplot as plt

# Module-level shorthand for JAX's pytree map (``tu`` is ``jax.tree_util``).
tree_map = jax.tree_util.tree_map

def prox_lasso(x: Any,
               l1reg: Optional[Any] = None,
               scaling: float = 1.0) -> Any:
    """Proximal operator of the (scaled) L1 norm, i.e. leaf-wise soft-thresholding.

    Every leaf ``u`` of ``x`` (with matching threshold ``v``) is mapped to
    ``sign(u) * max(|u| - v * scaling, 0)``.

    Args:
      x: input pytree of arrays.
      l1reg: threshold. Either a pytree with the same structure as ``x`` (one
        threshold per leaf) or a Python scalar (int or float), which is broadcast
        to every leaf. ``None`` means ``1.0``.
      scaling: multiplier applied to the threshold.

    Returns:
      A pytree with the same structure as ``x``.
    """
    if l1reg is None:
        l1reg = 1.0

    # Broadcast a scalar threshold over the leaves of ``x``. ``isinstance`` also
    # accepts ints; the former ``type(l1reg) == float`` check left e.g.
    # ``l1reg=1`` un-broadcast, which fails for pytree inputs.
    if isinstance(l1reg, (int, float)):
        threshold = l1reg
        l1reg = jax.tree_util.tree_map(lambda leaf: threshold * jnp.ones_like(leaf), x)

    def soft_threshold(u, v):
        return jnp.sign(u) * jax.nn.relu(jnp.abs(u) - v * scaling)

    return jax.tree_util.tree_map(soft_threshold, x, l1reg)


def update_M(X, Y, method="exact", solver_state=None, reg=0.5):
    """Update the linear map ``M`` from the current transport plan.

    Computes ``M = (P @ Y).T . X`` with ``P = solver_state.matrix``. For 2-D
    inputs this gives shape ``(Y.shape[1], X.shape[1])``.

    Args:
      X: source samples; its row count must match the rows of ``P``.
      Y: target samples; its row count must match the columns of ``P``.
      method: ``"exact"`` for the plain update, or ``"l1_reg"`` to additionally
        soft-threshold ``M`` with ``prox_lasso`` at level ``reg``.
      solver_state: object exposing the coupling as ``.matrix``.
      reg: L1 threshold, only used when ``method == "l1_reg"``.

    Returns:
      The updated ``M``.

    Raises:
      ValueError: if ``method`` is not supported (this used to surface as an
        ``UnboundLocalError``).
    """
    if method not in ("exact", "l1_reg"):
        raise ValueError(f"Unknown method {method!r}; expected 'exact' or 'l1_reg'.")

    PY = (solver_state.matrix @ Y).T
    M = PY.dot(X)
    if method == "l1_reg":
        M = prox_lasso(M, l1reg=reg)
    return M


def report_wandb_fn(metrics_dict, metrics_names, epoch, prefix):
    """Log the most recent value of each named metric to wandb.

    One ``wandb.log`` call is made per metric, under the key
    ``"<prefix>/<metric name>"`` and at step ``epoch``.
    """
    for name in metrics_names:
        latest_value = metrics_dict[name][-1]
        key = f'{prefix}/{name}'
        wandb.log({key: latest_value}, step=epoch)

@jax.jit
def f_eps(x, Y, g, beta, eps):
    """Entropic (soft) c-transform of the potential ``g`` at ``x``.

    Evaluates ``-eps * log(sum_j beta_j * exp((g_j + <x, Y_j>) / eps))``.
    Intended for a single point ``x``. Note that ``logsumexp`` reduces over all
    axes of the logits.
    """
    scores = x.dot(Y.T)
    logits = (scores + g[None, :]) / eps
    return -eps * jax.nn.logsumexp(a=logits, b=beta)

    
def beta_tilde_fn(Y, g, beta, eps):
    """Build a jitted map ``x -> beta_tilde(x)`` for a fixed ``(Y, g, beta, eps)``.

    ``beta_tilde(x)`` is minus the gradient of ``f_eps`` with respect to ``g``
    (argument index 2). From the formula in ``f_eps``, this gradient is
    ``beta_j * exp(a_j) / sum_k beta_k * exp(a_k)``, i.e. reweighted target
    weights.
    """
    grad_wrt_g = jax.grad(f_eps, argnums=2)

    def beta_tilde_sample(x):
        return -grad_wrt_g(x, Y, g, beta, eps)

    return jax.jit(beta_tilde_sample)


@jax.jit
def beta_tilde_vec(x, Y, g, beta, eps):
    """Apply ``beta_tilde_fn(Y, g, beta, eps)`` to every row of ``x`` (vectorised with ``vmap``).

    The whole function is already jitted, so the per-row map is vmapped without
    an extra inner ``jax.jit``.
    """
    per_row = jax.vmap(beta_tilde_fn(Y, g, beta, eps))
    return per_row(x)
    

def entropic_pred(X, Y, M, g, beta, eps):
    """Entropic barycentric prediction of targets for every row of ``X``.

    The targets are embedded as ``Y . M``. The per-row weights come from
    ``beta_tilde_vec``, and the prediction is the weighted combination of the
    rows of ``Y``.
    """
    embedded_targets = Y.dot(M)
    weights = beta_tilde_vec(X, embedded_targets, g, beta, eps)
    return weights.dot(Y)

@jax.tree_util.register_pytree_node_class
class GWCost_IP(CostFn):
    """OTT cost function ``c(x, y) = -<x, y>`` (negative inner product)."""

    def __init__(self):
        super().__init__()

    def pairwise(self, x: jnp.ndarray, y: jnp.ndarray):
        """Cost between the two points ``x`` and ``y``."""
        similarity = jnp.vdot(x, y)
        return -similarity

import torch

def f_eps_torch(x, Y, g, beta, eps):
    """PyTorch counterpart of ``f_eps``: entropic c-transform of ``g``.

    For each row ``x_i`` of ``x`` this computes
    ``-eps * log(sum_j beta_j * exp((g_j + <x_i, Y_j>) / eps))``.

    ``beta`` enters the log-sum-exp as ``log(beta)``, which matches the JAX
    version's ``logsumexp(..., b=beta)``. Previously ``beta`` itself was added
    to the logits, which weighted the terms by ``exp(beta_j)`` instead of
    ``beta_j``. Zero weights become ``-inf`` logits and contribute nothing.

    Returns:
      Tensor with one value per row of ``x`` (reduction over dim 1).
    """
    logits = (g.unsqueeze(0) + x @ Y.T) / eps
    return -eps * torch.logsumexp(logits + torch.log(beta), dim=1)

def beta_tilde_fn_torch(Y, g, beta, eps):
    """Build a map ``x -> -d f_eps_torch(x, Y, g, beta, eps) / d g`` for a single point ``x``.

    Only the gradient with respect to ``g`` is taken. If ``g`` does not already
    require grad, a detached leaf copy that does is used, so
    ``torch.autograd.grad`` always has something to differentiate.
    """
    if not g.requires_grad:
        g = g.detach().requires_grad_(True)

    def beta_tilde_sample(x):
        # ``x`` is used as-is: flagging it with ``requires_grad_`` (as before)
        # mutated the caller's tensor and built an unused graph through ``x``.
        fx = f_eps_torch(x.unsqueeze(0), Y, g, beta, eps).sum()
        (grad_g,) = torch.autograd.grad(fx, g, create_graph=False)
        return -grad_g

    return beta_tilde_sample

def beta_tilde_vec_torch(X, Y, g, beta, eps):
    """Stack the ``beta_tilde`` vectors of every row of ``X`` (one autograd call per row).

    ``g`` is replaced by a detached leaf that requires grad, so gradients with
    respect to it can be taken without touching the caller's tensor.
    """
    g_leaf = g.detach().requires_grad_(True)
    per_sample = beta_tilde_fn_torch(Y, g_leaf, beta, eps)
    return torch.stack(list(map(per_sample, X)))

def entropic_pred_torch(X, Y, M, g, beta, eps):
    """PyTorch counterpart of ``entropic_pred``: weighted combination of the rows of ``Y`` for each row of ``X``."""
    weights = beta_tilde_vec_torch(X, Y @ M, g, beta, eps)
    return weights @ Y

#@partial(jax.jit, static_argnums=(8,))
#def generate_preds(X_train, Y_train, X_test, M, g, beta, eps, batch_size):
#    X_train1 = X_train[:(X_train.shape[0] // batch_size) * batch_size, :]
#    X_train1 = X_train1.reshape((X_train1.shape[0] // batch_size, batch_size, X_train1.shape[1]))
#    X_train2 = X_train[(X_train.shape[0] // batch_size) * batch_size:]
#    genes_test_pred = jnp.zeros(shape=(X_train1.shape[0], batch_size, X_test.shape[1]))
#    
#    for ix in range(X_train1.shape[0]):
#        betas = beta_tilde_vec(X_train[ix], Y_train.dot(M), g, beta, eps)
#        new_pred_test = betas.dot(X_test)
#        genes_test_pred = genes_test_pred.at[i].set(new_pred_test)
#
#    genes_test_pred = genes_test_pred.reshape(X_train1.shape[0] * batch_size, X_test.shape[1])
#    last_pred_test = beta_tilde_vec(X_train2, Y_train.dot(M), g, beta, eps).dot(X_test)
#    genes_test_pred = jnp.concatenate((genes_test_pred, last_pred_test))
#    return genes_test_pred

def generate_preds(X_test, Y_train, M, g, beta, eps, batch_size):
    """Entropic barycentric predictions for ``X_test``, computed in mini-batches.

    Each batch of rows of ``X_test`` is passed through ``beta_tilde_vec``
    against the embedded training targets ``Y_train @ M``. The resulting
    weights are then applied to ``Y_train``.

    Args:
      X_test: query points, one per row.
      Y_train: training targets, one per row (2-D).
      M: linear map used to embed ``Y_train``.
      g, beta, eps: potential, target weights and regularisation forwarded to
        ``beta_tilde_vec``.
      batch_size: maximum number of rows per batch; must be positive.

    Returns:
      Array with one row per row of ``X_test`` and ``Y_train.shape[1]`` columns.

    Raises:
      ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_rows = X_test.shape[0]
    if n_rows == 0:
        return jnp.zeros((0, Y_train.shape[1]))

    Y_emb = Y_train @ M
    # Striding by ``batch_size`` covers the full batches and the trailing partial
    # batch uniformly. Previously, fewer than ``batch_size`` rows left the list of
    # full batches empty and ``jnp.concatenate`` raised.
    preds = [
        beta_tilde_vec(X_test[start:start + batch_size], Y_emb, g, beta, eps) @ Y_train
        for start in range(0, n_rows, batch_size)
    ]
    return jnp.concatenate(preds, axis=0)

    
@partial(jax.jit, static_argnums=(4,))
def update_geometry(X, Y, M, geom_epsilon, geom_batch_size):
    """Build the OTT point cloud between ``X . M.T`` and ``Y`` under the negative inner-product cost.

    ``geom_batch_size`` (argument index 4) is marked static for ``jit``.
    """
    return PointCloud(
        x=X.dot(M.T),
        y=Y,
        epsilon=geom_epsilon,
        batch_size=geom_batch_size,
        cost_fn=GWCost_IP(),
    )

def InnerProduct(x, y):
    """Negative inner-product cost matrix: ``C[i, j] = -<x_i, y_j>``."""
    return -(x @ y.T)

def update_M_torch(X, Y, solver_matrix):
    """Exact ``M`` update from a coupling: ``M = (P @ Y).T @ X`` with ``P = solver_matrix``."""
    transported = solver_matrix @ Y
    return transported.T @ X

class StructuredGW:
    """Alternating solver for a structured Gromov-Wasserstein problem with an inner-product cost.

    Each outer iteration does the following:

    1. Run an entropic OT solve between ``x_train @ M.T`` and ``y_train`` under
       the cost ``-<M x, y>``. This is ``GWCost_IP`` in :meth:`solve` and
       ``InnerProduct`` in :meth:`solve_torch`.
    2. Update the linear map ``M`` from the resulting coupling
       (``update_M`` / ``update_M_torch``).
    3. Re-run the OT solve with the updated cost.

    :meth:`solve` uses OTT/JAX and :meth:`solve_torch` uses POT/PyTorch.

    Every ``report_every`` iterations (and on the last one) the current model
    is evaluated:

    - Real data (``toy_type is False``): an entropic prediction on the training
      set is scored with ``compute_metrics``. An ``MLPRegressor`` is then fit on
      it to predict the test set, and both metric sets are logged to wandb.
    - Toy data: the barycentric projection of the coupling is plotted.
    """

    # Maximum number of Sinkhorn iterations per inner OT solve (both back-ends).
    _SINKHORN_MAX_ITERS = 1000
    # Matplotlib projection used for each supported toy setting.
    _TOY_PROJECTIONS = {'toy_2d_3d': '3d', 'toy_3d_2d': None}

    def __init__(self, M_init, method_M, eps=1e-4, tol=1e-3, toy_type=False, metric_names=None, device='cuda'):
        """Store the solver configuration.

        Args:
          M_init: initial map ``M``, or ``None`` for the constant initialisation
            ``ones((y_dim, x_dim)) / max(y_dim, x_dim)``.
          method_M: ``M``-update rule forwarded to ``update_M`` by :meth:`solve`
            (``'exact'`` or ``'l1_reg'``). ``None`` selects ``'exact'``.
            :meth:`solve_torch` always performs the exact update.
          eps: entropic regularisation. Also used as the L1 threshold for ``'l1_reg'``.
          tol: Sinkhorn stopping threshold (both back-ends).
          toy_type: ``False`` for real data. Otherwise ``'toy_2d_3d'`` or
            ``'toy_3d_2d'``, which plot instead of computing metrics.
          metric_names: names of the metrics tracked and logged.
          device: torch device used by :meth:`solve_torch`.
        """
        self.M_init = M_init
        # Honour the requested update rule; it used to be hard-coded to 'exact'.
        self.method_M = 'exact' if method_M is None else method_M
        self.eps = eps
        self.tol = tol
        self.toy_type = toy_type
        self.metric_names = metric_names
        self.device = device

    @staticmethod
    def _check_schedule(maxiter, report_every):
        """Reject settings that would leave the outputs unbound or divide by zero."""
        if maxiter < 1:
            raise ValueError(f"maxiter must be >= 1, got {maxiter}")
        if report_every < 1:
            raise ValueError(f"report_every must be >= 1, got {report_every}")

    @staticmethod
    def _should_report(it, maxiter, report_every):
        """True every ``report_every`` iterations (skipping iteration 0) and on the last iteration."""
        return (it % report_every == 0 and it != 0) or it == maxiter - 1

    def _empty_metric_dicts(self):
        """Fresh ``(train, test)`` dicts mapping each metric name to an empty list."""
        return ({name: [] for name in self.metric_names},
                {name: [] for name in self.metric_names})

    def _pot_sinkhorn(self, a, b, cost, warmstart):
        """Run POT's log-domain Sinkhorn; return the coupling and the ``(log_u, log_v)`` pair.

        ``warmstart`` is ``None`` (POT's default initialisation) or a previous
        ``(log_u, log_v)`` pair. It is passed through POT's ``warmstart``
        argument. The former ``u=``/``v=`` keywords are not parameters of
        ``ot.sinkhorn`` (they end up in ``**kwargs``), so the intended warm start
        never happened.
        """
        coupling, log = ot.sinkhorn(
            a,
            b,
            cost,
            self.eps,
            method='sinkhorn_log',
            numItermax=self._SINKHORN_MAX_ITERS,
            stopThr=self.tol,
            log=True,
            warmstart=warmstart,
        )
        return coupling, (log['log_u'], log['log_v'])

    def _report_metrics(self, it, continuous_solver, y_sampled_train, x_dict, y_dict, labels_dict,
                        labels_red_dict, target_vectors, target_vectors_red,
                        metrics_dict_train, metrics_dict_test):
        """Score the train predictions, fit ``continuous_solver`` on them, then score its test predictions.

        Both metric sets are logged to wandb under the ``'train'`` / ``'test'``
        prefixes at step ``it``. Returns the updated ``(train, test)`` metric dicts.
        """
        metric_names = self.metric_names
        x_train, y_train = x_dict['train'].cpu(), y_dict['train'].cpu()
        x_test, y_test = x_dict['test'].cpu(), y_dict['test'].cpu()
        y_sampled_train = y_sampled_train.cpu()

        metrics_dict_train = compute_metrics(x_train, y_train, y_sampled_train,
                                             labels_dict['train'].cpu(), labels_red_dict['train'].cpu(),
                                             target_vectors, target_vectors_red, metrics_dict_train, 'cosine')
        report_wandb_fn(metrics_dict_train, metric_names, it, 'train')

        # Out-of-sample extension: regress the entropic train predictions on x_train.
        continuous_solver.fit(x_train.numpy(), y_sampled_train.numpy())
        y_sampled_test = torch.tensor(continuous_solver.predict(x_test.numpy())).to(torch.float32)

        metrics_dict_test = compute_metrics(x_test, y_test, y_sampled_test,
                                            labels_dict['test'].cpu(), labels_red_dict['test'].cpu(),
                                            target_vectors, target_vectors_red, metrics_dict_test, 'cosine')
        report_wandb_fn(metrics_dict_test, metric_names, it, 'test')
        return metrics_dict_train, metrics_dict_test

    def _plot_toy(self, coupling, y_train, labels):
        """Scatter the barycentric projection ``(T @ y) / T.sum(axis=1)`` of a numpy coupling, coloured by label.

        Raises:
          ValueError: if ``self.toy_type`` is not a supported toy setting
            (previously this failed with an unbound ``ax``).
        """
        if self.toy_type not in self._TOY_PROJECTIONS:
            raise ValueError(f"Unsupported toy_type {self.toy_type!r}; "
                             f"expected one of {sorted(self._TOY_PROJECTIONS)}")

        y_bar = (coupling @ y_train) / coupling.sum(axis=1, keepdims=True)

        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(projection=self._TOY_PROJECTIONS[self.toy_type])
        ax.scatter(*y_bar.T, c=labels, cmap="Spectral")
        ax.set_title('StructuredGW')
        plt.show()

    def solve_torch(self, x_dict, y_dict, labels_dict, labels_red_dict, target_vectors, target_vectors_red, maxiter=200, report_every=20):
        """POT/PyTorch back-end of the alternating solver.

        Args:
          x_dict, y_dict: dicts with ``'train'`` and ``'test'`` source/target tensors.
          labels_dict, labels_red_dict: matching label tensors (for metrics/plots).
          target_vectors, target_vectors_red: forwarded to ``compute_metrics``.
          maxiter: number of outer iterations; must be >= 1.
          report_every: evaluation period in outer iterations; must be >= 1.

        Returns:
          ``(coupling, g, M, b, continuous_solver, metrics_dict_train, metrics_dict_test)``.
          ``coupling`` and ``g = eps * log_v`` come from the last Sinkhorn solve.

        Raises:
          ValueError: on non-positive ``maxiter`` / ``report_every``, or on an
            unsupported ``toy_type`` when plotting.
        """
        self._check_schedule(maxiter, report_every)

        x_train = x_dict['train'].to(self.device)
        y_train = y_dict['train'].to(self.device)

        y_dim, x_dim = y_train.shape[1], x_train.shape[1]
        # Honour M_init as ``solve`` does (it used to be ignored here).
        if self.M_init is None:
            M = torch.ones((y_dim, x_dim)) / max(y_dim, x_dim)
        elif isinstance(self.M_init, torch.Tensor):
            M = self.M_init.to(torch.float32)
        else:
            M = torch.tensor(np.array(self.M_init), dtype=torch.float32)
        M = M.to(self.device)

        a = torch.ones(x_train.shape[0], device=self.device) / x_train.shape[0]
        b = torch.ones(y_train.shape[0], device=self.device) / y_train.shape[0]

        metrics_dict_train, metrics_dict_test = self._empty_metric_dicts()
        continuous_solver = None
        # ``None`` -> POT's default initialisation on the first solve. The duals
        # used to be created with a hard-coded length of 3000.
        log_duals = None
        cost = InnerProduct(x_train @ M.T, y_train)

        for it in tqdm(range(maxiter)):
            coupling, log_duals = self._pot_sinkhorn(a, b, cost, log_duals)

            M = update_M_torch(x_train, y_train, coupling)
            cost = InnerProduct(x_train @ M.T, y_train)

            # Re-solve with the updated cost so the coupling/duals returned match M.
            coupling, log_duals = self._pot_sinkhorn(a, b, cost, log_duals)
            g = self.eps * log_duals[1]

            if self._should_report(it, maxiter, report_every):
                continuous_solver = MLPRegressor(hidden_layer_sizes=256, random_state=1, max_iter=500)
                if self.toy_type is False:
                    print('Iteration: ', it)
                    y_sampled_train = entropic_pred_torch(x_train, y_train, M, g, b, self.eps)
                    metrics_dict_train, metrics_dict_test = self._report_metrics(
                        it, continuous_solver, y_sampled_train, x_dict, y_dict, labels_dict, labels_red_dict,
                        target_vectors, target_vectors_red, metrics_dict_train, metrics_dict_test)
                else:
                    # Uses this iteration's coupling (previously an undefined ``solver_state``).
                    self._plot_toy(coupling.cpu().numpy(), y_train.cpu().numpy(),
                                   labels_dict['train'].cpu().numpy())

        return coupling, g, M, b, continuous_solver, metrics_dict_train, metrics_dict_test

    def solve(self, x_dict, y_dict, labels_dict, labels_red_dict, target_vectors, target_vectors_red, maxiter=200, report_every=20):
        """OTT/JAX back-end of the alternating solver.

        Arguments are as for :meth:`solve_torch`.

        Returns:
          ``(coupling_matrix, g, M, b, continuous_solver, metrics_dict_train, metrics_dict_test)``.
          ``coupling_matrix`` and ``g`` come from the last Sinkhorn solve.

        Raises:
          ValueError: on non-positive ``maxiter`` / ``report_every``, on an
            unsupported ``method_M``, or on an unsupported ``toy_type`` when plotting.
        """
        self._check_schedule(maxiter, report_every)

        x_train, y_train = x_dict['train'], y_dict['train']
        x_train_jnp = jnp.asarray(x_train.cpu().numpy())
        y_train_jnp = jnp.asarray(y_train.cpu().numpy())

        if self.M_init is None:
            M = jnp.ones((y_train.shape[1], x_train.shape[1])) / jnp.maximum(y_train.shape[1], x_train.shape[1])
        else:
            M = self.M_init

        a = jnp.ones(x_train.shape[0]) / x_train.shape[0]
        b = jnp.ones(y_train.shape[0]) / y_train.shape[0]

        solver = sinkhorn.Sinkhorn(
            threshold=self.tol,
            max_iterations=self._SINKHORN_MAX_ITERS,
            norm_error=2,
            lse_mode=True,
        )

        rng = jax.random.PRNGKey(0)
        initializer = solver.create_initializer()
        geom = update_geometry(X=x_train_jnp, Y=y_train_jnp, M=M, geom_batch_size=None, geom_epsilon=self.eps)
        prob = linear_problem.LinearProblem(geom, a=a, b=b)
        init_dual_a, init_dual_b = initializer(prob, None, None, lse_mode=True, rng=rng)

        metrics_dict_train, metrics_dict_test = self._empty_metric_dicts()
        continuous_solver = None

        for it in tqdm(range(maxiter)):
            solver_state = solver(prob, (init_dual_a, init_dual_b))
            init_dual_a, init_dual_b = solver_state.f, solver_state.g

            M = update_M(x_train_jnp, y_train_jnp, method=self.method_M, solver_state=solver_state, reg=self.eps)

            geom = update_geometry(X=x_train_jnp, Y=y_train_jnp, M=M, geom_epsilon=geom.epsilon,
                                   geom_batch_size=geom.batch_size)
            prob = linear_problem.LinearProblem(geom, a=a, b=b)

            # Re-solve with the updated geometry, warm-started from the previous duals.
            solver_state = solver(prob, (init_dual_a, init_dual_b))
            init_dual_a, init_dual_b = solver_state.f, solver_state.g
            coupling_matrix = solver_state.matrix

            if self._should_report(it, maxiter, report_every):
                continuous_solver = MLPRegressor(hidden_layer_sizes=256, random_state=1, max_iter=500)
                if self.toy_type is False:
                    print('Iteration: ', it)
                    y_sampled_train = entropic_pred(x_train_jnp, y_train_jnp, M, solver_state.g, b, self.eps)
                    y_sampled_train = torch.tensor(np.asarray(y_sampled_train)).to(torch.float32)
                    metrics_dict_train, metrics_dict_test = self._report_metrics(
                        it, continuous_solver, y_sampled_train, x_dict, y_dict, labels_dict, labels_red_dict,
                        target_vectors, target_vectors_red, metrics_dict_train, metrics_dict_test)
                else:
                    self._plot_toy(np.asarray(coupling_matrix), y_train.cpu().numpy(),
                                   labels_dict['train'].cpu().numpy())

        return coupling_matrix, solver_state.g, M, b, continuous_solver, metrics_dict_train, metrics_dict_test

    def fit(self, x_dict, y_dict, labels_dict, labels_red_dict, target_vectors, target_vectors_red, metric_names, max_iters, report_every):
        """Run :meth:`solve` (JAX back-end) and return ``(metrics_dict_train, metrics_dict_test)``.

        ``metric_names``, when not ``None``, replaces the names given at
        construction. Previously this argument was accepted but ignored.
        """
        self.x_dict, self.y_dict, self.labels_dict, self.labels_red_dict = x_dict, y_dict, labels_dict, labels_red_dict
        if metric_names is not None:
            self.metric_names = metric_names

        *_, metrics_dict_train, metrics_dict_test = self.solve(
            self.x_dict, self.y_dict, self.labels_dict, self.labels_red_dict,
            target_vectors, target_vectors_red, max_iters, report_every)
        return metrics_dict_train, metrics_dict_test
        
        
    #def valid_step(self, sampler_source, sampler_target, n_samples, metric_names, target_vectors, n_eval):
    #        
    #    metrics_dict = {metric_name:[] for metric_name in metric_names}
    #    y_train =  self.y_dict['train']
    #    with torch.no_grad():
    #    
    #        #sampler_source.reset_sampler()
#
    #        for _ in trange(n_eval, leave=False, desc="Evaluation"):
    #            
    #            if sampler_target is None:
    #                x, y, labels = sampler_source.sample(n_samples)
    #            else:
    #                #sampler_target.reset_sampler()
    #                x, labels = sampler_source.sample(n_samples)
    #                y, _      = sampler_target.sample(n_samples)
    #                
    #            x, y, labels = x.cpu(), y.cpu(), labels.cpu()
    #            y_sampled = self.continuous_solver.predict(x.numpy())
    #            #y_sampled = np.asarray(generate_preds(x.cpu().numpy(), y_train.cpu().numpy(), self.M, self.g, self.b, self.eps, 2048))
#
    #            y_sampled = torch.tensor(y_sampled).to(torch.float32)
#
    #            metrics_dict = compute_metrics(x, y, y_sampled, labels, target_vectors, metrics_dict)
    #        
    #        return metrics_dict

    