import os
import sys
import numpy as np
import matplotlib.pyplot as plt
# import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions import MixtureSameFamily, Categorical
import json

import datetime

import sys
sys.path.append("")
# print(sys.path)

from src.datasets.dist_for_ref import MixtureModel
from src.models.sbm import ScoreBaseModel, pretrain_sbm
from src.utils.sampling import model_based_langevin_monte_carlo
from src.doob_h.doob_h import doob_h, doob_h_tensorized, compute_conditional_expectation
from src.models.nn_potential import ApproxPotential

def _load_models(config_path, device):
    """Load the score-based model and the learned potential named in a JSON config.

    Args:
        config_path: Path to a JSON file with ``"sbm_path"`` and
            ``"potential_path"`` entries, each pointing at a saved state dict.
        device: ``torch.device`` the checkpoints are mapped to and the models
            are moved to.

    Returns:
        ``(sbm, model_potential, potential_path)``: both models wrapped in
        ``nn.DataParallel``, moved to ``device`` and put in eval mode, plus
        the potential checkpoint path (recorded next to the saved figure).
    """
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    sbm_path = config["sbm_path"]
    potential_path = config["potential_path"]

    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    sbm = ScoreBaseModel(2)
    sbm.load_state_dict(torch.load(sbm_path, map_location=device))
    model_potential = ApproxPotential(2)
    model_potential.load_state_dict(torch.load(potential_path, map_location=device))

    # State dicts go into the bare modules *before* wrapping, because a
    # DataParallel wrapper's state_dict keys carry a "module." prefix.
    # .to(device) replaces the unconditional .cuda() so CPU-only runs work.
    sbm = nn.DataParallel(sbm).to(device)
    model_potential = nn.DataParallel(model_potential).to(device)
    sbm.eval()
    model_potential.eval()
    return sbm, model_potential, potential_path


def _debug_single_point(sbm, model_potential, device):
    """Print ``doob_h`` outputs and the potential's negative gradient at (0.5, 0.5).

    Debug-only path (it was previously disabled behind ``if 1 == 0``).
    ``doob_h`` is called 10 times with identical arguments, presumably to
    eyeball run-to-run variation — TODO confirm.
    """
    x = torch.tensor([0.5, 0.5], requires_grad=True).to(device)
    t = 0.2
    step_size = 0.1

    for _ in range(10):
        sum_x, sum_y = doob_h(sbm, model_potential, x, t, step_size)
        print("doob_h output:")
        print("sum_x:", sum_x.item())
        print("sum_y:", sum_y.item())

    output = model_potential(x.unsqueeze(0))
    # autograd.grad accepts x even when .to(device) made it a non-leaf, and it
    # does not accumulate into x.grad the way backward() would.
    grad_model_potential = torch.autograd.grad(output, x, retain_graph=True)[0]
    print("\nGradient of model_potential:")
    print(-grad_model_potential)


def _print_conditional_expectations(sbm, model_potential, device, t_list=(0.1, 0.2)):
    """Print the conditional expectation next to exp(-potential) at two fixed points.

    Args:
        sbm, model_potential: Models returned by ``_load_models``.
        device: ``torch.device`` the evaluation points are placed on.
        t_list: Time values at which the expectation is computed.
    """
    print("-" * 50)
    x_t = torch.tensor([[-2.5, 0.0], [2.5, 0.0]], requires_grad=True).to(device)
    for t in t_list:
        print("t: ", t)
        # device.type is the plain string ("cuda"/"cpu"), which matches the
        # kind of value the helper was previously given ('cuda').
        expectation = compute_conditional_expectation(
            sbm, model_potential, x_t, t,
            num_steps=1, T=10, num_samples=100, device=device.type,
        )
        print("expectation: ", expectation)
        print("real potential: ", torch.exp(-model_potential(x_t)))
    print("-" * 50)


def _compute_vector_field(sbm, model_potential, device, t=0.2, step_size=0.1,
                          lim=5.0, num=15):
    """Evaluate ``doob_h_tensorized`` on a ``num`` x ``num`` grid over [-lim, lim]^2.

    Returns:
        ``(X, Y, U, V)``, all NumPy arrays of shape ``(num, num)``: the meshgrid
        coordinates and the x/y components of the field at each grid point.
    """
    # Build the grid on which the field is evaluated.
    x_grid = np.linspace(-lim, lim, num)
    y_grid = np.linspace(-lim, lim, num)
    X, Y = np.meshgrid(x_grid, y_grid)
    positions = np.vstack([X.ravel(), Y.ravel()])  # shape (2, num * num)
    n_points = positions.shape[1]

    # One row per grid point, shape (num * num, 2). Float64 and grad-free,
    # like the copy the original `torch.tensor(x)` produced.
    x = torch.tensor(positions.T, device=device)
    print("x.shape: ", x.shape)
    h = doob_h_tensorized(sbm, model_potential, x, t, step_size)

    # h[0] / h[1] hold the x / y components per grid point; they are indexed
    # element-wise so any indexable container of tensors works.
    U = np.array([h[0][i].cpu().detach().numpy() for i in range(n_points)]).reshape(X.shape)
    V = np.array([h[1][i].cpu().detach().numpy() for i in range(n_points)]).reshape(Y.shape)
    return X, Y, U, V


def _plot_vector_field(X, Y, U, V, lim=5.0, scale=0.5):
    """Draw fixed-length arrows coloured by field magnitude; return the Figure."""
    magnitude = np.hypot(U, V)
    # Arrows are normalised to length `scale`. Zero vectors would divide by zero
    # (NaN + RuntimeWarning), so divide them by 1 instead; they stay (0, 0).
    safe_magnitude = np.where(magnitude > 0, magnitude, 1.0)

    fig = plt.figure(figsize=(8, 6))
    plt.quiver(X, Y, scale * U / safe_magnitude, scale * V / safe_magnitude,
               magnitude, cmap='jet', width=0.03)
    # Dark background so the jet colormap stands out.
    plt.gca().set_facecolor('black')
    plt.colorbar(label='Magnitude')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Doob\'s h Vector Field')
    plt.xlim([-lim, lim])
    plt.ylim([-lim, lim])
    return fig


def _save_outputs(fig, potential_path):
    """Save the figure and the potential checkpoint path into a timestamped directory."""
    # A timestamped directory keeps successive runs from overwriting each other.
    now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    out_dir = os.path.join('outputs', 'figures', 'doob_h', now_str)
    os.makedirs(out_dir, exist_ok=True)

    fig.savefig(os.path.join(out_dir, 'doob_h_vector_field.png'))
    plt.close(fig)  # release the figure's memory

    # Record which potential checkpoint produced this figure.
    with open(os.path.join(out_dir, 'potential_path.txt'), mode='w', encoding='utf-8') as f:
        f.write(potential_path)


def main(*, run_single_point_debug=False):
    """Load the models, print diagnostics, then plot and save the Doob h vector field.

    Args:
        run_single_point_debug: If True, also run the single-point ``doob_h``
            and gradient printout (previously dead code behind ``if 1 == 0``).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Model paths are read from the config file.
    sbm, model_potential, potential_path = _load_models('configs/doob_h.json', device)

    if run_single_point_debug:
        _debug_single_point(sbm, model_potential, device)

    _print_conditional_expectations(sbm, model_potential, device)

    X, Y, U, V = _compute_vector_field(sbm, model_potential, device)
    fig = _plot_vector_field(X, Y, U, V)
    _save_outputs(fig, potential_path)

    print("### visualize_doob_h_vector.py finished. ###")

if __name__ == '__main__':
    # Script entry point; importing this module does not trigger a run.
    main()