import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions import MixtureSameFamily, Categorical
import json

import sys
sys.path.append("")
# print(sys.path)

from src.datasets.dist_for_ref import MixtureModel
from src.models.sbm import ScoreBaseModel, pretrain_sbm


def main():

    input_dim = 2
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # モデルのインスタンスを作成
    sbm = ScoreBaseModel(input_dim).to(device)

    # 保存したモデルの状態をロード
    sbm.load_state_dict(torch.load('outputs/pretrain/score_base_model.pth'))

    # モデルを評価モードに設定
    sbm.eval()

    # 格子点を作成
    x, y = torch.meshgrid(torch.linspace(-5, 5, 30), torch.linspace(-5, 5, 30))
    point = torch.tensor(np.vstack([x.flatten(), y.flatten()]).T, device=device)
    # 時刻tを設定
    time = 0
    t = torch.ones((point.shape[0], 1), device=device) * time
    # 入力データを作成
    x_in = torch.cat([point, t], dim=1)
    # 推論モードに変更
    sbm.eval()
    # 勾配を計算しないように設定
    with torch.no_grad():
        # モデルの出力を取得
        s = sbm(x_in)
    # ベクトル場をプロット
    plt.quiver(point[:, 0].cpu().numpy(), point[:, 1].cpu().numpy(),
            s[:, 0].cpu().numpy(), s[:, 1].cpu().numpy())
    plt.title('Vector field of score-based model, t='+str(time))
    plt.xlabel('x')
    plt.ylabel('y')
    plt.xlim([-5, 5])
    plt.ylim([-5, 5])
    # save plot
    plt.savefig('outputs/figures/pretrain/vector_field_sbm.png')

if __name__ == '__main__':
    main()