import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import DataLoader, TensorDataset
import random
import numpy as np
from torch.optim import lr_scheduler
import torch
import copy
import os
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.font_manager import FontProperties
import math
import argparse

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Command-line configuration of the target distribution and the sampler.
parser = argparse.ArgumentParser()
parser.add_argument('--d_input', type=int, default=50, help='dimension of input, also the dimension of the target distribution')
parser.add_argument('--distribution_type', type=str, default='mixture of gaussian', help='the type of target distribution')
parser.add_argument('--num_mixtures', type=int, default=5, help='total number of mixtures')
# The three options below are real-valued quantities (the variance defaults are
# 0.2), so they are parsed as floats; with type=int a value such as
# `--min_var 0.3` was rejected by argparse. Integer inputs are still accepted.
parser.add_argument('--mixture_distance', type=float, default=4, help='the distance of the mixtures divided by sqrt(d)')
parser.add_argument('--min_var', type=float, default=0.2, help='the min var of the mixtures')
parser.add_argument('--max_var', type=float, default=0.2, help='the max var of the mixtures')
parser.add_argument('--seed', type=int, default=667, help='Seed the randomness')
# Any value other than 'overdamped' or 'underdamped' disables the corrector.
parser.add_argument('--corrector', type=str, default='underdamped', help='if we want the overdamped corrector, underdamped corrector or no corrector')
args = parser.parse_args()

seed = args.seed

# Seed every random number generator the script draws from. The check is
# against None rather than truthiness so that `--seed 0` is honoured too.
if seed is not None:
    # Python's built-in generator
    random.seed(seed)

    # NumPy's global generator
    np.random.seed(seed)

    # PyTorch's CPU generator
    torch.manual_seed(seed)

    # If using CUDA, seed the CUDA generators and enable deterministic operations
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed) # if using multiple GPUs
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

#distribution type
# The only distribution type this script builds; it is compared against
# --distribution_type before the mixture is constructed and sampled.
dt = "mixture of gaussian"

print(f"type of distribution: {args.distribution_type}")


# Dimension of the samples drawn at the start of the reverse process.
d = args.d_input
# The reverse loops iterate the step index t from T_max - 1 down to T_min.
T_max = 2000
T_min = 1
# Default batch size of generate_batch_data.
total_data = 50000
# Step index t is mapped to the time (t + 1) / T_max * t_scale; t_scale / T_max
# is also the base of the default step sizes of the predictor and correctors.
t_scale = 3

# Default `num_recording_steps` of diffusion_with_recordings, where it scales
# the log-time progress value that decides when a snapshot is stored.
num_recording_steps = 100


def isfloat(num):
    """Return True when ``float(num)`` succeeds, False otherwise.

    Besides malformed strings (``ValueError``), this also reports False for
    objects that ``float`` cannot handle at all, such as ``None`` or a list
    (``TypeError``), and for integers too large for a float
    (``OverflowError``), instead of letting those exceptions escape.
    """
    try:
        float(num)
    except (TypeError, ValueError, OverflowError):
        return False
    return True


#generate the mixture of gaussian distribution.
if args.distribution_type == dt:
    class MixtureOfGaussian:
        """Randomly generated Gaussian mixture used as the target distribution.

        Attributes set by ``__init__``:
            means: tensor of shape (num_mixtures, d) holding the component means.
            covs: list of num_mixtures (d, d) covariance matrices.
            p: tensor of shape (num_mixtures,) with mixing weights summing to 1.
        """

        def __init__(self, args):
            """Draw random means, covariances and mixing weights.

            Args:
                args: namespace providing ``num_mixtures``, ``d_input``,
                    ``min_var``, ``max_var`` and ``mixture_distance``.
            """
            self.args = args
            self.num_mixtures = args.num_mixtures
            self.d = args.d_input
            self.min_var = args.min_var
            self.max_var = args.max_var
            self.mixture_distance = args.mixture_distance
            # every coordinate of every mean is Gaussian with std mixture_distance / 2
            self.means = torch.randn(self.num_mixtures, self.d) * self.mixture_distance /2
            print("means of the gaussians:")
            print(self.means)
            # generate the covariance matrices as U \Sigma U^T
            self.covs = []
            for i in range(self.num_mixtures):
                # diagonal of \Sigma is uniform on [min_var, max_var]
                sigma = torch.diag(torch.rand(self.d) * (self.max_var - self.min_var) + self.min_var)
                # orthogonal U taken from the SVD of a Gaussian matrix
                u, _, _ = torch.svd(torch.randn(self.d, self.d))
                cov = u @ sigma @ u.t()
                self.covs.append(cov)
            # mixing weights: uniform draws normalised to sum to 1
            self.p = torch.rand(self.num_mixtures)
            self.p = self.p / self.p.sum()
            print("probability of the gaussians:")
            print(self.p)

        def sample(self, b_size):
            """Draw ``b_size`` points from the mixture.

            Args:
                b_size: number of points to draw; 0 yields an empty tensor.

            Returns:
                CPU tensor of shape (b_size, d).
            """
            if b_size == 0:
                return torch.zeros(0, self.d)
            # component index of every sample, drawn according to p
            indices = torch.multinomial(self.p, b_size, replacement=True)
            data = torch.zeros(b_size, self.d)
            # Build each component distribution at most once: constructing a
            # MultivariateNormal factorises its covariance, which is far more
            # expensive than a draw. Rows are still drawn one by one, in order,
            # so the sequence of random draws is the same as a per-row build.
            components = {}
            for row, k in enumerate(indices.tolist()):
                dist = components.get(k)
                if dist is None:
                    dist = torch.distributions.MultivariateNormal(self.means[k], self.covs[k])
                    components[k] = dist
                data[row] = dist.sample()
            return data

        def compute_posterior_mean(self, X, t):
            """Return the gradient of the log-density of the noised mixture at ``X``.

            Despite its name, the returned quantity is the score
            ``grad log p_t(X)``. At time ``t`` component ``i`` has mean
            ``exp(-t) * means[i]`` and covariance
            ``exp(-2t) * covs[i] + (1 - exp(-2t)) * I``.

            The component responsibilities are computed as a softmax over
            ``log p_i + log N_i(X)``. This keeps the Gaussian normalising
            constants (which matter as soon as the covariances differ) and
            stays finite where the raw densities would underflow to zero.

            Args:
                X: tensor of shape (b_size, d); all work is done on ``X.device``.
                t: scalar time.

            Returns:
                Tensor of shape (b_size, d).
            """
            dev = X.device
            decay = math.exp(-t)
            signal_var = math.exp(-2 * t)
            noise_var = 1 - math.exp(-2 * t)
            identity = torch.eye(self.d).to(dev)
            log_p = torch.log(self.p.to(dev))

            log_joint = []          # per component: (b_size,)
            component_scores = []   # per component: (b_size, d)
            for i in range(self.num_mixtures):
                mean = decay * self.means[i].to(dev)
                cov = signal_var * self.covs[i].to(dev) + noise_var * identity
                log_density = torch.distributions.MultivariateNormal(mean, cov).log_prob(X)
                log_joint.append(log_p[i] + log_density)
                # score of a single Gaussian: -(X - mean) cov^{-1}
                component_scores.append(-((X - mean[None, :]) @ torch.inverse(cov)))

            # responsibilities, shape (b_size, num_mixtures)
            weights = torch.softmax(torch.stack(log_joint, dim=1), dim=1)
            # per-component scores, shape (b_size, num_mixtures, d)
            scores = torch.stack(component_scores, dim=1)
            return (weights[:, :, None] * scores).sum(dim=1)

    # Build the target mixture once, at module level; `mog` is the instance
    # that generate_batch_data samples from and whose compute_posterior_mean
    # is passed as the score function in the generation step at the end of
    # the file.
    mog = MixtureOfGaussian(args)
    
def generate_batch_data(b_size = total_data):
    """Draw ``b_size`` samples from the target distribution and move them to ``device``.

    Args:
        b_size: number of samples to draw (defaults to ``total_data``).

    Returns:
        The tensor returned by ``mog.sample(b_size)``, moved to ``device``.

    Raises:
        ValueError: if ``args.distribution_type`` is not the supported
            mixture-of-Gaussians type (previously this fell through to an
            UnboundLocalError on ``data``).
    """
    if args.distribution_type != dt:
        raise ValueError(f"unsupported distribution type: {args.distribution_type!r}")
    return mog.sample(b_size).to(device)


def diffusion_step(x, t, md, lr = t_scale/T_max):
    """Advance ``x`` by one deterministic reverse step of size ``lr``.

    Args:
        x: current points.
        t: time passed through to ``md``.
        md: callable ``md(x, t)`` returning the score estimate at ``x``.
        lr: step size.

    Returns:
        ``x + lr * (x + md(x, t))``.
    """
    drift = x + md(x, t)
    return x + lr * drift

#Corrector step with overdamped Langevin using the current score estimate at time t, with step size lr_corr for niter iterations
def overdamped_steps(x, t, md, niter, lr_corr = 0.1*t_scale/T_max):
    """Apply ``niter`` overdamped Langevin corrector updates at fixed time ``t``.

    Each update adds ``lr_corr`` times the score estimate ``md(., t)`` and
    Gaussian noise scaled by ``sqrt(2 * lr_corr)``.

    Args:
        x: starting points; its first two dimensions size the noise.
        t: time passed through to ``md``.
        md: callable returning the score estimate.
        niter: number of updates.
        lr_corr: corrector step size.

    Returns:
        The points after ``niter`` updates (``x`` itself when ``niter`` is 0).
    """
    n_rows, n_cols = x.shape[0], x.shape[1]
    noise_scale = np.sqrt(2*lr_corr)
    state = x
    for _ in range(niter):
        # score term is evaluated before the noise is drawn
        drift = lr_corr*md(state, t)
        xi = torch.randn(size = (n_rows, n_cols)).to(device)
        state = state + drift + noise_scale*xi
    return state

#Corrector step with underdamped Langevin using the current score estimate at time t, with step size lr_corr for niter iterations
def underdamped_steps(x, t, md, niter, gamma = 0.01, u = 0.1, scale = 0.001, lr_corr = 0.1*t_scale/T_max):
    """Apply ``niter`` underdamped (kinetic) Langevin corrector updates at time ``t``.

    A velocity ``v`` is initialised as ``scale`` times standard Gaussian noise.
    Each update applies friction ``gamma``, a force ``u`` times the score
    estimate ``md(y, t)``, and Gaussian noise scaled by
    ``sqrt(2 * lr_corr * gamma * u)``; the position then moves by ``v``.

    Args:
        x: starting points; its first two dimensions size the noise.
        t: time passed through to ``md``.
        md: callable returning the score estimate (gradient of the log-density).
        niter: number of updates.
        gamma: friction coefficient.
        u: multiplier of the score term.
        scale: standard deviation of the initial velocity.
        lr_corr: corrector step size.

    Returns:
        The positions after ``niter`` updates (``x`` itself when ``niter`` is 0).
    """
    y = x
    b_size = x.shape[0]
    d = x.shape[1]
    v = scale*torch.randn(size = (b_size, d)).to(device)
    noise_scale = np.sqrt(2*lr_corr*gamma*u)
    for _ in range(niter):
        # The force is minus the gradient of the potential -log p, i.e. PLUS
        # the score, matching the sign used by overdamped_steps. The score term
        # used to be subtracted, which pushed the velocity away from
        # high-density regions.
        v = v - gamma*lr_corr*v + u*lr_corr*md(y,t) + noise_scale*torch.randn(size = (b_size, d)).to(device)
        # NOTE(review): the position moves by v itself rather than lr_corr * v;
        # left unchanged because the defaults (scale, niter) appear tuned to it
        # - confirm the intended discretisation.
        y = y + v
    return y

def diffusion(b_size, md):
    """Run the full reverse process and return the final samples.

    Starts from ``b_size`` standard Gaussian points of dimension ``d`` and
    iterates the step index from ``T_max - 1`` down to ``T_min``. Every
    iteration applies one ``diffusion_step`` followed by the corrector selected
    by ``args.corrector`` (10 overdamped or 3 underdamped updates; any other
    value means no corrector). Gradients are disabled throughout.

    Args:
        b_size: number of samples to generate.
        md: callable ``md(x, t)`` returning the score estimate.

    Returns:
        Tensor of shape (b_size, d) on ``device``.
    """
    with torch.no_grad():
        #Sample from fresh Gaussian (base distribution)
        x = torch.randn(size = (b_size, d)).to(device)
        for t in reversed(range(int(T_min), T_max)):
            # single-line progress display; the stray per-step print of t_step
            # was removed because it scrolled this line away on every step
            print(f"Running reverse process... Step {t} out of {T_max}", end = "\r")
            t_step = (t + 1)/T_max * t_scale
            x = diffusion_step(x, t_step, md = md)
            if args.corrector == 'overdamped':
                x = overdamped_steps(x, t_step, md, 10)
            elif args.corrector == 'underdamped':
                x = underdamped_steps(x, t_step, md, 3)
        return x


def diffusion_with_recordings(b_size, md, num_recording_steps = num_recording_steps):
    """Run the reverse process and keep snapshots of the samples along the way.

    The loop is the same as in ``diffusion``: one ``diffusion_step`` plus the
    corrector chosen by ``args.corrector`` per step index. Snapshots are taken
    on a logarithmic scale of the step index, and the state reached at the last
    step (``t == T_min``) is stored 21 times.

    Args:
        b_size: number of samples to generate.
        md: callable ``md(x, t)`` returning the score estimate.
        num_recording_steps: scale of the log-time progress value.

    Returns:
        Tuple ``(x, records)``: the final samples and the list of snapshots.
    """
    def log_progress(step):
        # position of `step` between T_min and T_max on a log scale,
        # stretched to [0, num_recording_steps]
        low = math.log(int(T_min))
        return (math.log(step) - low)/(math.log(T_max)- low) * num_recording_steps

    with torch.no_grad():
        x = torch.randn(size = (b_size, d)).to(device)
        records = []
        threshold = num_recording_steps
        for t in reversed(range(int(T_min), T_max)):
            t_step = (t + 1)/T_max * t_scale
            x = diffusion_step(x, t_step, md = md)
            if args.corrector == 'overdamped':
                x = overdamped_steps(x, t_step, md, 10)
            elif args.corrector == 'underdamped':
                x = underdamped_steps(x, t_step, md, 3)

            progress = log_progress(t)
            final_step = t == int(T_min)
            if not (final_step or progress <= threshold - 1):
                continue
            if final_step:
                # 20 extra references to the final state
                records.extend([x] * 20)
            records.append(x)
            # NOTE(review): the threshold is lowered to progress - 1 and the
            # test above subtracts 1 again, so later snapshots are 2 progress
            # units apart while the first is 1 away - confirm this is intended.
            threshold = progress - 1
            print(f"Running reverse process... Step {t} out of {T_max}", end = "\r")

        return x, records
    



# If you want figures saved in your local directory


# def plot_all_data_test(records, mog):
#     # define a function to update the scatter plot and the ellipses for each frame
#     def update_plot(i):
#         # get the i-th record of data_test
#         data_test = records[i].cpu().numpy()
#         # update the scatter plot with the new data
#         scat.set_offsets(data_test)
#         # return the updated objects
#         return scat,

#     # select 4 equally spaced frames
#     num_frames = len(records)
#     selected_frames = [int(num_frames * i / 4) for i in range(4)]

#     # plot the selected frames
#     output_dir = os.getcwd() #Saves the figures in the current directory
#     for idx, frame in enumerate(selected_frames):
#         # create a figure and an axis object
#         fig, ax = plt.subplots()
#         # set the axis labels and title
#         ax.set_xlabel('x')
#         ax.set_ylabel('y')
#         ax.set_title('Iteration '+str(idx*100))
#         # create a scatter plot object for the data_test
#         scat = ax.scatter([], [], color='blue', label='data_test')
#         # create a list of ellipse objects for the mixture components
#         ellipses = []
#         # loop over the mixture components
#         for i in range(mog.num_mixtures):
#             # get the mean, covariance, and probability of the i-th component
#             mean = mog.means[i]
#             cov = mog.covs[i]
#             prob = mog.p[i]
#             # plot the mean as a red star
#             ax.plot(mean[0], mean[1], 'r*', markersize=10, label=f'mean {i}')
#             # compute the eigenvalues and eigenvectors of the covariance matrix
#             eigvals, eigvecs = torch.eig(cov, eigenvectors=True)
#             # sort the eigenvalues in descending order and get the corresponding indices
#             _, indices = torch.sort(eigvals[:, 0], descending=True)
#             # get the largest and smallest eigenvalues and eigenvectors
#             l1 = eigvals[indices[0], 0]
#             v1 = eigvecs[:, indices[0]]
#             l2 = eigvals[indices[1], 0]
#             v2 = eigvecs[:, indices[1]]
#             # compute the angle of the ellipse using the largest eigenvector
#             angle = torch.atan2(v1[1], v1[0]) * 180 / math.pi
#             # compute the width and height of the ellipse using the eigenvalues
#             width = 2 * math.sqrt(l1)
#             height = 2 * math.sqrt(l2)
#             # create an ellipse object with the mean, width, height, and angle
#             ellipse = matplotlib.patches.Ellipse(mean, width, height, angle, fill=False, edgecolor='red', linewidth=2)
#             # add the ellipse to the axis
#             ax.add_patch(ellipse)
#             # append the ellipse to the list
#             ellipses.append(ellipse)
#             # annotate the probability of the component without the arrow
#             ax.annotate(f'p = {prob:.2f}', xy=(mean[0]+ 0.2, mean[1] + 0.2))
#         plt.grid()
        
#         # adjust the axis limits to the min and max of records
#         xmin = torch.min(torch.cat(records)[:, 0]).cpu().numpy()
#         xmax = torch.max(torch.cat(records)[:, 0]).cpu().numpy()
#         ymin = torch.min(torch.cat(records)[:, 1]).cpu().numpy()
#         ymax = torch.max(torch.cat(records)[:, 1]).cpu().numpy()
#         ax.set_xlim(xmin - 0.5, xmax + 0.5)
#         ax.set_ylim(ymin - 0.5, ymax + 0.5)

#         update_plot(frame)
#         plt.savefig(os.path.join(output_dir, f"frame_{idx+1}_"+args.corrector+"_"+str(T_max)+"_seed"+str(seed)+".eps"))
#         plt.show()



# If you want a video



def plot_all_data_test(records, mog):
    """Animate the recorded samples on top of the mixture components.

    Only the first two coordinates are drawn: the data, the component means
    and the top-left 2x2 block of each covariance are restricted to them. For
    a 2-dimensional mixture this is the full picture; for higher dimensions it
    is the projection onto the first two axes (the previous code passed the
    full d-dimensional arrays to 2-D plotting calls).

    Args:
        records: non-empty list of tensors of shape (n_points, d), one per
            animation frame.
        mog: mixture object exposing ``num_mixtures``, ``means``, ``covs``
            and ``p``.
    """
    import math
    import matplotlib.pyplot as plt
    from matplotlib import animation
    from matplotlib.patches import Ellipse

    fig, ax = plt.subplots()
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title('Data generated by DPUM for mixture of Gaussians')
    # scatter artist whose offsets are replaced on every frame
    scat = ax.scatter([], [], color='blue', label='data_test')

    for i in range(mog.num_mixtures):
        mean = mog.means[i][:2].detach().cpu()
        cov = mog.covs[i][:2, :2].detach().cpu()
        prob = float(mog.p[i])
        mx, my = float(mean[0]), float(mean[1])
        # component mean as a red star
        ax.plot(mx, my, 'r*', markersize=10, label=f'mean {i}')
        # torch.eig no longer exists; the covariance block is symmetric, so
        # eigh applies and returns real eigenvalues in ascending order
        eigvals, eigvecs = torch.linalg.eigh(cov)
        largest = max(float(eigvals[-1]), 0.0)
        smallest = max(float(eigvals[0]), 0.0)
        major_axis = eigvecs[:, -1]
        # orientation of the major axis in degrees
        angle = math.degrees(math.atan2(float(major_axis[1]), float(major_axis[0])))
        # one standard deviation on each side of the mean along every axis
        width = 2 * math.sqrt(largest)
        height = 2 * math.sqrt(smallest)
        # angle is given by keyword for compatibility with current matplotlib
        ellipse = Ellipse((mx, my), width, height, angle=angle, fill=False, edgecolor='red', linewidth=2)
        ax.add_patch(ellipse)
        # label the component with its mixing weight
        ax.annotate(f'p = {prob:.2f}', xy=(mx, my), xytext=(mx + 0.2, my + 0.2), arrowprops=dict(arrowstyle='->'))
    plt.grid()

    def animate(i):
        # show the i-th snapshot
        scat.set_offsets(records[i][:, :2].cpu().numpy())
        return scat,

    # keep a reference so the animation is not garbage-collected before show()
    anim = animation.FuncAnimation(fig, animate, frames=len(records), interval=100, blit=True)

    # axis limits cover every recorded point, with a small margin
    all_points = torch.cat(records)[:, :2]
    xmin = all_points[:, 0].min().item()
    xmax = all_points[:, 0].max().item()
    ymin = all_points[:, 1].min().item()
    ymax = all_points[:, 1].max().item()
    ax.set_xlim(xmin - 0.5, xmax + 0.5)
    ax.set_ylim(ymin - 0.5, ymax + 0.5)
    plt.show()




#GENERATION
# Generate 500 samples with the reverse process, using the mixture's own
# compute_posterior_mean as the score function, and keep the recorded
# intermediate states.
data_test, recordings = diffusion_with_recordings(500,  md = mog.compute_posterior_mean)
print(data_test)
# Animate the recorded states together with the mixture components.
plot_all_data_test(recordings, mog)