import torch
import numpy as np
from scipy.stats import qmc
from scipy.stats import norm


def Gaussian_kernel(a, b, bandwidth):
    """Return the Gaussian (RBF) kernel between the rows of ``a`` and ``b``.

    Entry ``(i, j)`` is ``exp(-||a_i - b_j||**2 / (2 * bandwidth**2))``.
    Pairwise distances come from ``torch.cdist``, so the input shapes follow
    its conventions, e.g. ``(n, d)`` and ``(m, d)`` give an ``(n, m)`` result.

    Args:
        a: numpy array, torch tensor or nested sequence of points.
        b: numpy array, torch tensor or nested sequence of points.
        bandwidth: kernel bandwidth (standard deviation of the Gaussian).

    Returns:
        A torch tensor of kernel values.
    """
    a = torch.as_tensor(a)
    b = torch.as_tensor(b)
    # torch.cdist requires both operands to share a floating dtype; promote
    # mixed float32/float64 (or integer) inputs instead of raising.
    common = torch.promote_types(a.dtype, b.dtype)
    if not (common.is_floating_point or common.is_complex):
        common = torch.get_default_dtype()
    a = a.to(common)
    b = b.to(common)
    dist = torch.cdist(a, b, p=2)
    # torch.exp (not np.exp) keeps the result a tensor on its original device
    # and preserves autograd; np.exp relied on torch's __array_wrap__ fallback.
    return torch.exp(-((dist ** 2) / (2 * (bandwidth ** 2))))
    

class MCRF_Encoder:
    """Random-feature encoder mapping real vectors to complex hypervectors.

    Inputs are projected with the random matrix ``Ex`` (entries drawn from
    N(0, (1 / bandwidth)**2)) and mapped onto the complex unit circle, then
    scaled by ``1 / sqrt(dim)``. For real inputs every encoding therefore has
    unit squared norm.

    Attributes:
        Ex: float32 tensor of shape ``(input_dim, dim)``; the random projection.
        input_dim: dimensionality of the raw inputs.
        dim: encoding dimension.
        bandwidth: scale parameter; the projection std is ``1 / bandwidth``.
    """

    def __init__(self, input_dim, dim, bandwidth):
        # Drawn from torch's global RNG, so torch.manual_seed makes it reproducible.
        self.Ex = (
            torch.empty((input_dim, dim))
            .normal_(mean=0, std=(1 / bandwidth))
            .float()
        )
        self.input_dim = input_dim
        self.dim = dim
        self.bandwidth = bandwidth

    def encode_x(self, x):
        """Encode ``x`` as ``exp(1j * x @ Ex) / sqrt(dim)``.

        Args:
            x: array-like (numpy array, torch tensor or sequence) whose last
                dimension is ``input_dim``.

        Returns:
            Complex numpy array whose last dimension is ``dim``.
        """
        # Cast every input (numpy, list, or tensor of any dtype) to Ex's dtype;
        # previously only numpy inputs were cast, so e.g. float64 tensors failed.
        x = torch.as_tensor(x, dtype=self.Ex.dtype)
        return (1 / np.sqrt(self.dim)) * torch.exp(1j * (x @ self.Ex)).numpy()

    def similarity(self, a, b):
        """Return ``Re(conj(a) @ b^T)``, the real part of the Hermitian inner product.

        For 2-D inputs ``(n, dim)`` and ``(m, dim)`` the result is ``(n, m)``;
        for 1-D inputs it is a 0-d array.
        """
        a = torch.as_tensor(a)
        b = torch.as_tensor(b)
        # Swap only the last two dims: Tensor.T is deprecated for non-2-D
        # tensors, and transposing a 1-D vector is a no-op anyway.
        bt = b.transpose(-2, -1) if b.dim() >= 2 else b
        return (torch.conj(a) @ bt).real.numpy()

    def bind(self, a, b):
        """Bind two encodings by element-wise (complex) multiplication.

        Returns:
            Numpy array of ``a * b`` (broadcast per torch rules).
        """
        a = torch.as_tensor(a)
        b = torch.as_tensor(b)
        return torch.mul(a, b).numpy()




class QMCRF_Encoder:
    """Quasi-Monte Carlo random-feature encoder producing complex hypervectors.

    Same encoding as a plain random-feature encoder, but the projection
    frequencies come from a scrambled Halton sequence pushed through the
    standard-normal inverse CDF rather than from i.i.d. Gaussian draws.

    Attributes:
        Ex: float64 tensor of shape ``(input_dim, dim)``; each column is one
            frequency vector with marginals N(0, (1 / bandwidth)**2).
        input_dim: dimensionality of the raw inputs.
        dim: encoding dimension (number of Halton points).
        bandwidth: scale parameter; frequencies are divided by it.
    """

    def __init__(self, input_dim, dim, bandwidth, seed=None):
        """Build the QMC projection matrix.

        Args:
            input_dim: dimensionality of the inputs to ``encode_x``.
            dim: number of Halton points, i.e. the encoding dimension.
            bandwidth: frequencies are divided by this value.
            seed: optional seed forwarded to ``scipy.stats.qmc.Halton`` so the
                scrambling is reproducible. ``None`` (the default) keeps the
                non-deterministic behaviour.
        """
        sampler = qmc.Halton(d=input_dim, scramble=True, seed=seed)
        # sampler.random returns (dim, input_dim) points in [0, 1); transpose
        # so that each column is one frequency vector.
        uniform = sampler.random(n=dim).T
        # norm.ppf returns a numpy array; store a torch tensor so encode_x does
        # a plain torch matmul instead of relying on numpy/torch operator fallback.
        self.Ex = torch.from_numpy(norm.ppf(uniform) / bandwidth)

        self.input_dim = input_dim
        self.dim = dim
        self.bandwidth = bandwidth

    def encode_x(self, x):
        """Encode ``x`` as ``exp(1j * x @ Ex) / sqrt(dim)``.

        Args:
            x: array-like (numpy array, torch tensor or sequence) whose last
                dimension is ``input_dim``.

        Returns:
            Complex numpy array whose last dimension is ``dim``.
        """
        # Match Ex's dtype so mixed float32/float64 inputs do not fail the matmul.
        x = torch.as_tensor(x, dtype=self.Ex.dtype)
        return (1 / np.sqrt(self.dim)) * torch.exp(1j * (x @ self.Ex)).numpy()

    def similarity(self, a, b):
        """Return ``Re(conj(a) @ b^T)``, the real part of the Hermitian inner product.

        For 2-D inputs ``(n, dim)`` and ``(m, dim)`` the result is ``(n, m)``;
        for 1-D inputs it is a 0-d array.
        """
        a = torch.as_tensor(a)
        b = torch.as_tensor(b)
        # Swap only the last two dims: Tensor.T is deprecated for non-2-D
        # tensors, and transposing a 1-D vector is a no-op anyway.
        bt = b.transpose(-2, -1) if b.dim() >= 2 else b
        return (torch.conj(a) @ bt).real.numpy()

    def bind(self, a, b):
        """Bind two encodings by element-wise (complex) multiplication.

        Returns:
            Numpy array of ``a * b`` (broadcast per torch rules).
        """
        a = torch.as_tensor(a)
        b = torch.as_tensor(b)
        return torch.mul(a, b).numpy()
