import numpy as np
import matplotlib.pyplot as plt
from scipy.fftpack import dctn, idctn
import scipy.stats as stats
plt.rcParams['figure.figsize'] = (10, 8)
# plt.rcParams['image.cmap'] = 'viridis'
from scipy import interpolate
from scipy.integrate import quad
import cvxpy as cp


def dct(a):
  """Return the orthonormal N-dimensional DCT of ``a``.

  Thin wrapper around ``scipy.fftpack.dctn`` with ``norm='ortho'`` and all
  other options left at their SciPy defaults.
  """
  coefficients = dctn(a, norm='ortho')
  return coefficients

def idct(a):
  """Return the orthonormal N-dimensional inverse DCT of ``a``.

  Thin wrapper around ``scipy.fftpack.idctn`` with ``norm='ortho'`` and all
  other options left at their SciPy defaults.
  """
  samples = idctn(a, norm='ortho')
  return samples

def compute_w2(phi, psi, mu, nu, x):
  """Return ``sum(0.5*x**2*(mu + nu) - nu*psi - mu*phi) / len(mu)``.

  Args:
    phi, psi: potentials sampled on the grid ``x`` (``phi`` is paired with
      ``mu`` and ``psi`` with ``nu``).
    mu, nu: densities sampled on the same grid.
    x: grid coordinates.

  Returns:
    The summed integrand divided by the number of entries in ``mu``.
  """
  sample_count = len(mu)
  quadratic_term = 0.5 * (x*x) * (mu + nu)
  integrand = quadratic_term - nu*psi - mu*phi
  return np.sum(integrand) / sample_count


def pdf_to_cdf(rho):
  """Return the normalised cumulative sum of the density samples ``rho``.

  Entry ``i`` of the result is ``sum(rho[:i+1]) / sum(rho)``, so the last
  entry is 1 whenever the total is non-zero.

  Args:
    rho: array-like of density samples; accumulation runs along axis 0.

  Returns:
    ``numpy.ndarray`` with the same shape as ``rho``.
  """
  weights = np.asarray(rho, dtype=float)
  # One pass for the running sum and one for the total; the previous
  # implementation recomputed the total on every loop iteration (O(n^2)).
  return np.cumsum(weights, axis=0) / np.sum(weights, axis=0)

def cdf_to_qf(mucdf):
  """Sample the inverse of a tabulated CDF at ``n`` evenly spaced levels.

  ``mucdf`` is taken as the CDF values at the ``n`` cell centres of
  ``[0, 1]``. The inverse map (CDF value -> cell centre) is linearly
  interpolated and evaluated at ``np.linspace(0, 1, n)``; levels outside the
  tabulated CDF range are clamped to the first / last cell centre.

  Args:
    mucdf: sequence of ``n`` CDF values.

  Returns:
    ``numpy.ndarray`` of ``n`` quantile values.
  """
  size = len(mucdf)
  centers = np.linspace(0.5/size, 1-0.5/size, size)
  inverse_cdf = interpolate.interp1d(
      mucdf,
      centers,
      bounds_error=False,
      fill_value=(centers[0], centers[-1]),
  )
  levels = np.linspace(0, 1, size)
  return inverse_cdf(levels)

def qf_to_cdf(muqf):
    """Invert tabulated quantile values back onto the cell-centre grid.

    The ``n`` entries of ``muqf`` are paired with the ``n`` cell centres of
    ``[0, 1]``; the map (quantile value -> cell centre) is linearly
    interpolated, with linear extrapolation outside the tabulated range, and
    evaluated at those same cell centres.

    Args:
        muqf: sequence of ``n`` quantile values.

    Returns:
        ``numpy.ndarray`` of ``n`` interpolated values.
    """
    size = len(muqf)
    centers = np.linspace(0.5/size, 1-0.5/size, size)
    inverse_qf = interpolate.interp1d(
        muqf, centers, bounds_error=False, fill_value="extrapolate"
    )
    # The evaluation grid is the same cell-centre grid used as the targets.
    return inverse_qf(centers)

def projection(f, mu, L):
  """Least-squares projection of ``f`` onto a second-derivative band.

  Solves ``min ||f - phi||^2`` subject to ``mu <= phi'' <= L``, where
  ``phi''`` is the centred second difference on a grid of spacing
  ``1 / len(f)``.

  Args:
    f: 1-D array of samples to project.
    mu: lower bound on the discrete second derivative.
    L: upper bound on the discrete second derivative.

  Returns:
    The optimiser's solution as an array, or ``f`` itself (unchanged) when
    the solver raises, reports an infeasible/unbounded problem, or yields no
    value.
  """
  length = len(f)
  h = 1 / length
  phi = cp.Variable(length)
  objective = cp.Minimize(cp.sum_squares(f - phi))
  phi_second_derivative = (phi[2:] - 2*phi[1:-1] + phi[:-2]) / h**2
  constraints = [
      phi_second_derivative <= L,
      phi_second_derivative >= mu
  ]
  problem = cp.Problem(objective, constraints)
  try:
      problem.solve(verbose=False)
  except Exception as e:
      # Best effort: a solver failure falls back to the unprojected input.
      print(f"Optimization failed: {e}")
      return f

  # The "*_inaccurate" variants also leave the variable without a value, so
  # they must be treated as failures too; otherwise None is handed back.
  failed_statuses = (
      'infeasible',
      'unbounded',
      'infeasible_inaccurate',
      'unbounded_inaccurate',
  )
  if problem.status in failed_statuses or phi.value is None:
      return f
  return phi.value


class OT:
    """State and helpers for 1-D optimal transport on a uniform grid of [0, 1].

    The grid has ``len(func)`` cells; ``points`` holds the cell centres.

    Attributes set in ``__init__``:
        func: potential sampled at ``points`` (must be a NumPy array; it is
            updated in place by ``gradient_ascent``).
        mu, nu: the two densities sampled at ``points``.
        len, interval: number of cells and cell width ``1 / len``.
        mass: ``np.sum(mu)``.
        index: ``list(range(len))``.
        kernel: ``2 * len**2 * (1 - cos(pi * k / len))`` for ``k = 0..len-1``
            with entry 0 replaced by 1; used as the divisor of DCT
            coefficients in ``gradient_ascent``.

    Attributes set later: ``dual_index`` (``convex``), ``hull``
    (``convex_hull``), ``legendre`` (``legendre_transform``), ``mapping``
    (``displacement``), ``rho`` and ``w2`` (``gradient_ascent``).
    """

    def __init__(self, func, mu, nu):
        self.func = func
        self.len = len(self.func)
        self.interval = 1/self.len
        self.mass = np.sum(mu)
        self.index = list(range(self.len))
        self.points = np.linspace(0.5/self.len, 1-0.5/self.len, self.len)
        self.mu = mu
        self.nu = nu
        xx = np.linspace(0,np.pi,self.len,False)
        self.kernel = 2*self.len**2*(1-np.cos(xx))
        # Avoid dividing by zero; gradient_ascent zeroes that coefficient anyway.
        self.kernel[0] = 1

    def convex(self):
        """Find the grid indices that are vertices of the lower convex hull.

        Starting from index 0, the next vertex is the first ``j`` such that no
        later sample lies strictly below the line through the current vertex
        and ``j``. The result is stored in ``self.dual_index`` and returned.
        """
        last = self.len - 1
        vertices = [0]
        i = 0
        while i < last:
            for j in range(i + 1, self.len):
                slope = (self.func[j]-self.func[i])/(self.points[j]-self.points[i])
                intersection = self.func[j] - self.points[j]*slope
                # Compare against *every* later sample, including the final
                # one (previously skipped, which let [0, 1, 0] pass as convex).
                tail_line = slope * self.points[j+1:] + intersection
                tail_gap = self.func[j+1:] - tail_line
                if j == last or not np.any(tail_gap < 0):
                    vertices.append(j)
                    i = j
                    break
        self.dual_index = vertices
        return self.dual_index

    def convex_hull(self):
        """Return the lower convex envelope of ``self.func`` on the grid.

        Samples between consecutive hull vertices are replaced by the chord
        joining those vertices. The result is a new array, also stored in
        ``self.hull``.
        """
        self.convex()
        hull = self.func.copy()
        for id1, id2 in zip(self.dual_index[:-1], self.dual_index[1:]):
            if id2 - id1 > 1:
                x1, x2 = self.points[id1], self.points[id2]
                val1, val2 = self.func[id1], self.func[id2]
                span = slice(id1, id2)
                hull[span] = (val2-val1)/(x2-x1)*(self.points[span]-x1)+val1
        self.hull = hull
        return self.hull

    def visualize_original(self,num_de):
        """Plot ``self.func`` (blue) with secant lines (red).

        One line is drawn through samples ``num_de*i`` and ``num_de*i + 1``
        for each ``i`` in ``range((len - 1) // num_de)``.
        """
        for i in range((self.len-1)//num_de):
            lo, hi = num_de*i, num_de*i+1
            slope = (self.func[hi]-self.func[lo])/(self.points[hi]-self.points[lo])
            intersection = self.func[hi] - self.points[hi]*slope
            line = slope * self.points[self.index] + intersection
            plt.plot(self.points,line ,'r')
        plt.plot(self.points, self.func,'b')
        plt.show()

    def visualize_hull(self):
        """Plot ``self.func`` as markers and its convex envelope as a dashed line."""
        self.convex_hull()  # also refreshes self.dual_index
        plt.plot(self.points, self.func,'o')
        plt.plot(self.points, self.hull,'--')
        plt.show()

    def gradient(self, function):
        """Return forward differences of ``function`` divided by the cell width.

        The output is a list with ``len(function)`` entries: the first
        difference is duplicated at the front so the length matches the input.
        """
        first = (function[1]-function[0])/self.interval
        grad_list = [first]
        grad_list.extend(
            (function[k+1]-function[k])/self.interval
            for k in range(len(function)-1)
        )
        return grad_list

    def legendre_transform(self):
        """Return a discrete Legendre transform of the convex envelope of ``func``.

        For each grid point ``x_i`` the scan (resuming from the index used for
        the previous point) finds the first ``j`` with ``phi'(x_j) > x_i`` and
        records ``x_i * x_j - phi(x_j)``; if there is none, the last grid
        point is used. The list is stored in ``self.legendre`` and returned as
        an array.
        """
        phi = self.convex_hull()
        phi_prime = self.gradient(phi)
        last = 0
        legendre = []
        for i in range(self.len):
            for j in range(last, self.len):
                if (self.points[i] - phi_prime[j]) < 0:
                    legendre.append(self.points[i]*self.points[j]-phi[j])
                    last = j
                    break
            if len(legendre) < i+1:
                legendre.append(self.points[i]*self.points[-1]-phi[-1])
        self.legendre = legendre
        return np.array(legendre)

    def visualize(self):
        """Plot the convex envelope and the transform of its transform.

        NOTE: this overwrites ``self.func`` with its Legendre transform and
        does not restore it.
        """
        plt.plot(self.convex_hull())
        self.func = self.legendre_transform()
        legendre = self.legendre_transform()
        plt.plot(legendre,'o')
        plt.show()

    def visualize_legendre(self):
        """Plot the Legendre transform against the grid points."""
        self.legendre_transform()
        plt.plot(self.points, self.legendre)
        plt.show()

    def displacement(self, func, dist, visualize = False):
        """Push ``dist`` forward along the gradient of ``func``.

        The gradient of ``func`` is stored in ``self.mapping``. For each pair
        of neighbouring cells ``(i, i+1)`` the mass ``dist[i]`` is split over
        sub-samples between the two mapped positions, and each sub-sample is
        deposited on the two nearest grid cells with linear weights (indices
        are clamped to the grid).

        Args:
            func: potential sampled on the grid.
            dist: density sampled on the grid.
            visualize: when true, plot ``dist`` and the result.

        Returns:
            1-D array of length ``self.len`` with the pushed-forward density.
        """
        self.mapping = self.gradient(func)
        top = self.len - 1
        rho = np.zeros((self.len,1))
        for i in range(top):
            x_stretch = abs(self.mapping[i+1] - self.mapping[i])
            x_sample = int(max(self.len * x_stretch, 1))
            factor = 1/x_sample*1.0
            for j in range(x_sample):
                a = (j+.5)/(x_sample*1.0)
                xpoint = self.mapping[i+1]*(1-a) + self.mapping[i]*(a)
                X = xpoint*self.len - 0.5
                # floor (not truncation toward zero) so that positions just
                # left of the first cell centre get a weight in [0, 1) rather
                # than a negative one.
                xindex = int(np.floor(X))
                xfrac = X - xindex
                xother = xindex + 1
                xindex = max(min(xindex, top), 0)
                xother = max(min(xother, top), 0)
                rho[xindex] += (1-xfrac) * dist[i] * factor
                rho[xother] += (xfrac) * dist[i] * factor
        if visualize:
            plt.plot(self.points, dist, 'r', label = 'original')
            plt.plot(self.points, rho,'g', label = 'pushforwarded')
            plt.legend()
            plt.show()
        return np.reshape(rho, self.len)

    def visualize_density_cdf(self):
        """Plot ``mu``, ``nu`` and ``rho`` as densities (left) and CDFs (right).

        Requires ``self.rho``, which ``gradient_ascent`` sets.
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

        ax1.plot(self.points, self.mu, 'r', label = 'mu')
        ax1.plot(self.points, self.nu, 'b', label = 'nu')
        ax1.plot(self.points, self.rho,'--', label = 'rho')
        # Legends are built after all curves so that 'rho' is listed too.
        ax1.legend()
        ax1.set_title('Density')

        mu_cdf, nu_cdf, rho_cdf = pdf_to_cdf(self.mu), pdf_to_cdf(self.nu), pdf_to_cdf(self.rho)
        # Each CDF has one value per grid cell, so it is plotted against the
        # grid points (an extra leading 0 in the domain caused a shape error).
        domain = self.points
        ax2.plot(domain, mu_cdf, 'r', label = 'mu')
        ax2.plot(domain, nu_cdf, 'b', label = 'nu')
        ax2.plot(domain, rho_cdf,'--', label = 'rho')
        ax2.legend()
        ax2.set_title('CDF')

        plt.show()

    def error(self):
        """Return the mean squared difference between the CDFs of ``rho`` and ``nu``.

        Both densities are normalised by their own totals before accumulating.
        Requires ``self.rho``.
        """
        nu_cdf = np.cumsum(self.nu) / np.sum(self.nu)
        rho_cdf = np.cumsum(self.rho) / np.sum(self.rho)
        return np.mean((rho_cdf - nu_cdf)**2)

    def gradient_ascent(self,sigma, n_iter,print_option=True, visualize = False):
        """Run ``n_iter`` update steps on ``self.func``.

        Each step:
          1. stores the push-forward of ``mu`` by ``func`` in ``self.rho``;
          2. forms ``mu`` minus the push-forward of ``nu`` by the Legendre
             transform of ``func``;
          3. divides the DCT of that difference by ``self.kernel`` (zeroing
             coefficient 0), transforms back, and subtracts ``sigma`` times
             the result from ``func``;
          4. replaces ``func`` by its convex envelope and stores
             ``compute_w2(...)`` in ``self.w2``.

        Args:
            sigma: step size.
            n_iter: number of steps.
            print_option: print ``w2`` and ``error()`` every 20 steps.
            visualize: call ``visualize_density_cdf`` after every step.
        """
        for i in range(n_iter):
            self.rho = self.displacement(self.func, self.mu)
            diff = self.mu-self.displacement(self.legendre_transform(), self.nu)
            workspace = dct(diff) / self.kernel
            workspace[0] = 0
            workspace = idct(workspace)
            # In-place on purpose: the array handed to the constructor is
            # updated as well, and callers in this file rely on that.
            self.func -= sigma * workspace
            self.func = self.convex_hull()
            # The original recomputed a push-forward here and discarded it;
            # only its side effect on self.mapping is kept.
            self.mapping = self.gradient(self.func)
            self.w2 = compute_w2(self.func, self.legendre_transform(), self.mu, self.nu, self.points)

            if print_option and (i+1)%20==0:
                # Uses self.w2; the bare name `w2` was undefined here.
                print(f'iter {i+1:4d},   W2 value: {self.w2:.6e}, Error : {self.error():.6e}')

            if visualize:
                self.visualize_density_cdf()

def gaussian1d(mean, sigma, length, proj_option, plot_option = True, alpha = 1e-3, beta = 1e+3):
  """Evolve a uniform density on [0, 1] against several Gaussian-shaped densities.

  One density per ``(mean[k], sigma[k])`` pair is sampled on ``length`` cell
  centres and rescaled to sum to ``length``. For 300 outer iterations, each
  density gets one ``OT.gradient_ascent`` step on its own potential, after
  which ``rho`` is pushed forward along a blend of the initial quadratic
  potential and the mean of the current potentials.

  NOTE(review): this looks like a Wasserstein-barycenter iteration, with the
  Gaussian of averaged mean/sigma plotted as the reference ("label") —
  confirm the intent.

  Args:
    mean: sequence of Gaussian means, one per input density.
    sigma: sequence of Gaussian standard deviations, indexed like ``mean``.
    length: number of grid cells.
    proj_option: when truthy, each potential is explicitly reassigned after
      its step: to ``ascent.func`` if the second-derivative test passes,
      otherwise to ``projection(ascent.func, alpha, beta)``.
    plot_option: when truthy, plot densities and CDFs on the final iteration.
    alpha: lower bound used in the second-derivative test and passed to
      ``projection``.
    beta: upper bound used in the second-derivative test and passed to
      ``projection``.

  Returns:
    The final ``rho`` array of length ``length``.
  """
  # func1 computed here is overwritten below before use; only h (grid spacing) survives.
  func1, h = np.linspace(0.5/length, 1-0.5/length, length)**2/2, 1/length
  support = np.linspace(0.5/length, 1-0.5/length, length)
  mu = np.zeros((len(mean), length))
  # Reference curves: a single Gaussian using the averaged mean and averaged sigma.
  label_pdf = stats.norm.pdf(support,np.mean(mean),np.mean(sigma))
  label_cdf = stats.norm.cdf(support,np.mean(mean),np.mean(sigma))
  for k in range(len(mu)):
      mu[k] = stats.norm.pdf(support,mean[k],sigma[k])
      # Rescale so that the samples of mu[k] sum to `length`.
      mu[k] *= len(mu[k])/sum(mu[k])
  n_iter = 300
  # Initial potential x^2/2 on the cell centres, shifted to zero mean.
  func1 = np.linspace(0.5/length, 1-0.5/length, length)**2/2
  func1 = func1 - np.mean(func1)
  # Start from the uniform density with the same total (sum == length).
  rho = np.ones(length)
  rho *= length/sum(rho)
  # One potential per input density, all initialised to func1.
  phi = np.array([func1]*len(mu))
  w2s = np.zeros((len(mu),n_iter+1))
  # Per-density step sizes.
  lr1 = 5e-2 * np.ones(len(mu))
  for i in range(n_iter):
      # NOTE(review): this is an alias, not a copy, so prev_phi tracks every
      # later change to phi — confirm whether a snapshot was intended.
      prev_phi = phi
      for j in range(len(mu)):
        # phi[j] is a view into phi; OT.gradient_ascent subtracts its update
        # from that array in place before rebinding its own func to the
        # convex envelope, so phi[j] already holds the pre-envelope update
        # when the step returns.
        ascent = OT(phi[j], rho, mu[j])
        ascent.gradient_ascent(lr1[j], n_iter = 1,print_option=False)
        second_derivative = np.gradient(np.gradient(phi[j],h),h)
        # NOTE(review): when proj_option is falsy, phi[j] is never explicitly
        # assigned here and changes only through the in-place update above —
        # confirm this is intended.
        if proj_option:
          if np.all((second_derivative >= alpha ) &  (second_derivative <= beta)):
            phi[j] = ascent.func
          else:
            phi[j] = projection(ascent.func, alpha, beta)
        w2s[j,i+1] = ascent.w2
        # Shrink this density's step size whenever its recorded w2 went down.
        if w2s[j,i+1] < w2s[j,i]:
          lr1[j] *= 0.99
      # This OT instance is only used for its displacement() method, which
      # takes the potential as an argument; `j` is the last index left over
      # from the loop above.
      ascent = OT(np.mean(prev_phi, axis = 0), rho, mu[j])
      # Blend weight decays from exp(-1/n_iter) down to exp(-1) over the run.
      lr2 = np.exp(-(i+1)/n_iter)
      rho = ascent.displacement(func1 - lr2*(func1 - np.mean(phi, axis = 0)), rho)
      if plot_option:
        # With i < n_iter this condition holds only on the final iteration.
        if (i+1) % n_iter == 0 :
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
            print("n_iter : ", i+1)
            for j in range(len(mu)):
              ax1.plot(support, mu[j], label = 'mu'+str(j))
              mu_cdf = pdf_to_cdf(mu[j])
              ax2.plot(support, mu_cdf, label = 'mu'+str(j))

            ax1.plot(support, label_pdf, label = 'label')
            rho_cdf = pdf_to_cdf(rho)

            ax2.plot(support, rho_cdf,'--', label = 'rho')
            ax2.plot(support, label_cdf, label = 'label')

            ax1.plot(support, rho,'--', label = 'rho')
            ax1.legend()
            ax1.set_title('Density')
            ax2.legend()
            ax2.set_title('CDF')
            plt.show()
  return rho



def l2dist(f,g):
  """Return the root-mean-square difference between ``f`` and ``g``.

  Args:
    f, g: array-likes (NumPy arrays or plain sequences) of matching or
      broadcastable shape.

  Returns:
    ``sqrt(mean((f - g)**2))`` as a NumPy float.
  """
  # asarray lets plain lists/tuples be subtracted element-wise.
  difference = np.asarray(f) - np.asarray(g)
  return np.sqrt(np.mean(difference**2))

def w2dist(muqf, nuqf):
  """Return the RMS difference of two sampled quantile functions, divided by ``n``.

  Args:
    muqf, nuqf: array-likes (NumPy arrays or plain sequences) holding the
      quantile-function samples; ``n`` is ``len(muqf)``.

  Returns:
    ``sqrt(mean((muqf - nuqf)**2)) / n`` as a NumPy float.
  """
  # asarray lets plain lists/tuples be subtracted element-wise.
  mu_values = np.asarray(muqf)
  nu_values = np.asarray(nuqf)
  n = len(mu_values)
  # NOTE(review): the extra division by n is kept from the original; it
  # presumably converts from grid-index units — confirm against callers.
  return np.sqrt(np.mean((mu_values-nu_values)**2))/n

