import torch
import numpy as np
import math, random
import json
import os

import torch.nn as nn
import torch.functional as F

import sys
sys.path.append("/home/***/work/doob")

from src.utils.sampling import doob_langevin_monte_carlo_modified, new_sbm_based_langevin_monte_carlo, calc_density_ratio
from src.doob_h.doob_h import doob_h_tensorized

from tqdm import tqdm
import gc, math

# Module-level constant: pi as a 0-d torch tensor (not referenced in the code shown here;
# presumably kept for importers of this module — verify before removing).
pi = torch.tensor(math.pi)

class objective_like_dpo():
    '''
    DPO-style objective (and its functional derivative) for a potential model.

    The log density ratio of a sample x is taken to be ``-model_potential(x)``.
    Human preference is emulated by the synthetic oracle ``_p_bt``, which does
    not depend on any model.

    Settings are read at construction time from ``configs/mixturemodel.json``,
    ``configs/regularization.json`` and ``configs/train_potential.json``
    (paths relative to the current working directory).
    '''
    def __init__(self, device="cuda"):
        self.device = device
        self.eps = 1e-10  # added inside log() to avoid log(0)
        self.gamma = 1  # inverse temperature; overwritten by configs/regularization.json
        self.sample_n_samples = 10000  # number of samples drawn per sampling call
        self.sample_num_steps = 50  # Langevin steps per sampling call
        self.p_bt_mode = 1  # selects the preference oracle used by _p_bt
        self.pot_max = 0.0  # overwritten by configs/regularization.json
        self.pot_min = 0.0  # overwritten by configs/regularization.json
        self._load_config()

    def _load_config(self):
        '''Load mixture-model, regularization and training settings from ``configs/``.'''
        # Paths are relative to the current working directory.
        config_file_path = 'configs/mixturemodel.json'
        config_reg_path = 'configs/regularization.json'
        config_train_path = 'configs/train_potential.json'
        with open(config_file_path, 'r', encoding='utf-8') as config_file:
            config = json.load(config_file)
        with open(config_reg_path, 'r', encoding='utf-8') as config_file:
            config_reg = json.load(config_file)
        with open(config_train_path, 'r', encoding='utf-8') as config_file:
            config_train = json.load(config_file)

        self.means = config['means']
        # means[0] is treated as the "winning" mode and means[1] as the "losing" mode.
        self.mean_win = self.means[0]
        self.mean_lose = self.means[1]
        self.covs = config['covs']
        self.weights = config['weights']
        self.gamma = config_reg['gamma']
        self.pot_max = config_reg['pot_max']
        self.pot_min = config_reg['pot_min']
        self.obj_mode = config_train['obj_mode']

    def convert_to_tensor(self, x):
        '''Convert a NumPy array to a tensor on ``self.device``; return anything else unchanged.'''
        if isinstance(x, np.ndarray):
            return torch.tensor(x).to(self.device)
        return x

    def _sample(self, model_sbm, model_potential, T=10):
        '''Sample from the current (Doob-h transformed) model via Langevin Monte Carlo.'''
        # TODO (Doob): this sampler still needs to be revised.
        return doob_langevin_monte_carlo_modified(model_sbm, model_potential,
                                                  self.sample_n_samples, self.sample_num_steps,
                                                  T, device=self.device)

    def _sample_from_ref(self, model_sbm, T = 10):
        '''Sample ``self.sample_n_samples`` points from the reference score model.'''
        num_steps = self.sample_num_steps
        n_samples = self.sample_n_samples
        return new_sbm_based_langevin_monte_carlo(model_sbm, n_samples, num_steps, T, device=self.device)

    def _sigmoidLog(self, log_ratio_win, log_ratio_lose):
        '''Return sigmoid(gamma * (log_ratio_win - log_ratio_lose)).

        The exponent is clamped to [-10, 10] before exponentiation, which keeps
        the output strictly inside (0, 1).
        '''
        diff = -self.gamma * (log_ratio_win - log_ratio_lose)
        clamped_diff = torch.clamp(diff, min=-10, max=10)
        return 1 / (1 + torch.exp(clamped_diff))

    def _LogSigmoidLog(self, log_ratio_win, log_ratio_lose):
        '''Return log(sigmoid(gamma * (win - lose)) + eps).'''
        return torch.log(self._sigmoidLog(log_ratio_win, log_ratio_lose)+self.eps)

    def _deriv_LogSigmoidLog_mult_by_p(self, log_ratio_win, log_ratio_lose, wrt = "win"):
        '''Derivative of log sigmoid(gamma*(win - lose)) w.r.t. one of its arguments.

        Returns ``gamma * (1 - sigmoid)`` for ``wrt="win"`` and its negation for
        ``wrt="lose"`` (the clamp inside ``_sigmoidLog`` is ignored).

        Raises:
            ValueError: if ``wrt`` is neither "win" nor "lose".
        '''
        if wrt not in ("win", "lose"):
            raise ValueError(f"wrt must be 'win' or 'lose', got {wrt!r}")
        gamma = self.gamma
        sigmoid = self._sigmoidLog(log_ratio_win, log_ratio_lose)
        if wrt == "win":
            return gamma * (1 - sigmoid)
        return - gamma * (1 - sigmoid)

    def _p_bt(self, x_1, x_2, ):
        '''Synthetic preference oracle: 1.0 where x_1 is preferred over x_2, else 0.0.

        The oracle does not depend on any model, so it does not affect the
        functional derivative. ``x_1`` and ``x_2`` are (N, input_dim) with the same N.

        Modes (``self.p_bt_mode``):
          1: x_1 is closer (squared Euclidean distance) to ``mean_win`` than x_2.
          2: x_1[:, 0] > x_2[:, 0].
          3: on coordinate 0, |x - mean_win| - |x - mean_lose| is smaller for x_1
             (N(0, 0.1^2) noise is added to x_1's metric).
          4: like 3, but with square roots of the absolute differences.

        Raises:
            ValueError: for an unknown ``p_bt_mode``.
        '''
        p_bt_mode = self.p_bt_mode
        device = self.device
        if p_bt_mode == 1:
            mean_win = torch.tensor(self.mean_win).to(device)
            mean_win = mean_win.repeat(x_1.size(0), 1)  # (N, input_dim)

            dist_1 = torch.sum((x_1 - mean_win) ** 2, dim=1).to(device)
            dist_2 = torch.sum((x_2 - mean_win) ** 2, dim=1).to(device)

            def flip_with_probability_tensor(dist_1, dist_2, strongness = 1):
                # Bradley-Terry style stochastic preference: P(x_1 wins) = sigmoid(dist_2 - dist_1).
                probabilities = torch.sigmoid(strongness * (- dist_1 + dist_2)).to(device)
                random_numbers = torch.rand(dist_1.size()).to(device)
                return (random_numbers < probabilities).float().to(device)

            # NOTE(review): the stochastic oracle is evaluated but its result is not
            # returned (the deterministic comparison below is). The call is kept so
            # that RNG consumption is unchanged; TODO decide which oracle is intended.
            flip_with_probability_tensor(dist_1, dist_2)

            return (dist_1 < dist_2).float()

        elif p_bt_mode == 2:
            return (x_1[:,0] > x_2[:,0]).float()

        elif p_bt_mode == 3:
            mean_win = torch.tensor(self.mean_win).to(device)
            mean_win = mean_win.repeat(x_1.size(0), 1)  # (N, input_dim)
            mean_lose = torch.tensor(self.mean_lose).to(device)
            mean_lose = mean_lose.repeat(x_1.size(0), 1)  # (N, input_dim)
            metric_1 = torch.abs(x_1[:,0] - mean_win[:,0]).to(device) - torch.abs(x_1[:,0] - mean_lose[:,0]).to(device)
            metric_2 = torch.abs(x_2[:,0] - mean_win[:,0]).to(device) - torch.abs(x_2[:,0] - mean_lose[:,0]).to(device)
            # Add a little noise to metric_1.
            metric_1 = metric_1 + torch.randn_like(metric_1) * 0.1
            return (metric_1 < metric_2).float()

        elif p_bt_mode == 4:
            mean_win = torch.tensor(self.mean_win).to(device)
            mean_win = mean_win.repeat(x_1.size(0), 1).to(device)
            mean_lose = torch.tensor(self.mean_lose).to(device)
            mean_lose = mean_lose.repeat(x_1.size(0), 1).to(device)
            metric_1 = torch.sqrt(torch.abs(x_1[:,0] - mean_win[:,0]).to(device)) - torch.sqrt(torch.abs(x_1[:,0] - mean_lose[:,0]).to(device))
            metric_2 = torch.sqrt(torch.abs(x_2[:,0] - mean_win[:,0]).to(device)) - torch.sqrt(torch.abs(x_2[:,0] - mean_lose[:,0]).to(device))
            return (metric_1 < metric_2).float()

        raise ValueError(f"unknown p_bt_mode: {p_bt_mode!r}")

    def _clip_exp_pot_inv(self, exp_pot_inv):
        '''Clamp exp(+potential) values into [exp(gamma*pot_min), exp(gamma*pot_max)].

        Values above the upper bound are replaced first, then values below the
        lower bound (same order as a pair of ``torch.where`` calls).
        '''
        upper = torch.exp(torch.tensor(self.pot_max * self.gamma, dtype=exp_pot_inv.dtype, device=exp_pot_inv.device))
        lower = torch.exp(torch.tensor(self.pot_min * self.gamma, dtype=exp_pot_inv.dtype, device=exp_pot_inv.device))
        exp_pot_inv = torch.where(exp_pot_inv > upper, upper, exp_pot_inv)
        exp_pot_inv = torch.where(exp_pot_inv < lower, lower, exp_pot_inv)
        return exp_pot_inv

    def _clip_exp_pot_ref(self, exp_pot_ref):
        '''Clamp exp(-potential) values into [exp(-gamma*pot_max), exp(-gamma*pot_min)].

        Values above ``exp(-gamma*pot_min)`` are set to exactly that bound (the
        previous inline code wrongly used ``exp(gamma*pot_max)`` here), then
        values below ``exp(-gamma*pot_max)`` are raised to it.
        '''
        upper = torch.exp(-torch.tensor(self.pot_min * self.gamma, dtype=exp_pot_ref.dtype, device=exp_pot_ref.device))
        lower = torch.exp(-torch.tensor(self.pot_max * self.gamma, dtype=exp_pot_ref.dtype, device=exp_pot_ref.device))
        exp_pot_ref = torch.where(exp_pot_ref > upper, upper, exp_pot_ref)
        exp_pot_ref = torch.where(exp_pot_ref < lower, lower, exp_pot_ref)
        return exp_pot_ref

    def _expectation(self, x, model_potential, model_sbm, for_what, beta = 1, model_sbm_training = None):
        '''Monte-Carlo estimates over samples from the reference model ``model_sbm``.

        Args:
            x: (batch_size, input_dim) evaluation points; for the two objective
                modes only its shape is used (batch_size must be >= 2 because of
                the sanity asserts below).
            for_what: one of
                "potential": per-row functional derivative of the DPO objective
                    at each row of ``x``; returns a tensor of shape (batch_size,).
                "objective": scalar DPO objective.
                "regularized_objective": tuple (objective - beta*KL, objective, KL).
            beta: KL weight for "regularized_objective".
            model_sbm_training: if given, "regularized_objective" uses the empirical
                density ratio between this model and ``model_sbm`` instead of the potential.

        Raises:
            ValueError: for an unknown ``for_what``.
            NotImplementedError: for "potential" when ``self.obj_mode != 1``.
        '''
        if for_what not in ("potential", "objective", "regularized_objective"):
            raise ValueError(f"unknown for_what: {for_what!r}")
        # Validate before the (expensive) sampling below; only obj_mode 1 is implemented.
        if for_what == "potential" and self.obj_mode != 1:
            raise NotImplementedError(f"obj_mode {self.obj_mode!r} is not implemented for 'potential'")

        batch_size = x.shape[0]
        samples_from_ref = self._sample_from_ref(model_sbm)  # (n_samples, input_dim)
        n_samples = len(samples_from_ref)
        # Tile so that row i*n_samples + j pairs x[i] with samples_from_ref[j]:
        #   x       -> [x[0], ..., x[0], x[1], ..., x[1], ...]
        #   samples -> [s[0], s[1], ..., s[0], s[1], ...]
        x = x.repeat_interleave(n_samples, dim = 0)
        print("x: ", x)
        assert x[0, 0] == x[n_samples-1, 0] and x[0, 1] == x[n_samples-1, 1]
        samples = samples_from_ref.repeat(batch_size, 1)
        assert samples[0][0] == samples[n_samples][0] and samples[0][1] == samples[n_samples][1]

        if for_what == "potential":
            p_bt_wrt_x2 = self._p_bt(x, samples)  # 1 where x is preferred over the sample
            p_bt_wrt_x1 = self._p_bt(samples, x)  # 1 where the sample is preferred over x
            # Note: log ratio = -potential. Each forward pass is computed once and reused.
            pot_x = model_potential(x)
            pot_samples = model_potential(samples)
            exp_pot_ref = self._clip_exp_pot_ref(torch.exp(-model_potential(samples_from_ref)).squeeze(1))
            exp_pot_inv = self._clip_exp_pot_inv(torch.exp(pot_x).squeeze(1))
            print("max:", torch.max(exp_pot_inv),", min:", torch.min(exp_pot_inv))

            smp_wrt_x2 = self._deriv_LogSigmoidLog_mult_by_p(-pot_x, -pot_samples, wrt = "win").squeeze(1) \
                            * exp_pot_inv * p_bt_wrt_x2
            smp_wrt_x1 = self._deriv_LogSigmoidLog_mult_by_p(-pot_samples, -pot_x, wrt = "lose").squeeze(1) \
                            * exp_pot_inv * p_bt_wrt_x1
            # The "win" derivative term must be non-negative and the "lose" term non-positive.
            assert torch.sum(smp_wrt_x2 < 0) == 0
            assert torch.sum(smp_wrt_x1 > 0) == 0
            assert smp_wrt_x2.shape == (batch_size * n_samples,)

            # Row i of the (batch_size, n_samples) view holds all pairs for x[i].
            chunks_wrt_x2 = smp_wrt_x2.view(batch_size, n_samples)
            assert chunks_wrt_x2[1, 0] == smp_wrt_x2[n_samples]
            chunks_wrt_x1 = smp_wrt_x1.view(batch_size, n_samples)
            sums_wrt_x2 = torch.sum(chunks_wrt_x2, dim=1)
            sums_wrt_x1 = torch.sum(chunks_wrt_x1, dim=1)

            # Multiply by the mean of exp(-potential) over the reference samples.
            exp_pot_avg = torch.sum(exp_pot_ref) / n_samples
            sums_wrt_x2 = sums_wrt_x2 * exp_pot_avg
            sums_wrt_x1 = sums_wrt_x1 * exp_pot_avg

            average = sums_wrt_x2 / n_samples + sums_wrt_x1 / n_samples

        elif for_what == "objective":  # returns a scalar tensor
            samples_from_ref_2 = self._sample_from_ref(model_sbm)  # (n_samples, input_dim)
            lsl = self._LogSigmoidLog(-model_potential(samples_from_ref), -model_potential(samples_from_ref_2))
            # (n_samples, 1) -> (n_samples,); otherwise the product with p_bt broadcasts to (n, n).
            lsl = lsl.squeeze(1)
            p_bt = self._p_bt(samples_from_ref, samples_from_ref_2)
            average = torch.sum(lsl * p_bt) / n_samples

        else:  # "regularized_objective": returns (objective - beta*KL, objective, KL)
            samples_from_ref_2 = self._sample_from_ref(model_sbm)  # (n_samples, input_dim)
            # Warn when any sample has absolute value above 5, then clamp to [-5, 5].
            if torch.max(torch.abs(samples_from_ref)) > 5:
                print("samples_from_refの値が5を超えています.")
            if torch.max(torch.abs(samples_from_ref_2)) > 5:
                print("samples_from_ref_2の値が5を超えています.")
            samples_from_ref = torch.clamp(samples_from_ref, -5, 5)
            samples_from_ref_2 = torch.clamp(samples_from_ref_2, -5, 5)

            # DPO objective
            if model_sbm_training is None:
                lsl = self._LogSigmoidLog(-model_potential(samples_from_ref), -model_potential(samples_from_ref_2))
                lsl = lsl.squeeze(1)  # (n_samples, 1) -> (n_samples,)
            else:
                print("calulating DPO with empirical density ratio.")
                lsl = self._LogSigmoidLog(torch.log(calc_density_ratio(samples_from_ref, model_sbm_training, model_sbm,
                                                                       100, 100)),
                                        torch.log(calc_density_ratio(samples_from_ref_2, model_sbm_training, model_sbm,
                                                                       100, 100)))
                lsl = lsl.squeeze()
                assert lsl.shape == (n_samples,)
            p_bt = self._p_bt(samples_from_ref, samples_from_ref_2)
            if torch.sum(p_bt) == 0:
                print("count_p_btが0です.")
            average = torch.sum(lsl * p_bt) / n_samples

            # KL divergence estimate over the reference samples.
            pot_ref = model_potential(samples_from_ref).squeeze(1).to(self.device)
            # Normalization constant Z = mean(exp(-pot)); .item() detaches it from the graph.
            Z = (torch.sum(torch.exp(-pot_ref)) / samples_from_ref.size(0)).item()
            log_Z = torch.log(torch.full_like(pot_ref, Z))
            KL_tensor = (- pot_ref - log_Z) * torch.exp( - pot_ref) / Z
            print("KL_tensor: ", KL_tensor.shape)
            KL = torch.sum(KL_tensor) / samples_from_ref.size(0)

            print("(not negated) DPO: ", average)
            print("KL: ", KL)
            avg_and_kl = average - beta * KL
            return avg_and_kl, average, KL
        return average

    def potential(self, x, model_potential, model_sbm):
        '''Return -dF/dp evaluated at each row of ``x`` (shape (batch_size,)).'''
        model_potential, model_sbm = self._parallelize(model_potential, model_sbm)
        x = self.convert_to_tensor(x)  # (batch_size, input_dim)
        return - self._expectation(x, model_potential, model_sbm, "potential")

    def objective(self, model_potential, model_sbm):
        '''Return -F (minimizing -F maximizes the DPO objective F).'''
        model_potential, model_sbm = self._parallelize(model_potential, model_sbm)
        # Dummy input: only its shape is used by _expectation.
        x = torch.tensor([[0,0],[1,1]]).to(self.device)
        return - self._expectation(x, model_potential, model_sbm, "objective")

    def reglarized_objective(self, model_potential, model_sbm, beta = 1, model_sbm_training = None):
        '''Return (-(F - beta*KL), -F, KL) for the KL-regularized DPO objective.'''
        model_potential, model_sbm = self._parallelize(model_potential, model_sbm)
        # Dummy input: only its shape is used by _expectation.
        x = torch.tensor([[0,0],[1,1]]).to(self.device)
        avg_and_kl, average, KL = self._expectation(x, model_potential, model_sbm, "regularized_objective",
                                                    beta, model_sbm_training)
        return -avg_and_kl, -average, KL

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    regularized_objective = reglarized_objective

    def _parallelize(self, model_potential, model_sbm):
        '''Wrap both models in ``nn.DataParallel`` (once) and move them to CUDA.'''
        # Avoid nesting DataParallel wrappers when a model was already wrapped.
        if not isinstance(model_sbm, nn.DataParallel):
            model_sbm = nn.DataParallel(model_sbm)
        if not isinstance(model_potential, nn.DataParallel):
            model_potential = nn.DataParallel(model_potential)
        model_sbm.cuda()
        model_potential.cuda()
        return model_potential, model_sbm

    def _sample_noised_images(self, model_sbm, batch_size, t):
        '''Sample from p_ref and apply forward noise: x_t = e^-t x + sqrt(1 - e^-2t) * noise.

        Returns ``(x_t, noise)``.
        '''
        x_T = new_sbm_based_langevin_monte_carlo(model_sbm, batch_size, self.sample_num_steps, T=10, device=self.device)
        noise = torch.randn_like(x_T).to(self.device)
        t = torch.tensor(t).to(self.device)
        x_t = torch.exp(-t) * x_T + torch.sqrt(1 - torch.exp(-2*t)) * noise
        # Free memory.
        del x_T
        gc.collect()
        torch.cuda.empty_cache()
        return x_t, noise

    def _predict_noise(self, model_sbm, model_potential, t, x):
        '''Predict the noise at time t for the Doob-h corrected model and the reference model.

        Returns ``(noise_pred_doob, noise_pred_ref)``.
        '''
        batch_t = torch.ones((x.shape[0], 1)).to(self.device) * t
        doob = doob_h_tensorized(model_sbm, model_potential, x, t, step_size=10/self.sample_num_steps)
        # Assumes a 2-D input: the two components are stacked back into (N, 2).
        doob = torch.stack([doob[0], doob[1]], axis=1)
        with torch.no_grad():
            score_pred = model_sbm(torch.cat([x, batch_t], axis=1))
        t = torch.tensor(t).to(self.device)
        coeff = torch.sqrt(1 - torch.exp(-2*t))
        noise_pred_ref = - coeff * score_pred
        noise_pred_doob = - coeff * (doob+score_pred)
        # Free memory.
        del doob, score_pred
        gc.collect()
        torch.cuda.empty_cache()
        return noise_pred_doob, noise_pred_ref

    def _predict_noise_training(self, model_sbm, model_sbm_training, t, x):
        '''Predict the noise at time t for the trained score model and the reference model.

        Returns ``(noise_pred_training, noise_pred_ref)``; only the reference
        prediction is computed under ``torch.no_grad()``.
        '''
        batch_t = torch.ones((x.shape[0], 1)).to(self.device) * t
        with torch.no_grad():
            score_pred = model_sbm(torch.cat([x, batch_t], axis=1))
        score_pred_training = model_sbm_training(torch.cat([x, batch_t], axis=1))
        t = torch.tensor(t).to(self.device)
        coeff = torch.sqrt(1 - torch.exp(-2*t))
        noise_pred_ref = - coeff * score_pred
        noise_pred_doob = - coeff * score_pred_training
        return noise_pred_doob, noise_pred_ref

    def _calc_upperbound(self, model_sbm, model_potential=None, model_sbm_training=None, regularized=False, beta_upperbound=0.04, batch_size=250):
        '''Monte-Carlo estimate of a diffusion-DPO style bound over random times t.

        Exactly one of ``model_potential`` (Doob-h corrected model) or
        ``model_sbm_training`` must be given. ``beta_upperbound`` is currently unused.

        Returns:
            (avg, kl_avg): ``avg`` is the negated mean p_bt-weighted log-sigmoid term
            (a Python float when ``model_potential`` is given, a tensor otherwise);
            ``kl_avg`` is 0 unless ``regularized`` is true.
        '''
        total = 0
        num_exp = 20
        kl_avg = 0.0
        for _ in tqdm(range(1, num_exp + 1)):
            t = random.uniform(0.01, 10)
            print("t: ", t)
            # Draw noised "win" and "lose" batches at time t.
            x_t_w, noise_w = self._sample_noised_images(model_sbm, batch_size, t)
            x_t_l, noise_l = self._sample_noised_images(model_sbm, batch_size, t)
            if model_potential is not None:
                assert model_sbm_training is None
                pred_noise_w_doob, pred_noise_w_ref = self._predict_noise(model_sbm, model_potential, t, x_t_w)
                pred_noise_l_doob, pred_noise_l_ref = self._predict_noise(model_sbm, model_potential, t, x_t_l)
            else:
                assert model_sbm_training is not None
                pred_noise_w_doob, pred_noise_w_ref = self._predict_noise_training(model_sbm, model_sbm_training, t, x_t_w)
                pred_noise_l_doob, pred_noise_l_ref = self._predict_noise_training(model_sbm, model_sbm_training, t, x_t_l)
            # Difference of noise-prediction errors (fine-tuned minus reference), scaled by T/2 with T = 10.
            error_w = 0.5 * 10 * (torch.norm(pred_noise_w_doob - noise_w, dim=1)**2 - torch.norm(pred_noise_w_ref - noise_w, dim=1)**2)
            error_l = 0.5 * 10 * (torch.norm(pred_noise_l_doob - noise_l, dim=1)**2 - torch.norm(pred_noise_l_ref - noise_l, dim=1)**2)
            # Passing the negated errors here was observed to break training.
            sigmoid = self._LogSigmoidLog(error_w, error_l)
            p_bt = self._p_bt(x_t_w, x_t_l)
            total += torch.sum(p_bt * sigmoid)

            kl_avg += 10 * torch.sum(torch.norm(pred_noise_w_doob - pred_noise_w_ref, dim=1)**2 + torch.norm(pred_noise_l_doob - pred_noise_l_ref, dim=1)**2) * (1 - math.exp(-2*t)) / 4

            del x_t_w, x_t_l, noise_w, noise_l, pred_noise_w_doob, pred_noise_w_ref, \
                pred_noise_l_doob, pred_noise_l_ref, error_w, error_l, sigmoid, p_bt
            gc.collect()
            torch.cuda.empty_cache()

        if model_potential is not None:
            avg = - total.item() / num_exp / batch_size
        else:
            # Kept as a tensor so gradients can flow into model_sbm_training.
            avg = - total / num_exp / batch_size

        if not regularized:
            return avg, 0
        kl_avg = kl_avg / num_exp / batch_size
        return avg, kl_avg