import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import datetime
import json

import sys
sys.path.append("")

from src.models.nn_potential import ApproxPotential
from src.models.sbm import ScoreBaseModel
from src.objective.obj import objective_like_dpo
from src.potentials.potential_for_inner_loop import Potential_for_inner_loop
from src.utils.set_seed import set_seed

import matplotlib.pyplot as plt

import math

import gc

from matplotlib.colors import Normalize

def main():
    """Plot a 4x4 grid of heatmaps of saved ``ApproxPotential`` checkpoints.

    Steps:
      1. Seed RNGs and load the ``ScoreBaseModel`` checkpoint and the
         regularization config (``beta``).
      2. For k = 1..16, load
         ``outputs/figures/dual_averaging/20240916_105414/model_potential_<k>.pth``
         into an ``ApproxPotential(2)``.
      3. Evaluate each one on a 30x30 grid over [-5, 5] x [-5, 5].
      4. Draw a filled contour plot (shared color normalization, vmin=0,
         vmax=2) with a red star at (2.5, 0).

    The figure is saved to
    ``outputs/figures/heatmap/heatmap_<YYYYmmdd_HHMMSS>.png``.

    Requires CUDA device ``cuda:0`` and the checkpoint/config files above.
    """
    print("## heatmap ##")
    set_seed(223)
    device = "cuda:0"

    # NOTE(review): model_sbm, objective_class and innerloop are not used by
    # the plotting below. They are kept so the script's setup (checkpoint /
    # config loading) is unchanged; consider removing them.
    model_sbm = ScoreBaseModel(2).to(device)
    # map_location pins the loaded tensors to the target device regardless of
    # which device the checkpoint was saved from.
    model_sbm.load_state_dict(
        torch.load("outputs/pretrain/score_base_model.pth", map_location=device)
    )
    # Parallelize across the available GPUs.
    model_sbm = nn.DataParallel(model_sbm)
    objective_class = objective_like_dpo(device)

    # Load the regularization config; only "beta" is read here.
    config_reg_path = "configs/regularization.json"
    with open(config_reg_path) as f:
        config_reg = json.load(f)
    beta = config_reg["beta"]

    innerloop = Potential_for_inner_loop(beta=beta, device=device)

    fig, axs = plt.subplots(4, 4, figsize=(10, 10))

    # Common color normalization so every panel maps values to colors alike.
    vmin = 0  # Minimum value for color scale
    vmax = 2  # Maximum value for color scale
    norm = Normalize(vmin=vmin, vmax=vmax)

    # The evaluation grid is identical for every panel, so build it once.
    # indexing="ij" is torch's historical default; stating it explicitly
    # silences the deprecation warning about the implicit choice.
    x = torch.linspace(-5, 5, 30).to(device)
    y = torch.linspace(-5, 5, 30).to(device)
    X, Y = torch.meshgrid(x, y, indexing="ij")
    # Flatten the grid into a batch of 2-D points, shape (30 * 30, 2).
    grid_points = torch.stack([X, Y], dim=2).reshape(-1, 2)
    X_np = X.cpu().numpy()
    Y_np = Y.cpu().numpy()

    for idx_loop in range(1, 16 + 1):
        print("idx_loop: ", idx_loop)
        # Row/column of this panel in the 4x4 axes grid.
        row, col = divmod(idx_loop - 1, 4)
        ax = axs[row][col]

        Potential = ApproxPotential(2).to(device)
        potential_path = (
            "outputs/figures/dual_averaging/20240916_105414/model_potential_"
            + str(idx_loop)
            + ".pth"
        )
        Potential.load_state_dict(torch.load(potential_path, map_location=device))
        Potential = nn.DataParallel(Potential, device_ids=[0], output_device=0)
        Potential.eval()

        # Inference only: do not build an autograd graph for the grid pass.
        with torch.no_grad():
            Z = Potential(grid_points).reshape(X.shape)

        cf = ax.contourf(
            X_np, Y_np, Z.cpu().numpy(), levels=100, cmap="jet", norm=norm
        )
        ax.set_title("k=" + str(idx_loop), fontsize=20)
        ax.set_xlabel("x_1", fontsize=15)
        ax.set_ylabel("x_2", fontsize=15)
        ax.axis("equal")
        ax.axis("off")
        # Mark the point (2.5, 0) with a red star.
        ax.scatter(2.5, 0, s=100, c="red", marker="*", label="target")
        fig.colorbar(cf, ax=ax, orientation="vertical")
        print("potential_path: ", potential_path)

    os.makedirs("outputs/figures/heatmap", exist_ok=True)
    # Include a timestamp in the filename so repeated runs don't overwrite.
    now_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    plt.tight_layout()
    plt.savefig("outputs/figures/heatmap/heatmap_" + now_str + ".png")
    # Release the figure now that it has been written to disk.
    plt.close(fig)

    print("## end testing objective ##")

# Entry point: render the heatmaps only when run as a script, not on import.
if __name__ == "__main__":
    main()