import numpy as np

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from parfor import parfor

from tools_weak_proximal import *

import ot


class comparison_maps:
    """Compare barycentric projection maps built from several transport plans.

    The source sample ``X_source`` and the target sample ``Y_target`` carry
    uniform weights ``a`` and ``b``. Plans are computed with:

    * exact OT (``ot.emd``);
    * entropy-regularised OT (``ot.sinkhorn``);
    * two weak-transport solvers (``solve_weak_proximal`` and ``solve_weak_QP``).

    Each plan ``P`` is turned into the barycentric projection map
    ``S = (P @ y) / a[:, None]``, which gives one image point per source point.
    """

    # Initializer / Instance Attributes
    def __init__(self, X_source, Y_target):
        """Store the samples, their uniform weights and the cost matrix.

        Samples are indexed as ``x[i, :]``; the plotting code reads their first
        two columns.
        """
        self.x = X_source
        self.y = Y_target
        self.ns = np.shape(self.x)[0]
        self.nt = np.shape(self.y)[0]
        self.a = np.ones((self.ns,)) / self.ns
        self.b = np.ones((self.nt,)) / self.nt
        # Pairwise source/target cost matrix, reused by the OT solvers and by
        # print_objective().
        self.M = ot.dist(self.x, self.y)
        self.plan_OWT = None
        self.S_OWT = None
        self.plan_OWT_QP = None
        self.S_OWT_QP = None

        self.plan_OT = None
        self.S_OT = None
        self.plan_OTreg = None
        self.S_OTreg = None

        # Transport costs filled in by print_objective().
        self.objOWT = None
        self.objOWT_QP = None
        self.objOT = None
        # Legacy attribute name kept for backward compatibility; it is never
        # written (print_objective() fills ``objOWT_QP``).
        self.objOT_QP = None

    def set_parameters(self, gamma, method_proj, batch_size_proj, nb_iter_proj, nb_iter_prox):
        """Store the parameters forwarded unchanged to ``solve_weak_proximal``."""
        self.gamma = gamma
        self.method_proj = method_proj
        self.batch_size_proj = batch_size_proj
        self.nb_iter_proj = nb_iter_proj
        self.nb_iter_prox = nb_iter_prox

    def _barycentric_map(self, plan):
        """Return ``(plan @ y) / a[:, None]``, the barycentric projection of ``plan``."""
        return np.matmul(plan, self.y) / self.a[:, None]

    def compute_OWT(self):
        """Compute the two weak-transport plans and their barycentric maps.

        The proximal solver starts from a random coupling normalised to total
        mass 1. ``set_parameters`` must have been called first.
        """
        P0 = np.random.rand(self.ns, self.nt)
        P0 /= np.sum(P0)
        self.plan_OWT = solve_weak_proximal(self.x, self.y, P0, self.a, self.b, self.gamma,
                                            self.method_proj, self.batch_size_proj,
                                            self.nb_iter_proj, self.nb_iter_prox)
        self.S_OWT = self._barycentric_map(self.plan_OWT)

        self.plan_OWT_QP = solve_weak_QP(self.x, self.y, self.a, self.b)
        self.S_OWT_QP = self._barycentric_map(self.plan_OWT_QP)

    def compute_OT_OTreg(self, reg = 1):
        """Compute the exact OT plan, the Sinkhorn plan (regularisation ``reg``) and their maps."""
        self.plan_OT = ot.emd(self.a, self.b, self.M)
        self.S_OT = self._barycentric_map(self.plan_OT)
        self.plan_OTreg = ot.sinkhorn(self.a, self.b, self.M, reg)  # Entropy regularised OT
        self.S_OTreg = self._barycentric_map(self.plan_OTreg)

    def print_objective(self):
        """Compute and print three exact-OT costs.

        * ``objOT``: the cost between the source and target samples.
        * ``objOWT`` / ``objOWT_QP``: the cost between the source sample and its
          image under ``S_OWT`` / ``S_OWT_QP``, with weights ``a`` on both sides.

        ``compute_OWT`` must have been called first.
        """
        # self.M already equals ot.dist(self.x, self.y); no need to recompute it.
        self.objOT = ot.emd2(self.a, self.b, self.M)
        self.objOWT = ot.emd2(self.a, self.a, ot.dist(self.x, self.S_OWT))
        self.objOWT_QP = ot.emd2(self.a, self.a, ot.dist(self.x, self.S_OWT_QP))
        print("Objective OT: {}, Objective OWT: {} and Objective OWT via QP: {}".format(self.objOT, self.objOWT, self.objOWT_QP))

    def plot_Smaps(self):
        """Plot the data (left) and the images of the source under every map (right).

        Both ``compute_OWT`` and ``compute_OT_OTreg`` must have been called.
        """
        plt.figure(figsize=(15,8))
        plt.subplot(1,2,1)
        plt.plot(self.x[:, 0], self.x[:, 1], '+b', label=r'$\mu$')
        plt.plot(self.y[:, 0], self.y[:, 1], '+r', label=r'$\nu$')
        plt.legend()
        plt.title('Data', fontsize = 20)
        plt.subplot(1,2,2)
        plt.plot(self.x[:, 0], self.x[:, 1], '+b', label=r'$\mu$')
        plt.plot(self.y[:, 0], self.y[:, 1], '+r', label=r'$\nu$')
        plt.scatter(self.S_OTreg[:, 0], self.S_OTreg[:, 1], color = 'magenta',marker = 'x', label=r'$S^{reg}\#\mu$')
        plt.scatter(self.S_OT[:, 0], self.S_OT[:, 1], color = 'orange',marker = '+', label=r'$S^{OT}\#\mu$')
        plt.scatter(self.S_OWT[:, 0], self.S_OWT[:, 1], color = 'black',marker = 'o', label=r'$S^{OWT}\#\mu$')
        plt.scatter(self.S_OWT_QP[:, 0], self.S_OWT_QP[:, 1], color = 'green',marker = 'o', label=r'$S^{OWT}\#\mu$, via QP')

        plt.legend(loc=0, fontsize = 20)
        plt.title('Barycentric projection maps', fontsize = 20)
        
class barycenters:
    """Compute and plot barycenters of several empirical point clouds.

    ``self.data[k]`` (for ``k < nb_data``) is a sample of ``n[k]`` points
    carrying the uniform weights ``self.w[k]``. The coupling is chosen by a
    ``method`` string:

    * ``'OT'``     -- exact OT, ``ot.emd``;
    * ``'OTreg'``  -- entropy-regularised OT, ``ot.sinkhorn`` with ``self.reg``;
    * ``'OWT'``    -- weak transport via ``solve_weak_proximal``;
    * ``'OWT_QP'`` -- weak transport via ``solve_weak_QP``.

    The barycenter is initialised at ``self.data[0]``, so it always has ``n[0]``
    support points with weights ``self.w[0]``. It is then updated with the
    barycentric projection maps towards the samples.

    Call ``set_parameters`` (which also builds the weights) and ``set_data``
    before computing.
    """

    # Attribute receiving the result of stochastic_bar for each method.
    _STOCHASTIC_ATTR = {'OT': 'barOT', 'OTreg': 'barOTreg',
                        'OWT': 'barOWT', 'OWT_QP': 'barOWT_QP'}
    # Attribute receiving the result of non_stochastic_bar for each method.
    _NON_STOCHASTIC_ATTR = {'OT': 'barOT_NS', 'OTreg': 'barOTreg_NS',
                            'OWT': 'barOWT_NS', 'OWT_QP': 'barOWT_QP_NS'}
    # Header printed by non_stochastic_bar for each method.
    _NON_STOCHASTIC_MSG = {'OT': 'Method OT', 'OTreg': 'Method regularised OT',
                           'OWT': 'Method OWT', 'OWT_QP': 'Method OWT via QP'}
    # (method, colour, legend label) in drawing order; every label is passed
    # through str.format(self.reg), which only affects the Sinkhorn label.
    _PLOT_STYLE = (('OT', 'red', 'OT'),
                   ('OTreg', 'dodgerblue', r'OT Sinkhorn, $\varepsilon={}$'),
                   ('OWT', 'black', 'Weak'),
                   ('OWT_QP', 'green', 'Weak via QP'))

    # Initializer / Instance Attributes
    def __init__(self, example, nb_data, n):
        """Store the experiment description.

        Parameters
        ----------
        example : str
            Data generator used by ``set_data``: 'gaussian' or 'double_ring'.
        nb_data : int
            Number of input measures.
        n : sequence of int
            ``n[k]`` is the number of points of measure ``k``.
        """
        self.example = example
        self.nb_data = nb_data
        self.n = n
        self.data = []    # point clouds, filled by set_data()
        self.w = []       # uniform weight vectors, filled by set_parameters()
        self.wbar = None  # weights of the current barycenter (set to w[0])
        self.barOT = None
        self.barOWT = None
        self.barOWT_QP = None
        self.barOTreg = None
        self.barOT_NS = None
        self.barOWT_NS = None
        self.barOTreg_NS = None
        self.barOWT_QP_NS = None
        self.reg = 1      # entropic regularisation used by the 'OTreg' method

    def set_parameters(self, gamma, method_proj, batch_size_proj, nb_iter_proj, nb_iter_prox):
        """Store the ``solve_weak_proximal`` parameters and build uniform weights.

        The five parameters are forwarded unchanged to ``solve_weak_proximal``,
        which only the 'OWT' method uses. ``self.w`` is rebuilt from scratch, so
        repeated calls do not grow the list.
        """
        self.gamma = gamma
        self.method_proj = method_proj
        self.batch_size_proj = batch_size_proj
        self.nb_iter_proj = nb_iter_proj
        self.nb_iter_prox = nb_iter_prox
        self.w = [np.ones((self.n[k],)) / self.n[k] for k in range(self.nb_data)]

    def set_data(self):
        """Generate ``nb_data`` random samples according to ``self.example``.

        * 'gaussian': sample ``k`` has ``n[k]`` points from
          ``ot.datasets.make_2D_samples_gauss`` with identity covariance and a
          centre drawn uniformly in [-5, 5)^2.
        * 'double_ring': each point of sample ``k`` is taken, with probability
          1/2, from one of two ``generate_ring`` clouds. The two clouds are
          translated by offsets drawn uniformly in [-5, 5)^2. The ring
          parameters are drawn around 4 (first ring) and around 10 (second ring).

        Previously generated data is discarded.

        Raises
        ------
        ValueError
            If ``self.example`` is neither 'gaussian' nor 'double_ring'.
        """
        if self.example not in ('gaussian', 'double_ring'):
            raise ValueError("Unknown example {!r}; expected 'gaussian' or 'double_ring'".format(self.example))
        self.data = []
        if self.example == 'gaussian':
            cov = np.array([[1, 0], [0, 1]])
            param = np.random.uniform(-5, 5, (2, self.nb_data))
            for j in range(self.nb_data):
                self.data.append(ot.datasets.make_2D_samples_gauss(self.n[j], param[:, j], cov))
        else:
            for j in range(self.nb_data):
                # Only column j of the ring-parameter arrays is used for sample j.
                size_ring1 = 4 + np.random.uniform(-1, 1, (2, self.nb_data))
                size_ring2 = 10 + np.random.uniform(-1, 1, (2, self.nb_data))
                pos_ring = np.random.uniform(-5, 5, (2, 2))
                # Per-point Bernoulli(1/2) choice between the two rings.
                rv = np.random.binomial(1, 0.5, self.n[j])[:, None]
                # generate_ring(n, p1, p2) comes from tools_weak_proximal;
                # presumably p1, p2 are the ring radii -- confirm there.
                ring1 = pos_ring[0] + generate_ring(self.n[j], size_ring1[0, j], size_ring1[1, j])
                # The second ring uses its own parameters (drawn around 10).
                ring2 = pos_ring[1] + generate_ring(self.n[j], size_ring2[0, j], size_ring2[1, j])
                self.data.append(rv * ring1 + (1 - rv) * ring2)

    def plot_data(self):
        """Scatter every input sample on one equal-aspect figure."""
        fig = plt.figure(figsize=(6,6))
        ax = fig.add_subplot(111, aspect='equal')
        plt.title('Data', fontsize = 20)
        for dd in self.data[:self.nb_data]:
            plt.plot(dd[:,0],dd[:,1],'.')
        ax.set_aspect('equal', adjustable='box')

    @classmethod
    def _check_method(cls, method):
        """Raise ValueError unless ``method`` is one of the supported couplings."""
        if method not in cls._STOCHASTIC_ATTR:
            raise ValueError("Unknown method {!r}; expected one of {}".format(
                method, sorted(cls._STOCHASTIC_ATTR)))

    def _transport_plan(self, method, source, wsource, target, wtarget):
        """Return the coupling between (source, wsource) and (target, wtarget) for ``method``."""
        if method == 'OWT':
            # Random initial coupling normalised to total mass 1.
            P0 = np.random.rand(len(wsource), len(wtarget))
            P0 /= np.sum(P0)
            return solve_weak_proximal(source, target, P0, wsource, wtarget, self.gamma,
                                       self.method_proj, self.batch_size_proj,
                                       self.nb_iter_proj, self.nb_iter_prox)
        if method == 'OWT_QP':
            return solve_weak_QP(source, target, wsource, wtarget)
        M = ot.dist(source, target)
        if method == 'OT':
            return ot.emd(wsource, wtarget, M)
        if method == 'OTreg':
            return ot.sinkhorn(wsource, wtarget, M, self.reg)
        self._check_method(method)

    @staticmethod
    def _barycentric_map(plan, target, wsource):
        """Return ``(plan @ target) / wsource[:, None]``, one image point per source point."""
        return np.matmul(plan, target) / wsource[:, None]

    def _map_to_sample(self, method, bar, j):
        """Barycentric map from the barycenter ``bar`` (weights ``self.wbar``) to sample ``j``."""
        target = self.data[j]
        plan = self._transport_plan(method, bar, self.wbar, target, self.w[j])
        return self._barycentric_map(plan, target, self.wbar)

    def stochastic_bar(self, method):
        """Compute a barycenter by visiting the samples one at a time.

        Starting from ``data[0]``, sample ``j`` (j >= 1) updates
        ``bar <- (1 - theta) * bar + theta * S_j(bar)`` with ``theta = 1/(j+1)``.
        Here ``S_j`` is the barycentric map of the ``method`` coupling. The
        result is stored in ``barOT`` / ``barOTreg`` / ``barOWT`` / ``barOWT_QP``.

        Raises
        ------
        ValueError
            If ``method`` is not 'OT', 'OTreg', 'OWT' or 'OWT_QP'.
        """
        self._check_method(method)
        bar = self.data[0]
        self.wbar = self.w[0]
        for j in range(1, self.nb_data):
            theta = 1 / (1 + j)
            bar = (1 - theta) * bar + theta * self._map_to_sample(method, bar, j)
        setattr(self, self._STOCHASTIC_ATTR[method], bar)

    def plot_bar_stochastic(self, show_data, show_bar):
        """Plot barycenters computed by ``stochastic_bar``.

        Parameters
        ----------
        show_data : bool or str
            If ``True`` (or the string ``'True'``), also plot the input samples.
        show_bar : str or iterable of str
            Methods to draw among 'OT', 'OTreg', 'OWT' and 'OWT_QP'. A string is
            split on whitespace and commas.

        Raises
        ------
        ValueError
            If a requested barycenter has not been computed.
        """
        self._plot_bars(show_data, show_bar, self._STOCHASTIC_ATTR,
                        'Barycenters computed with stochastic algorithm')

    def non_stochastic_bar(self, method, nb_iter):
        """Compute a barycenter by ``nb_iter`` fixed-point iterations over all samples.

        Each iteration replaces ``bar`` by the mean of its barycentric maps
        towards every sample. The 'OWT' maps are computed in parallel with
        ``parfor``. The 'OTreg' method uses ``self.reg``. The result is stored
        in ``barOT_NS`` / ``barOTreg_NS`` / ``barOWT_NS`` / ``barOWT_QP_NS``.

        Raises
        ------
        ValueError
            If ``method`` is not 'OT', 'OTreg', 'OWT' or 'OWT_QP'.
        """
        self._check_method(method)
        bar = self.data[0]
        self.wbar = self.w[0]
        print(self._NON_STOCHASTIC_MSG[method])
        for k in range(nb_iter):
            if method == 'OWT':
                # The decorated name is bound to the list of per-sample maps,
                # computed immediately for the current value of ``bar``.
                @parfor(range(self.nb_data))
                def maps(j):
                    return self._map_to_sample(method, bar, j)
            else:
                maps = [self._map_to_sample(method, bar, j) for j in range(self.nb_data)]
            bar = np.mean(maps, axis=0)
            print('Iteration {} out of {} done'.format(k+1,nb_iter))
        setattr(self, self._NON_STOCHASTIC_ATTR[method], bar)

    def plot_bar_non_stochastic(self, show_data, show_bar):
        """Plot barycenters computed by ``non_stochastic_bar``.

        The parameters are the same as for ``plot_bar_stochastic``.

        Raises
        ------
        ValueError
            If a requested barycenter has not been computed.
        """
        self._plot_bars(show_data, show_bar, self._NON_STOCHASTIC_ATTR,
                        'Barycenters computed with non-stochastic algorithm')

    @staticmethod
    def _selected_bars(show_bar):
        """Return the set of method names requested by ``show_bar``.

        A string is split on whitespace and commas, so that ``'OTreg'`` selects
        only 'OTreg' (a substring test would also match 'OT').
        """
        if show_bar is None:
            return set()
        if isinstance(show_bar, str):
            return set(show_bar.replace(',', ' ').split())
        return set(show_bar)

    def _plot_bars(self, show_data, show_bar, attrs, title):
        """Shared implementation of the two ``plot_bar_*`` methods.

        ``attrs`` maps each method name to the attribute holding its barycenter.
        """
        selected = self._selected_bars(show_bar)
        # Fail with a clear message rather than a TypeError on None[:, 0].
        missing = [m for m, _, _ in self._PLOT_STYLE
                   if m in selected and getattr(self, attrs[m]) is None]
        if missing:
            raise ValueError('Barycenter(s) {} not computed yet'.format(missing))
        fig = plt.figure(figsize=(8,8))
        fig.add_subplot(111, aspect='equal')
        # Accept the boolean True as well as the historical string 'True'.
        if show_data in (True, 'True'):
            for dd in self.data[:self.nb_data]:
                plt.plot(dd[:,0],dd[:,1],'+', alpha = 0.5)
        for method, color, label in self._PLOT_STYLE:
            if method in selected:
                bar = getattr(self, attrs[method])
                plt.scatter(bar[:,0], bar[:,1], color = color, marker = '.',
                            label = label.format(self.reg))
        plt.tick_params(axis='x', labelsize=20)
        plt.tick_params(axis='y', labelsize=20)
        plt.legend(loc = 'upper left', fontsize = 20)
        plt.title(title, fontsize = 20)
        plt.show()





class weak_map_1D:
    """Weak-transport maps between two one-dimensional samples.

    ``X_source`` and ``Y_target`` are 1-D arrays (they are broadcast as
    ``x[:, None]`` and histogrammed) with uniform weights ``a`` and ``b``. Two
    weak plans are computed, with ``solve_OWT_1D`` and with
    ``solve_weak_QP_1D``. Each plan is turned into the map
    ``S[i] = sum_j P[i, j] * y[j] / a[i]``.
    """

    # Initializer / Instance Attributes
    def __init__(self, X_source, Y_target):
        """Store the samples and their uniform weights."""
        self.x = X_source
        self.y = Y_target
        self.ns = len(self.x)
        self.nt = len(self.y)
        self.a = np.ones((self.ns,)) / self.ns
        self.b = np.ones((self.nt,)) / self.nt
        self.plan_OWT = None
        self.S_OWT = None
        self.plan_OWT_QP = None
        self.S_OWT_QP = None

        # Transport costs filled in by print_objective().
        self.objOWT = None
        self.objOWT_QP = None
        self.objOT = None

    def set_parameters(self, gamma, method_proj, nb_iter_proj, nb_iter_prox):
        """Store the parameters forwarded unchanged to ``solve_OWT_1D``."""
        self.gamma = gamma
        self.method_proj = method_proj
        self.nb_iter_proj = nb_iter_proj
        self.nb_iter_prox = nb_iter_prox

    def compute_OWT_1D(self):
        """Compute both weak plans and their maps ``S_OWT`` and ``S_OWT_QP``.

        The iterative solver starts from a random coupling normalised to total
        mass 1. ``set_parameters`` must have been called first.
        """
        P0 = np.random.rand(self.ns,self.nt)
        P0 /= np.sum(P0)
        # solve_OWT_1D returns a pair; only the plan (first item) is kept.
        self.plan_OWT, _ = solve_OWT_1D(self.x, self.y, P0, self.a, self.b, self.gamma, self.method_proj,
                                        self.nb_iter_proj, self.nb_iter_prox)
        self.S_OWT = np.sum(self.plan_OWT*self.y,1)/self.a
        self.plan_OWT_QP = solve_weak_QP_1D(self.x, self.y, self.a, self.b)
        self.S_OWT_QP = np.sum(self.plan_OWT_QP*self.y,1)/self.a

    @staticmethod
    def _squared_cost(u, v):
        """Return the matrix ``C[i, j] = (u[i] - v[j]) ** 2`` for 1-D arrays ``u``, ``v``."""
        return (u[:, None] - v[None, :]) ** 2

    def print_objective(self):
        """Compute and print three exact-OT costs under the squared cost.

        * ``objOT``: the cost between ``x`` and ``y``.
        * ``objOWT`` / ``objOWT_QP``: the cost between ``x`` and its image
          ``S_OWT`` / ``S_OWT_QP``, with weights ``a`` on both sides.

        ``compute_OWT_1D`` must have been called first.
        """
        self.objOT = ot.emd2(self.a, self.b, self._squared_cost(self.x, self.y))
        self.objOWT = ot.emd2(self.a, self.a, self._squared_cost(self.x, self.S_OWT))
        self.objOWT_QP = ot.emd2(self.a, self.a, self._squared_cost(self.x, self.S_OWT_QP))
        print("Objective OT: {}, Objective OWT: {}, and Objective OWT via QP: {}".format(self.objOT, self.objOWT, self.objOWT_QP))

    @staticmethod
    def _nb_bins(n):
        """Histogram bin count ``n // 2``, floored at 1 (0 bins is rejected by numpy)."""
        return max(1, n // 2)

    def plot_Smaps(self):
        """Overlay normalised histograms of x, y, S_OWT and S_OWT_QP."""
        plt.figure(figsize=(7,7))
        plt.hist(self.x, self._nb_bins(self.ns), density=True, color='b', alpha = 0.3, label = r"$\mu$")
        plt.hist(self.y, self._nb_bins(self.nt), density=True, color='r', alpha = 0.3, label = r"$\nu$")
        plt.hist(self.S_OWT, self._nb_bins(self.ns), density=True, color='black', alpha = 0.3, label = r"$S^{OWT}\#\mu$")
        plt.hist(self.S_OWT_QP, self._nb_bins(self.ns), density=True, color='green', alpha = 0.3, label = r"$S^{OWT}\#\mu$ via QP")
        plt.legend(fontsize = 20)
        plt.show()
        