import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
import argparse
import logging
from scipy import integrate
from scipy.linalg import sqrtm
from scipy import integrate
import copy
from scipy import integrate
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn
import torch.nn.functional as F


def marginal_prob_std(t, sigma):
  """Return the standard deviation of the perturbation kernel at time ``t``.

  Evaluates ``sqrt((sigma**(2*t) - 1) / 2 / log(sigma))`` elementwise.

  Args:
    t: Tensor of time values. It is moved to CUDA when CUDA is available,
      otherwise to the CPU, before the computation.
    sigma: Scalar base of the exponential noise schedule.

  Returns:
    A tensor shaped like ``t`` on the selected device.
  """
  target = "cuda" if torch.cuda.is_available() else "cpu"
  t_dev = t.to(torch.device(target))
  variance = (sigma ** (2 * t_dev) - 1.) / 2. / np.log(sigma)
  return torch.sqrt(variance)

def diffusion_coeff(t, sigma):
  """Return the diffusion coefficient ``sigma ** t`` elementwise.

  Args:
    t: Tensor of time values. It is moved to CUDA when CUDA is available,
      otherwise to the CPU, before exponentiation.
    sigma: Scalar base of the exponential noise schedule.

  Returns:
    A tensor shaped like ``t`` on the selected device.
  """
  use_cuda = torch.cuda.is_available()
  target = torch.device("cuda" if use_cuda else "cpu")
  exponent = t.to(target)
  return sigma ** exponent


class GaussianMixtureDataset(Dataset):
    """In-memory dataset of labelled samples drawn from a Gaussian mixture.

    Sampling uses NumPy's global random state, so seed ``np.random`` before
    constructing the dataset for reproducible draws.

    Args:
        n_samples: Number of samples to draw.
        means: Sequence of per-component mean vectors.
        covs: Sequence of per-component covariance matrices.
        weights: Mixture probabilities, one per component (passed to
            ``np.random.choice`` as ``p``).

    Indexing returns a ``(sample, label)`` pair, where ``label`` is the index
    of the component the sample was drawn from.
    """

    def __init__(self, n_samples, means, covs, weights):
        self.n_samples = n_samples
        self.means = means
        self.covs = covs
        self.weights = weights
        self.n_components = len(means)
        self.data, self.labels = self._generate_data()

    def _generate_data(self):
        """Draw a fresh ``(data, labels)`` pair from the mixture.

        Returns:
            A float32 array of samples and an int64 array of component indices.
        """
        # Pick the component of every sample first, then draw each sample from
        # its component, one at a time and in order.
        comp_indices = np.random.choice(
            self.n_components, size=self.n_samples, p=self.weights
        )
        data = [
            np.random.multivariate_normal(self.means[comp], self.covs[comp])
            for comp in comp_indices
        ]
        return (
            np.array(data, dtype=np.float32),
            np.array(comp_indices, dtype=np.int64),
        )

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

    
class GaussianFourierProjection(nn.Module):
  """Random Fourier feature embedding of time steps.

  ``embed_dim // 2`` frequencies are drawn from a normal distribution scaled by
  ``scale`` at construction. ``forward`` returns the sines followed by the
  cosines of ``2 * pi * x * W`` along the last dimension.
  """
  def __init__(self, embed_dim, scale=30.):
    super(GaussianFourierProjection, self).__init__()
    frequencies = torch.randn(embed_dim // 2) * scale
    # Registered as a parameter so it is saved with the model, but excluded
    # from gradient updates.
    self.W = nn.Parameter(frequencies, requires_grad=False)

  def forward(self, x):
    angles = x[:, None] * self.W[None, :] * 2 * np.pi
    sin_part = torch.sin(angles)
    cos_part = torch.cos(angles)
    return torch.cat((sin_part, cos_part), dim=-1)
  
class Dense(nn.Module):
  """Thin wrapper around a single ``nn.Linear(input_dim, output_dim)`` layer."""
  def __init__(self, input_dim, output_dim):
    super(Dense, self).__init__()
    linear = nn.Linear(input_dim, output_dim)
    self.dense = linear

  def forward(self, x):
    projected = self.dense(x)
    return projected
    
class ScoreNet_1HiddenLayerFC(nn.Module):
  """Time-dependent score model with one fully-connected hidden layer.

  The input is projected to ``hidden_dim`` features, the time embedding is
  added, the activation ``h * relu(h)`` is applied, and a second linear layer
  maps back to ``d`` outputs. The result is divided by 16 and then by the
  perturbation standard deviation at ``t``.
  """

  def __init__(self, d, marginal_prob_std, sigma, hidden_dim=16, embed_dim=4):
    super(ScoreNet_1HiddenLayerFC, self).__init__()
    # Time embedding: random Fourier features followed by a linear map.
    time_features = GaussianFourierProjection(embed_dim=embed_dim)
    time_linear = nn.Linear(embed_dim, hidden_dim)
    self.embed = nn.Sequential(time_features, time_linear)
    self.dense1 = Dense(d, hidden_dim)
    self.dense2 = Dense(hidden_dim, d)

    self.act = lambda h: h * torch.relu(h)
    # Stored as given. NOTE(review): forward() calls the module-level
    # marginal_prob_std function, not this attribute.
    self.marginal_prob_std = marginal_prob_std
    self.sigma = sigma

  def forward(self, x, t):
    # Embed the time steps, then combine them with the projected input.
    time_embedding = self.embed(t)
    hidden = self.dense1(x) + time_embedding
    activated = self.act(hidden)
    out = self.dense2(activated) / 16

    # Normalize the output by the perturbation std at time t.
    std = marginal_prob_std(t, self.sigma)
    return out / std[:, None]
    
class Diffusion(nn.Module):
    """Denoising score-matching wrapper around a time-dependent score model.

    Args:
        eps_model: Module called as ``eps_model(x_t, t)`` to produce a score
            estimate.
        n_T: Stored on the instance; not used by the methods of this class.
        sigma: Noise-schedule base passed to ``marginal_prob_std``.
    """

    def __init__(
        self,
        eps_model: nn.Module,
        n_T: int,
        sigma: float
    ) -> None:
        super().__init__()
        self.eps_model = eps_model
        self.n_T = n_T
        self.sigma = sigma

    def score(self, x_t: torch.Tensor, _ts) -> torch.Tensor:
        """Return the wrapped model's score estimate at ``(x_t, _ts)``."""
        return self.eps_model(x_t, _ts)

    def forward(
        self, x: torch.Tensor, eps=1e-5
    ) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
        """Compute the score-matching loss on a randomly perturbed batch.

        Args:
            x: Batch of clean samples; indexed as ``(batch, features)``.
            eps: Lower bound of the sampled times, which are uniform in
                ``[eps, 1)``.

        Returns:
            ``(loss, perturbed_x, random_t)``: the scalar loss, the noised
            batch, and the per-sample times used to noise it.
        """
        # One time per sample, then Gaussian noise scaled by the std at that time.
        random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
        z = torch.randn_like(x)
        std = marginal_prob_std(random_t, self.sigma)
        perturbed_x = x + z * std[:, None]

        score = self.eps_model(perturbed_x, random_t)
        # Squared residual summed over features and averaged over the batch.
        residual = score * std[:, None] + z
        loss = torch.mean(torch.sum(residual**2, dim=(1)))
        return loss, perturbed_x, random_t

    
def sample_means(n_components, d, bound=5.0, max_iters=100000):
    """Draw well-separated component means by rejection sampling.

    Candidates are sampled uniformly from ``[-bound, bound]^d`` using NumPy's
    global random state. A candidate is kept only if it lies at least
    ``sqrt(d)`` (Euclidean distance) from every mean accepted so far.

    Args:
        n_components: Number of means to return.
        d: Dimension of each mean.
        bound: Half-width of the sampling cube.
        max_iters: Maximum number of candidates to try.

    Returns:
        An array of shape ``(n_components, d)``.

    Raises:
        ValueError: If fewer than ``n_components`` means were accepted within
            ``max_iters`` candidates.
    """
    min_separation = np.sqrt(d)
    means = []

    for _ in range(max_iters):
        if len(means) >= n_components:
            break
        cand = np.random.uniform(-bound, bound, size=d)
        if all(np.linalg.norm(cand - m) >= min_separation for m in means):
            means.append(cand)

    if len(means) < n_components:
        raise ValueError(
            f"Only found {len(means)}/{n_components} means after {max_iters} trials. "
            "Try increasing `bound` or `max_iters`, or reducing `n_components`."
        )

    if not means:
        # np.vstack rejects an empty list; return an explicitly empty result.
        return np.empty((0, d))

    return np.vstack(means)



def _parse_args():
    """Build the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--d', type=int, default=1)
    parser.add_argument('--n_components', type=int, default=2)
    parser.add_argument('--n_samples', type=int, default=30)
    parser.add_argument('--n_epochs', type=int, default=5000)
    parser.add_argument('--tau', type=float, default=0.01)
    parser.add_argument('--alpha', type=float, default=2.5)
    parser.add_argument('--sigma', type=float, default=25.0)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--path', type=str, default='./contents/defalt')
    parser.add_argument('--log_path_name', type=str, default='defalt')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed')
    return parser.parse_args()


def _gaussian_pdf_1d(grid, mean, var):
    """Evaluate the N(mean, var) density at every point of ``grid``."""
    x = torch.tensor(grid, dtype=torch.float32)
    return torch.exp(-(x - mean) ** 2 / 2 / var) / float(np.sqrt(2 * np.pi * var))


def _density_from_score(model, grid, step, t_eval, device):
    """Integrate a 1-D score model into a density normalised on ``grid``.

    For each grid point the score at time ``t_eval`` is integrated from -100 up
    to that point, giving a log-density up to an additive constant. The result
    is exponentiated and rescaled so that ``sum(density) * step == 1``.
    """
    t = torch.tensor([t_eval], device=device)

    def score_at(x):
        # quad expects a plain float; inputs must live on the model's device.
        return model.score(torch.tensor([[x]], device=device), t).item()

    log_density = torch.zeros(len(grid))
    with torch.no_grad():
        for j in tqdm(range(len(grid))):
            log_density[j], _ = integrate.quad(score_at, -100.0, grid[j])

    # Subtract the maximum before exponentiating for numerical stability.
    unnormalised = torch.exp(log_density - log_density.max())
    return unnormalised / (unnormalised.sum() / (1 / step))


def main():
    """Train vanilla and mutually-regularised score models and compare them.

    One score model per mixture component is trained twice on that component's
    samples: once with the plain score-matching loss ("vanilla"), and once with
    an extra penalty (weighted by ``--tau``) on the difference between its
    score and the other components' scores ("mutual"). Each trained model is
    then turned into a 1-D density and compared with its component's Gaussian
    through a KL divergence.

    Files written: ``./log/<log_path_name>`` (log),
    ``checkpoints/<log_path_name>/model{i}.pth`` and ``model{i}_vanilla.pth``,
    ``min_grad_norm.png`` (gradient-norm curve) and
    ``epoch_<n_epochs - 1>.png`` (densities).

    Raises:
        ValueError: If ``--n_epochs`` is below 1, if ``--n_components`` or
            ``--d`` do not match the built-in mixture, or if a component
            received no samples.
    """
    args = _parse_args()

    d = args.d
    n_components = args.n_components
    n_samples = args.n_samples
    num_epochs = args.n_epochs
    tau = args.tau
    alpha = args.alpha
    sigma = args.sigma
    learning_rate = args.lr
    log_path_name = args.log_path_name

    log_path = './log/' + log_path_name
    ckpt_root = "checkpoints/" + log_path_name
    os.makedirs(ckpt_root, exist_ok=True)
    # The log directory must exist before logging opens the file.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    def ckpt_file(name):
        return os.path.join(ckpt_root, name)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    logging.basicConfig(
        filename=log_path,
        filemode='a',
        level=logging.INFO,
        format='%(asctime)s %(message)s'
    )
    logger = logging.getLogger('train')

    logger.info('-' * 19)
    for k, v in vars(args).items():
        logger.info("%s = %r", k, v)

    # The mixture is fixed in code, so the CLI values must agree with it.
    means = [[-5], [5]]
    covs = [3*np.eye(1), 3*np.eye(1)]
    if num_epochs < 1:
        raise ValueError("--n_epochs must be at least 1.")
    if n_components != len(means) or d != len(means[0]):
        raise ValueError(
            f"The built-in mixture has {len(means)} components of dimension "
            f"{len(means[0])}; got --n_components={n_components}, --d={d}."
        )
    weights = np.arange(1, n_components+1) ** (-alpha)
    weights /= weights.sum()

    # Create dataset
    # NOTE: the constructor already samples once; the second draw below is kept
    # so that seeded runs reproduce the data of earlier runs.
    dataset = GaussianMixtureDataset(n_samples, means, covs, weights)
    dataset, labels = dataset._generate_data()

    # Per-component batches never change, so build them once.
    batches = [
        torch.tensor(dataset[labels == comp]).float().to(device)
        for comp in range(n_components)
    ]
    for comp, batch in enumerate(batches):
        if batch.shape[0] == 0:
            raise ValueError(
                f"Component {comp} received no samples; "
                "increase --n_samples or lower --alpha."
            )

    dms = {}
    optimizers = {}
    min_norms = {}
    min_grad_norms = []

    # Models 0..n-1 are trained with mutual learning.
    for comp in range(n_components):
        net = ScoreNet_1HiddenLayerFC(d, marginal_prob_std=marginal_prob_std, sigma=sigma)
        dms[comp] = Diffusion(eps_model=net, n_T=100, sigma=sigma).to(device)
        optimizers[comp] = optim.Adam(dms[comp].parameters(), lr=learning_rate)
        min_norms[comp] = 0
        dms[comp].eval()

    # Models n..2n-1 are vanilla baselines starting from identical weights.
    for comp in range(n_components, 2*n_components):
        dms[comp] = copy.deepcopy(dms[comp - n_components])
        optimizers[comp] = optim.Adam(dms[comp].parameters(), lr=learning_rate)
        dms[comp].eval()

    # Vanilla Training
    for epoch in range(num_epochs):
        for comp in range(n_components):
            model = dms[comp + n_components]
            optimizer = optimizers[comp + n_components]

            model.train()
            loss_v, _, _ = model(batches[comp])

            optimizer.zero_grad()
            loss_v.backward()
            optimizer.step()
            model.eval()

    # Only the final weights are read back later, so save once after training.
    for comp in range(n_components):
        torch.save(dms[comp + n_components].state_dict(), ckpt_file(f"model{comp}_vanilla.pth"))

    # Mutual Learning
    best_worst_norm = None
    for epoch in range(num_epochs):
        for comp in range(n_components):
            model = dms[comp]
            model.train()

            loss, x_noise, t_noise = model(batches[comp])

            loss_m = loss
            if n_components > 1:
                own_score = model.score(x_noise, t_noise)
                mutual_err_sum = 0
                for other in range(n_components):
                    if other == comp:
                        continue
                    # Peers are not updated in this step, so skip their graph.
                    with torch.no_grad():
                        peer_score = dms[other].score(x_noise, t_noise)
                    error_term = own_score - peer_score
                    norm_square = torch.exp(t_noise) * error_term.norm(p=2, dim=1)**2
                    mutual_err_sum += torch.mean(norm_square)
                loss_m = loss + tau * mutual_err_sum/(n_components - 1)

            optimizers[comp].zero_grad()
            loss_m.backward()
            optimizers[comp].step()

            grads = [p.grad.view(-1) for p in model.parameters() if p.grad is not None]
            min_norms[comp] = torch.cat(grads).norm().item()

            model.eval()

        # Checkpoint whenever the worst per-model gradient norm is no larger
        # than the best value seen so far.
        worst_norm = max(min_norms.values())
        if best_worst_norm is None or worst_norm <= best_worst_norm:
            best_worst_norm = worst_norm
            for comp in range(n_components):
                torch.save(dms[comp].state_dict(), ckpt_file(f"model{comp}.pth"))
        min_grad_norms.append(best_worst_norm)

    # The gradient-norm curve gets its own figure so that it does not share
    # axes with the density plot below.
    plt.plot(range(1, 1+num_epochs), min_grad_norms)
    plt.xlabel("Epoch", fontsize=20)
    plt.ylabel("Min worst grad norm", fontsize=20)
    plt.savefig('min_grad_norm.png', dpi=300)
    plt.close()

    eps = 0.065
    grid_size = 50
    left = 10.
    step = 2 * left / grid_size
    point = [-left + step * i for i in range(grid_size + 1)]

    seaborn.kdeplot(dataset.reshape(-1,), label = 'Target')

    for i in range(n_components):
        mean_i = float(np.asarray(means[i]).reshape(-1)[0])
        var_i = float(np.asarray(covs[i]).reshape(-1)[0])
        # Probability mass of the true component on each grid cell.
        target_mass = _gaussian_pdf_1d(point, mean_i, var_i) * step

        runs = (
            ("KL_mutual", dms[i], f"model{i}.pth", f"Trained_mutual{i}"),
            ("KL_vanilla", dms[i + n_components], f"model{i}_vanilla.pth", f"Trained_vanilla{i}"),
        )
        for tag, model, ckpt_name, curve_label in runs:
            state = torch.load(ckpt_file(ckpt_name), map_location=device)
            model.load_state_dict(state)

            dens = _density_from_score(model, point, step, eps, device)
            plt.plot(point, dens, label = curve_label)

            kl = F.kl_div((dens * step).log(), target_mass, reduction='sum')
            print(kl)
            logger.info(f"Epoch {num_epochs:4d} | Component {i:4d} | {tag} {kl:.4f}")
        plt.legend()

    plt.savefig('epoch_{}.png'.format(num_epochs - 1), dpi=300)
    plt.close()

    logger.info('-' * 19)

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()