from collections import deque
import numpy as np
import torch
from ot import emd2 # for the MW2 metric

def state_pipeline(states, state_shape, state_type, device, unsqueeze=False):
    """
    Turn raw environment observations into a tensor in the layout the model expects.

    Parameters :
    _ states : observation(s), anything ``np.array`` accepts
    _ state_shape (tuple) : shape of a single observation; a 4-D shape selects the MinAtar layout
    _ state_type : dtype of the observations; ``np.uint8`` selects the Atari path
    _ device : torch device the result is placed on
    _ unsqueeze (bool) : if True, a leading batch axis is added first

    Returns :
    _ Tensor : uint8 input is rescaled into float32 values in [0, 1]; 4-D shaped input gets
      axes 2 and 4 swapped and axes 1-2 merged, then cast to float32; anything else keeps
      the numpy dtype.
    """
    batch = np.array(states)
    if unsqueeze:
        batch = batch[None, :]

    if state_type == np.uint8:
        # gymnasium Atari: raw bytes rescaled into [0, 1].
        as_bytes = torch.ByteTensor(batch)
        return as_bytes.to(device).float() / 255.

    if len(state_shape) == 4:
        # gymnasium MinAtar: presumably (stack, H, W, C) per observation; the last axis is
        # moved next to the stack axis and the two are merged into one channel axis.
        tensor = torch.tensor(batch, device=device)
        swapped = tensor.transpose(4, 2)
        return swapped.flatten(start_dim=1, end_dim=2).float()

    return torch.tensor(batch).to(device)

def update_params(optim, loss, networks, retain_graph=False,
                  grad_cliping=None):
    """
    Backpropagate ``loss`` and apply one optimizer step.

    Parameters :
    _ optim : optimizer whose gradients are reset and which performs the step
    _ loss (Tensor) : scalar to differentiate
    _ networks : modules whose gradient norms are clipped (only when ``grad_cliping`` is truthy)
    _ retain_graph (bool) : forwarded to ``loss.backward``
    _ grad_cliping (float or None) : max gradient norm, applied to each network separately
    """
    optim.zero_grad()
    loss.backward(retain_graph=retain_graph)

    if grad_cliping:
        # Stabilise training: each network's gradient norm is bounded independently.
        clip = torch.nn.utils.clip_grad_norm_
        for net in networks:
            clip(net.parameters(), grad_cliping)

    optim.step()


def disable_gradients(networks):
    """
    Freeze the given networks.

    Every tensor yielded by ``parameters()`` of every network in ``networks``
    has its ``requires_grad`` flag switched off.
    """
    every_param = (p for net in networks for p in net.parameters())
    for p in every_param:
        p.requires_grad_(False)


def calculate_huber_loss(td_errors, kappa=1.0):
    """
    Element-wise Huber loss.

    Returns ``0.5 * e**2`` where ``|e| <= kappa`` and ``kappa * (|e| - 0.5 * kappa)``
    elsewhere, with the same shape as ``td_errors``.
    """
    magnitude = td_errors.abs()
    quadratic_part = 0.5 * td_errors.pow(2)
    linear_part = kappa * (magnitude - 0.5 * kappa)
    return torch.where(magnitude <= kappa, quadratic_part, linear_part)


def calculate_quantile_huber_loss(td_errors, taus, weights=None, kappa=1.0):
    """
    Quantile Huber loss: the Huber loss weighted by ``|tau - 1{td_error < 0}|``.

    Parameters :
    _ td_errors (Tensor[batch_size, N, N_dash]) : pairwise TD errors
    _ taus (Tensor) : quantile fractions; ``taus[..., None]`` must broadcast to the
      shape of td_errors (presumably taus is (batch_size, N)); must not require grad
    _ weights (Tensor or None) : optional factor multiplied with the
      (batch_size, 1) per-sample losses before averaging
    _ kappa (float) : Huber threshold

    Returns :
    _ loss (Tensor[]) : scalar loss
    """
    assert not taus.requires_grad
    batch_size, N, N_dash = td_errors.shape
    full_shape = (batch_size, N, N_dash)

    huber = calculate_huber_loss(td_errors, kappa)
    assert huber.shape == full_shape

    # Asymmetric quantile weight; the sign indicator is taken from a detached copy.
    is_negative = (td_errors.detach() < 0).float()
    quantile_huber = torch.abs(taus[..., None] - is_negative) * huber / kappa
    assert quantile_huber.shape == full_shape

    # Sum over the N current quantiles, then average over the N_dash target samples.
    per_sample = quantile_huber.sum(dim=1).mean(dim=1, keepdim=True)
    assert per_sample.shape == (batch_size, 1)

    if weights is None:
        return per_sample.mean()
    return (per_sample * weights).mean()


def evaluate_quantile_at_action(s_quantiles, actions):
    """
    Pick, for every sample, the quantile values of the chosen action.

    Parameters :
    _ s_quantiles (Tensor[batch_size, N, num_actions]) : quantile values for all actions
    _ actions (LongTensor) : chosen actions, presumably of shape (batch_size, 1)

    Returns :
    _ sa_quantiles (Tensor[batch_size, N, 1]) : quantile values of the chosen actions
    """
    assert s_quantiles.shape[0] == actions.shape[0]

    n_batch = s_quantiles.shape[0]
    n_quantiles = s_quantiles.shape[1]

    # Repeat each action index along the quantile axis so gather selects one column per row.
    index = actions.unsqueeze(-1).expand(n_batch, n_quantiles, 1)
    return torch.gather(s_quantiles, 2, index)

def normal_density(x, mean, sigma, dim=2):
    """
    Gaussian kernel exp(-(x - mean)**2 / (2 * sigma**2)) / sigma evaluated on all pairs.

    Note: there is no 1/sqrt(2*pi) factor, so the value is the normal p.d.f.
    multiplied by sqrt(2*pi).

    Parameters :
    _ x (Tensor[B, K1] for dim=2, Tensor[B, n] for dim=3) : evaluation points
    _ mean (Tensor[B, K2] for dim=2, Tensor[B, K] for dim=3) : Gaussian means
    _ sigma : standard deviations; broadcastable to (B, K1, K2) for dim=2, Tensor[B, K] for dim=3
    _ dim (int) : 2 or 3, selects the pairing below

    Returns :
    dim = 2 : Tensor[B, K1, K2], out[b,i,j] uses x[b,i], mean[b,j], sigma[b,i,j]
    dim = 3 : Tensor[B, K, n],   out[b,j,k] uses x[b,k], mean[b,j], sigma[b,j]

    Raises :
    ValueError : if dim is neither 2 nor 3.
    """
    if dim == 2:
        return torch.exp(-(x[:,:,None]-mean[:,None,:])**2/(2*sigma**2))/sigma
    if dim == 3:
        return torch.exp(-(x[:,None,:]-mean[:,:,None])**2/(2*sigma[:,:,None]**2))/sigma[:,:,None]
    raise ValueError(f"normal_density: dim must be 2 or 3, got {dim!r}")


def weighted_normal(mean1, sigma1, mean2, sigma2):
    """
    x^2-weighted overlap of every pair of Gaussian components.

    For components (mean1[b,i], sigma1[b,i]) and (mean2[b,j], sigma2[b,j]), the
    product of the two Gaussians is proportional to a Gaussian with variance
    s*^2 = v1*v2/(v1+v2) and mean m* = s*^2 * (m1/v1 + m2/v2). This returns
    (s*^2 + m*^2) * normal_density(m1, m2, sqrt(v1 + v2)), a Tensor[B, K1, K2].
    """
    var1 = sigma1[:, :, None]**2
    var2 = sigma2[:, None, :]**2

    product_var = var1 * var2 / (var1 + var2)
    product_mean = product_var * (mean1[:, :, None] / var1 + mean2[:, None, :] / var2)
    second_moment = product_var + product_mean**2

    return second_moment * normal_density(mean1, mean2, torch.sqrt(var1 + var2))

def calculate_JT_gaussian_loss(density1, density2):
    '''
    Jensen-Tsallis (quadratic) loss between two batches of 1-D Gaussian mixtures.

    Uses the closed form of the integral of a product of two Gaussians, so the result
    equals int (p1 - p2)^2 dx scaled by the constant factor that normal_density omits.

    Parameters :
    _ density1 (Tuple[Tensor(N,K)]) : (props, means, stds) of the first batch of gaussian mixtures
    _ density2 (Tuple[Tensor(N,K)]) : (props, means, stds) of the second batch of gaussian mixtures

    Returns :
    _ loss (Tensor(1)) : batch mean of the divergence
    '''
    def pair_term(d_a, d_b):
        # (B, Ka, Kb): weight products times the overlap of each pair of components.
        weights = d_a[0][:, :, None] * d_b[0][:, None, :]
        pair_sigma = torch.sqrt(d_a[2][:, :, None]**2 + d_b[2][:, None, :]**2)
        return weights * normal_density(d_a[1], d_b[1], pair_sigma)

    integrand = pair_term(density1, density1) + pair_term(density2, density2) - 2 * pair_term(density1, density2)
    return integrand.sum(dim=1).sum(dim=1).mean()

def calculate_JTxpow2_gaussian_loss(density1, density2):
    '''
    x^2-weighted Jensen-Tsallis (quadratic) loss between two batches of 1-D Gaussian mixtures.

    Each pairwise term comes from weighted_normal, so the result corresponds to
    int x^2 (p1 - p2)^2 dx up to the constant factor that normal_density omits.

    Parameters :
    _ density1 (Tuple[Tensor(N,K)]) : (props, means, stds) of the first batch of gaussian mixtures
    _ density2 (Tuple[Tensor(N,K)]) : (props, means, stds) of the second batch of gaussian mixtures

    Returns :
    _ loss (Tensor(1)) : batch mean of the divergence
    '''
    def pair_term(d_a, d_b):
        # (B, Ka, Kb): weight products times the x^2-weighted overlap of each pair of components.
        weights = d_a[0][:, :, None] * d_b[0][:, None, :]
        return weights * weighted_normal(d_a[1], d_a[2], d_b[1], d_b[2])

    integrand = pair_term(density1, density1) + pair_term(density2, density2) - 2 * pair_term(density1, density2)
    return integrand.sum(dim=1).sum(dim=1).mean()

def sampleGMM(density, n):
    """
    Samples from the GM model.

    For every row, a multinomial draw splits the n samples among the components;
    each component then contributes that many draws from N(mu, sigma^2). Samples
    are returned grouped by component in ascending order (component 0 first).
    All tensors are created on the device (and with the dtype) of the inputs, so
    this works for GPU-resident mixtures.

    Parameters :
    _ density (Tuple[Tensor[N, K]]) : (pi, mu, sigma) — mixture weights, means, standard deviations
    _ n (int) : number of samples per mixture

    Returns:
    _ samples (Tensor[N, n])
    """
    pi, mu, sigma = density
    batch_size, n_components = pi.shape

    # counts[b, k] = number of samples assigned to component k of row b (each row sums to n).
    counts = torch.distributions.multinomial.Multinomial(total_count=n, probs=pi).sample()

    # Component index of every sample; repeat_interleave keeps each row grouped by component.
    component_ids = torch.arange(n_components, device=pi.device).repeat(batch_size)
    component_ids = torch.repeat_interleave(component_ids, counts.flatten().long()).view(batch_size, n)

    chosen_mu = mu.gather(1, component_ids)
    chosen_sigma = sigma.gather(1, component_ids)
    return chosen_mu + torch.randn_like(chosen_mu) * chosen_sigma

def calculate_MMD_gaussian_loss(density1, density2, theta, gamma):
    """
    Squared MMD between two batches of 1-D Gaussian mixtures, computed in closed form.

    Parameters :
    _ density1 (Tuple(Tensor[N,K])) : (props, means, stds) of the first batch of gaussian mixtures
    _ density2 (Tuple(Tensor[N,K])) : (props, means, stds) of the second batch of gaussian mixtures
    _ theta : weights of the mixture kernel
      k = theta[0] * k_lap + theta[1] * k_mixrbf + theta[2] * k_en ;
      a kernel whose weight is not > 0 is skipped entirely
    _ gamma : kernel hyperparameters. gamma[0] is the Laplacian rate
      (k_lap(x, y) = exp(-gamma[0] * |x - y|)); gamma[1] is the number of RBF kernels
      in k_mixrbf (bandwidths h_l = sqrt(l / 2), l = 1..gamma[1])

    Returns :
    Scalar tensor: the batch mean of k11 + k22 - 2 * k12, i.e. the squared MMD between
    density1 and density2 (with the energy part equivalent to k_en(x, y) = -|x - y| / 2).
    """

    # With Phi / phi the standard normal cdf / pdf and gelu(t) = t * Phi(t):
    #   sym=False : U(t) = t * Phi(t) + phi(t)            -> s * U(t) = E[max(Z, 0)]
    #   sym=True  : U(t) = 2 * t * Phi(t) - t + 2 * phi(t) -> s * U(t) = E[|Z|]
    # for Z ~ N(s * t, s^2).
    def U(x, sym=False):
        if(sym):
            return 2*torch.nn.functional.gelu(x)-x + 2*torch.exp(-x**2/2)/np.sqrt(2*np.pi)
        else:
            return torch.nn.functional.gelu(x) + torch.exp(-x**2/2)/np.sqrt(2*np.pi)
        
    # With x = g * mu and y = g * s, G(x, y) = E[exp(-g * |Z|)] for Z ~ N(mu, s^2),
    # i.e. the expected Laplacian kernel.
    def G(x,y):
        normal_dist = torch.distributions.normal.Normal(0., 1.)
        return torch.exp((y**2/2)-x)*normal_dist.cdf((x/y)-y) + torch.exp((y**2/2)+x)*normal_dist.cdf((-x/y)-y)
        
    # Per batch row: sum_ij p1_i p2_j E[k(X_i, Y_j)], where X_i - Y_j ~ N(mutot_ij, sigmatot_ij^2).
    # sym=True is used for the cross term and yields 2 * k12.
    def k_tot(p1, p2, m1, m2, sig1, sig2, sym=False):
        # Std and mean of the difference of two independent components, shape (B, K1, K2).
        sigmatot = torch.sqrt((sig1**2)[:,:,None]+(sig2**2)[:,None,:])
        mutot = m1[:,:,None] - m2[:,None,:]

        k_loss = torch.zeros_like(mutot)
        if(theta[0] > 0):
            # Laplacian kernel; the factor doubles the cross term.
            k_loss += (1+(sym==True))*theta[0]*G(mutot*gamma[0], sigmatot*gamma[0])

        if(theta[1] > 0):
            # Sum of RBF kernels with h_l = sqrt(l / 2):
            # E[exp(-Z^2 / (2 h^2))] = h / sqrt(h^2 + s^2) * exp(-mu^2 / (2 (h^2 + s^2))).
            # The leading axis of gammasig indexes the bandwidths and is summed out.
            gammas = torch.sqrt(torch.arange(1, gamma[1]+1).to(sigmatot.device)/2)
            gammasig = gammas[:,None, None, None]**2+sigmatot[None, :, :, :]**2
            k_loss += (1+(sym==True))*theta[1]*torch.sum((gammas[:,None, None, None]/torch.sqrt(gammasig))*torch.exp(-mutot[None,:, :, :]**2/(2*gammasig)), dim=0)

        if(theta[2] > 0):
            # Energy part: not doubled explicitly. U(sym=True) already gives E|X - Y|, while the
            # symmetric self terms built from U(sym=False) sum to E|X - X'| / 2.
            # nan_to_num guards the 0/0 that occurs when sigmatot == 0.
            k_loss += -theta[2]*sigmatot*torch.nan_to_num(U(mutot/sigmatot, sym=sym))

        return torch.sum(torch.sum(p1[:,:,None]*p2[:,None,:]*k_loss, dim=2), dim=1)

    prop1, mu1, sigma1 = density1
    prop2, mu2, sigma2 = density2

    # k11 + k22 - 2 * k12 per row, averaged over the batch.
    MMD_loss = torch.mean(k_tot(prop1, prop1, mu1, mu1, sigma1, sigma1) + k_tot(prop2, prop2, mu2, mu2, sigma2, sigma2) - k_tot(prop1, prop2, mu1, mu2, sigma1, sigma2, sym=True))

    return MMD_loss

def calculate_MW2_loss(density1, density2):
    """
    Mixture-Wasserstein (MW2) distance between two batches of 1-D Gaussian mixtures.

    The ground cost between components i and j is (m_i - m_j)^2 + (s_i - s_j)^2,
    the squared W2 distance between two 1-D Gaussians. The optimal-transport cost
    between the mixture weights is then solved exactly per row with ``emd2``.

    Parameters :
    _ density1 (Tuple(Tensor[N,K])) : (props, means, stds) of the first batch of gaussian mixtures
    _ density2 (Tuple(Tensor[N,K])) : (props, means, stds) of the second batch of gaussian mixtures

    Returns :
    _ Tensor[1] : the MW2 cost averaged over the batch.
    """
    weights_a, means_a, stds_a = density1
    weights_b, means_b, stds_b = density2

    mean_gap = means_a[:, :, None] - means_b[:, None, :]
    std_gap = stds_a[:, :, None] - stds_b[:, None, :]
    pairwise_cost = mean_gap**2 + std_gap**2

    n_rows = weights_a.shape[0]
    accumulated = torch.zeros(1, device=weights_a.device)
    for row in range(n_rows):
        accumulated += emd2(weights_a[row], weights_b[row], pairwise_cost[row])
    return accumulated / n_rows

def mse_gauss(value, density):
    """
    Calculate the MSE associated to the gaussian density (to compute the MLE for instance).

    For batch row b and value column t this is
        sum_k props[b,k] * (sigma[b,k]**2 + (value[b,t] - mu[b,k])**2),
    the expected squared error between value[b,t] and a draw from mixture b.
    It is summed over t and averaged over the batch.

    Parameters :
    _ value (Tensor[B, T]) : target values
    _ density (Tuple(Tensor[B, K])) : (props, mu, sigma) of the mixtures. Parameters that
      already carry a singleton middle axis (Tensor[B, 1, K]) are used as given.

    Returns :
    _ Tensor[] : scalar loss
    """
    props, mu, sigma = density
    if mu.dim() == 2:
        # (B, K) -> (B, 1, K) so each row's components line up with that row's values.
        # Without this, value[:, :, None] (B, T, 1) broadcasts against every batch row of mu.
        props, mu, sigma = props[:, None, :], mu[:, None, :], sigma[:, None, :]
    diffmean = value[:, :, None] - mu  # (B, T, K)
    return torch.mean(torch.sum(torch.sum(props*(sigma**2 + diffmean**2), dim=1), dim=1))

def calculate_SSBU_loss(density1, density2, N, beta):
    """
    Sample-Set Bellman Update based loss to replicate the Zhang 2024 paper results (WORK IN PROGRESS).

    Builds a sample set with int(N*(1-beta)) draws from density1 and int(N*beta) draws from
    density2, runs one EM step on it starting from density1's parameters, and returns the
    squared error between density1's parameters and the EM estimates (computed without grad).

    Parameters :
    _ density1 (Tuple(Tensor[B,K])) : (props, means, stds) of the current mixtures
    _ density2 (Tuple(Tensor[B,K])) : (props, means, stds) of the target mixtures
    _ N (int) : total size of the sample set
    _ beta (float) : fraction of the sample set drawn from density2

    Returns :
    _ Tensor[] : sum over all elements of the squared parameter errors
    """
    with torch.no_grad(): # maybe to delete
        # Sampling
        samples = sampleGMM(density1, int(N*(1.0-beta)))
        target_samples = sampleGMM(density2, int(N*beta))
        sample_set_mix = torch.cat((samples, target_samples), dim=1)  # (B, M)

        props, means, stds = density1[0], density1[1], density1[2]

        # E step: gamma[b,k,j] ∝ props[b,k] * N(x[b,j] | means[b,k], stds[b,k]^2), normalized over k.
        # Computed in log-space with a softmax so that samples far from every component
        # (whose densities all underflow to 0) do not produce 0/0 = NaN.
        log_p = (torch.log(props)[:, :, None]
                 - torch.log(stds)[:, :, None]
                 - (sample_set_mix[:, None, :] - means[:, :, None])**2 / (2*stds[:, :, None]**2))
        gamma_terms = torch.softmax(log_p, dim=1)  # (batch_size, K, M)

        # M step; components without any responsibility fall back to density2's parameters.
        sum_gamma_terms = torch.sum(gamma_terms, dim=2)
        has_mass = sum_gamma_terms > 0
        no_mass = sum_gamma_terms == 0

        hat_mu = has_mass*torch.nan_to_num(torch.sum(gamma_terms*sample_set_mix[:,None,:], dim=2)/sum_gamma_terms) + no_mass*density2[1]
        hat_sigma = torch.sqrt(has_mass*torch.nan_to_num(torch.sum(gamma_terms*(sample_set_mix[:,None,:]-hat_mu[:,:,None])**2, dim=2)/sum_gamma_terms)) + no_mass*density2[2]
        hat_pi = torch.mean(gamma_terms, dim=2)

        #TODO : sorting by column to get the sorted modes for MSE comparison

    # MSE Loss
    return torch.sum((density1[1]-hat_mu)**2 + (density1[2] - hat_sigma)**2 + (density1[0] - hat_pi)**2)

def sum_mixtures(beta, density1, density2):
    """
    Parameters of the mixture (1 - beta) * density1 + beta * density2 (WORK IN PROGRESS).

    Only the mixture weights (index 0) are scaled by the combination coefficients.
    Means and standard deviations are concatenated unchanged, because scaling a
    p.d.f. neither moves nor widens its components.

    Parameters :
    _ beta (float) : weight of density2 in the combination
    _ density1, density2 (Tuple(Tensor[B,K])) : (props, means, stds) of the two mixtures

    Returns :
    _ list of tensors [props, means, stds], each concatenated along dim 1
    """
    combined = []
    for i in range(len(density1)):
        part1, part2 = density1[i], density2[i]
        if i == 0:
            part1, part2 = (1-beta)*part1, beta*part2
        combined.append(torch.concat((part1, part2), dim=1))
    return combined

class RunningMeanStats:
    """
    Mean of the most recent ``n`` values pushed with ``append``.
    """

    def __init__(self, n=10):
        self.n = n
        # Bounded deque: once it holds n items, each append evicts the oldest one.
        self.stats = deque([], n)

    def append(self, x):
        """Record ``x`` in the window."""
        self.stats.append(x)

    def get(self):
        """Return the numpy mean of the values currently in the window."""
        window = np.asarray(self.stats)
        return window.mean()


class LinearAnneaer:
    """
    Linear scheduler depending on the steps.

    Starts at ``start_value`` before any call to ``step``, moves linearly to
    ``end_value`` after ``num_steps`` calls, then stays at ``end_value``.
    """

    def __init__(self, start_value, end_value, num_steps):
        assert num_steps > 0 and isinstance(num_steps, int)

        self.steps = 0
        self.start_value = start_value
        self.end_value = end_value
        self.num_steps = num_steps

        # value(steps) = a * steps + b
        self.a = (self.end_value - self.start_value) / self.num_steps
        self.b = self.start_value

    def step(self):
        """Advance the schedule by one step, saturating at ``num_steps``."""
        self.steps = min(self.num_steps, self.steps + 1)

    def get(self):
        """Return the current value; equals ``start_value`` before the first ``step``."""
        # steps == 0 is valid (value = start_value); the upper bound is kept by step().
        assert 0 <= self.steps <= self.num_steps
        return self.a * self.steps + self.b
