import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import datetime
import json

import sys
sys.path.append("/home/***/work/doob")

from src.models.nn_potential import ApproxPotential
from src.models.sbm import ScoreBaseModel
from src.objective.obj import objective_like_dpo
from src.potentials.potential_for_inner_loop import Potential_for_inner_loop
from src.utils.set_seed import set_seed

import matplotlib.pyplot as plt

import math

def xavier_init(model, beta=1):
    """Initialise ``model``'s parameters in place with a scaled Xavier-uniform rule.

    Bias parameters are set to zero. Every other parameter is drawn from
    ``U(-bound, bound)`` with ``bound = sqrt(6 / (shape[0] + shape[1])) / beta``.
    ``beta == 0`` gives ``bound = 0``, i.e. all-zero weights.

    Args:
        model (torch.nn.Module): model whose parameters are overwritten in place.
        beta (float): divisor applied to the Xavier bound; larger values give
            smaller initial weights and 0 zeroes every parameter.

    Note:
        When ``beta != 0`` every non-bias parameter must have at least two
        dimensions. Only the first two dimensions enter the bound, so a
        convolution kernel's receptive-field size is not taken into account.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            # A module's *own* bias is reported as "bias" (no dot), e.g. when an
            # nn.Linear itself is passed in; submodule biases end in ".bias".
            if name == "bias" or name.endswith(".bias"):
                param.fill_(0)
                continue
            if beta == 0:
                bound = 0
            else:
                bound = math.sqrt(6) / math.sqrt(param.shape[0] + param.shape[1]) / beta
            # uniform_ is used even when bound == 0 (yielding zeros) so the call
            # sequence, and presumably RNG consumption, matches the beta != 0 path.
            param.uniform_(-bound, bound)

def main():
    """Print diagnostic values from ``objective_like_dpo`` for manual inspection.

    Seeds the RNG, builds a zero-initialised ``ApproxPotential`` and a
    ``ScoreBaseModel`` loaded from ``outputs/pretrain/score_base_model.pth``.
    It then prints the outputs of the objective's helper methods
    (``_LogSigmoidLog``, ``_p_bt``, ``potential``, ``reglarized_objective``,
    ``_calc_upperbound``) on a few hand-picked 2-D points. Nothing is
    returned; the output is meant to be read by eye.
    """
    print("## testing objective ##")
    idx_loop = 1
    set_seed(222)  # seed 111 was used in an earlier variant of this check
    device = "cuda"

    with open('configs/regularization.json') as cfg_file:
        beta = json.load(cfg_file)["beta"]

    # xavier_init with beta=0 zeroes every parameter. To inspect a trained
    # potential instead, load 'outputs/potential/model_potential.pth' here.
    prev_potential = ApproxPotential(2).to(device)
    xavier_init(prev_potential, beta=0)

    score_model = ScoreBaseModel(2).to(device)
    score_model.load_state_dict(torch.load('outputs/pretrain/score_base_model.pth'))

    objective = objective_like_dpo(device)
    probe_points = torch.tensor([
        [0.0, 0.0],
        [-2.5, 0.0],
        [2.5, 0.0],
        [0.0, 2.5],
        [0.0, -2.5],
        [-3, 0.0],
        [3, 0.0],
        [0.0, 3],
        [0.0, -3],
    ]).to(device)

    print("p_bt_mode: ", objective.p_bt_mode)
    print("beta: ", beta)
    print("gamma: ", objective.gamma)

    # Note: these two are integer (int64) CPU tensors.
    low = torch.tensor([1, 1])
    high = torch.tensor([2, 2])
    lsl_win = objective._LogSigmoidLog(low, high)
    lsl_lose = objective._LogSigmoidLog(high, low)
    print("win: ", lsl_win, ",  lose: ", lsl_lose)

    p_bt = objective._p_bt
    x_a = torch.tensor([[1.0, 1.0], [1.0, 1.0], [0.0, 0.0], [-4.9, -4.9], [1.8, 0.0]]).to(device)
    x_b = torch.tensor([[2.0, 1.5], [-1.0, 1.5], [-1.0, -1.0], [-4.8, -4.0], [2.2, 0.0]]).to(device)
    print("p_bt: ", p_bt(x_a, x_b))

    print("mean_win: ", objective.mean_win)

    neg_pot_a = -prev_potential(x_a)
    neg_pot_b = -prev_potential(x_b)
    lsl = objective._LogSigmoidLog(neg_pot_a, neg_pot_b)
    print("lsl: ", lsl.shape, lsl)
    pbt = p_bt(x_a, x_b)
    print("pbt: ", pbt.shape, pbt)
    print(lsl.squeeze(1) * pbt)

    lsl_0 = objective._LogSigmoidLog(torch.tensor(1), torch.tensor(1))
    print("lsl_0: ", lsl_0)

    print("x: ", probe_points)
    for point in probe_points:
        print("x: ", point, ", previousPotential: ", prev_potential(point))

    run_potential_check = True  # set to False to skip this section
    if run_potential_check:
        raw_pot = objective.potential(probe_points, prev_potential, score_model)
        # At idx_loop == 1 both branches scale by 1 / (3 * beta).
        if idx_loop == 1:
            pot = raw_pot / (3 * beta)
        else:
            pot = raw_pot * idx_loop / (idx_loop + 2) / beta
        for i, point in enumerate(probe_points):
            print("x: ", point, ", potential: ", pot[i])

        obj, dpo, kl = objective.reglarized_objective(prev_potential, score_model, beta)
        print("regularized_obj: ", obj, ", dpo: ", dpo, ", kl: ", kl)

    with torch.no_grad():
        # The same score model is passed as both the reference and the
        # "training" model.
        upperbound, kl_girsanov = objective._calc_upperbound(
            score_model,
            model_sbm_training=score_model,
            regularized=True,
            beta_upperbound=beta,
            batch_size=10000,
        )
        print("upperbound: ", upperbound, ", kl_girsanov: ", kl_girsanov)

    print("## testing objective finished ! ##")

if __name__ == "__main__":
    # Run the diagnostics only when this file is executed as a script, not on import.
    main()