import torch
import gc

def compute_conditional_expectation(model, model_potential, x_t, t, num_steps=100, T=10, num_samples=100, device='cuda'):
    """
    Estimate the per-row mean of exp(-model_potential(x_0)) given x at time t.

    Every row of ``x_t`` is replicated ``num_samples`` times and integrated
    from time ``t`` down to time 0 with ``num_steps`` explicit steps of

        x <- x + (x + 2 * score) * dt + sqrt(2 * dt) * noise,

    where ``score = model(cat([x, time], dim=1))`` and ``dt = t / num_steps``.
    The last step is taken without noise. ``exp(-model_potential(x))`` is then
    averaged over the replicas belonging to each input row.

    Args:
        model: trained score network (torch.nn.Module); called on the
            concatenation of the samples and a time column.
        model_potential: torch.nn.Module whose output is squeezed on dim 1,
            i.e. it is expected to return shape (N, 1).
        x_t: values of x at time t, shape (batch_size, input_dim).
        t: start time of the integration. If ``t <= 0`` no integration step
            is performed and the potential is evaluated on ``x_t`` directly.
        num_steps: number of integration steps. Integral floats (e.g. ``5.0``)
            are accepted and rounded to the nearest integer.
        T: unused; kept for backward compatibility of the signature.
        num_samples: number of Monte Carlo replicas per input row.
        device: device on which the simulation runs.

    Returns:
        Tensor of shape (batch_size,) holding the sample mean of
        exp(-model_potential(x_0)) for every input row.

    Raises:
        ValueError: if ``t > 0`` but ``num_steps`` is smaller than 1.
    """
    t = float(t)
    # Callers may pass a float step count (e.g. ``t * 10``); range() needs an int.
    num_steps = int(round(float(num_steps)))
    if t > 0 and num_steps < 1:
        raise ValueError(
            "num_steps must be at least 1 when t > 0 (got {})".format(num_steps)
        )

    batch_size = x_t.shape[0]
    total = batch_size * num_samples

    # Inference mode for both networks.
    model.eval()
    model_potential.eval()

    # Replicate each row num_samples times, keeping replicas of a row adjacent:
    # (x_t[0], ..., x_t[0], x_t[1], ..., x_t[1], ...).
    samples = x_t.repeat_interleave(num_samples, dim=0).to(device)
    if samples.shape[0] > 0:
        # Sanity check of the replica layout, independent of input_dim.
        assert torch.equal(samples[0], samples[num_samples - 1])

    # Integrate from time t down to 0 using the score predicted by the model.
    if t > 0:
        step_size = t / num_steps
        noise_scale = torch.sqrt(2 * torch.tensor(step_size))
        ones_col = torch.ones((total, 1)).to(device)
        with torch.no_grad():
            for i in range(num_steps):
                t_cur = t - step_size * i
                if t_cur <= 0:
                    continue
                score = model(torch.cat([samples, ones_col * t_cur], dim=1))
                samples = samples + (samples + 2 * score) * step_size
                # Only the final step is noise-free (pure drift update).
                if i < num_steps - 1:
                    samples = samples + noise_scale * torch.randn_like(samples)

    # Release cached memory once, instead of once per integration step.
    gc.collect()
    torch.cuda.empty_cache()

    # Evaluated outside no_grad on purpose: the result stays differentiable
    # with respect to the parameters of model_potential, as before.
    expectations = torch.exp(-model_potential(samples)).squeeze(1)  # (total,)

    del samples
    gc.collect()
    torch.cuda.empty_cache()

    # Average over the replicas of every input row.
    expectations = expectations.view(batch_size, num_samples)
    return torch.mean(expectations, dim=1)  # (batch_size,)

def doob_h(model, model_potential, x, t, step_size, calc_grad=False):
    """
    Self-normalised Monte Carlo estimate of a weighted score at a single point.

    A Gaussian cloud ``x_t ~ N(mean, std^2 I)`` is drawn around
    ``mean = exp(step_size) * x - 2 * score * (1 - exp(-step_size))`` with
    ``std = sqrt(exp(2 * step_size) - 1)``. Every sample is weighted by
    ``compute_conditional_expectation`` started at time ``max(t - step_size, 0)``,
    and the weighted average of ``J^T (x_t - mean) / std^2`` is returned, where
    ``J`` is the Jacobian of ``mean`` with respect to ``x``.

    Args:
        model: score network (torch.nn.Module) called on ``cat([x, t])``.
        model_potential: torch.nn.Module passed on to
            ``compute_conditional_expectation``.
        x: point of evaluation, shape (2,). With ``calc_grad=True`` the tensor
            is flagged ``requires_grad=True`` in place.
        t: current time.
        step_size: time increment used for the Gaussian transition.
        calc_grad: if True the Jacobian of the score is obtained by autograd
            (the model is put in train mode for that forward pass, as before);
            if False the score Jacobian is treated as zero, so ``J`` is
            ``exp(step_size) * I``.

    Returns:
        Tuple ``(value_x, value_y)`` of 0-dim tensors: the two components of
        ``sum(w * g) / (sum(w) + 1e-10)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_samples = 300

    model_potential.eval()

    if calc_grad:
        # NOTE(review): train mode is kept from the original code; autograd
        # itself does not need it - confirm it is intended.
        model.train()
        x = x.requires_grad_(True).to(device).unsqueeze(0)
        t_col = torch.ones((x.shape[0], 1), requires_grad=False, dtype=x.dtype).to(device) * t
        score = model(torch.cat([x.float(), t_col], dim=1))  # concat x and t

        # Rows of the score Jacobian: d score[0] / dx and d score[1] / dx.
        gradients_first = torch.autograd.grad(score[0, 0], x, retain_graph=True)[0][0]
        gradients_second = torch.autograd.grad(score[0, 1], x)[0][0]
        # The graph buffers of `score` are freed by the call above; detach so
        # that no later backward pass tries to traverse them.
        score = score.detach()
    else:
        model.eval()
        # detach() instead of requires_grad_(False): the latter raises on the
        # non-leaf tensor produced by a real device transfer.
        x = x.detach().to(device).unsqueeze(0)
        t_col = torch.ones((x.shape[0], 1), requires_grad=False, dtype=x.dtype).to(device) * t
        with torch.no_grad():
            score = model(torch.cat([x.float(), t_col], dim=1))
        # Score Jacobian treated as zero.
        gradients_first = torch.zeros_like(x[0])
        gradients_second = torch.zeros_like(x[0])

    model.eval()

    # Convert step_size to a tensor.
    step_size_tensor = torch.tensor(step_size, device=device)
    growth = torch.exp(step_size_tensor)
    shrink = 1 - torch.exp(-step_size_tensor)

    mean = growth * x + (-2 * score) * shrink  # (1, 2)
    std = torch.sqrt(torch.exp(2 * step_size_tensor) - 1)

    # Sample x_t from the Gaussian transition (CPU RNG, then moved to device).
    x_t = torch.randn(num_samples, x.shape[-1]).to(device) * std + mean

    # Jacobian of `mean` with respect to x; all four entries are 0-dim so they
    # broadcast against the (num_samples,) residuals below.
    deriv_mean_x_by_x = growth + (-2 * gradients_first[0]) * shrink
    deriv_mean_x_by_y = (-2 * gradients_first[1]) * shrink
    deriv_mean_y_by_x = (-2 * gradients_second[0]) * shrink
    deriv_mean_y_by_y = growth + (-2 * gradients_second[1]) * shrink

    # Vectorised over the samples.
    mean_diff_x = deriv_mean_x_by_x * (x_t[:, 0] - mean[0][0]) / std**2 \
                + deriv_mean_y_by_x * (x_t[:, 1] - mean[0][1]) / std**2
    mean_diff_y = deriv_mean_x_by_y * (x_t[:, 0] - mean[0][0]) / std**2 \
                + deriv_mean_y_by_y * (x_t[:, 1] - mean[0][1]) / std**2

    # Weights: one conditional expectation per sample, using about 10
    # integration steps per unit of time and at least one step so the callee
    # never divides by zero.
    t_e = max(t - step_size, 0)
    num_steps = max(1, int(round(float(t_e * 10))))
    expectations = compute_conditional_expectation(
        model, model_potential, x_t, t_e, num_steps=num_steps, device=device
    )

    sum_x = torch.sum(mean_diff_x * expectations)
    sum_y = torch.sum(mean_diff_y * expectations)
    denom = torch.sum(expectations)

    del expectations, x_t, mean_diff_x, mean_diff_y
    gc.collect()
    torch.cuda.empty_cache()

    eps = 1e-10  # guards against a vanishing total weight

    return sum_x / (denom + eps), sum_y / (denom + eps)

def doob_h_tensorized(model, model_potential, x, t, step_size, calc_grad=False):
    """
    Batched version of ``doob_h``: evaluate the estimator for every row of ``x``.

    For each row a Gaussian cloud ``x_t ~ N(mean, std^2 I)`` is drawn around
    ``mean = exp(step_size) * x - 2 * score * (1 - exp(-step_size))`` with
    ``std = sqrt(exp(2 * step_size) - 1)``. Each sample is weighted by
    ``compute_conditional_expectation`` started at ``max(t - step_size, 0)``,
    and the weighted average of ``J^T (x_t - mean) / std^2`` is returned, with
    ``J`` the Jacobian of ``mean`` with respect to the row.

    Args:
        model: score network (torch.nn.Module) called on ``cat([x, t], dim=1)``.
            NOTE(review): it is run in train mode for the forward pass even
            when ``calc_grad`` is False, unlike ``doob_h`` - confirm intent.
        model_potential: torch.nn.Module passed on to
            ``compute_conditional_expectation``.
        x: evaluation points, shape (batch_size, 2). Flagged
            ``requires_grad=True`` in place.
        t: current time.
        step_size: time increment used for the Gaussian transition.
        calc_grad: if True the per-row Jacobian of the score is obtained by
            autograd; otherwise it is treated as zero.

    Returns:
        Tuple ``(value_x, value_y)`` of tensors of shape (batch_size,).

    Raises:
        ValueError: if the score Jacobian rows do not have shape
            (batch_size, 2).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_samples = 300
    input_dim = 2

    model_potential.eval()
    model.train()
    batch_size = x.shape[0]

    x = x.requires_grad_(True).to(device)
    t_gpu = torch.ones((batch_size, 1), requires_grad=False, dtype=x.dtype).to(device) * t

    score = model(torch.cat([x.float(), t_gpu], dim=1))  # concat x and t

    # Per-row Jacobian of the score. Row i of the result must come from row i
    # of the gradient: d score[i] / d x is a (batch_size, input_dim) tensor
    # whose entry for sample i lives at index i, not at index 0.
    gradients_first = torch.zeros_like(x)
    gradients_second = torch.zeros_like(x)
    if calc_grad:
        for i in range(batch_size):
            gradients_first[i] = torch.autograd.grad(score[i, 0], x, retain_graph=True)[0][i]
            gradients_second[i] = torch.autograd.grad(score[i, 1], x, retain_graph=True)[0][i]
    gradients_first = gradients_first.detach()
    gradients_second = gradients_second.detach()
    if not (gradients_first.shape == (batch_size, input_dim) and gradients_second.shape == (batch_size, input_dim)):
        raise ValueError("gradients_first.shape: {grad1}, gradients_second.shape: {grad2}".format(grad1=gradients_first.shape, grad2=gradients_second.shape))

    score = score.detach()

    model.eval()

    # Convert step_size to a tensor.
    step_size_tensor = torch.tensor(step_size, device=device)
    growth = torch.exp(step_size_tensor)
    shrink = 1 - torch.exp(-step_size_tensor)

    mean = growth * x + (-2 * score) * shrink
    assert mean.shape == (batch_size, input_dim)
    std = torch.sqrt(torch.exp(2 * step_size_tensor) - 1)  # 0-dim

    # Sample x_t from the Gaussian transition: (batch_size, num_samples, input_dim).
    # The noise is drawn on the CPU and then moved, as before.
    x_t = torch.randn(batch_size, num_samples, x.shape[1]).to(device) * std + mean.unsqueeze(1)

    # Flatten to (batch_size * num_samples, input_dim) for the weight computation.
    x_t_flat = x_t.reshape(batch_size * num_samples, input_dim)
    expectations = compute_conditional_expectation(
        model, model_potential, x_t_flat, max(t - step_size, 0), device=device
    )

    # Jacobian of `mean` with respect to x, one column per batch row so that
    # it broadcasts over the num_samples axis.
    deriv_mean_x_by_x = (growth + (-2 * gradients_first[:, 0]) * shrink).unsqueeze(1)
    deriv_mean_x_by_y = ((-2 * gradients_first[:, 1]) * shrink).unsqueeze(1)
    deriv_mean_y_by_x = ((-2 * gradients_second[:, 0]) * shrink).unsqueeze(1)
    deriv_mean_y_by_y = (growth + (-2 * gradients_second[:, 1]) * shrink).unsqueeze(1)

    del gradients_first, gradients_second
    gc.collect()
    torch.cuda.empty_cache()

    # Residuals of the samples around their row mean: (batch_size, num_samples).
    residual = x_t - mean.unsqueeze(1)
    residual_x = residual[:, :, 0]
    residual_y = residual[:, :, 1]

    mean_diff_x = deriv_mean_x_by_x * residual_x / std**2 \
                + deriv_mean_y_by_x * residual_y / std**2
    mean_diff_y = deriv_mean_x_by_y * residual_x / std**2 \
                + deriv_mean_y_by_y * residual_y / std**2
    assert mean_diff_x.shape == (batch_size, num_samples)
    assert mean_diff_y.shape == (batch_size, num_samples)

    expectations = expectations.reshape(batch_size, num_samples)
    assert expectations.shape == (batch_size, num_samples)

    sum_x = torch.sum(mean_diff_x * expectations, dim=1)
    sum_y = torch.sum(mean_diff_y * expectations, dim=1)
    denom = torch.sum(expectations, dim=1)

    del expectations, x_t, x_t_flat, residual, score
    gc.collect()
    torch.cuda.empty_cache()

    eps = 1e-10  # guards against a vanishing total weight

    return sum_x / (denom + eps), sum_y / (denom + eps)

