import numpy as np
import torch
from numpy.lib.type_check import nan_to_num
from utils.wavelets import reconstruction_fields
from numpy.random import randint
from sklearn.neighbors import KernelDensity


def phi4_energy(x, beta = 0.68):
    """
    Unnormalized phi^4 energy of a batch of fields.

    Maps (N,C,H,W) to (N,) (e.g. (N,C,4,4) to (N,)): the last three
    dimensions are summed out.
    Not normalized with the partition constant.
    Default beta is the 2d-critical temperature.

    args:
    - x = tensor of fields whose last three dimensions are (C,H,W).
    - beta = coupling constant.

    output:
    - tensor (N,) equal to
      -beta * sum of squared nearest-neighbour differences
      + sum of x^4 - (1 + 2*beta) * x^2.
    """
    # On-site double-well potential: x^4 - (1 + 2*beta) * x^2.
    quartic_interaction = torch.sum(x**4 - (1 + 2*beta)*x**2, dim=(-3, -2, -1))
    # Squared differences with the previous neighbour along each of the two
    # spatial axes; torch.roll wraps around, i.e. periodic boundaries.
    gaussian_interaction = torch.sum(
        (x - torch.roll(x, 1, -2))**2 + (x - torch.roll(x, 1, -1))**2,
        dim=(-3, -2, -1),
    )
    # NOTE(review): the gradient term enters with weight -beta, as in the
    # original code -- confirm this sign/factor against the intended convention.
    return -beta * gaussian_interaction + quartic_interaction


class PotentialAnsatz:
    """
    Family of m symmetrized, shifted spline features summed over a field.
    """

    def __init__(self, m, eta) -> None:
        """
        args:
        - m = number of splines on the positive axis (= number of features).
        - eta = step size (width) of the box splines.
        The splines with centers +i*eta and -i*eta are added together.
        """
        self.m = m
        self.eta = eta
        self.relu = torch.nn.ReLU()

    def rho(self, x):
        """
        Elementary spline of size eta, applied pointwise.
        As written it is the negative of the hat function: it equals -eta
        at 0 and vanishes outside (-eta, eta).
        """
        relu = self.relu
        return 2*relu(x) - relu(x - self.eta) - relu(x + self.eta)

    def __call__(self, phi):
        """
        (n,...) -> (n,m), typically (n,1,L,L) -> (n,m)
        output[i,0] = sum over sites of rho(phi[i])
        output[i,k] = sum over sites of rho(phi[i] + k*eta) + rho(phi[i] - k*eta)
        The output is allocated on the device of phi.
        """
        batch = phi.shape[0]
        features = torch.zeros((batch, self.m), device=phi.device)
        # Centered spline: no symmetric partner.
        features[:, 0] = self.rho(phi).reshape(batch, -1).sum(dim=1)
        for k in range(1, self.m):
            shift = k*self.eta
            pair = self.rho(phi + shift) + self.rho(phi - shift)
            features[:, k] = pair.reshape(batch, -1).sum(dim=1)
        return features


#################################################################
#                      Sampler class
#################################################################

class Sampler:
    """
    Generator of random proposals for the MCMC.
    A bank of fields is built once at construction by calling
    reconstruction_fields on one-hot coefficient arrays; it is created on
    the CPU. Call self.to(device) to move the bank to another device.
    """
    def __init__(self, L, gamma = 1.):
        self.L = L
        self.gamma = gamma
        half = int(self.L/2)
        sites = half*half
        self.w = 3*sites  # channels 1..3, one entry per site of the half-size grid
        self.device = torch.device("cpu")

        # Row k of the coefficient array is one-hot: channel 0 stays empty and
        # the k-th entry of the flattened channels 1..3 is set to 1.
        one_hot = np.zeros((self.w, 4, half, half))
        rows = np.arange(self.w)
        one_hot.reshape(self.w, 4*sites)[rows, sites + rows] = 1.
        self.bank = torch.from_numpy(reconstruction_fields(one_hot, self.gamma))

    def to(self, device):
        """Moves the bank to `device` and records it as the sampler's device."""
        self.bank = self.bank.to(device)
        self.device = device

    def generate(self, n, sigma=1):
        """
        Draws n fields: each one is a bank element picked uniformly at random
        among the self.w = 3*(L/2)*(L/2) entries, multiplied by an independent
        N(0, sigma^2) weight.

        input:
        - n = number of samples
        - sigma = std of the normal weights. Default 1.

        output :
        - tensor (n,1,L,L), allocated on self.device
        """
        dev = self.device
        out = torch.zeros((n, 1, self.L, self.L), device=dev)
        picks = torch.randint(self.w, size=(n,), device=dev)
        amplitudes = sigma * torch.randn(n, dtype=torch.float64, device=dev)
        for row, (k, a) in enumerate(zip(picks, amplitudes)):
            out[row,:,:,:] += a * self.bank[k,:,:]
        return out


#######################################################
#          definition of the acceptance class         #
#######################################################


class FastUpdate:
    """
    Placeholder for cached quantities (fft, statcov, V, quad_energy).
    Nothing is computed yet: every field starts at 0 and calling the
    object does nothing.
    """

    def __init__(self):
        # Slots meant to hold the fft, potential, quadratic energy and statcov.
        for field in ("phi_fft", "potential", "quadenergy", "statcov"):
            setattr(self, field, 0)

    def __call__(self):
        """Not implemented yet; returns None."""
        return None

    


#######################################################
#          definition of the hamiltonian class        #
#######################################################



class ElementaryHamiltonian:
    """
    This class is used for the estimation of a 1d density from samples.
    It is typically used at the coarsest scale of the multiscale algorithm.
    The method is Sklearn's KernelDensity estimator.
    """

    def __init__(self):
        # Fitted KernelDensity instance; None until fit() has been called.
        self.kernel = None

    @staticmethod
    def _rule_of_thumb_bandwidth(x):
        """
        Normal-reference rule of thumb for a 1d Gaussian kernel:
        b = 1.06 * std(x) * n^(-1/5), with n = x.shape[0].
        """
        n = x.shape[0]
        return 1.06 * np.std(x) / n**0.2

    def fit(self,x):
        """
        input: array (n,1,1,1)
        Fits the Gaussian Kernel Density estimator
        using the rule-of-thumb bandwidth b = 1.06*std/n^{1/5}.
        Raises AssertionError if the trailing dimensions are not (1,1,1).
        """
        dims = x.shape
        assert dims[1:]==(1,1,1), f"The input must have dimension (n,1,1,1), currently {dims}."
        bandwidth = self._rule_of_thumb_bandwidth(x)
        self.kernel = KernelDensity(kernel='gaussian', bandwidth=bandwidth)
        # KernelDensity expects a 2d (n_samples, n_features) array.
        self.kernel.fit(x.reshape(dims[0], 1))

    def sample(self,n):
        """
        Draws n samples from the fitted estimator.
        output: array (n,1,1,1)
        Raises AssertionError if fit() has not been called yet.
        """
        assert self.kernel is not None, "Kernel is not fitted to data yet. "
        return self.kernel.sample(n).reshape(n,1,1,1)


class Hamiltonian:
    """
    Hamiltonian made of a quadratic (convolutional) term with kernel C and
    a potential term theta . V, sampled with a Metropolis scheme.

    attributes: j=scale and m=dimension of the potential ansatz approximation
    """

    def __init__(self, j, m, eta) -> None:
        """
        - j is the scale. The Hamiltonian will fit fields of size (L,L) with L=2^j.
        - m is the dimension of the potential ansatz V
        - eta is the chosen mesh size for the potential ansatz V
        """
        self.j = j
        self.L = int(2**j)
        self.m = m # dimension of the linear space for the approximation of the potential
        self.eta = eta #mesh size for the potential approximation
        self.V = PotentialAnsatz(m,eta)
        self.gamma = 1. # normalization constant between scales
        self.device = torch.device("cpu")

        self.C = torch.zeros((self.L,self.L)) # this is Delta K
        self.C[0,0] = 1. # no convolution at init
        self.θ = torch.zeros((1,m))#for the moment it's zero, so the hamiltonian is Gaussian

    def to(self, device):
        """Moves the parameters θ and C to `device` and records it in self.device."""
        self.θ = self.θ.to(device)
        self.C = self.C.to(device)
        self.device = device

    def sample(self,
        n_mcmc,
        init,
        sampler,
        nu=0.1,
        compute_moments=True):
        """
        Runs n_mcmc Metropolis steps on n chains in parallel.

        At each step the proposal is current_state + sampler.generate(n, nu).
        It is accepted, chain by chain, when
            0.5*(E(current) - E(prop)) + θ.(V(current) - V(prop)) > log(u),
        with u uniform on [0,1) and
            E(phi) = Re sum( fft2(C) * |fft2(phi)|^2 ) / l^2.

        input:
        - n_mcmc = number of Metropolis steps.
        - init = tensor (n,1,L,L) of initial states (it is cloned, not modified);
          expected to live on self.device.
        - sampler = object exposing generate(n, sigma) -> tensor (n,1,L,L).
        - nu = std passed to sampler.generate. Default 0.1.
        - compute_moments = whether running averages are accumulated. Default True.

        output:
        - if compute_moments: (covariance, moment, acc) where covariance (L,L)
          and moment (1,m) are averages, over the steps and the chains, of
          Re ifft2(|fft2(phi)|^2)/l^2 and of V(phi).
        - otherwise: (current_state, acc), the final states (n,1,L,L).
        In both cases acc (n,) counts the accepted proposals of each chain.

        Raises AssertionError if init does not have dim (n, 1, L, L).
        """
        if compute_moments:
            covariance = torch.zeros_like(self.C)
            moment = torch.zeros((1, self.m), device=self.device)

        dims = init.shape
        n,l = dims[0], dims[-1]
        assert l==int(self.L), f"The init fields must have dim (n, 1, {self.L}, {self.L}), currently have dim {dims}."
        assert dims[1]==1, f"The init fields must have dim (n, 1, {self.L}, {self.L}), currently have dim {dims}."

        C_fft = torch.fft.fft2(self.C)

        current_state = torch.clone(init)
        current_fft_squared = torch.abs(torch.fft.fft2(current_state[:,0,:,:]))**2
        current_energy = torch.real(torch.sum(torch.multiply(C_fft, current_fft_squared)/l**2, dim=(1,2)))
        current_potential = self.V(current_state)

        # Allocated on self.device so that it can be indexed by the `accept`
        # mask, which lives on that device (a CPU counter cannot be indexed
        # with a mask stored on another device).
        acc = torch.zeros(n, device=self.device)

        for epoch in range(n_mcmc):

            prop = current_state + sampler.generate(n, nu)
            prop_fft_squared = torch.abs(torch.fft.fft2(prop[:,0,:,:]))**2
            prop_energy = torch.real(torch.sum(torch.multiply(C_fft, prop_fft_squared)/l**2, dim=(1,2)))
            prop_potential = self.V(prop)
            delta_energy = 0.5 * (current_energy - prop_energy)
            delta_moment = torch.sum(torch.multiply(self.θ, current_potential - prop_potential), dim=1)
            # Metropolis test in log form, one decision per chain.
            accept = ( delta_energy + delta_moment > torch.log(torch.rand(n, device=self.device)) )

            #update the states (and the cached quantities) of the accepted chains
            current_state[accept] = prop[accept]
            current_fft_squared[accept] = prop_fft_squared[accept]
            current_energy[accept] = prop_energy[accept]
            current_potential[accept] = prop_potential[accept]

            acc[accept] += 1

            if compute_moments:
                moment += torch.mean(current_potential, dim=0)
                covariance += torch.mean(torch.real(torch.fft.ifft2(current_fft_squared))/l**2, dim=0)

        if compute_moments:
            return covariance / n_mcmc, moment / n_mcmc, acc
        return current_state, acc





#######################################################
#                 Loading models
#######################################################

def load_hamiltonians(path = 'save/models/ham.npz'):
    """
    Loads the saved hamiltonians from `path` with numpy.load and returns
    the loaded object as is.
    """
    # NOTE: allow_pickle=True lets numpy unpickle object arrays, which can
    # execute arbitrary code -- only load files from a trusted source.
    return np.load(path, allow_pickle=True)