import torch
import random

# Module-level standard normal distribution (mean 0, standard deviation 1).
normal = torch.distributions.Normal(0, 1)

def rho2sigma(rho):
    """Map an unconstrained parameter to a positive scale via softplus.

    Computes ``log(1 + exp(rho))`` element-wise using the numerically stable
    library implementation, so large ``rho`` returns (approximately) ``rho``
    instead of overflowing to ``inf``, and very negative ``rho`` keeps a
    small positive value instead of rounding to exactly zero.

    Args:
        rho: torch tensor of unconstrained values.

    Returns:
        Tensor with the same shape as ``rho`` holding ``softplus(rho)``.
    """
    return torch.nn.functional.softplus(rho)

def sigma2rho(sigma):
    """Inverse softplus: recover ``rho`` such that ``log(1 + exp(rho)) == sigma``.

    Mathematically this is ``log(exp(sigma) - 1)``. It is evaluated as
    ``sigma + log(1 - exp(-sigma))`` so that large ``sigma`` does not
    overflow in ``exp`` and tiny positive ``sigma`` does not lose all
    precision in ``exp(sigma) - 1``.

    Args:
        sigma: torch tensor of positive values.

    Returns:
        Tensor with the same shape as ``sigma``. Non-positive entries yield
        ``-inf`` or ``nan``, as with the direct formula.
    """
    # -expm1(-sigma) == 1 - exp(-sigma), computed without cancellation.
    return sigma + torch.log(-torch.expm1(-sigma))

def sample_normal(mu, sigma):
    """Draw a reparameterised Gaussian sample ``mu + sigma * eps``.

    ``eps`` is standard normal noise with the shape of ``mu``. It is created
    directly on ``mu``'s device and in ``mu``'s floating dtype, avoiding a
    CPU draw followed by a device copy and avoiding mixed-precision noise.

    Args:
        mu: torch tensor of means.
        sigma: torch tensor (or scalar) of standard deviations, broadcastable
            against ``mu``.

    Returns:
        Tensor ``mu + sigma * eps``.
    """
    # Non-floating means cannot carry Gaussian noise in their own dtype;
    # fall back to the default float dtype for the noise in that case.
    if mu.is_floating_point():
        noise_dtype = mu.dtype
    else:
        noise_dtype = torch.get_default_dtype()
    eps = torch.randn(mu.shape, dtype=noise_dtype, device=mu.device)
    return mu + (sigma * eps)

def KL_normal(mu0, mu1, sigma0, sigma1):
    """Element-wise KL divergence KL(N(mu0, sigma0^2) || N(mu1, sigma1^2)).

    Note the argument order: both means come first, then both standard
    deviations.

    Args:
        mu0: mean of the first Gaussian.
        mu1: mean of the second Gaussian.
        sigma0: standard deviation of the first Gaussian.
        sigma1: standard deviation of the second Gaussian.

    Returns:
        Tensor of per-element divergences (no reduction is applied).
    """
    scale_ratio = sigma0 / sigma1
    sq_mean_gap = ((mu0 - mu1) ** 2) / (sigma1 ** 2)
    log_term = 2 * torch.log(scale_ratio)
    return 0.5 * (scale_ratio ** 2 + sq_mean_gap - 1 - log_term)

class ReplayMemory(object):
    """Fixed-capacity ring buffer of states with uniform random sampling.

    Once ``capacity`` items have been stored, each new item overwrites the
    oldest one.
    """

    def __init__(self, capacity):
        """Create an empty buffer.

        Args:
            capacity: maximum number of stored items. Must be a positive
                integer; integral floats such as ``1e6`` are accepted and
                converted to ``int``.

        Raises:
            ValueError: if ``capacity`` is not a positive whole number.
        """
        size = int(capacity)
        # A float capacity would turn ``position`` into a float after the
        # first push and break list indexing; a non-positive capacity makes
        # ``push`` fail. Reject or normalise both here.
        if size != capacity or size <= 0:
            raise ValueError(
                "capacity must be a positive integer, got {!r}".format(capacity)
            )
        self.capacity = size
        self.memory = []
        self.position = 0

    def push(self, state):
        """Store ``state``, overwriting the oldest entry when the buffer is full."""
        if len(self.memory) < self.capacity:
            # Grow the list by one slot; it is filled on the next line.
            self.memory.append(None)
        self.memory[self.position] = state
        # Advance the write cursor, wrapping around at capacity.
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored items chosen uniformly at random.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored items
                (raised by ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Return the number of items currently stored."""
        return len(self.memory)
    
    
class Identity:
    """No-op scaler exposing the same fit/transform/transform_inv interface
    as the other normalizers in this file."""

    def fit(self, x):
        """Accept the data and learn nothing.

        x : torch tensor of shape (N, ...) where N is the number of data
        """
        return None

    def transform(self, x):
        """Return ``x`` unchanged."""
        return x

    def transform_inv(self, x):
        """Return ``x`` unchanged (the inverse of a no-op is a no-op)."""
        return x
    
    
class Standardization():
    """Per-feature standardization: subtract the mean, divide by the
    standard deviation, both estimated along dim 0 by ``fit``."""

    def __init__(self):
        pass

    def fit(self, x):
        """Estimate per-feature mean and sample standard deviation.

        x : torch tensor of shape (N, ...) where N is the number of data

        Sets ``self.mean`` and ``self.std``. ``self.std`` uses the unbiased
        (N - 1) estimator, so N must be at least 2 for a finite result.
        """
        count = x.size(0)
        self.mean = torch.sum(x, dim=0) / count
        variance = torch.sum((x - self.mean) ** 2, dim=0) / (count - 1)
        # The square root turns the variance into the standard deviation;
        # without it, transform would divide by the variance.
        self.std = torch.sqrt(variance)

    def transform(self, x):
        """Return ``(x - mean) / std``.

        Features whose fitted std is exactly zero (constant features) are
        divided by 1 instead, so they map to 0 rather than NaN.
        """
        std = self.std.to(x.device)
        safe_std = torch.where(std == 0, torch.ones_like(std), std)
        return (x - self.mean.to(x.device)) / safe_std

    def transform_inv(self, x):
        """Undo ``transform``: return ``x * std + mean``."""
        return x * self.std.to(x.device) + self.mean.to(x.device)
    
class MinMax_Normalizer():
    """Per-feature min-max scaling to [0, 1] using the extrema found along
    dim 0 by ``fit``."""

    def __init__(self):
        pass

    def fit(self, x):
        """Record the per-feature maximum and minimum.

        x : torch tensor of shape (N, ...) where N is the number of data

        Sets ``self.max`` and ``self.min``.
        """
        self.max = torch.max(x, dim=0)[0]
        self.min = torch.min(x, dim=0)[0]

    def transform(self, x):
        """Return ``(x - min) / (max - min)``.

        Features with ``max == min`` (constant features) are divided by 1
        instead, so they map to 0 rather than NaN.
        """
        low = self.min.to(x.device)
        span = self.max.to(x.device) - low
        safe_span = torch.where(span == 0, torch.ones_like(span), span)
        return (x - low) / safe_span

    def transform_inv(self, x):
        """Undo ``transform``: return ``x * (max - min) + min``."""
        low = self.min.to(x.device)
        span = self.max.to(x.device) - low
        return x * span + low