import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import datetime
import json

import sys
sys.path.append("/home/***/work/doob")

from src.models.nn_potential import ApproxPotential
from src.models.sbm import ScoreBaseModel
from src.objective.obj import objective_like_dpo
from src.potentials.potential_for_inner_loop import Potential_for_inner_loop
from src.utils.set_seed import set_seed
from src.utils.sampling import calc_density_ratio

import matplotlib.pyplot as plt

import math

import gc

class model_potential_zero(nn.Module):
    """Potential that is identically zero.

    Used in this script as the ``model_potential`` argument of the objective,
    i.e. a null potential with no learnable parameters.

    Args:
        input_dim: dimensionality of the inputs; stored on the instance but not
            used by ``forward``.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim

    def forward(self, x):
        """Return a ``(x.shape[0], 1)`` tensor of zeros on ``x.device``.

        The dtype is torch's current default dtype, independent of ``x``.
        """
        # Allocate directly on the target device instead of creating a CPU
        # tensor and copying it across with ``.to(x.device)``.
        return torch.zeros(x.shape[0], 1, device=x.device)

def _load_score_model(checkpoint_path: str, device: str):
    """Build a ``ScoreBaseModel(2)`` on ``device`` and load weights from ``checkpoint_path``."""
    model = ScoreBaseModel(2).to(device)
    # map_location places the tensors straight on ``device`` so loading does not
    # depend on which device the checkpoint was saved from.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model


def _make_run_dir(root: str = "outputs/diffusionDPO") -> str:
    """Create (if needed) and return ``<root>/train_<YYYYmmdd_HHMMSS>``."""
    date = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    save_dir = f"{root}/train_{date}"
    # exist_ok avoids the check-then-create race; makedirs also creates ``root``.
    os.makedirs(save_dir, exist_ok=True)
    return save_dir


def main(n_train: int = 200, save_every: int = 1, device: str = "cuda:0") -> None:
    """Fine-tune a pretrained score-based model by minimising a regularized upper bound.

    Two copies of the model are loaded from
    ``outputs/pretrain/score_base_model.pth``: a trainable one and a frozen
    reference. ``beta`` is read from ``configs/regularization.json``. Each
    iteration does the following:

    1. Computes diagnostics without gradients: the regularized objective
       (with a zero potential) and a density ratio.
    2. Takes one Adam step on ``upperbound + beta * kl_girsanov``.
    3. Logs to wandb.
    4. Saves a checkpoint every ``save_every`` iterations.

    Args:
        n_train: number of optimisation steps.
        save_every: save a checkpoint every this many steps (must be >= 1).
        device: torch device string that models and tensors are placed on.

    Raises:
        ValueError: if ``save_every`` < 1.
    """
    if save_every < 1:
        raise ValueError(f"save_every must be >= 1, got {save_every}")

    print("## training upper bound ##")
    set_seed(222)
    pretrained_path = 'outputs/pretrain/score_base_model.pth'

    # Model being fine-tuned plus a frozen reference copy, both initialised from
    # the same pretrained checkpoint. (For multi-GPU, either can be wrapped in
    # nn.DataParallel(..., output_device=1).)
    model_sbm_training = _load_score_model(pretrained_path, device)
    model_sbm_ref = _load_score_model(pretrained_path, device)

    objective_class = objective_like_dpo(device)

    import wandb

    model_sbm_training.train()
    model_sbm_ref.eval()

    wandb.init(project="MoG-upperbound-train")
    lr = 1e-4
    # 0.0001 for unregularized, 0.0005 for regularized
    optimizer = torch.optim.Adam(model_sbm_training.parameters(), lr=lr)
    regularized = False
    wandb.config.update({"lr": lr, "regularized": regularized})

    # Null potential passed to the true objective; sanity-check that it really
    # returns zeros.
    model_potential = model_potential_zero(2).to(device)
    x_test = torch.randn(100, 2).to(device)
    y_test = model_potential(x_test)
    assert torch.all(y_test == 0.0)

    config_reg_path = 'configs/regularization.json'
    with open(config_reg_path, encoding="utf-8") as f:
        beta = json.load(f)["beta"]

    save_dir = _make_run_dir()

    for train_loop in range(1, n_train + 1):
        # Diagnostics: the true (regularized) objective and the density ratio,
        # evaluated in eval mode without tracking gradients.
        model_sbm_training.eval()
        with torch.no_grad():
            true_loss, dpo, kl = objective_class.reglarized_objective(
                model_potential=model_potential,
                model_sbm=model_sbm_ref,
                beta=beta,
                model_sbm_training=model_sbm_training,
            )
            x_input = torch.randn(10, 2).to(device)
            density_ratio = calc_density_ratio(
                x_input, model_sbm_training, model_sbm_ref, 100, 100, T=10, device=device
            )
            print("density_ratio: ", density_ratio)
        # Restore train mode before the gradient step. Previously the model was
        # left in eval mode, so every update ran in eval mode (relevant if
        # ScoreBaseModel has dropout / batch-norm layers).
        model_sbm_training.train()

        # Objective: upper bound plus beta-weighted Girsanov KL.
        # NOTE(review): _calc_upperbound also receives beta; confirm it does not
        # already include the KL term, or it would be counted twice.
        optimizer.zero_grad()
        upperbound, kl_girsanov = objective_class._calc_upperbound(
            model_sbm_ref,
            model_sbm_training=model_sbm_training,
            regularized=regularized,
            beta_upperbound=beta,
        )
        regularized_upperbound = upperbound + beta * kl_girsanov
        # Backpropagate and update model_sbm_training.
        regularized_upperbound.backward()
        optimizer.step()

        # Convert to plain floats so the print / wandb output holds numbers
        # rather than graph-attached tensors. backward() above already
        # requires regularized_upperbound to be a single element.
        upperbound_val = float(upperbound)
        kl_girsanov_val = float(kl_girsanov)
        regularized_val = float(regularized_upperbound)
        print("true_loss: ", true_loss, "upperbound: ", upperbound_val, "kl_girsanov: ", kl_girsanov_val, "regularized_upperbound: ", regularized_val)
        wandb.log({"train_loop": train_loop, "upperbound": upperbound_val, "regularized upperbound": regularized_val, "true loss": true_loss, "DPO": dpo, "KL": kl})

        if train_loop % save_every == 0:
            model_path = os.path.join(save_dir, f"score_base_model_{train_loop}.pth")
            torch.save(model_sbm_training.state_dict(), model_path)
            print("model is saved")

    wandb.finish()

if __name__ == '__main__':
    # Script entry point: run the training loop when executed directly.
    main()