import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

import json

import sys
sys.path.append("/home/***/work/doob")

from src.models.nn_potential import ApproxPotential
from src.potentials.target_potential import new_target_potential
import matplotlib.pyplot as plt

def main():
    """Fit an ``ApproxPotential`` network to ``new_target_potential`` and save it.

    Hyperparameters (``num_samples``, ``lr``, ``num_epochs``) are read from
    ``configs/train_potential.json``. Training inputs are drawn uniformly
    from the square [-5, 5) x [-5, 5); the targets are the potential values
    at those points, shifted so that the smallest sampled target is 0. The
    model is trained full-batch with Adam on an MSE loss, and its
    ``state_dict`` is written to ``outputs/potential/model_potential.pth``.

    Both paths are relative to the current working directory.

    Raises:
        FileNotFoundError: If the config file does not exist.
        KeyError: If a required key is missing from the config.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # JSON is UTF-8 by specification; do not rely on the locale default.
    with open('configs/train_potential.json', encoding='utf-8') as f:
        config_potential = json.load(f)

    num_samples = config_potential["num_samples"]  # e.g. 10000
    lr = config_potential["lr"]  # e.g. 0.0001
    num_epochs = config_potential["num_epochs"]  # e.g. 1000

    # Sample 2-D training points uniformly from [-5, 5).
    x_train = torch.rand(num_samples, 2).to(device) * 10 - 5
    # The potential is evaluated on the transposed batch, i.e. it receives a
    # (2, num_samples) tensor; the result then gets a new axis at position 1
    # (presumably to match the model's output shape -- TODO confirm).
    y_train = new_target_potential(x_train.T).unsqueeze(1)
    # Shift the targets so that their minimum is 0.
    y_train = y_train - y_train.min()

    print("x_train: ", x_train[:10])
    print("y_train: ", y_train[:10])

    # Model, optimizer and loss function.
    model_potential = ApproxPotential(2).to(device)
    optimizer_potential = torch.optim.Adam(model_potential.parameters(), lr=lr)
    criterion_potential = nn.MSELoss()

    ## Training ##
    # Nothing switches the model to eval mode during training, so selecting
    # train mode once before the loop is sufficient.
    model_potential.train()
    for _ in tqdm(range(num_epochs)):
        optimizer_potential.zero_grad()
        y_pred = model_potential(x_train)
        loss = criterion_potential(y_pred, y_train)
        loss.backward()
        optimizer_potential.step()

    ## Save ##
    save_dir = 'outputs/potential/'

    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test.
    os.makedirs(save_dir, exist_ok=True)

    # Save the model parameters (state_dict only, not the whole module).
    torch.save(model_potential.state_dict(), os.path.join(save_dir, 'model_potential.pth'))

# Run the training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()