from typing import Optional

import os
import sys

import numpy as np
import torch
from torch.optim import LBFGS

from tqdm import trange

import pickle

import time

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from sklearn.linear_model import LinearRegression

sys.path.append(os.path.join(os.path.dirname(__file__), "../../"))

from models.gpfa import (
    VariationalGPFA, 
    SparseVariationalGPFA, 
    InfiniteSparseVariationalGPFA, 
    DoublyInfiniteSparseVariationalGPFA, 
)
from utils.data_utils import (
    generate_rates_gpfa, 
    generate_rates_gpfa_binary_activation, 
    generate_binary_mask, 
)
from utils.kernels import SquaredExponential
from utils.likelihoods import (
    Poisson, 
    Gaussian,
    Gaussian_with_link, 
)
from utils.constants import np_float_type


def _plot_data(rate, t_max, num_neurons, path):
    """Save an image of the trial-0 ground-truth rates (time x neuron) to ``path``."""
    fig = plt.figure()
    plt.imshow(
        rate[:, :, 0].T,
        interpolation="nearest",
        extent=[0, t_max, 0, num_neurons],
        origin="lower",
        aspect="auto",
        cmap="gray",
    )
    plt.colorbar()
    plt.title("binned raster plot")
    plt.xlabel("time")
    plt.ylabel("neuron index")
    plt.savefig(path)
    plt.close(fig)


def _build_likelihood(conditional_likelihood):
    """Return the observation model named by ``conditional_likelihood``.

    Raises:
        NotImplementedError: for an unrecognised likelihood name.
    """
    if conditional_likelihood == "Poisson":
        return Poisson(inverse_link=torch.exp)
    if conditional_likelihood == "Gaussian":
        return Gaussian(variance=0.1)
    if conditional_likelihood == "Gaussian_with_link":
        return Gaussian_with_link(variance=0.1, inverse_link=torch.exp)
    raise NotImplementedError(
        f"Unknown conditional likelihood: {conditional_likelihood!r}"
    )


def _build_model(model_name, X, Y, kernels, likelihood, Z, C_init, d_init):
    """Instantiate the GPFA variant named by ``model_name``.

    ``Z`` holds the per-latent inducing locations and is only used by the
    sparse ("svGPFA") variants.

    Raises:
        ValueError: for an unrecognised model name.
    """
    C = torch.from_numpy(C_init)
    d = torch.from_numpy(d_init)
    common = dict(X=X, Y=Y, kernels=kernels, likelihood=likelihood)

    if model_name == "vGPFA":
        return VariationalGPFA(**common, C=C, d=d)
    if model_name == "svGPFA":
        return SparseVariationalGPFA(
            **common, Z=Z, C=C, d=d, q_diagonal=False, whitening=True,
        )

    # Priors on C and d are disabled (None). The dict form
    # {"prior_variance": 0.1} is the shape that was prepared for these
    # arguments if a prior is wanted.
    alpha_prior = {"s1": 1.0, "s2": 1.0}
    if model_name == "infinite-svGPFA":
        return InfiniteSparseVariationalGPFA(
            **common,
            Z=Z,
            C=C,
            d=d,
            q_diagonal=False,
            alpha=1.0,
            train_alpha=True,
            C_prior=None,
            d_prior=None,
            alpha_prior=alpha_prior,
            m_step=True,
            train_inducing_locs=False,
        )
    if model_name == "doubly-infinite-svGPFA":
        return DoublyInfiniteSparseVariationalGPFA(
            **common,
            Z=Z,
            C=C,
            d=d,
            q_diagonal=False,
            alpha=1.0,
            train_alpha=True,
            C_prior=None,
            alpha_prior=alpha_prior,
            m_step=True,
            train_inducing_locs=False,
        )
    raise ValueError(f"Unknown model_name: {model_name!r}")


def _train(model, opt, learning_rate, num_epochs):
    """Minimise ``model.variational_free_energy()``; return (loss history, seconds).

    Raises:
        ValueError: if ``opt`` is neither "Adam" nor "LBFGS".
    """
    if opt == "Adam":
        optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif opt == "LBFGS":
        optimiser = LBFGS(model.parameters(), lr=learning_rate)
    else:
        raise ValueError(f"Unsupported optimiser: {opt!r}")

    # A closure works for both optimisers: LBFGS re-evaluates it internally,
    # Adam calls it once per step (zero_grad -> forward -> backward -> update).
    def closure():
        optimiser.zero_grad()
        loss = model.variational_free_energy()
        loss.backward()
        return loss

    loss_history = []
    t0 = time.perf_counter()
    with trange(num_epochs, dynamic_ncols=True) as pbar:
        for e in pbar:
            model.train()
            free_energy = optimiser.step(closure).item()
            pbar.set_description(f"Epoch {e} | free energy {free_energy:.2f}")
            loss_history.append(free_energy)
    return loss_history, time.perf_counter() - t0


def _plot_results(
    x, fs, f_means, f_covs, log_rate, log_rates_mean, loss_history,
    conditional_likelihood, latent_dim, num_neurons, num_trials, out_path,
):
    """Draw the four diagnostic panels, save them to ``out_path``, return the R2 score.

    The R2 score measures how well a linear map from the inferred latents
    (trial 0) reproduces the ground-truth latents ``fs``.
    """
    colors_M = plt.cm.jet(np.linspace(0, 1, num_neurons))
    # Enough colours for both the inferred latents and the ground-truth ones,
    # so latent_dim != len(fs) does not index out of range.
    colors_D = plt.cm.winter(np.linspace(0, 1, max(latent_dim, len(fs))))

    fig = plt.figure(figsize=(15, 15))
    ax1 = fig.add_subplot(2, 2, 1)
    ax2 = fig.add_subplot(2, 2, 2)
    ax3 = fig.add_subplot(2, 2, 3)
    ax4 = fig.add_subplot(2, 2, 4)

    # Ground-truth and inferred latents (no alignment transform).
    for d_ in range(latent_dim):
        for r in range(num_trials):
            f_m, f_s = f_means[d_, :, r], np.sqrt(f_covs[d_, :, r])
            ax1.plot(x, f_m, color=colors_D[d_])
            if d_ < len(fs):  # only the first len(fs) latents have a ground truth
                ax1.plot(x, fs[d_](x), color=colors_D[d_], linestyle="--")
            ax1.fill_between(
                x.flatten(), f_m - f_s, f_m + f_s, alpha=0.3, facecolor=colors_D[d_]
            )
    ax1.set_xlabel("time (s)")
    ax1.set_title("True and inferred latents")

    for m in range(num_neurons):
        for r in range(num_trials):
            y = log_rates_mean[:, m, r].flatten()
            if conditional_likelihood == "Gaussian":
                # The plain Gaussian model is fitted to the rate itself (no
                # inverse link is configured), so compare against exp(log_rate).
                ax2.scatter(np.exp(log_rate[:, m, r]).flatten(), y, color=colors_M[m])
            else:
                # Poisson and Gaussian_with_link use an exp inverse link.
                ax2.scatter(log_rate[:, m, r].flatten(), y, color=colors_M[m], s=1.0)
    lo, hi = log_rates_mean.min(), log_rates_mean.max()
    ax2.plot([lo, hi], [lo, hi], "k--", linewidth=3, alpha=0.5)
    ax2.set_xlabel("log rate (GT)")
    ax2.set_ylabel("log rate (posterior predictive)")

    ax3.plot(loss_history, linewidth=3, color="blue")
    ax3.set_xlabel("Epochs")
    ax3.set_ylabel("Negative variational free energy")

    # NOTE(review): for "sinusoidal-binary" the target ignores the activation
    # mask used to generate the data -- confirm this is the intended target.
    features = f_means[:, :, 0].T
    f_target = np.stack([f(x) for f in fs], axis=1)
    lr = LinearRegression()
    lr.fit(features, f_target)
    fitted_f_means = lr.predict(features)
    score = lr.score(features, f_target)

    for d_ in range(len(fs)):
        ax4.plot(x, fitted_f_means[:, d_], color=colors_D[d_])
        ax4.plot(x, fs[d_](x), color=colors_D[d_], linestyle="--")
    ax4.set_xlabel("time (s)")
    ax4.set_title(f"R2 score: {score:.2f}")

    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)
    return score


def main(
    dataset: str = "sinusoidal",
    conditional_likelihood: str = "Poisson",
    model_name: str = "svGPFA",
    opt: str = "Adam",
    num_epochs: int = 1000,
    num_inducing: int = 30,
    learning_rate: float = 1e-2,
    seed: int = 0,
    latent_dim: Optional[int] = None,
    fig_dir: Optional[str] = None,
    log_dir: Optional[str] = None,
):
    """Fit a GPFA variant to synthetic sinusoidal data and save diagnostics.

    Four sinusoidal latents are mixed through random loadings ``C`` and
    offsets ``d`` with an exponential rate link. Observations are Poisson
    counts, or the rate plus N(0, 0.1**2) noise for the Gaussian likelihoods.

    Args:
        dataset: "sinusoidal" or "sinusoidal-binary" (latents gated by a binary mask).
        conditional_likelihood: "Poisson", "Gaussian" or "Gaussian_with_link".
        model_name: "vGPFA", "svGPFA", "infinite-svGPFA" or "doubly-infinite-svGPFA".
        opt: optimiser, "Adam" or "LBFGS".
        num_epochs: number of optimiser steps.
        num_inducing: inducing points per latent (sparse variants only).
        learning_rate: optimiser learning rate.
        seed: seeds NumPy and torch; also used in the output file names.
        latent_dim: model latent dimensionality (defaults to the 4 true latents).
        fig_dir: base figure directory (default "figures").
        log_dir: base log directory (default "logs").

    Outputs are written under ``<base>/<dataset>/<likelihood>/<model>/latent_dim_<k>``:
    a results PDF, a pickle of ``(score, loss_history, training_time)`` and the
    saved model. The ground-truth rate image goes to ``<fig_dir>/test_data.pdf``.

    Returns:
        The trained model.

    Raises:
        ValueError: unknown ``dataset``, ``model_name`` or ``opt``.
        NotImplementedError: unknown ``conditional_likelihood``.
    """
    # Seed both RNGs so ``seed`` actually controls the run, not just file names.
    np.random.seed(seed)
    torch.manual_seed(seed)

    D = 4  # number of ground-truth latents, one per entry of ``fs``
    num_trials = 1
    M = 50  # number of neurons
    t_max = 20
    bin_width = 0.005
    C_true = np.random.randn(M, D).astype(np_float_type)
    d_true = np.random.rand(M, 1).astype(np_float_type)

    fs = [
        lambda x: np.sin(x) ** 3,
        lambda x: np.cos(3.0 * x),
        lambda x: np.sin(3.0 * x),
        lambda x: np.cos(x) ** 3,
    ]
    if dataset == "sinusoidal":
        X, rate, log_rate = generate_rates_gpfa(
            fs, C_true, d_true, t_max, num_trials, bin_width=bin_width, rate_link=np.exp,
        )
    elif dataset == "sinusoidal-binary":
        mask = generate_binary_mask(int(t_max / bin_width), D, 50, 150, 30, 100)
        X, rate, log_rate, _ = generate_rates_gpfa_binary_activation(
            fs, C_true, d_true, t_max, num_trials,
            bin_width=bin_width, rate_link=np.exp, Z=mask,
        )
    else:
        raise ValueError(f"Unknown dataset: {dataset!r}")

    N = len(X)
    if conditional_likelihood == "Poisson":
        Y = np.random.poisson(rate)
    else:
        Y = rate + np.random.randn(N, M, num_trials) * 0.1

    # Resolve default directories *before* anything is written to them. The run
    # sub-path below already encodes dataset/likelihood, so the defaults are
    # just the base folders.
    if fig_dir is None:
        fig_dir = "figures"
    if log_dir is None:
        log_dir = "logs"
    os.makedirs(fig_dir, exist_ok=True)
    _plot_data(rate, t_max, M, os.path.join(fig_dir, "test_data.pdf"))

    latent_dim = D if latent_dim is None else latent_dim

    print(f"{model_name} model | seed {seed}")
    run_subdir = os.path.join(
        dataset, conditional_likelihood, model_name, f"latent_dim_{latent_dim}"
    )
    fig_dir = os.path.join(fig_dir, run_subdir)
    log_dir = os.path.join(log_dir, run_subdir)
    os.makedirs(fig_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    X = torch.from_numpy(X.astype(np_float_type))
    Y = torch.from_numpy(Y.astype(np_float_type))
    kernels = [
        SquaredExponential(input_dim=1, variance=1.0, lengthscales=1.0)
        for _ in range(latent_dim)
    ]
    likelihood = _build_likelihood(conditional_likelihood)

    Z = None
    if "svGPFA" in model_name:
        # Evenly spaced, fixed inducing locations per latent.
        Z = [
            torch.from_numpy(
                np.linspace(0, t_max, num_inducing, endpoint=True)
                .astype(np_float_type)
                .reshape(-1, 1)
            )
            for _ in range(latent_dim)
        ]
    # Fresh random initialisation sized to the *model* latent dimension, for
    # every variant (never the ground-truth C/d, which have D columns).
    C_init = np.random.randn(M, latent_dim).astype(np_float_type)
    d_init = np.random.rand(M, 1).astype(np_float_type)
    model = _build_model(model_name, X, Y, kernels, likelihood, Z, C_init, d_init)

    loss_history, training_time = _train(model, opt, learning_rate, num_epochs)

    f_means, f_covs, log_rates_mean, _ = model.predict_log_rates(X)
    score = _plot_results(
        x=X[:, 0].numpy(),
        fs=fs,
        f_means=f_means.detach().numpy(),
        f_covs=f_covs.detach().numpy(),
        log_rate=log_rate,
        log_rates_mean=log_rates_mean.detach().numpy(),
        loss_history=loss_history,
        conditional_likelihood=conditional_likelihood,
        latent_dim=latent_dim,
        num_neurons=M,
        num_trials=num_trials,
        out_path=os.path.join(fig_dir, f"results_seed{seed}.pdf"),
    )

    with open(os.path.join(log_dir, f"NI{num_inducing}_seed{seed}.pkl"), "wb") as f:
        pickle.dump((score, loss_history, training_time), f)

    torch.save(model, os.path.join(log_dir, f"final_model_seed{seed}.pt"))

    return model


if __name__ == "__main__":
    # Seed sweep for the doubly-infinite model on the binary-activation
    # dataset; every run shares this configuration and only the seed changes.
    sweep_kwargs = {
        "dataset": "sinusoidal-binary",
        "conditional_likelihood": "Poisson",
        "model_name": "doubly-infinite-svGPFA",
        "opt": "Adam",
        "learning_rate": 1e-2,
        "num_epochs": 2000,
        "num_inducing": 50,
        "fig_dir": "figures/evaluate_synthetic/",
        "log_dir": "logs/evaluate_synthetic/",
        "latent_dim": 4,
    }
    for run_seed in range(10):
        main(seed=run_seed, **sweep_kwargs)