import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy import signal
from scipy.linalg import hankel
from tqdm import trange
from itertools import permutations
try:
    import pyinform
except:
    pass

import warnings

from sklearn import metrics

# import pyinform
import statsmodels.api as sm
import torch
from torch import nn
import torch.nn.functional as F


def exp_basis(decay: float, window_size: int, time_span: float):
    """Discretized, normalized exponential-decay basis.

    Samples ``exp(-decay * t)`` at ``t = dt, 2*dt, ..., time_span`` (the right
    edges of ``window_size`` equal bins, ``dt = time_span / window_size``) and
    rescales it so that ``dt * basis.sum() == 1``.  This is a discrete
    approximation of the density

        \\phi(t) = \\beta exp(-\\beta t),   \\beta = decay.

    Parameters
    ----------
    decay : float
        Decay rate (\\beta).  ``decay == 0`` gives a flat basis equal to ``1 / time_span``.
    window_size : int
        Number of time bins the basis is discretized into.
    time_span : float
        Maximum influence time span covered by the basis.

    Returns
    -------
    basis : torch.Tensor of shape (window_size,)
        Normalized basis; ``basis[k]`` is the value at ``t = (k + 1) * dt``.
    """
    dt = time_span / window_size
    # linspace(dt, time_span, window_size) has spacing exactly dt.
    t = torch.linspace(dt, time_span, window_size)
    basis = torch.exp(-decay * t)
    # Normalize so the Riemann sum dt * sum(basis) equals 1.
    basis /= dt * basis.sum()
    return basis


def convolve_spikes_with_basis(spikes_list: torch.FloatTensor, basis: torch.FloatTensor) -> torch.FloatTensor:
    """Causally filter every spike train ``spikes_list[i, :, j]`` with a single basis.

    For each sequence ``i`` and neuron ``j``::

        out[i, t, j] = sum_{k=0}^{window_size-1} basis[k] * spikes_list[i, t - 1 - k, j]

    Terms with a negative time index are zero.  So bin ``t`` only sees spikes
    strictly before ``t``, and ``out[i, 0, j] == 0``.

    Parameters
    ----------
    spikes_list : torch.Tensor or ndarray of shape (n_seq, n_time_bins, n_neurons)
        Spike trains.  Tensors may require grad or live on a GPU; they are
        detached and copied to host memory for the convolution.
    basis : torch.Tensor or ndarray of shape (window_size,)
        Convolution kernel.

    Returns
    -------
    convolved_spikes_list : torch.Tensor of shape (n_seq, n_time_bins, n_neurons)
        Convolved spike trains (CPU tensor, same dtype as ``spikes_list``).
    """

    # Convert once up front instead of letting numpy coerce each tensor slice.
    # Tensors that require grad or live on a GPU cannot be coerced implicitly.
    if isinstance(spikes_list, torch.Tensor):
        spikes_np = spikes_list.detach().cpu().numpy()
    else:
        spikes_np = np.asarray(spikes_list)
    if isinstance(basis, torch.Tensor):
        basis_np = basis.detach().cpu().numpy()
    else:
        basis_np = np.asarray(basis)

    window_size = len(basis_np)
    n_seq, n_time_bins, n_neurons = spikes_np.shape
    convolved_spikes_list = np.zeros_like(spikes_np)
    if n_time_bins == 0 or window_size == 0:
        # Empty sums are zero; np.convolve rejects empty inputs.
        return torch.from_numpy(convolved_spikes_list)
    for i in range(n_seq):
        for j in range(n_neurons):
            # The full convolution has n_time_bins + window_size - 1 samples.
            # Dropping the last window_size leaves n_time_bins - 1 samples.
            # Shifting them by one bin makes the filter strictly causal.
            convolved_spikes_list[i, 1:, j] = np.convolve(spikes_np[i, :, j], basis_np)[:-window_size]
    return torch.from_numpy(convolved_spikes_list)


def log_likelihood(spikes: torch.FloatTensor, firing_rates: torch.FloatTensor, distribution='Poisson') -> torch.FloatTensor:
    """Element-wise log-likelihood of ``spikes`` given ``firing_rates`` (no reduction).

    A 1e-8 epsilon is added inside every logarithm so zero rates give a large
    negative value instead of ``-inf``/NaN.

    Parameters
    ----------
    spikes : torch.Tensor
        Observed spike counts (Poisson) or 0/1 spikes (Bernoulli).
    firing_rates : torch.Tensor
        Rate (Poisson) or spike probability (Bernoulli), broadcastable to ``spikes``.
    distribution : {'Poisson', 'Bernoulli'}
        Observation model.

    Returns
    -------
    torch.Tensor
        Log-likelihood of every element, with the broadcast shape of the inputs.

    Raises
    ------
    ValueError
        If ``distribution`` is not supported.
    """
    if distribution == 'Poisson':
        return spikes * (firing_rates + 1e-8).log() - firing_rates - torch.lgamma(spikes+1)
    if distribution == 'Bernoulli':
        return spikes * (firing_rates + 1e-8).log() + (1-spikes) * (1-firing_rates + 1e-8).log()
    raise ValueError(f"Unsupported distribution {distribution!r}; expected 'Poisson' or 'Bernoulli'.")


def neg_log_likelihood(spikes: torch.FloatTensor, firing_rates: torch.FloatTensor, distribution='Poisson') -> torch.FloatTensor:
    """Negative log-likelihood summed over the last two axes.

    ``torch.xlogy`` is used for the ``x * log(y)`` terms so that ``0 * log(0)``
    evaluates to 0 rather than NaN.  This happens for a zero rate where no spike
    occurred, or for a Bernoulli probability of exactly 0 or 1 that matches the
    data.  No epsilon is added, so a spike at zero rate still gives ``+inf``.

    Parameters
    ----------
    spikes : torch.Tensor of shape (..., A, B)
        Observed spike counts (Poisson) or 0/1 spikes (Bernoulli).
    firing_rates : torch.Tensor, broadcastable to ``spikes``
        Rate (Poisson) or spike probability (Bernoulli).
    distribution : {'Poisson', 'Bernoulli'}
        Observation model.

    Returns
    -------
    torch.Tensor of shape (...)
        NLL summed over the last two dimensions.

    Raises
    ------
    ValueError
        If ``distribution`` is not supported.
    """
    if distribution == 'Poisson':
        nll_terms = -torch.xlogy(spikes, firing_rates) + firing_rates + torch.lgamma(spikes + 1)
        return torch.sum(nll_terms, dim=(-2, -1))
    if distribution == 'Bernoulli':
        ll_terms = torch.xlogy(spikes, firing_rates) + torch.xlogy(1 - spikes, 1 - firing_rates)
        return -torch.sum(ll_terms, dim=(-2, -1))
    raise ValueError(f"Unsupported distribution {distribution!r}; expected 'Poisson' or 'Bernoulli'.")


def accuracy_score(edges_true, edges_pred):
    """Fraction of entries where the flattened predicted edges equal the true edges."""
    y_true = edges_true.flatten()
    y_pred = edges_pred.flatten()
    return metrics.accuracy_score(y_true, y_pred)


def balanced_accuracy_score(edges_true, edges_pred):
    """Class-balanced accuracy of the flattened edge labels (sklearn ``balanced_accuracy_score``)."""
    y_true = edges_true.flatten()
    y_pred = edges_pred.flatten()
    return metrics.balanced_accuracy_score(y_true, y_pred)


def f1_score(edges_true, edges_pred):
    """Macro-averaged F1 score of the flattened edge labels."""
    y_true = edges_true.flatten()
    y_pred = edges_pred.flatten()
    return metrics.f1_score(y_true, y_pred, average='macro')


def coincidence_indicator(spikes, window_size_1=2, window_size_2=99):
    """Short-lag share of the positive-lag cross-correlogram for every neuron pair.

    ``result[i, j]`` divides the cross-correlogram mass at lags
    ``1..window_size_1`` by the mass at lags ``1..window_size_2``.  The
    cross-correlogram is ``scipy.signal.correlate(spikes[:, i], spikes[:, j])``,
    so a positive lag counts spikes of neuron ``i`` that follow spikes of
    neuron ``j``.  If the denominator window holds no coincidences, the
    division yields NaN/inf.

    Parameters
    ----------
    spikes : ndarray of shape (n_time_bins, n_neurons)
    window_size_1, window_size_2 : int
        Numerator and denominator lag windows (in bins).

    Returns
    -------
    ndarray of shape (n_neurons, n_neurons)
    """
    _, n_neurons = spikes.shape
    indicator = np.zeros((n_neurons, n_neurons))
    for row in range(n_neurons):
        for col in range(n_neurons):
            ccg = signal.correlate(spikes[:, row], spikes[:, col])
            first_positive_lag = int(len(ccg) / 2) + 1  # zero lag sits at the centre
            short = ccg[first_positive_lag:first_positive_lag + window_size_1]
            long = ccg[first_positive_lag:first_positive_lag + window_size_2]
            indicator[row, col] = np.sum(short) / np.sum(long)
    return indicator


def ccg_diff(spikes, window_size=13):
    """Asymmetry of the cross-correlogram around zero lag for every neuron pair.

    ``result[i, j]`` is the mean of ``scipy.signal.correlate(spikes[:, i], spikes[:, j])``
    over lags ``1..window_size`` minus its mean over lags ``-1..-window_size``.
    A positive lag counts spikes of neuron ``i`` that follow spikes of neuron ``j``.
    When ``window_size`` exceeds the available lags (short recordings), both
    sides are truncated to the lags that exist, so the two windows stay mirror
    images of each other.

    Parameters
    ----------
    spikes : ndarray of shape (n_time_bins, n_neurons)
    window_size : int
        Number of lags (bins) averaged on each side of zero.

    Returns
    -------
    ndarray of shape (n_neurons, n_neurons)
    """
    n_time_bins, n_neurons = spikes.shape
    result = np.zeros((n_neurons, n_neurons))
    for i in range(n_neurons):
        for j in range(n_neurons):
            corr = signal.correlate(spikes[:, i], spikes[:, j])
            midpoint = len(corr) // 2  # index of zero lag
            positive = corr[midpoint + 1:midpoint + 1 + window_size]
            # Clamp at 0: a negative slice bound would wrap around to the end
            # of `corr` and silently pick the wrong (or no) negative lags.
            negative = corr[max(midpoint - window_size, 0):midpoint][::-1]
            result[i, j] = np.mean(positive) - np.mean(negative)
    return result


def mutual_information(spikes):
    """Pairwise mutual information between the spike trains of all neurons.

    ``result[a, b] = pyinform.mutual_info(spikes[:, a], spikes[:, b])``.

    Parameters
    ----------
    spikes : ndarray of shape (n_time_bins, n_neurons)
        Discrete spike trains (pyinform operates on integer-valued series).

    Returns
    -------
    ndarray of shape (n_neurons, n_neurons)
    """
    _, n_neurons = spikes.shape
    mi = np.zeros((n_neurons, n_neurons))
    for a in range(n_neurons):
        series_a = spikes[:, a]
        for b in range(n_neurons):
            mi[a, b] = pyinform.mutual_info(series_a, spikes[:, b])
    return mi


def transfer_entropy(spikes, k=5):
    """Pairwise transfer entropy between the spike trains of all neurons.

    ``result[i, j]`` calls pyinform's ``transfer_entropy`` with ``spikes[:, j]``
    as the first argument and ``spikes[:, i]`` as the second.  pyinform documents
    these as (source, target), so this is the information flow j -> i.

    Parameters
    ----------
    spikes : ndarray of shape (n_time_bins, n_neurons)
        Discrete spike trains.
    k : int
        History length passed to pyinform.

    Returns
    -------
    ndarray of shape (n_neurons, n_neurons)
    """
    _, n_neurons = spikes.shape
    te = np.zeros((n_neurons, n_neurons))
    for target in range(n_neurons):
        for source in range(n_neurons):
            te[target, source] = pyinform.transferentropy.transfer_entropy(
                spikes[:, source], spikes[:, target], k=k
            )
    return te


def poisson_glm(spikes, lag=5):
    """Estimate pairwise coupling by fitting one Poisson GLM per neuron.

    The shared design matrix is an intercept column followed, for every neuron
    ``j``, by ``lag`` columns.  They hold neuron ``j``'s spikes delayed by
    ``lag, lag-1, ..., 1`` bins, zero-padded at the start of the recording.
    For each target neuron ``i``, a statsmodels Poisson GLM regresses
    ``spikes[:, i]`` on that matrix.  ``result[i, j]`` is the largest
    coefficient among neuron ``j``'s lag columns.

    Parameters
    ----------
    spikes : ndarray of shape (n_time_bins, n_neurons)
        Spike counts.
    lag : int
        Number of past bins per neuron included as regressors.

    Returns
    -------
    ndarray of shape (n_neurons, n_neurons)
    """
    n_time_bins, n_neurons = spikes.shape

    # The design matrix does not depend on the target neuron, so build it once.
    columns = [np.ones((n_time_bins, 1))]
    for j in range(n_neurons):
        padded_spikes = np.hstack((np.zeros(lag), spikes[:, j]))
        # hankel(...)[t, k] == padded_spikes[t + k], i.e. spikes[t + k - lag, j].
        columns.append(hankel(padded_spikes[:n_time_bins], padded_spikes[n_time_bins-1:-1]))
    X = np.hstack(columns)

    result = np.zeros((n_neurons, n_neurons))
    for i in range(n_neurons):
        poisson_results = sm.GLM(spikes[:, i], X, family=sm.families.Poisson()).fit()
        # Drop the intercept; the remaining coefficients are grouped per neuron.
        theta_lag = np.asarray(poisson_results.params[1:])
        result[i, :] = theta_lag.reshape(n_neurons, lag).max(axis=1)
    return result


def half_of_max(continuous_pred: torch.FloatTensor):
    """Deprecated: ternarize weights at half of their minimum / maximum.

    Same computation as ``weight_to_adjacency_index(continuous_pred, cut_pos=0.5)``.
    It returns -1 for values <= ``0.5 * min``, +1 for values > ``0.5 * max``,
    and 0 otherwise.
    """
    # stacklevel=2 attributes the warning to the caller, not to this module.
    warnings.warn('Please use `utils.weight_to_adjacency_index` instead.', DeprecationWarning, stacklevel=2)
    return torch.bucketize(continuous_pred, torch.tensor([-torch.inf, continuous_pred.min()*0.5, continuous_pred.max()*0.5, torch.inf])) - 2


def weight_to_adjacency(weight):
    """Encode a signed weight tensor as a 3-channel tensor (channel axis appended last).

    For each entry ``w``:

    * ``w > 0``:  channel 2 = ``w / weight.max()``, channel 1 = 1 - channel 2, channel 0 = 0.
    * ``w < 0``:  channel 0 = ``w / weight.min()``, channel 1 = 1 - channel 0, channel 2 = 0.
    * ``w == 0``: ``(0, 1, 0)``.

    NOTE(review): the channels look like (R, G, B) for plotting; confirm against callers.

    Parameters
    ----------
    weight : torch.Tensor of any shape

    Returns
    -------
    torch.Tensor of shape ``weight.shape + (3,)`` on ``weight.device``.
    """
    adjacency = torch.zeros(list(weight.shape) + [3], device=weight.device)
    if weight.numel() == 0:
        return adjacency  # min()/max() are undefined on empty tensors
    positive = weight > 0
    non_positive = ~positive
    adjacency[..., 2][positive] = weight[positive] / weight.max()
    adjacency[..., 1][positive] = 1 - adjacency[..., 2][positive]
    min_weight = weight.min()
    if min_weight < 0:
        # Without negative weights, the non-positive entries are all exactly 0.
        # Dividing them by min() == 0 would produce NaN; leave channel 0 at 0.
        adjacency[..., 0][non_positive] = weight[non_positive] / min_weight
    adjacency[..., 1][non_positive] = 1 - adjacency[..., 0][non_positive]
    return adjacency


def weight_to_adjacency_index(weight: torch.FloatTensor, cut_pos=0.5, upperbound=None):
    """Ternarize weights into -1 / 0 / +1 edge labels.

    The thresholds are ``cut_pos`` times the negative and positive extremes.
    These are ``weight.min()`` and ``weight.max()`` by default, or
    ``-upperbound`` and ``+upperbound`` when ``upperbound`` is given.  Values
    <= the negative threshold map to -1, values > the positive threshold map
    to +1, and everything else maps to 0.
    """
    if upperbound is not None:
        low, high = -upperbound, upperbound
    else:
        low, high = weight.min(), weight.max()
    edges = torch.tensor([-torch.inf, low * cut_pos, high * cut_pos, torch.inf])
    bucket = torch.bucketize(weight, edges)
    return bucket - 2


def discrete_to_continuous(spikes: np.array, discrete_states: np.array, dt: float) -> tuple:
    """Turn binned spikes into per-neuron spike-time lists with matching state labels.

    Only bins whose value equals exactly 1 are treated as spikes; larger counts
    are skipped.

    Returns
    -------
    points_hawkes : list of n_neurons lists
        Spike times of each neuron.
    states : list
        ``discrete_states[t]`` for every spike, in (time bin, neuron) scan
        order, followed by one extra trailing ``discrete_states[-1]``.
    states_n : list of n_neurons lists
        ``discrete_states[t]`` for every spike of each neuron.
    """
    n_time_bins, n_neurons = spikes.shape
    points_hawkes = [[] for _ in range(n_neurons)]
    states_n = [[] for _ in range(n_neurons)]
    states = []

    for t in trange(n_time_bins):
        for neuron in range(n_neurons):
            if spikes[t, neuron] != 1:
                continue
            # NOTE(review): `neuron % n_neurons` is just `neuron`, so each neuron's
            # spike times are shifted by `neuron` whole time units.  This looks
            # like it was meant to be a small within-bin tie-breaking offset;
            # confirm with the author before relying on these timestamps.
            points_hawkes[neuron].append(t * dt + neuron % n_neurons)
            states_n[neuron].append(discrete_states[t])
            states.append(discrete_states[t])

    states.append(discrete_states[-1])
    return points_hawkes, states, states_n


# def continuous_to_discrete(points_hawkes: list, states: list, states_n: list, dt: float, T: float) -> tuple:
#     n_time_bins = int(T / dt)
#     time_bins = np.linspace(0, T, n_time_bins+1)
#     n_neurons = len(states_n)
#     spikes = np.zeros((n_time_bins, n_neurons))

#     for neuron in range(n_neurons):
#         spikes[:, neuron] = np.histogram(points_hawkes[neuron], bins=time_bins)[0]
    
#     time_stamps = np.sort(np.concatenate(points_hawkes))
#     idx = np.searchsorted(time_stamps, np.linspace(dt/2, T-dt/2, n_time_bins))
#     discrete_states = np.array(states)[idx]
#     return spikes, discrete_states


def continuous_to_discrete(timestamps_list: list, dt: float, T: float, start: float = 0) -> np.ndarray:
    """Convert timestamps spike data to discretized spike count data.

    The interval ``[start, start + T]`` is split into ``n_time_bins`` equal bins
    and the spikes of each neuron are counted per bin.  ``n_time_bins`` is
    ``T / dt``, rounded when ``T / dt`` is an integer up to floating-point error
    and truncated otherwise.  When ``T`` is not a multiple of ``dt``, the bins
    are ``T / n_time_bins`` wide rather than exactly ``dt``.  As with
    ``np.histogram``, the last bin includes its right edge.

    Parameters
    ----------
    timestamps_list : list of shape (n_neurons,)
        Spiking time for each neuron.
    dt : float
        Width of time bins.
    T : float
        Duration of the discretized window.
    start : float, default 0
        Start time of the window.

    Returns
    -------
    spikes : ndarray of shape (n_time_bins, n_neurons)
        Discretized spike count.
    """
    ratio = T / dt
    nearest = round(ratio)
    # Plain int() truncation turns e.g. 0.3 / 0.1 == 2.9999999999999996 into 2.
    # That drops a bin and silently stretches the remaining ones.
    if np.isclose(ratio, nearest, rtol=1e-9, atol=0.0):
        n_time_bins = int(nearest)
    else:
        n_time_bins = int(ratio)
    time_bins = np.linspace(start, T + start, n_time_bins+1)
    n_neurons = len(timestamps_list)
    spikes = np.zeros((n_time_bins, n_neurons))

    for neuron, timestamps in enumerate(timestamps_list):
        spikes[:, neuron] = np.histogram(timestamps, bins=time_bins)[0]
    return spikes


def match_states(one_hot_true_states: torch.LongTensor, gamma: torch.FloatTensor, force=True):
    """Match the hidden states of a learned model to the ground-truth states.

    Parameters
    ----------
    one_hot_true_states : torch.LongTensor of shape (n_seq, n_time_bins, n_states) or (n_time_bins, n_states)
        One-hot true state sequence(s).
    gamma : torch.FloatTensor of shape (n_seq, n_time_bins, n_states) or (n_time_bins, n_states)
        Posterior probability or one-hot predicted state sequence(s).
    force : bool, default True
        If true, search all ``n_states!`` permutations for the one with the
        lowest mean squared error, which guarantees a one-to-one matching.
        Otherwise, match each true state greedily to the learned state whose
        column has the lowest mean squared error.  Several true states may then
        map to the same learned state.

    Returns
    -------
    true_to_learned : torch.LongTensor of shape (n_states,)
        `true_to_learned[s]` represents the state in the learned model that corresponds to the state `s` in the original model.
    """

    if len(gamma.shape) == 2:
        one_hot_true_states = one_hot_true_states[None, :]
        gamma = gamma[None, :, :]
    n_states = gamma.shape[2]

    if force:
        all_possible_permutations = torch.tensor(list(permutations(range(n_states))))
        mse_list = torch.zeros(len(all_possible_permutations))
        for idx, perm in enumerate(all_possible_permutations):
            # Column k of the permuted gamma is learned state perm[k], compared to true state k.
            mse_list[idx] = (one_hot_true_states - gamma[:, :, perm]).square().mean()
        return all_possible_permutations[mse_list.argmin()]

    true_to_learned = torch.zeros(n_states, dtype=torch.int64)
    for state in range(n_states):
        # Broadcast the true state's column against every learned column and
        # average over sequences AND time, which leaves one MSE per learned state.
        mse_per_learned = (one_hot_true_states[:, :, [state]] - gamma).square().mean(dim=(0, 1))
        true_to_learned[state] = mse_per_learned.argmin()
    return true_to_learned


def visualize_edges(edge_matrix, fig=None, ax=None, v_bound=None):
    """Plot a connectivity matrix with a colour bar (rows = post, columns = pre).

    Parameters
    ----------
    edge_matrix : array-like of shape (n_neurons, n_neurons)
    fig, ax : matplotlib Figure / Axes, optional
        Where to draw.  If only ``ax`` is given, its figure is used.  If only
        ``fig`` is given, its current axes are used.  If neither is given, a
        new 6x4 figure is created.
    v_bound : float, optional
        Symmetric colour limits ``[-v_bound, v_bound]``; autoscaled if None.

    Returns
    -------
    matplotlib.figure.Figure
    """
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(1, 1, figsize=(6, 4))
        else:
            ax = fig.gca()
    elif fig is None:
        fig = ax.figure
    v_min, v_max = (None, None) if v_bound is None else (-v_bound, v_bound)

    im = ax.matshow(edge_matrix, cmap='seismic', vmin=v_min, vmax=v_max)
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im, cax=cax, orientation='vertical')
    ax.set_xlabel('pre')
    ax.set_ylabel('post')
    ax.xaxis.set_label_position('top')
    return fig


def visualize_vector(v, fig=None, ax=None, v_bound=None):
    """Plot a 1-D vector as a single colour-coded column with a colour bar.

    Parameters
    ----------
    v : array-like with ``len`` and ``reshape`` (e.g. ndarray or tensor)
    fig, ax : matplotlib Figure / Axes, optional
        Where to draw.  If only ``ax`` is given, its figure is used.  If only
        ``fig`` is given, its current axes are used.  If neither is given, a
        new 1x4 figure is created.
    v_bound : float, optional
        Symmetric colour limits ``[-v_bound, v_bound]``; autoscaled if None.

    Returns
    -------
    matplotlib.figure.Figure
    """
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(1, 1, figsize=(1, 4))
        else:
            ax = fig.gca()
    elif fig is None:
        fig = ax.figure
    v_min, v_max = (None, None) if v_bound is None else (-v_bound, v_bound)

    im = ax.matshow(v.reshape((len(v), 1)), cmap='seismic', vmin=v_min, vmax=v_max)
    ax.get_xaxis().set_visible(False)
    cax = make_axes_locatable(ax).append_axes('right', size='50%', pad=0.1)
    fig.colorbar(im, cax=cax, orientation='vertical')
    return fig


def visualize_spikes(spikes, firing_rates_pred, firing_rates=None, fig=None, ax=None, n_neurons_plot=None, n_time_bins_plot=None):
    """Plot spikes with predicted (and optionally true) firing rates, one panel per neuron.

    Spikes are drawn at half height (``0.5 * spikes``) so they remain visible
    next to the rate curves.

    Parameters
    ----------
    spikes : array of shape (n_time_bins, n_neurons)
    firing_rates_pred : array of shape (n_time_bins, n_neurons)
    firing_rates : array of shape (n_time_bins, n_neurons), optional
    fig : matplotlib Figure, optional
    ax : Axes or sequence/array of Axes, optional
        One Axes per plotted neuron.  If omitted and ``fig`` is given,
        ``fig.axes`` is used.  If both are omitted, a new figure with
        ``n_neurons_plot`` stacked panels is created.
    n_neurons_plot, n_time_bins_plot : int, optional
        Plot only the first neurons / time bins (default: all).

    Returns
    -------
    matplotlib.figure.Figure
    """
    n_time_bins, n_neurons = spikes.shape
    if n_neurons_plot is None:
        n_neurons_plot = n_neurons
    if n_time_bins_plot is None:
        n_time_bins_plot = n_time_bins
    if ax is None:
        if fig is None:
            # squeeze=False keeps a 2-D array even for a single panel.
            fig, ax = plt.subplots(n_neurons_plot, 1, figsize=(10, 3*n_neurons_plot), sharex=True, squeeze=False)
        else:
            ax = fig.axes
    # Accept a single Axes, a list, or an ndarray of Axes uniformly.
    axs = np.asarray(ax, dtype=object).ravel()
    if fig is None:
        fig = axs[0].figure

    for neuron in range(n_neurons_plot):
        axs[neuron].plot(spikes[:n_time_bins_plot, neuron] * 0.5, label='spikes')
        axs[neuron].plot(firing_rates_pred[:n_time_bins_plot, neuron], label='predicted firing rates')
        if firing_rates is not None:
            axs[neuron].plot(firing_rates[:n_time_bins_plot, neuron], label='firing rates')
        axs[neuron].set_ylabel(f"neuron {neuron + 1}")
    axs[n_neurons_plot - 1].set_xlabel("$t$")
    axs[0].legend()
    return fig
