import functools
import torch
import torch.nn as nn
import numpy as np
import tqdm.notebook
import random
import tqdm

from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
from scipy import integrate


def drift_coeff(x, t, beta_1, beta_0):
    """Return the VP-SDE drift f(x, t) = -0.5 * beta(t) * x.

    The noise schedule is linear: beta(t) = beta_0 + t * (beta_1 - beta_0),
    so beta(0) = beta_0 and beta(1) = beta_1.

    Args:
        x: current state tensor.
        t: diffusion time, either a tensor or a plain Python number; it is
            broadcast against ``x``.
        beta_1: value of beta at t = 1.
        beta_0: value of beta at t = 0.

    Returns:
        Tensor equal to ``-0.5 * beta(t) * x``.
    """
    # Accept plain numbers as well as tensors, and keep t out of the autograd
    # graph (the original detached t for the same reason).
    t = torch.as_tensor(t).detach()
    beta_t = beta_0 + t * (beta_1 - beta_0)
    return -0.5 * beta_t * x

# g(t)
def diffusion_coeff(t, beta_1, beta_0):
    """Return the VP-SDE diffusion coefficient g(t) = sqrt(beta(t)).

    The noise schedule is linear: beta(t) = beta_0 + t * (beta_1 - beta_0),
    so beta(0) = beta_0 and beta(1) = beta_1.

    Args:
        t: diffusion time, either a tensor or a plain Python number.
        beta_1: value of beta at t = 1.
        beta_0: value of beta at t = 0.

    Returns:
        Tensor ``sqrt(beta(t))`` with the shape of ``t``.
    """
    # Accept plain numbers as well as tensors, and keep t out of the autograd
    # graph (the original detached t for the same reason).
    t = torch.as_tensor(t).detach()
    beta_t = beta_0 + t * (beta_1 - beta_0)
    return torch.sqrt(beta_t)

# Linear beta schedule shared by the drift f(x, t) and the diffusion g(t).
# Both partials read the same constants so they always describe one SDE.
BETA_0 = 0.1
BETA_1 = 20

drift_coeff_fn = functools.partial(drift_coeff, beta_1=BETA_1, beta_0=BETA_0)
diffusion_coeff_fn = functools.partial(diffusion_coeff, beta_1=BETA_1, beta_0=BETA_0)


def Euler_Maruyama_sampling(model, T, N, P, device):
    """Generate N samples by integrating the reverse-time SDE with Euler-Maruyama.

    Sampling starts from standard normal noise at t = 1. It steps t down a
    uniform grid of T points from 1 to 1e-5, applying at each step

        x <- x - (f(x, t) - g(t)**2 * model(x, t)) * dt + sqrt(dt) * g(t) * z

    with z ~ N(0, I). Here f and g come from ``drift_coeff_fn`` and
    ``diffusion_coeff_fn``, and dt is the grid spacing.

    Args:
        model: network called as ``model(x, t)`` with ``t`` of shape (1,); its
            output enters the update as the score term (presumably a score
            model -- confirm with the training code).
        T: number of integration steps; must be >= 2.
        N: number of samples.
        P: size parameter; the generated samples have ``P // 2`` features.
        device: device on which sampling runs and the result is returned.

    Returns:
        Tensor of shape (N, P // 2) on ``device``. This assumes ``model``
        returns a tensor shaped like ``x``.

    Raises:
        ValueError: if ``T < 2``, because the step size needs two grid points.
    """
    if T < 2:
        raise ValueError(f"T must be >= 2 to define a step size, got {T}")
    dim = P // 2  # NOTE(review): P is halved here; confirm the reason with callers.

    # Build the grid on the CPU as before, then move it once outside the loop.
    time_steps = torch.linspace(1., 1e-5, T).to(device)
    step_size = time_steps[0] - time_steps[1]
    sqrt_step = torch.sqrt(step_size)

    # Draw the prior on the CPU first so seeded runs match the original RNG stream.
    x = torch.randn(N, dim).to(device)

    with torch.no_grad():
        for t in time_steps:
            time_step = t.unsqueeze(0)  # shape (1,), as the model was called before
            # Predictor step (Euler-Maruyama).
            f = drift_coeff_fn(x, time_step)
            g = diffusion_coeff_fn(time_step)
            score = model(x, time_step)
            x = x - (f - g**2 * score) * step_size + sqrt_step * g * torch.randn_like(x)

    return x


def gaussian_kl_divergence(policy_mean, policy_std, ref_mean, ref_std):
    """Return KL(policy || ref) between diagonal Gaussians, summed over the last axis.

    Per dimension the closed form is

        log(ref_std / policy_std)
        + (policy_std**2 + (policy_mean - ref_mean)**2) / (2 * ref_std**2)
        - 1/2

    All four arguments are broadcast against each other before the reduction,
    so a scalar-like std can be paired with full-shape means.
    """
    log_std_ratio = torch.log(ref_std / policy_std)
    squared_gap = (policy_mean - ref_mean).pow(2)
    quadratic = (policy_std.pow(2) + squared_gap) / (2 * ref_std.pow(2))
    per_dim = log_std_ratio + quadratic - 0.5
    return per_dim.sum(dim=-1)

def Euler_Maruyama_sampling_for_RL(model, ref_model, T, N, P, device):
    """Sample with Euler-Maruyama while recording per-step policy statistics.

    Each of the T steps draws the next state from the Gaussian transition

        x_next ~ N(x - delta, std**2),
        delta = (f(x, t) - g(t)**2 * model(x, t)) * dt,   std = sqrt(dt) * g(t).

    It records that transition's log-density, entropy, and KL divergence to
    the transition induced by ``ref_model``. Autograd is left enabled so the
    recorded quantities stay differentiable with respect to ``model``.

    Args:
        model: policy network called as ``model(x, t)`` with ``t`` of shape (1,).
        ref_model: reference network with the same call signature; used only
            for the KL term.
        T: number of integration steps; must be >= 2.
        N: number of samples.
        P: size parameter; the generated samples have ``P // 2`` features.
        device: device on which sampling runs.

    Returns:
        Tuple ``(samples, x_seqs, log_probs, entropys, kl_divs)``:
            samples: copy of the final state, on ``device``.
            x_seqs: list of the T states produced by each step.
            log_probs: list of T element-wise log-densities of each new state
                under the policy transition.
            entropys: list of T per-sample transition entropies, averaged over
                all non-batch dimensions.
            kl_divs: list of T per-sample KL(policy || reference) values of the
                transition, summed over the last dimension.
        Shapes assume ``model``/``ref_model`` return tensors shaped like ``x``,
        i.e. (N, P // 2) -- confirm against the model definitions.

    Raises:
        ValueError: if ``T < 2``, because the step size needs two grid points.
    """
    if T < 2:
        raise ValueError(f"T must be >= 2 to define a step size, got {T}")
    dim = P // 2  # NOTE(review): P is halved here; confirm the reason with callers.

    # Build the grid on the CPU as before, then move it once outside the loop.
    time_steps = torch.linspace(1., 1e-5, T).to(device)
    step_size = time_steps[0] - time_steps[1]
    sqrt_step = torch.sqrt(step_size)

    # Draw the prior on the CPU first so seeded runs match the original RNG stream.
    X = torch.randn(N, dim).to(device)

    x_seqs = []
    log_probs = []
    entropys = []
    kl_divs = []
    for t in time_steps:
        time_step = t.unsqueeze(0)  # shape (1,), as the models were called before
        # Predictor step (Euler-Maruyama).
        f = drift_coeff_fn(X, time_step)
        g = diffusion_coeff_fn(time_step)
        std = sqrt_step * g  # identical for policy and reference

        # Increments subtracted from X under the policy and reference networks.
        # NOTE(review): ref_model runs with autograd enabled, so the KL term can
        # backprop through X into earlier policy steps; wrap this call in
        # torch.no_grad() if only the current-step policy gradient is wanted.
        delta = (f - g**2 * model(X, time_step)) * step_size
        ref_delta = (f - g**2 * ref_model(X, time_step)) * step_size

        # The transition mean is the previous state minus the increment. The
        # original used the increment alone as the mean, which mis-scored
        # every log-probability.
        policy_mean = X - delta
        ref_mean = X - ref_delta
        X = policy_mean + std * torch.randn_like(X)

        dist = torch.distributions.Normal(policy_mean, std)
        # Evaluate at the detached sample. X - policy_mean is just std * noise,
        # so differentiating through X would cancel the mean's gradient to zero.
        log_prob = dist.log_prob(X.detach())

        entropy = dist.entropy()
        entropy = entropy.mean(dim=list(range(1, entropy.dim())))

        x_seqs.append(X)
        log_probs.append(log_prob)
        entropys.append(entropy)
        # The KL depends only on the mean difference, so full transition means
        # and raw increments give the same value.
        kl_divs.append(gaussian_kl_divergence(policy_mean, std, ref_mean, std))

    # Return a distinct tensor from x_seqs[-1], as the original copy did.
    return X.clone(), x_seqs, log_probs, entropys, kl_divs

