import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions import MixtureSameFamily, Categorical
import json
import datetime

import gc

import sys
sys.path.append("/home/***/work/doob")
# print(sys.path)

from src.datasets.dist_for_ref import MixtureModel
from src.models.sbm import ScoreBaseModel, pretrain_sbm
from src.doob_h.doob_h import doob_h_tensorized # doob_h
from src.models.nn_potential import ApproxPotential
from src.utils.sampling import doob_langevin_monte_carlo_modified, new_sbm_based_langevin_monte_carlo

import wandb

def plot(x_doobs, x_refs, sbm_path, potential_paths, dir, iters=None):
    """Save 2-D histograms of the reference samples and of each Doob-h sample set.

    A single row of panels is drawn: the first panel shows ``x_refs[0]`` and
    every following panel shows ``x_doobs[i]``. The figure is written to
    ``<dir>/doob_h_monte_carlo.png`` and the model paths are written to
    ``<dir>/model_path.txt`` (overwriting any existing file).

    Args:
        x_doobs: tensor indexed as ``[set, sample, coordinate]``; coordinates
            0 and 1 of every set are plotted.
        x_refs: tensor indexed the same way; only set 0 is plotted.
        sbm_path: path string written on the first line of ``model_path.txt``.
        potential_paths: iterable of path strings, written one per line after
            ``sbm_path``.
        dir: output directory; it must already exist.
        iters: optional sequence of iteration numbers used in the panel
            titles. Defaults to ``(1, 2, 100)``. Panels past the end of the
            sequence are titled with their 1-based position instead of
            raising ``IndexError``.
    """
    iters = (1, 2, 100) if iters is None else tuple(iters)

    len_x_doobs = x_doobs.shape[0]
    x_doob_np = x_doobs.cpu().numpy()
    x_ref_np = x_refs.cpu().numpy()

    n_bins = 30
    fs = 15
    lim = (-5, 5)  # shared plotting window for both axes

    def draw_panel(ax, points, title):
        # One histogram panel; `points` is indexed as [sample, coordinate].
        ax.set_title(title, fontsize=fs * 2)
        ax.hist2d(
            points[:, 0],
            points[:, 1],
            range=(lim, lim),
            cmap='viridis',
            bins=n_bins,
        )
        ax.set_aspect('equal')
        ax.set_xlim(lim)
        ax.set_ylim(lim)
        ax.set_xlabel('x1', fontsize=fs)
        ax.set_ylabel('x2', fontsize=fs)

    # squeeze=False keeps `axs` indexable even when only one panel is requested.
    fig, axs = plt.subplots(
        1, len_x_doobs + 1, figsize=(5 * len_x_doobs + 2, 5), squeeze=False
    )
    axs = axs[0]
    try:
        draw_panel(axs[0], x_ref_np[0], 'reference')
        for i in range(len_x_doobs):
            label = iters[i] if i < len(iters) else i + 1
            draw_panel(axs[1 + i], x_doob_np[i], 'iter = ' + str(label))
        fig.savefig(os.path.join(dir, 'doob_h_monte_carlo.png'))
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    # Record which model files produced the plotted samples.
    with open(os.path.join(dir, 'model_path.txt'), mode='w') as f:
        f.write(sbm_path + '\n')
        for potential_path in potential_paths:
            f.write(potential_path + '\n')

import os
import matplotlib.pyplot as plt
import seaborn as sns

def KDE_plot(x_doobs, x_refs, sbm_path, potential_paths, dir):
    """Save kernel-density plots of the reference samples and each Doob-h sample set.

    A single row of panels is drawn: the first panel shows ``x_refs[0]`` and
    every following panel shows ``x_doobs[i]`` titled ``k = i + 1``. The
    figure is written to ``<dir>/doob_h_monte_carlo_kde.png``, logged to
    wandb on a best-effort basis, and the model paths are written to
    ``<dir>/model_path.txt`` (overwriting any existing file).

    Args:
        x_doobs: tensor indexed as ``[set, sample, coordinate]``; coordinates
            0 and 1 of every set are plotted.
        x_refs: tensor indexed the same way; only set 0 is plotted.
        sbm_path: path string written on the first line of ``model_path.txt``.
        potential_paths: iterable of path strings, written one per line after
            ``sbm_path``.
        dir: output directory; it must already exist.
    """
    len_x_doobs = x_doobs.shape[0]
    x_doob_np = x_doobs.cpu().numpy()
    x_ref_np = x_refs.cpu().numpy()

    # Plotting window, used for both the KDE clip region and the axis limits.
    xmin, xmax = -5, 5
    ymin, ymax = -5, 5

    fs = 30        # title font size
    fs_axis = 15   # axis-label font size
    levels = 100   # number of contour levels passed to seaborn

    def draw_panel(ax, points, title):
        # One KDE panel; `points` is indexed as [sample, coordinate].
        ax.set_title(title, fontsize=fs)
        sns.kdeplot(
            x=points[:, 0],
            y=points[:, 1],
            fill=True,
            cmap='Blues',
            ax=ax,
            levels=levels,
            clip=[[xmin, xmax], [ymin, ymax]],
        )
        ax.set_aspect('equal', adjustable='box')
        ax.set_xlim([xmin, xmax])
        ax.set_ylim([ymin, ymax])
        ax.set_xlabel('x1', fontsize=fs_axis)
        ax.set_ylabel('x2', fontsize=fs_axis)

    filename = os.path.join(dir, 'doob_h_monte_carlo_kde.png')

    # squeeze=False keeps `axs` indexable even when only one panel is requested.
    fig, axs = plt.subplots(
        1, len_x_doobs + 1, figsize=(5 * len_x_doobs + 4, 5), squeeze=False
    )
    axs = axs[0]
    try:
        draw_panel(axs[0], x_ref_np[0], 'reference')
        for i in range(len_x_doobs):
            draw_panel(axs[1 + i], x_doob_np[i], 'k = ' + str(i + 1))
        fig.savefig(filename)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    # Logging to wandb is best-effort: a failure must not abort the run, but
    # KeyboardInterrupt / SystemExit are allowed to propagate.
    try:
        wandb.log({"plot": wandb.Image(filename)})
    except Exception as exc:
        print("wandbに画像をログできませんでした", exc)

    # Record which model files produced the plotted samples.
    with open(os.path.join(dir, 'model_path.txt'), mode='w') as f:
        f.write(sbm_path + '\n')
        for potential_path in potential_paths:
            f.write(potential_path + '\n')


def _load_data_parallel(model, weights_path, device):
    """Load weights into ``model`` and return it wrapped in DataParallel, in eval mode, on ``device``."""
    # The bare model lives on the CPU at this point, so load the checkpoint
    # there; this also works for checkpoints saved on a GPU that is absent here.
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    wrapped = nn.DataParallel(model)
    wrapped.to(device)
    wrapped.eval()
    return wrapped


def _read_samples_csv(path):
    """Read a sample CSV written by ``main`` and return its values as a tensor."""
    # The files are written with DataFrame.to_csv defaults, i.e. with a header
    # row and an index column. Reading them with header=None would turn the
    # header line (",0,1") into a spurious extra sample.
    frame = pd.read_csv(path, header=0, index_col=0)
    return torch.tensor(frame.values)


def main():
    """Sample (or reload) Doob-h and reference samples, then plot them.

    Reads ``configs/doob_h_1116.json`` and ``configs/mixturemodel.json``.

    * If the config's ``sampled_dir`` is ``None``, samples are generated for
      every configured potential model and written as ``doob<i>.csv`` /
      ``ref0.csv`` under ``outputs/figures/doob_h_MC/<timestamp>``.
    * Otherwise ``doob<i>.csv`` / ``ref0.csv`` are read from ``sampled_dir``,
      distance statistics to the first mixture mean are printed, and output
      goes to ``outputs/figures/doob_h_MC_plot_only/<timestamp>``.

    In both cases ``dists_ref_means_dict.json`` is written and ``plot`` and
    ``KDE_plot`` are called on the stacked samples.
    """
    config_doob_h_path = 'configs/doob_h_1116.json'
    with open(config_doob_h_path) as f:
        config_doob_h = json.load(f)

    wandb.init(
        project="mog_doob",
        config=config_doob_h
    )

    sampled_dir = config_doob_h["sampled_dir"]

    # Timestamped output directory; the sub-folder tells the two modes apart.
    now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    run_kind = 'doob_h_MC' if sampled_dir is None else 'doob_h_MC_plot_only'
    out_dir = os.path.join('outputs', 'figures', run_kind, now_str)
    os.makedirs(out_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    sbm_path = config_doob_h["sbm_path"]
    sbm = _load_data_parallel(ScoreBaseModel(2), sbm_path, device)

    T = 10
    num_steps = 50
    num_repeat_samples = 1
    num_samples_i = 100  # up to roughly 50 was noted as workable

    print("### Start doob h ###")

    # The config holds potential_path_1 ... potential_path_<path_num>.
    path_num = config_doob_h["path_num"]
    potential_paths = [
        config_doob_h["potential_path_" + str(k)] for k in range(1, path_num + 1)
    ]

    x_doobs = []
    x_refs = []

    # First mixture mean, used as the reference point for distance statistics.
    with open("configs/mixturemodel.json") as f:
        mixturemodel_params = json.load(f)
    mean_win = mixturemodel_params["means"][0]

    # Maps each loaded doob CSV path to the reference-sample mean distance.
    dists_ref_means_dict = {}

    if sampled_dir is None:
        for path_iter, potential_path in enumerate(potential_paths):
            print("### potential_path: ", potential_path, "###")
            model_potential = _load_data_parallel(
                ApproxPotential(2), potential_path, device
            )
            x_doob = []
            x_ref = []
            for _ in tqdm(range(num_repeat_samples)):
                x_doob.append(
                    doob_langevin_monte_carlo_modified(
                        sbm, model_potential, num_samples_i, num_steps, T, device=device
                    )
                )
                # Reference samples do not depend on the potential, so they
                # are drawn only alongside the first potential model.
                if path_iter == 0:
                    x_ref.append(
                        new_sbm_based_langevin_monte_carlo(
                            sbm, num_samples_i, num_steps, T, device=device
                        )
                    )

            x_doob = torch.cat(x_doob, dim=0)
            pd.DataFrame(x_doob.cpu().numpy()).to_csv(
                os.path.join(out_dir, 'doob' + str(path_iter) + '.csv')
            )
            x_doobs.append(x_doob)

            if path_iter == 0:
                x_ref = torch.cat(x_ref, dim=0)
                pd.DataFrame(x_ref.cpu().numpy()).to_csv(
                    os.path.join(out_dir, 'ref' + str(path_iter) + '.csv')
                )
                x_refs.append(x_ref)

            print("### Finished potential_path: ", potential_path, "###")
            print("x_doob: ", x_doob.shape)
    else:
        # Plot-only mode: sampled_dir contains doob0.csv, doob1.csv, ... and ref0.csv.
        # Broadcasting against the (coordinate,) mean works for any row count,
        # so the doob and reference sets need not have the same size.
        mean_win_t = torch.tensor(mean_win)
        dists_ref_mean = None
        for i in range(path_num):
            doob_path = os.path.join(sampled_dir, 'doob' + str(i) + '.csv')
            print("### doob_path: ", doob_path, "###")
            x_doob_i = _read_samples_csv(doob_path)
            print("x_doob_i: ", x_doob_i.shape)
            x_doobs.append(x_doob_i)

            if i == 0:
                ref_path = os.path.join(sampled_dir, 'ref' + str(i) + '.csv')
                x_ref_i = _read_samples_csv(ref_path)
                x_refs.append(x_ref_i)
                # The reference statistic is identical for every i; compute it once.
                dists_ref = torch.norm(x_ref_i - mean_win_t, dim=1)
                dists_ref_mean = torch.mean(dists_ref[dists_ref < 10])

            wandb.log({"i": i})

            # Distance of every doob sample to the first mixture mean.
            dists = torch.norm(x_doob_i - mean_win_t, dim=1)
            print("number of samples with distance > 10: ", (dists > 10).sum())
            # Mean over the samples closer than 10 only.
            dists_mean = torch.mean(dists[dists < 10])
            print("mean distance between samples and mean_win: ", dists_mean)

            # NOTE(review): the value printed and stored here is the mean
            # distance between the *reference* samples and mean_win, so every
            # doob path gets the same number. Confirm whether dists_mean was
            # meant to be stored instead.
            print("mean distance between samples and x_ref: ", dists_ref_mean)
            dists_ref_means_dict[doob_path] = dists_ref_mean.item()

    with open(os.path.join(out_dir, 'dists_ref_means_dict.json'), 'w') as f:
        json.dump(dists_ref_means_dict, f)

    x_doobs = torch.stack(x_doobs, dim=0)
    x_refs = torch.stack(x_refs, dim=0)
    plot(x_doobs, x_refs, sbm_path, potential_paths, out_dir)
    KDE_plot(x_doobs, x_refs, sbm_path, potential_paths, out_dir)

    print("### Finished doob h ###")

# Disabled legacy implementation: because of the `if False:` guard the two
# functions below are never defined at import time. main() calls
# doob_langevin_monte_carlo_modified from src.utils.sampling instead.
if False:
    def compute_doob_h_for_tensor(model, model_potential, x, t, step_size):
        """Evaluate ``doob_h`` row by row over ``x`` and stack the results.

        For every ``x[i]`` the pair returned by ``doob_h`` is stacked into one
        tensor; the per-row tensors are then stacked again and moved to the
        CUDA device when available, otherwise to the CPU.

        NOTE(review): ``doob_h`` is not imported in the visible part of this
        file (the import line brings in only ``doob_h_tensorized``), so this
        function would raise ``NameError`` if the block were re-enabled.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        doob_h_values = []
        for i in range(x.shape[0]):
            x[i] = x[i].requires_grad_(True)
            result_x, result_y = doob_h(model, model_potential, x[i].float(), t, step_size)
            doob_h_values.append(torch.stack([result_x, result_y]))
            # Free per-row intermediates eagerly to limit memory use.
            del result_x, result_y
            gc.collect()
            torch.cuda.empty_cache()
        # Stack the results into a tensor
        return torch.stack(doob_h_values).to(device)

    def doob_langevin_monte_carlo(model, model_potential, num_samples, num_steps, step_size, T=10):
        """Run two coupled Langevin chains: one with the Doob-h term, one without.

        Both chains start from the same standard-normal draw of shape
        ``(num_samples, model.input_dim)`` and share the same noise at every
        step. ``x`` is driven by the score plus the Doob-h term, ``x_ref`` by
        the score alone.

        Returns:
            Tuple ``(x, x_ref)`` with the final state of both chains.
        """
        tensorized = True
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Draw the initial samples from a standard normal.
        x = torch.randn(num_samples, model.input_dim).to(device)
        x_ref = x.clone()
        # Switch both models to inference mode.
        model.eval()
        model_potential.eval()
        # Run Langevin Monte Carlo using the score predicted by the trained model.
        for i in tqdm(range(num_steps), leave=False):
            t = T - T / num_steps * i ## t starts at T and decreases towards 0 (last value is T / num_steps)
            batch_t = torch.ones((x.shape[0], 1)).to(device) * t
            ## x: (num_samples, input_dim)
            if tensorized:
                doob = doob_h_tensorized(model, model_potential, x, t, step_size)
                doob = torch.stack([doob[0], doob[1]], axis=1)
            else:
                doob = compute_doob_h_for_tensor(model, model_potential, x, t, step_size)
                # NOTE(review): squeeze() is not in-place and its result is
                # discarded here, so `doob` keeps its original shape.
                doob.squeeze()
            
            with torch.no_grad():
                score = model(torch.cat([x, batch_t], axis=1))
                score_ref = model(torch.cat([x_ref, batch_t], axis=1))
                # On the final step only, update along the drift without noise.
                if i < num_steps - 1:
                    noise = torch.randn(num_samples, model.input_dim).to(device)
                else:
                    noise = 0

                x     = x     + (x     + 2 * score + 2 * doob) * step_size + np.sqrt(2 * step_size) * noise
                x_ref = x_ref + (x_ref + 2 * score_ref)        * step_size + np.sqrt(2 * step_size) * noise
                # print("doob",doob)
                # print("score",score)
            # Drop per-step tensors and release cached GPU memory.
            del batch_t, doob, score, score_ref, noise
            gc.collect()
            torch.cuda.empty_cache()

        return x, x_ref


# Run the sampling / plotting pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()