import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions import MixtureSameFamily, Categorical
import json

import sys
sys.path.append("/home/***/work/doob")
# print(sys.path)

from src.datasets.dist_for_ref import MixtureModel
from src.models.sbm import ScoreBaseModel, pretrain_sbm
from src.utils.sampling import model_based_langevin_monte_carlo

def main():

    input_dim = 2
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # モデルのインスタンスを作成
    sbm = ScoreBaseModel(input_dim).to(device)

    # 保存したモデルの状態をロード
    sbm.load_state_dict(torch.load('outputs/pretrain/score_base_model.pth'))

    # モデルを評価モードに設定
    sbm.eval()

    # ランジュバン・モンテカルロ法のパラメータ
    num_samples = 10000
    num_steps = 100
    T = 10
    step_size = T / num_steps

    # サンプリングの実行
    ref_df = pd.read_csv("outputs/pretrain/samples_from_ref.csv")
    # ref_dfを, torch.tensorに変換
    ref_df = ref_df.to_numpy()
    samples_ref = torch.tensor(ref_df, dtype=torch.float32)
    # ref_dfの形は(20000, 2)である

    # サンプリングの実行
    # サンプリング結果の可視化
    fs = 15
    # Move the tensor to the CPU and convert it to a NumPy array before plotting
    samples_ref_np = samples_ref.cpu().numpy()
    plt.hist2d(
        samples_ref_np[:,0],
        samples_ref_np[:,1],
        range=((-5, 5), (-5, 5)),
        cmap='viridis',
        bins=50,
    )
    plt.gca().set_aspect('equal', adjustable='box')
    plt.xlabel('x1', fontsize=fs)
    plt.ylabel('x2', fontsize=fs)
    plt.xlim([-5, 5])
    plt.ylim([-5, 5])
    plt.colorbar()

    # save plot
    plt.savefig('outputs/figures/pretrain/pdf_sbm.png')

if __name__ == '__main__':
    main()