import torch
import numpy as np
from tqdm import tqdm
import gc

import sys
sys.path.append("/home/***/work/doob")


from src.doob_h.doob_h import doob_h_tensorized

# モデルベースのランジュバン・モンテカルロ法の実装
def model_based_langevin_monte_carlo(model, num_samples, num_steps, step_size, T=10, device='cuda'):
    # 初期サンプルを乱数から生成
    x = torch.randn(num_samples, model.input_dim).to(device)
    # モデルを推論モードに変更
    model.eval()
    # 以下、学習済みモデルによって予測されたスコアを用いてランジュバン・モンテカルロ法を実行
    for i in tqdm(range(num_steps)):
        with torch.no_grad():
            t = T - T / num_steps * i
            batch_t = torch.ones((x.shape[0], 1)).to(device) * t
            score =model(torch.cat([x, batch_t], axis=1)) # model(torch.concat([x, batch_t], axis=1))
            # 最終ステップのみノイズ無しでスコアの方向に更新
            if i < num_steps - 1:
                noise = torch.randn(num_samples, model.input_dim).to(device)
            else:
                noise = 0
            x = x + (x + 2 * score) * step_size + np.sqrt(2 * step_size) * noise
    return x

# モデルベースのランジュバン・モンテカルロ法の実装
def new_sbm_based_langevin_monte_carlo(model_sbm, num_samples, num_steps, T=10, device='cuda'):
    step_size = T / num_steps
    # input_dim = model_sbm.module.input_dim # 並列化のため
    try:
        input_dim = model_sbm.module.input_dim if hasattr(model_sbm, 'module') else model_sbm.input_dim
    except:
        input_dim = 2
        print("input_dim is set to 2")
    # 初期サンプルを乱数から生成
    x = torch.randn(num_samples, input_dim).to(device)
    # モデルを推論モードに変更
    model_sbm.eval()
    # 以下、学習済みモデルによって予測されたスコアを用いてランジュバン・モンテカルロ法を実行
    for i in tqdm(range(num_steps), leave=False):
        with torch.no_grad():
            t = T - T / num_steps * i
            batch_t = torch.ones((x.shape[0], 1)).to(device) * t
            score =model_sbm(torch.cat([x, batch_t], axis=1)) # model(torch.concat([x, batch_t], axis=1))
            # 最終ステップのみノイズ無しでスコアの方向に更新
            if i < num_steps - 1:
                noise = torch.randn(num_samples, input_dim).to(device)
            else:
                noise = 0
            x = x + (x + 2 * score) * step_size + np.sqrt(2 * step_size) * noise
    return x

def doob_langevin_monte_carlo_modified(model, model_potential, num_samples, num_steps, T=10, device = "cuda"):
    with_ref = False
    step_size = T / num_steps
    input_dim = model.module.input_dim #並列化のため
    # 初期サンプルを乱数から生成
    x = torch.randn(num_samples, input_dim).to(device)
    if with_ref:
        x_ref = x.clone()
    # モデルを推論モードに変更
    model.eval()
    model_potential.eval()
    # 以下、学習済みモデルによって予測されたスコアを用いてランジュバン・モンテカルロ法を実行
    for i in tqdm(range(num_steps), leave=False):
        t = T - T / num_steps * i ## t = T から, t = 0に行く
        batch_t = torch.ones((x.shape[0], 1)).to(device) * t
        ## x: (num_samples, input_dim)
        doob = doob_h_tensorized(model, model_potential, x, t, step_size)
        doob = torch.stack([doob[0], doob[1]], axis=1)
        
        with torch.no_grad():
            score = model(torch.cat([x, batch_t], axis=1))
            if with_ref:
                score_ref = model(torch.cat([x_ref, batch_t], axis=1))
            # 最終ステップのみノイズ無しでスコアの方向に更新
            if i < num_steps - 1:
                noise = torch.randn(num_samples, input_dim).to(device)
            else:
                noise = 0

            x     = x     + (x     + 2 * score + 2 * doob) * step_size + np.sqrt(2 * step_size) * noise
            if with_ref:
                x_ref = x_ref + (x_ref + 2 * score_ref)        * step_size + np.sqrt(2 * step_size) * noise
            # print("doob",doob)
            # print("score",score)
        del batch_t, doob, score, noise
        if with_ref:
            del score_ref
        gc.collect()
        torch.cuda.empty_cache()

    if with_ref:
        return x, x_ref
    else:
        return x


# モデルベースのランジュバン・モンテカルロ法の実装
def calc_density_ratio(x_input, aligned_sbm, model_sbm, num_samples, num_steps, T=10, device='cuda'):
    '''
    x: torch.tensor, shape=(batch_size, input_dim)
    '''
    step_size = T / num_steps
    # input_dim = model_sbm.module.input_dim # 並列化のため
    try:
        input_dim = model_sbm.module.input_dim if hasattr(model_sbm, 'module') else model_sbm.input_dim
    except:
        input_dim = 2
        print("input_dim is set to 2")
    # 初期サンプルを乱数から生成
    batch_size = x_input.shape[0]
    x = x_input.repeat_interleave(num_samples, dim=0)
    # モデルを推論モードに変更
    model_sbm.eval()
    aligned_sbm.eval()
    # 以下、学習済みモデルによって予測されたスコアを用いてランジュバン・モンテカルロ法を実行
    log_density_ratio_aligned = torch.zeros(num_samples*batch_size).to(device)
    log_density_ratio_ref = torch.zeros(num_samples*batch_size).to(device)
    log_density_ratio = torch.zeros(num_samples*batch_size).to(device)
    for i in tqdm(range(num_steps), leave=False):
        with torch.no_grad():
            t = T / num_steps * i
            batch_t = torch.ones((x.shape[0], 1)).to(device) * t
            score = model_sbm(torch.cat([x, batch_t], axis=1)) # model(torch.concat([x, batch_t], axis=1))
            aligned_score = aligned_sbm(torch.cat([x, batch_t], axis=1))
            score = score.squeeze()
            aligned_score = aligned_score.squeeze()
            # 最終ステップのみノイズ無しでスコアの方向に更新
            if i < num_steps - 1:
                noise = torch.randn(num_samples*batch_size, input_dim).to(device)
            else:
                noise = torch.zeros(num_samples*batch_size, input_dim).to(device)
            x = x + (-x) * step_size + np.sqrt(2 * step_size) * noise
            ## density ratio
            true_score = -(np.sqrt(1-np.exp(-2*t)+1e-10)**(-1)) * noise
            assert score.shape == aligned_score.shape and score.shape == true_score.shape
            # true score is not nan
            assert torch.isnan(true_score).sum() == 0
            beta = 1 - np.exp(-2 * step_size)
            alpha = np.exp(-2*step_size)
            alpha_bar = np.exp(-t)
            sigma = np.sqrt(1 - alpha)
            ## sigma^2 ~ beta, then beta offsets the variance asymtotically
            log_density_ratio_aligned = log_density_ratio_aligned + torch.clamp((- torch.square(torch.norm(aligned_score - true_score, dim=1))) * (beta / (2 * alpha + 1e-10)), min=-1, max=1)
            log_density_ratio_ref     = log_density_ratio_ref     + torch.clamp((- torch.square(torch.norm(score - true_score        , dim=1))) * (beta / (2 * alpha + 1e-10)), min=-1, max=1)
            log_density_ratio = log_density_ratio + torch.clamp((- torch.square(torch.norm(aligned_score - true_score, dim=1))\
                                                                 + torch.square(torch.norm(score - true_score        , dim=1)))\
                                                                      * (beta / (2 * alpha + 1e-10)), min=-10, max=10)
            # print("log_density_ratio", log_density_ratio)
    log_density_ratio_aligned = log_density_ratio_aligned.view(batch_size, num_samples)
    log_density_ratio_ref = log_density_ratio_ref.view(batch_size, num_samples)
    density_ratio_aligned = torch.exp(log_density_ratio_aligned)
    density_ratio_ref = torch.exp(log_density_ratio_ref)
    density_ratio_aligned_mean = torch.mean(density_ratio_aligned, dim=1)
    density_ratio_ref_mean = torch.mean(density_ratio_ref, dim=1)
    print("density_ratio_aligned_mean", density_ratio_aligned_mean)
    print("density_ratio_ref_mean", density_ratio_ref_mean)
    return density_ratio_aligned_mean / density_ratio_ref_mean