import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions import MixtureSameFamily, Categorical
import json

from src.datasets.dist_for_ref import MixtureModel
from src.models.sbm import ScoreBaseModel, pretrain_sbm

from datetime import datetime

import tqdm

def get_log_dir():
    # 現在の日付と時刻を取得
    now = datetime.now()
    date_str = now.strftime("%Y-%m-%d_%H-%M-%S")  # 例: "2024-08-21_14-30-00"
    # ログディレクトリ名を作成
    log_dir = os.path.join('logs/pretrain', f'pretrain_{date_str}')
    
    return log_dir

def plot_loss(df_res, output_dir):
    # 出力先ディレクトリを指定
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # プロットの作成
    # df_res['loss']の保存
    # df_res.to_csv(os.path.join(output_dir, 'pretrain.csv'), index=False)

    # ax = df_res['loss'].plot(color="blue", alpha=0.3, label='loss')
    # df_resの点を10集めて, それらのうち中央値をプロット
    # まずは, indexとmeadianのリストを作成
    index_list = []
    min_list = []
    index_list.append(0)
    for i in range(0, len(df_res)):
        index_list.append(i+1)
        # min_listには, これまでの最小値を格納
        min_list.append(min(df_res['loss'][:i+1]))

    #  df_resに"min"を追加
    df_res['min'] = min_list
    ax = df_res['min'].plot(color="blue", label='loss')

    fs = 15

    ax.set_xlabel('epoch', fontsize=fs)
    ax.set_ylabel('loss', fontsize=fs)
    # 対数軸に設定
    # ax.set_yscale('log')
    ax.set_xscale('log')
    ax.set_ylim(0, df_res['loss'][0]+0.2)
    # プロットを指定した出力先に保存
    output_path = os.path.join(output_dir, 'pretrain.png')
    plt.savefig(output_path)
    # close the plot
    plt.close()

def pretrain_score():
    # 乱数シードを固定
    torch.manual_seed(1234)

    # 使用可能なデバイスを取得
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the model
    mm = MixtureModel(config_dir="configs", device=device)
    print("mixture model created.")

    # Sample from the distribution
    samples = mm.sample(1000)

    # Calculate log probabilities
    log_probs = mm.log_prob(samples)

    # 学習時のパラメータを設定
    with open("configs/pretrain_sbm.json", "r") as f:
        config_sbm = json.load(f)
    input_dim = config_sbm["input_dim"]
    lr = config_sbm["lr"]
    batch_size = config_sbm["batch_size"] # 2000
    num_epoch = config_sbm["num_epoch"] # 10000
    T = config_sbm["T"] # 10

    # モデル作成
    sbm = ScoreBaseModel(input_dim).to(device)
    # モデルパラメータの最適化手法を決定
    optimizer = torch.optim.Adam(sbm.parameters(), lr=lr)
    # Loss関数を決定
    criterion = torch.nn.MSELoss()

    # csvを良いみとる
    df_res = pd.read_csv('/home/***/work/doob/outputs/pretrain/20240918_162035/pretrain.csv')
    # 学習の実行
    # csvがある場合は、学習を実行しない
    if df_res.empty:
        df_res = pretrain_sbm(batch_size, num_epoch, T, mm.dist, sbm, optimizer, criterion, device, log_dir=get_log_dir())

    # 保存先ディレクトリ
    now = datetime.now()
    now_str = now.strftime('%Y%m%d_%H%M%S')
    save_dir = os.path.join('outputs', 'pretrain', now_str)

    # ディレクトリが存在しない場合は作成
    os.makedirs(save_dir, exist_ok=True)
    plot_loss(df_res, save_dir)

    # モデルの状態を保存
    torch.save(sbm.state_dict(), os.path.join(save_dir, 'score_base_model.pth'))

    # モデルを評価モードに設定
    sbm.eval()

    # 格子点を作成
    x, y = torch.meshgrid(torch.linspace(-5, 5, 30), torch.linspace(-5, 5, 30))
    point = torch.tensor(np.vstack([x.flatten(), y.flatten()]).T, device=device)
    # 時刻tを設定
    time = 0
    t = torch.ones((point.shape[0], 1), device=device) * time
    # 入力データを作成
    x_in = torch.cat([point, t], dim=1)
    # 推論モードに変更
    sbm.eval()
    # 勾配を計算しないように設定
    with torch.no_grad():
        # モデルの出力を取得
        s = sbm(x_in)
    # ベクトル場をプロット
    plt.figure()
    plt.quiver(point[:, 0].cpu().numpy(), point[:, 1].cpu().numpy(),
            s[:, 0].cpu().numpy(), s[:, 1].cpu().numpy())
    plt.title('Vector field of score-based model, t='+str(time))
    plt.xlabel('x')
    plt.ylabel('y')
    # スケールをlogではなくlinearに設定
    plt.xscale('linear')
    plt.yscale('linear')
    plt.xlim([-5, 5])
    plt.ylim([-5, 5])
    # save plot
    plt.savefig(os.path.join(save_dir, 'vector_field_sbm.png'))

if __name__ == "__main__":
    pretrain_score()