from random import vonmisesvariate
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
import numpy as np
from numpy.random import normal, dirichlet, beta, multinomial, multivariate_normal
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
import os
from flows import *
import sys

# print('PRINTS WILL BE SAVED TO ouput.txt')
# file_path = 'output.txt'
# open('output.txt', 'w').close() # empties the file
# sys.stdout = open(file_path, "w")

# Apply seaborn's default theme to every matplotlib figure created by this module.
sns.set()


class HDPHMM:
    def __init__(self, alpha, gma, rho, lamda, n_features, device, n_MADE_blocks=1, periodic=False,
                 MADE_input_size=None, MADE_hidden_size=8, MADE_n_hidden=1, MADE_batch_norm=False,
                 k_max=50):
        """
        Hierarchical non-parametric generative model for time series.

        All random model components (global state weights ``beta``, per-state flow parameters
        ``theta``, self-transition weights ``kappa`` and the transition weights ``pi``) are drawn
        once here; ``generate`` then simulates sequences from them.

        :param alpha: Concentration of the stick-breaking process used to build every row of ``pi``
        :param gma: Concentration of the stick-breaking process that yields the global weights ``beta``
        :param rho: Pair ``(a, b)``; every state's self-transition weight ``kappa[k]`` is drawn
            from ``Beta(a, b)``
        :param lamda: Pair describing the base distribution H of the flow parameters: ``lamda[0]``
            is the mean of every coordinate and ``lamda[1]`` scales the identity covariance
            (i.e. it is used as a variance)
        :param n_features: The size of the multivariate time series data features
        :param device: Torch device the flow is moved to
        :param n_MADE_blocks: Number of MADE blocks in the flow
        :param periodic: Forwarded unchanged to ``WrappedMAF``
        :param MADE_input_size: Dimension of the distribution the flow models; defaults to ``n_features``
        :param MADE_hidden_size: Dimension of the hidden layers in the MADE blocks
        :param MADE_n_hidden: Number of hidden layers in each MADE block
        :param MADE_batch_norm: True iff a batch norm layer follows each MADE block
        :param k_max: Number of sticks broken for ``beta``. Because ``_gem`` appends the leftover
            stick, the truncated model has ``k_max + 1`` states.
        """
        if k_max < 1:
            raise ValueError("k_max must be at least 1")

        self.alpha = alpha
        self.gma = gma
        self.rho = rho
        self.lamda = lamda
        self.k_max = k_max
        self.n_features = n_features
        self.device = device
        self.periodic = periodic
        # Resolve the default *before* storing it so the attribute is never left as None.
        if MADE_input_size is None:
            MADE_input_size = self.n_features
        self.MADE_input_size = MADE_input_size
        self.MADE_hidden_size = MADE_hidden_size
        self.MADE_n_hidden = MADE_n_hidden
        self.MADE_batch_norm = MADE_batch_norm
        self.n_MADE_blocks = n_MADE_blocks

        self.flows = WrappedMAF(n_blocks=n_MADE_blocks, # Number of MADE blocks
                                input_size=MADE_input_size, # dimension of distribution you'd like to learn
                                hidden_size=MADE_hidden_size, # dimension of hidden layers in the MADE Blocks
                                n_hidden=MADE_n_hidden, # Number of hidden layers in each MADE block
                                periodic=periodic,
                                cond_label_size=1, # Size of context information we feed at each time step. None if you don't want any.
                                batch_norm=MADE_batch_norm # True iff we want a batch norm layer after each MADE block
                            ).to(self.device)
        self.flows.eval()
        # The flow is never trained here: its weights are overwritten from ``theta`` on demand,
        # so gradients are switched off and the parameters are only counted.
        n_params = 0
        for param in self.flows.parameters():
            n_params += torch.numel(param)
            param.requires_grad = False
        print('Number of parameters in the NF (theta_size): ', n_params)
        self.theta_size = n_params
        self.s_size = n_params

        self.beta = self._gem(gma, self.k_max)
        n_states = len(self.beta)
        # kappa needs one entry per state. It used to be sized with k_max while beta/theta/pi
        # have k_max + 1 entries, which broke the transition computation in ``generate``.
        self.kappa = np.array([beta(rho[0], rho[1]) for _ in range(n_states)])
        # ``np.ones * lamda[0]`` so the configured mean is actually used (``np.zeros * lamda[0]``
        # silently discarded it).
        self.theta = multivariate_normal(mean=np.ones(self.theta_size) * lamda[0],
                                         cov=np.diag(np.ones(self.theta_size)) * lamda[1], size=n_states)
        self.pi = np.zeros((n_states, n_states))

        n_sticks = 800  # truncation level of the per-row stick-breaking draw
        for i in range(n_states):
            eps = self._gem(alpha, n_sticks)
            for weight in eps:
                # Assign each stick weight to a state index drawn via Dirichlet(beta) -> multinomial.
                phi = np.where(multinomial(1, dirichlet(self.beta)))[0][0]
                self.pi[i][phi] += weight

    def _gem(self, gma, lim=None):
        """
        Stick-breaking process with parameter gma.
        """
        beta_weights = []
        prev = 1
        if lim is None:
            while (prev>0.1 or len(beta_weights)<self.k_min):
                beta_k = beta(1, gma) * prev
                prev -= beta_k
                beta_weights.append(beta_k)
        else:
            while (len(beta_weights) < lim):
                beta_k = beta(1, gma) * prev
                prev -= beta_k
                beta_weights.append(beta_k)
            beta_weights.append(prev)
        return np.array(beta_weights)

    def fill_flow_params(self, theta_vector, nf=None):
        """
        Fills the parameters of self.flows using a theta_vector. 
        theta_vector must be a k dimensional vector, where k is the number of total trainable parameters in
        self.flows

        Set nf to a separate flow if you wish to fill those flow parameters. Else, it'll do it for self.flows
        """
        # flow_model = nf if nf is not None else self.flows
        
        curr_ind = 0
        with torch.no_grad():
            for param in self.flows.parameters():
                size = torch.numel(param) # total number of elements in param
                # param.data = theta_vector[curr_ind: curr_ind + size].reshape(param.shape) # Sets the data of the param, i.e. the tensor still has requires_grad=True but now has new values.
                param.copy_(theta_vector[curr_ind: curr_ind + size].reshape(param.shape))
                curr_ind += size
        self.flows.eval()

    def _nf(self, z_t, z_count=None, verbose=False):
        """
        Time series generative process

        z_t is a scalar
        base_dist_params is a tuple of size 2.
          The first element is a torch Tensor of shape (self.n_features,)
          The second element is a torch Tensor of shape (self.n_features,)
        z_count is a scalar. It is the number of steps the time series has been in the current state
        """
        # theta_t = torch.Tensor(self.theta[z_t]).to(self.device)
        # if verbose:
        #     print('theta_t: ', theta_t)

        self.fill_flow_params(theta_vector=torch.Tensor(self.theta[z_t]).to(self.device))

        cond = torch.Tensor([z_count]).reshape(1, 1).to(self.device) if z_count is not None else None
        x = self.flows.sample(num_samples=1, context=cond)
        return x[0]

    # def log_prob_batch_test(self, x, theta_list, z=None, cond=None, mini_batch_size=1):
    #     px_all = 0
    #     for k_ind in range(theta_list.shape[1]):
    #         px = Normal(loc=theta_list[:,k_ind,:].unsqueeze(1),
    #                     scale=torch.ones_like(theta_list[:,k_ind,:].unsqueeze(1))).log_prob(x)
    #         px_all += px.sum(-1).sum(-1)
    #     return px_all

    def log_prob_batch(self, x, theta_list, z=None, cond=None, mini_batch_size=1):
        """
        Takes in data x of shape (n_samples, T, num_features)
        theta_list is of shape (n_samples, k_max, theta_size)
        z is the underlying state of shape (n_samples, T)
        cond is conditioning information of shape (num_samples, T)

        outputs a tensor of shape (num_samples,)
        """

        # flow = WrappedMAF(n_blocks=self.n_MADE_blocks,  # Number of MADE blocks
        #                   hidden_size=self.MADE_hidden_size,
        #                   input_size=self.MADE_input_size,
        #                   n_hidden=self.MADE_n_hidden,  # Number of hidden layers in each MADE block
        #                   periodic=self.periodic,
        #                   cond_label_size=1,
        #                   batch_norm=self.MADE_batch_norm).to(self.device)

        with torch.no_grad():
            all_probs = []
            for sample_ind, theta_mc in enumerate(theta_list):
                # Looping over all samples
                out = []
                for theta_ind, theta in enumerate(theta_mc):
                    # Looping over all possible states
                    # with torch.no_grad():
                    #     curr_ind = 0
                    #     for param in flow.parameters():
                    #         size = torch.numel(param)  # total number of elements in param
                    #         param.copy_(theta[curr_ind: curr_ind + size].reshape(param.shape))
                    #         curr_ind += size
                    self.fill_flow_params(theta_vector=theta)
                    p = self.flows._log_prob(x[sample_ind],
                                             cond[sample_ind].unsqueeze(-1) if cond is not None else None)
                        # p = flow._log_prob(x[sample_ind],
                        #                    cond[sample_ind].unsqueeze(-1) if cond is not None else None)
                    out.append(p)
                    # out.append(torch.clamp(input=p, min=-10000, max=10000))
                probs = torch.stack(out, -1)  # [T, k]
                if z is None:
                    all_probs.append(probs)
                else:
                    all_probs.append(probs[np.arange(len(z[sample_ind])), z[sample_ind]])
        return torch.stack(all_probs)

    # def log_prob(self, x, theta_list, z_t=None, cond=None, mini_batch_size=1):
    #     """
    #     Takes in data x of shape (T, num_features), or (1, num_features)
    #     theta is a batch of theta vetors for each time step. Should be of shape (num_samples,k,  theta_size)
    #     cond is conditioning information. See the definition of the MAF for the shape of this cond information. Should be of shape (num_samples, cond_size)
    #
    #     outputs a tensor of shape (num_samples,)
    #     """
    #     with torch.no_grad():
    #         all_probs = []
    #         for sample_ind in range(len(theta_list)//mini_batch_size):
    #         # for sample_ind, theta_mc in enumerate(theta_list):
    #             theta_mc = theta_list[sample_ind*mini_batch_size:(sample_ind+1)*mini_batch_size].mean(0)
    #             if not cond is None:
    #                 cond = cond.to(self.device)
    #                 if (len(cond.shape) == 1):
    #                     sample_cond = cond[sample_ind].unsqueeze(-1)
    #             else:
    #                 sample_cond = None
    #             if z_t is None:
    #                 out = []
    #                 for theta in theta_mc:
    #                     self.fill_flow_params(theta_vector=theta)
    #                     out.append(self.flows.log_prob(x[sample_ind*mini_batch_size:(sample_ind+1)*mini_batch_size].to(self.device), sample_cond)/mini_batch_size)
    #                     # out.append(self.flows.log_prob(x[sample_ind].unsqueeze(0).to(self.device), sample_cond))
    #                 probs = torch.stack(out, -1)
    #             else:
    #                 theta = theta_mc[z_t[sample_ind]]
    #                 self.fill_flow_params(theta_vector=theta)
    #                 probs = self.flows.log_prob(x[sample_ind*mini_batch_size:(sample_ind+1)*mini_batch_size].to(self.device), sample_cond)/mini_batch_size
    #             all_probs.extend(probs)
    #     return torch.stack(all_probs)
    #     # if not cond is None:
    #     #     cond = cond.to(self.device)
    #     #     if (len(cond.shape) == 1):
    #     #         cond = cond.reshape(-1, 1)
    #     # out = []
    #     # for theta in theta_list:
    #     #     self.fill_flow_params(theta_vector=theta)
    #     #     out.append(self.flows.log_prob(x.to(self.device), cond))
    #     # all_probs = torch.stack(out, -1)
    #     # if z_t is None:
    #     #     return all_probs
    #     # else:
    #     #     return all_probs[torch.arange(len(z_t)),z_t]
    #
    #     # num_samples = x.shape[0]
    #     # for i in range(num_samples):
    #     #     theta_t = theta[i]
    #     #     self.fill_flow_params(theta_vector=theta_t)
    #     #     # NOTE This function currently only works when the base distribution does *not* change over time!
    #     #     out.append(self.flows.log_prob(x[i], cond[i]))
    #     # return torch.stack(out)

    def generate(self, T, verbose=False):
        """
        Simulate a time series sample
        """
        base_distr_means = []
        base_distr_vars = [] # Store the base distributions for the flow. This changes each time step. 

        sample_x = []
        
        sample_z = list(np.where(multinomial(1, dirichlet(self.beta)))[0])
       
        #p_s = MultivariateNormal(loc=torch.Tensor(self.theta[sample_z[-1]]),
        #                         scale_tril=torch.diag(torch.ones(len(self.theta[sample_z[-1]]))*0.5))
        #sample_s = [list(p_s.sample().numpy())]
        # TODO: Need to change base_dist_params.
        x_t = self._nf(z_t=sample_z[0], z_count=torch.Tensor([1]), verbose=verbose)
        x_t = x_t.cpu().detach().numpy()
        if verbose:
            print('x_t: ', x_t)
        sample_x.append(x_t) 
        z_count = 1
        for _ in range(1, T):
            if verbose:
                print('Time Step: ', _)
            # Estimate the underlying state and encourage self-transition
            # self_trans = np.zeros_like(self.pi[sample_z[-1], :])
            # self_trans[sample_z[-1]] += self.kappa[0]/(1+np.exp(-self.kappa[1]*z_count))
            self_trans = np.diag(self.kappa)
            z_t = np.where(multinomial(1, dirichlet(self.pi[sample_z[-1], :] + self_trans[sample_z[-1]] + 1e-22)))[0][0]
            if z_t == sample_z[-1]:
                z_count += 1
                # s_t = sample_s[-1]
            else:
                # p_xt, x_t = self._nf(p_s, z_count)
                # sample_x.append(x_t)
                z_count = 1
            p_s = torch.distributions.MultivariateNormal(loc=torch.Tensor(self.theta[z_t]),
                                                         scale_tril=torch.diag(torch.ones(len(self.theta[z_t]))*0.5))
            
            x_t = self._nf(z_t=sample_z[-1], z_count=torch.Tensor([z_count]), verbose=verbose)
           
            
            x_t = x_t.cpu().detach().numpy()
            sample_x.append(x_t)
            sample_z.append(z_t)
            #sample_s.append(s_t)
        # p_xt, x_t = self._nf(p_s, z_count)
        # sample_x.append(x_t)
        return np.array(sample_z), np.stack(sample_x, axis=0)
    #
    # def posterior_predictive(self, x_hat, n_mc_sample=20, x_hat_lens=None):
    #     if x_hat_lens is None:
    #         x_hat_lens = torch.Tensor([x_hat.shape[1]] * len(x_hat)).to(self.device)
    #     N, T, _ = x_hat.shape
    #     posterior_like = []
    #     for n_mc in range(n_mc_sample):
    #         init_probs = torch.Tensor(self._gem(self.gma, 19))
    #         theta = multivariate_normal(mean=np.zeros(self.theta_size) * self.lamda[0],
    #                                     cov=np.diag(np.ones(self.theta_size)) * self.lamda[1], size=20)
    #         pi = np.zeros((20, 20))
    #         for i in range(len(init_probs)):
    #             # stick_breaking = self._gem(alpha)
    #             eps = self._gem(self.alpha, 800)
    #             for j in range(len(eps)):  # 800 * k_max):  # Change this!
    #                 phi = np.where(multinomial(1, dirichlet(init_probs)))[0][0]
    #                 pi[i][phi] += eps[j]
    #         kappa = np.array([beta(self.rho[0], self.rho[1]) for _ in range(20)])
    #         # init_probs = torch.Tensor(self.beta)[:20]
    #         pi, kappa, theta = torch.Tensor(pi), torch.Tensor(kappa), torch.Tensor(theta)
    #         A = pi * (1 - kappa).unsqueeze(-1) + torch.diag(kappa).to('cpu')
    #         px_X = self.forward_backward(x_hat, A, theta, init_probs=init_probs, lens=x_hat_lens, forward_only=True)
    #         posterior_like.append(px_X)
    #     log_like = torch.logsumexp(torch.stack(posterior_like), axis=0)
    #     return log_like.mean(0).item(), log_like.std(0).item()
    #
    # def forward_backward(self, obs_seq, A, theta, init_probs, lens, forward_only=False):
    #     device = 'cpu'
    #     N, T, F = obs_seq.shape
    #     obs_seq = obs_seq.to(device)
    #     ## Forward
    #     z_count = torch.ones((1, N)).to(device)
    #     alpha = torch.zeros((N, T, 20)).to(device)
    #     logp_x_zt_new = self.log_prob_batch(obs_seq, theta.unsqueeze(0).repeat(N, 1, 1),
    #                                                     z=None, cond=torch.ones((N, T)).to(device).float())  # [N,T,k]
    #     logp_x = [logp_x_zt_new[:, 0, :]]
    #     alpha[:, 0, :] = torch.matmul(torch.stack([init_probs] * len(obs_seq)), A) * (torch.exp(logp_x[-1])+1e-25)
    #     cs = torch.log(torch.sum(alpha[:, 0, :], -1, keepdim=True))
    #     alpha[:, 0, :] = alpha[:, 0, :] / torch.sum(alpha[:, 0, :], -1, keepdim=True)
    #     log_like = torch.zeros(N, ).to(device)
    #     t_all = []
    #     for t in range(1, T):
    #         # Logp if state didn't change at t
    #         logp_x_zt_cont = self.log_prob_batch(obs_seq[:, t, :].unsqueeze(0), theta.unsqueeze(0),
    #                                                          z=None, cond=(z_count + 1.))[0]
    #         p_x_mat = torch.exp(logp_x_zt_new[:, t, :].unsqueeze(1).repeat(1, 20, 1))
    #         # p_x_mat = torch.exp(logp_x_zt_new[:, t, :].unsqueeze(-1).repeat(1, 1, self.k_max))
    #         p_x_mat *= 1 - torch.diag_embed(torch.ones((N, 20)))
    #         p_x_mat += torch.diag_embed(torch.exp(logp_x_zt_cont))
    #         alpha[:, t, :] = torch.bmm(alpha[:, t - 1, :].unsqueeze(1), torch.mul(A.unsqueeze(0), p_x_mat+1e-25))[:, 0, :]
    #
    #         state_persist = (torch.argmax(alpha[:, t - 1, :], dim=-1) == torch.argmax(alpha[:, t, :], dim=-1))
    #         z_count = torch.where(state_persist, (z_count + 1), torch.ones(N).to(device))
    #         t_all.append(z_count[0, 0])
    #         logp_x_z_t = torch.stack([logp_x_zt_new[:, t, :], logp_x_zt_cont])[state_persist * 1, np.arange(N)]
    #         logp_x.append(logp_x_z_t)
    #         cs += torch.log(torch.sum(alpha[:, t, :], -1, keepdim=True))
    #         alpha[:, t, :] = alpha[:, t, :] / torch.sum(alpha[:, t, :], -1, keepdim=True)
    #         observed_likelihood = torch.where(lens - 1 == t,
    #                                           cs.squeeze(),
    #                                           torch.zeros(N, ).to(device))
    #         alpha[:, t, :] = torch.where(lens.unsqueeze(-1) - 1 >= t,
    #                                      alpha[:, t, :],
    #                                      torch.zeros(N, 20).to(device))
    #         log_like += observed_likelihood
    #     if forward_only:
    #         return log_like
    #
    #     ## Backward
    #     beta = torch.ones((N, T, 20)).to(device)
    #     beta[:, -1, :] = beta[:, -1, :] / torch.sum(beta[:, -1, :], -1, keepdim=True)
    #     for t in range(T - 2, -1, -1):
    #         logp_x_z_t = logp_x[t + 1]
    #         p_x_z_t = torch.exp(logp_x_z_t) + 1e-25
    #         beta[:, t, :] = torch.matmul(A, (beta[:, t + 1, :] * p_x_z_t).T).T
    #         beta[:, t, :] = beta[:, t, :] / torch.sum(beta[:, t, :], -1, keepdim=True)
    #         beta[:, t, :] = torch.where(lens.unsqueeze(-1) - 1 > t,
    #                                     beta[:, t, :],
    #                                     torch.ones(N, 20).to(device))
    #     return (alpha * beta) / torch.sum((alpha * beta), axis=-1, keepdim=True), t_all


# model = HDPHMM(lmbda=[[0, 0], [0.5, 1]], alpha=20, gma=10, kappa=[4,1], k_max=10, n_features=4, nf_model=None)
# NOTE: When using MAF as flow, 
# theta_size needs to equal n_blocks(input_size*hidden_size + hidden_size + 
#                                    n_hidden(hidden_size^2 + hidden_size) + 
#                                    2*input_size*hidden_size + 2*input_size +
#                                    hidden_size*cond_label_size) + batch_norm(2*input_size)
# Where input_size is n_features, hidden_size is the size of a hidden layer in a MADE block, n_hidden is the number of hidden layers in a MADE block, n_blocks is the number of MADE blocks, and batch_norm is True iff we want a BatchNorm layer at the end of MAF. cond_label_size is the size of context information we condition on at each time step.
#
# NOTE: New formula after removing learned variances in MAF:
# theta_size needs to equal n_blocks(input_size*hidden_size + hidden_size + 
#                                    n_hidden(hidden_size^2 + hidden_size) + 
#                                    input_size*hidden_size + input_size +
#                                    hidden_size*cond_label_size) + batch_norm(2*input_size)
#


if __name__=="__main__":
    # Script entry point: simulate a data set, save it, and plot the model's state weights,
    # one example series and a scatter of the simulated observations.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    N = 650
    n_features = 4
    T = 40
    # The constructor requires ``rho`` and ``lamda`` and accepts no ``kappa`` / ``k_min`` /
    # ``theta_size`` arguments, so the previous call raised TypeError.
    # NOTE(review): rho=[8, 1] reuses the values formerly passed as ``kappa`` and lamda=[0, 1]
    # is a placeholder (zero mean, unit variance) -- confirm both against the intended setup.
    model = HDPHMM(alpha=30, gma=8, rho=[8, 1], lamda=[0, 1], n_features=n_features, device=device)
    x_all, z_all = [], []
    for _ in range(N):
        z, x = model.generate(T, verbose=False)
        x_all.append(x)
        z_all.append(z)
    x_all = np.array(x_all)
    z_all = np.array(z_all)

    # Make sure both output directories exist before anything is written to them.
    os.makedirs('./data', exist_ok=True)
    os.makedirs('./plots', exist_ok=True)

    with open('./data/sim_hard_x.npy', 'wb') as f:
        np.save(f, x_all)
    with open('./data/sim_hard_z.npy', 'wb') as f:
        np.save(f, z_all)

    # Global state weights (top panel) and the first seven rows of the transition weights.
    fig, axs = plt.subplots(8, 1, figsize=(20, 8))
    # One label per state; the model has no ``k_min`` attribute, so use the length of beta.
    theta_space = ["S%d" % i for i in range(len(model.beta))]
    sns.barplot(x=theta_space, y=model.beta, ax=axs[0])
    axs[0].set_ylabel("G")
    axs[0].set_ylim(0, 0.5)
    for i in range(1, 8):
        sns.barplot(x=theta_space, y=model.pi[i - 1], ax=axs[i])
        axs[i].set_ylabel("G%d" % (i - 1))
        axs[i].set_ylim(0, 0.5)
    plt.tight_layout()
    plt.savefig("./plots/dp_distributions.pdf")

    # Last simulated series with its state sequence shown as coloured background spans.
    plt.figure(figsize=(10, 3))
    plt.xticks(ticks=np.array([10*i for i in range(T//10)]), labels = np.array([10*i for i in range(T//10)]))
    for i in range(x.shape[-1]):
        plt.plot(x[:,i], label="Feature %d"%i)

    # Colours are handed out in order of first appearance. The palette keeps its original
    # 10 entries but grows when more distinct states occur, so it can never run out.
    states_in_order = list(dict.fromkeys(z[:T - 1].tolist()))
    palette = plt.cm.rainbow(np.linspace(0, 1, max(10, len(states_in_order))))
    z_colors = dict(zip(states_in_order, palette))
    for t in range(T-1):
        plt.axvspan(t, t+1, alpha=0.5, color=z_colors[z[t]])
    plt.title("Time series sample with underlying states")
    plt.savefig("./plots/ts_sample.pdf")


    # Scatter of the first two observed features, coloured by state. ``generate`` returns only
    # (z, x), so the observations are plotted; the old 3-way unpacking raised ValueError.
    plt.figure()
    n_scatter = 50
    scatter_len = 50
    x_dist = []
    z_dist = []
    for _ in range(n_scatter):
        z_s, x_s = model.generate(scatter_len)
        x_dist.append(x_s)
        z_dist.append(z_s)
    x_dist = np.array(x_dist)
    z_dist = np.array(z_dist)
    sns.scatterplot(x=x_dist[:,:,0].reshape(-1,), y=x_dist[:,:,1].reshape(-1,), hue=z_dist.reshape(-1,))
    # Report the length that was actually generated (the title previously used T).
    plt.title("Distribution of observations (X) of %d samples of length %d"%(n_scatter, scatter_len))
    plt.savefig("./plots/representation_distribution.pdf")
