import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import datetime
import json
import copy
import gc

import sys
sys.path.append("/home/***/work/doob")

from src.models.nn_potential import ApproxPotential
from src.models.sbm import ScoreBaseModel
from src.objective.obj import objective_like_dpo
from src.potentials.potential_for_inner_loop import Potential_for_inner_loop
from src.utils.set_seed import set_seed
from src.models.init_models import xavier_init

import matplotlib.pyplot as plt

import wandb

def _write_loss_file(path, values):
    """Write *values* to the text file *path*, one value per line."""
    with open(path, 'w') as file:
        file.writelines(f"{value}\n" for value in values)


def _shift_for_log_plot(values, eps=1e-5):
    """Return *values* as a float array shifted so that its minimum equals *eps*.

    Subtracting the minimum makes curves with different absolute levels
    comparable on a log-scaled axis; adding *eps* keeps every point strictly
    positive so it can be drawn on that axis.
    """
    arr = np.asarray(values, dtype=float)
    return arr - arr.min() + eps


def _plot_losses(dirname, loss_obj_list, loss_dpo_list, loss_beta_kl_list, beta, gamma):
    """Plot the per-iteration loss curves on a log axis and save ``da_loss.png``.

    Every curve (regularized objective, DPO loss, beta * KL) is shifted by its
    own minimum via ``_shift_for_log_plot`` before plotting.
    """
    obj_shifted = _shift_for_log_plot(loss_obj_list)
    dpo_shifted = _shift_for_log_plot(loss_dpo_list)
    # Shift beta * KL too, so all three curves follow the same convention.
    beta_kl_shifted = _shift_for_log_plot(loss_beta_kl_list)

    plt.figure(figsize=(10, 5))
    plt.plot(obj_shifted, label="Regularized Objective")
    plt.plot(dpo_shifted, label="Loss of DPO")
    plt.plot(beta_kl_shifted, label="beta * KL")
    plt.yscale("log")
    plt.xlabel("idx_loop")
    plt.ylabel("loss")
    # NOTE(review): the lower limit 1e-3 is above the shifted minimum (1e-5),
    # so points very close to each curve's minimum fall outside the axis.
    plt.ylim(1e-3, np.max(obj_shifted))
    plt.title("loss of dual averaging"+", beta: "+str(beta)+", gamma: "+str(gamma), fontsize=20)
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(dirname, 'da_loss.png'))
    plt.close()


def _plot_potential_heatmaps(dirname, potential, samples_ref, beta, gamma, device, delta=0.2):
    """Plot the reference-sample histogram and its potential-reweighted version.

    A 50x50 grid over [-5, 5] x [-5, 5] is built. For each grid point, the
    reference samples inside the open square of half-width *delta* centred
    there are counted (``Z_first``). ``Z_final`` is that count multiplied by
    ``exp(-potential(xy))``. The two maps are saved in *dirname* as
    ``da_potential_first.png`` and ``da_potential_final.png``.

    *samples_ref* is indexed as ``samples_ref[:, 0]`` (x) and
    ``samples_ref[:, 1]`` (y).
    """
    x = torch.linspace(-5, 5, 50)
    y = torch.linspace(-5, 5, 50)
    # "ij" is torch's current default; stating it silences the deprecation
    # warning without changing the grid.
    X, Y = torch.meshgrid(x, y, indexing="ij")
    X = X.to(device)
    Y = Y.to(device)
    Z_first = torch.zeros_like(X)
    Z_final = torch.zeros_like(X)
    sample_x = samples_ref[:, 0]
    sample_y = samples_ref[:, 1]

    # Evaluation only: with grad enabled, every assignment into Z_final would
    # chain another autograd node onto it and keep all forward graphs alive.
    with torch.no_grad():
        for i in range(X.shape[0]):
            for j in range(X.shape[1]):
                cx = X[i, j]
                cy = Y[i, j]
                in_cell = ((sample_x > cx - delta) & (sample_x < cx + delta)
                           & (sample_y > cy - delta) & (sample_y < cy + delta))
                hist = torch.sum(in_cell.float())
                Z_first[i, j] = hist
                xy = torch.stack((cx, cy))
                Z_final[i, j] = torch.exp(-potential(xy)) * hist

    X_np = X.cpu().numpy()
    Y_np = Y.cpu().numpy()

    plt.figure()
    plt.pcolormesh(X_np, Y_np, Z_first.detach().cpu().numpy(), cmap='jet')
    plt.colorbar()
    plt.title("pretrained model")
    plt.savefig(os.path.join(dirname, 'da_potential_first.png'))
    plt.close()

    plt.figure()
    plt.title("potential_final"+", beta: "+str(beta)+", gamma: "+str(gamma), fontsize=20)
    plt.pcolormesh(X_np, Y_np, Z_final.detach().cpu().numpy(), cmap='jet')
    plt.colorbar()
    plt.savefig(os.path.join(dirname, 'da_potential_final.png'))
    plt.close()


def main():
    """Run the dual-averaging outer loop and save its losses and plots.

    All outputs are written to ``outputs/figures/dual_averaging/<timestamp>``.

    1. Read the regularization hyper-parameters from
       ``configs/regularization.json``.
    2. Initialise a potential with ``xavier_init``. Load the pretrained
       score-based model from ``outputs/pretrain/score_base_model.pth``.
    3. For ``max_iter`` outer iterations:
       - train a new potential from the previous one
         (``Potential_for_inner_loop.train_potential``);
       - evaluate ``reglarized_objective`` on it;
       - log the losses to stdout and wandb.
    4. Write the per-iteration losses to text files and plot them.
    5. Plot a histogram of ``outputs/pretrain/samples_from_ref.csv``, before
       and after reweighting it by ``exp(-potential)``.

    Uses ``device = "cuda"``.
    """
    # NOTE: seed 111 produces the exact opposite preference; 222 is used here.
    set_seed(222)

    print("## start dual averaging ##")

    # Create a timestamped output directory for this run.
    now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    dirname = os.path.join("outputs/figures/dual_averaging", now_str)
    os.makedirs(dirname, exist_ok=True)

    device = "cuda"
    config_reg_path = 'configs/regularization.json'
    with open(config_reg_path) as f:
        config_reg = json.load(f)
    beta = config_reg["beta"]
    beta_prime = config_reg["beta_prime"]
    gamma = config_reg["gamma"]
    y_min = config_reg["y_min"]
    y_max = config_reg["y_max"]
    pot_min = config_reg["pot_min"]
    pot_max = config_reg["pot_max"]
    print("beta: ", beta, ", gamma: ", gamma)

    objective_class = objective_like_dpo(device)

    previousPotential = ApproxPotential(2).to(device)
    xavier_init(previousPotential)

    sbm_path = 'outputs/pretrain/score_base_model.pth'
    model_sbm = ScoreBaseModel(2).to(device)
    model_sbm.load_state_dict(torch.load(sbm_path, map_location=device))
    # Record which pretrained model this run used.
    with open(os.path.join(dirname, 'sbm_path.txt'), mode='w') as f:
        f.write(sbm_path)

    max_iter = 100

    p_bt_mode = objective_class.p_bt_mode

    # Save the run's settings, one "key: value" line each.
    run_settings = {
        "beta": beta,
        "beta_prime": beta_prime,
        "gamma": gamma,
        "y_min": y_min,
        "y_max": y_max,
        "pot_min": pot_min,
        "pot_max": pot_max,
        "max_iter": max_iter,
        "p_bt_mode": p_bt_mode,
    }
    with open(os.path.join(dirname, 'beta_gamma.txt'), mode='w') as f:
        for key, value in run_settings.items():
            f.write(f"{key}: {value!s}\n")

    wandb.init(
        project="dual_averaging",
        name=f"{now_str}_beta={beta}_bprime={beta_prime}_gamma={gamma}",
        config=config_reg
    )

    loss_obj_list = []
    loss_dpo_list = []
    loss_kl_list = []
    loss_beta_kl_list = []
    for idx_loop in tqdm(range(1, max_iter + 1)):
        if idx_loop > 1:
            # Release the previous inner-loop trainer before building a new one.
            del potential_for_inner_loop
            gc.collect()
            torch.cuda.empty_cache()
        potential_for_inner_loop = Potential_for_inner_loop(beta=beta, beta_prime=beta_prime, device=device)
        if idx_loop > 1:
            # The potential trained in the last iteration becomes the starting
            # point; the one before it is no longer needed.
            del previousPotential
            gc.collect()
            torch.cuda.empty_cache()
            previousPotential = newPotential
        newPotential = potential_for_inner_loop.train_potential(previousPotential, model_sbm, idx_loop, savedirname=dirname, plot_data=True)
        loss_obj, loss_dpo, loss_kl = objective_class.reglarized_objective(newPotential, model_sbm, beta)
        obj_value = loss_obj.item()
        dpo_value = loss_dpo.item()
        kl_value = loss_kl.item()
        loss_obj_list.append(obj_value)
        loss_dpo_list.append(dpo_value)
        loss_kl_list.append(kl_value)
        loss_beta_kl_list.append(beta * kl_value)
        print("## inner-loop ended. loss: ", loss_obj, ", idx_loop: ", idx_loop, ", beta: ", beta, ", loss_dpo: ", loss_dpo, ", loss_kl: ", loss_kl)
        # Log plain Python floats rather than (possibly GPU) tensors.
        wandb.log({"loss_obj": obj_value, "loss_dpo": dpo_value, "loss_kl": kl_value, "loss_beta_kl": beta * kl_value, "idx_loop": idx_loop})

    _write_loss_file(os.path.join(dirname, 'da_loss_obj.txt'), loss_obj_list)
    _write_loss_file(os.path.join(dirname, 'da_loss_dpo.txt'), loss_dpo_list)
    _write_loss_file(os.path.join(dirname, 'da_loss_kl.txt'), loss_kl_list)

    print("## outer-loop ended. loss: ", loss_obj_list, ", beta: ", beta)

    _plot_losses(dirname, loss_obj_list, loss_dpo_list, loss_beta_kl_list, beta, gamma)

    # Samples drawn from the reference (pretrained) model; columns are x and y.
    samples_ref = pd.read_csv('outputs/pretrain/samples_from_ref.csv')
    samples_ref = torch.tensor(samples_ref.to_numpy()).to(device)
    print("samples_ref: ", samples_ref.shape)

    final_potential = newPotential
    if max_iter == 1:
        # NOTE(review): with the current loop bounds newPotential is always a
        # trained potential, so this swaps the *untrained* initial potential
        # into the final plot. Looks like a leftover from an older loop;
        # kept to preserve behaviour — confirm intent.
        final_potential = previousPotential
    _plot_potential_heatmaps(dirname, final_potential, samples_ref, beta, gamma, device)

    wandb.finish()
    print("完了")


if __name__ == "__main__":
    main()