import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions import MixtureSameFamily, Categorical
import json

import sys
sys.path.append("/home/***/work/doob")
# print(sys.path)

from src.datasets.dist_for_ref import MixtureModel
from src.models.sbm import ScoreBaseModel, pretrain_sbm
from src.utils.sampling import model_based_langevin_monte_carlo

# Default checkpoint evaluated by ``main``.
# 2024-11-29 rerun, without KL term, gamma=0.1.
_DEFAULT_MODEL_PATH = "/home/***/work/doob/outputs/diffusionDPO/train_20241129_170540/score_base_model_50.pth"
# Other checkpoints used in earlier experiments (kept for reference):
#   pretrained:            outputs/pretrain/score_base_model.pth
#   without KL:            /home/***/work/doob/outputs/diffusionDPO/train_20241112/score_base_model_50.pth
#   without KL:            /home/***/work/doob/outputs/diffusionDPO/train_20241108/score_base_model_40.pth
#   with KL:               /home/***/work/doob/outputs/diffusionDPO/train_20241116_214506/score_base_model_200.pth
#   with KL, gamma=0.1:    /home/***/work/doob/outputs/diffusionDPO/train_20241129_140958/score_base_model_50.pth


def _plot_samples(samples_np, lim=5.0, bins=50):
    """Draw a 2-D histogram of the first two sample coordinates on the current pyplot figure."""
    plt.title('model-based langevin monte carlo sampling')
    plt.hist2d(
        samples_np[:, 0],
        samples_np[:, 1],
        range=((-lim, lim), (-lim, lim)),
        cmap='viridis',
        bins=bins,
    )
    plt.gca().set_aspect('equal', adjustable='box')
    plt.xlim([-lim, lim])
    plt.ylim([-lim, lim])
    plt.colorbar()


def _distance_stats(samples_np, center, threshold):
    """Measure how far samples lie from ``center``, ignoring outliers.

    Args:
        samples_np: array of points, one per row; each row must broadcast
            against ``center``.
        center: reference point subtracted from every row.
        threshold: samples whose distance is ``>= threshold`` (or not finite)
            are excluded from the mean and counted as outliers.

    Returns:
        ``(n_excluded, mean_distance)``; ``mean_distance`` is NaN when every
        sample was excluded (instead of letting ``np.mean`` warn on an empty array).
    """
    dists = np.linalg.norm(
        np.asarray(samples_np, dtype=float) - np.asarray(center, dtype=float), axis=1
    )
    # ``dists < threshold`` is False for NaN too, so diverged samples are
    # counted as excluded rather than silently disappearing.
    keep = dists < threshold
    n_excluded = int((~keep).sum())
    kept = dists[keep]
    mean_distance = float(kept.mean()) if kept.size else float("nan")
    return n_excluded, mean_distance


def main(
    *,
    model_path=_DEFAULT_MODEL_PATH,
    config_path="configs/mixturemodel.json",
    num_samples=100000,  # 1000 was used when only dists_mean was needed
    num_steps=100,
    total_time=10,
    dist_threshold=10.0,
):
    """Sample from a trained score-based model and report/save the results.

    Loads a ``ScoreBaseModel`` checkpoint, draws samples with
    ``model_based_langevin_monte_carlo``, plots a 2-D histogram, prints the
    mean distance of the samples to the first mixture mean in the config, and
    writes the samples (CSV) and the figure (PNG) with a shared timestamp.

    All keyword arguments default to the values the script originally used.
    """
    import datetime

    input_dim = 2
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load mixture-model parameters; the first mean is the reference ("win") point.
    with open(config_path, encoding="utf-8") as f:
        mixturemodel_params = json.load(f)
    mean_win = np.asarray(mixturemodel_params["means"][0], dtype=float)

    # Build the model and load the checkpoint onto the selected device, so a
    # GPU-saved checkpoint also loads on a CPU-only machine.
    sbm = ScoreBaseModel(input_dim).to(device)
    sbm.load_state_dict(torch.load(model_path, map_location=device))
    sbm.eval()

    # Langevin Monte Carlo parameters.
    step_size = total_time / num_steps
    samples_ref = model_based_langevin_monte_carlo(
        sbm, num_samples, num_steps, step_size, device=device
    )
    # detach() in case the sampler returns a tensor that tracks gradients.
    samples_ref_np = samples_ref.detach().cpu().numpy()

    _plot_samples(samples_ref_np)

    n_excluded, dists_mean = _distance_stats(samples_ref_np, mean_win, dist_threshold)
    print(f"number of samples with distance >= {dist_threshold} (or non-finite): ", n_excluded)
    print("mean distance between samples and mean_win: ", dists_mean)

    # Save samples and figure under one timestamp; create the folders if missing.
    now_str = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    csv_dir = os.path.join("outputs", "pretrain")
    fig_dir = os.path.join("outputs", "figures", "pretrain")
    os.makedirs(csv_dir, exist_ok=True)
    os.makedirs(fig_dir, exist_ok=True)

    df = pd.DataFrame(samples_ref_np, columns=['x', 'y'])
    df.to_csv(os.path.join(csv_dir, f"samples_from_ref_{now_str}.csv"), index=False)
    plt.savefig(os.path.join(fig_dir, f"pdf_sbm_{now_str}.png"))

if __name__ == "__main__":
    # Run the sampling experiment only when executed as a script, not on import.
    main()