import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image
import torchvision
from datasets import load_dataset
from torchvision import transforms
import matplotlib.pyplot as plt
import datetime, gc
import os, json

from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline

from torch.autograd.functional import jacobian
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm

import sys
sys.path.append("/home//work/doob_apps/hug")

from src.utils.img_func import show_images, make_grid, preprocess, transform
from src.utils.set_seed import set_seed
from src.models.CT_autoencoder import Autoencoder32

import warnings

import wandb

def load_config(mode="butterfly"):
    """Load the base config and the Doob config for ``mode`` and merge them.

    Args:
        mode: ``"butterfly"`` or ``"CT"``. The historical misspelling
            ``"buttefly"`` is still accepted for backward compatibility.
            Defaults to ``"butterfly"`` so argument-less callers keep working.

    Returns:
        dict: the base config updated with the Doob config. Keys from the
        Doob file take precedence.

    Raises:
        ValueError: if ``mode`` is not recognised. The original code fell
            through with undefined paths and raised UnboundLocalError instead.
        OSError: if a config file cannot be opened.
    """
    config_dir = "/home//work/doob_apps/hug/configs"
    if mode in ("butterfly", "buttefly"):
        file_names = ("configs.json", "doob.json")
    elif mode == "CT":
        file_names = ("configs_CT.json", "doob_CT.json")
    else:
        raise ValueError(f"Unknown mode: {mode!r} (expected 'butterfly' or 'CT')")

    # Read the base config first, then overlay the Doob-specific config on top.
    config = {}
    for file_name in file_names:
        with open(f"{config_dir}/{file_name}", "r", encoding="utf-8") as f:
            config.update(json.load(f))
    return config

def replace_nan_with_zero(tensor):
    """Return a new tensor in which every NaN of ``tensor`` is replaced by 0.

    A ``UserWarning`` reporting how many NaNs were found is emitted when there
    is at least one.

    Note: ``torch.nan_to_num`` is called with its default ``posinf``/``neginf``
    arguments. As a result, +inf/-inf entries are also replaced, by the
    largest/smallest finite value of the dtype.
    """
    num_nan = int(torch.isnan(tensor).sum().item())
    if num_nan:
        warnings.warn(
            f"Found {num_nan} NaN values in the tensor. Replacing them with 0."
        )
    return torch.nan_to_num(tensor, nan=0.0)

def make_id_tensor(batch_size, channels=3, height=32, width=32):
    """Return the identity operator on ``(channels, height, width)`` images as an 8-D tensor.

    The result has shape ``(1, channels, height, width, 1, channels, height, width)``.
    Its dtype is float32 and it lives on the CPU. Entry ``[0, c, h, w, 0, c2, h2, w2]``
    is 1.0 when ``(c, h, w) == (c2, h2, w2)`` and 0.0 otherwise. The leading
    singleton axes are left for the caller to repeat or broadcast.

    Args:
        batch_size: ignored. The output is not batch-coupled (the old
            ``(i == m)`` variant was disabled); the parameter is kept for
            interface compatibility.
        channels: number of image channels.
        height: image height.
        width: image width.
    """
    size = channels * height * width
    # A flat (size x size) identity viewed as (C, H, W) x (C, H, W) equals the
    # broadcast (c == c2) & (h == h2) & (w == w2) mask. It is built directly,
    # without unused index tensors or forced GC / CUDA cache flushes.
    return torch.eye(size, dtype=torch.float32).view(
        1, channels, height, width, 1, channels, height, width
    )

def compute_conditional_expectation(unet, scheduler, model_potential, x, t, n_samples = 4, n_samples_2 = 4, device = "cuda", strongness = 10.0, mode="butterfly", autoencoder=None):
    """Compute unnormalised importance weights for the estimate of "E[h_T(x_T) | x_t]".

    Steps for every proposal ``x[k, s]``:

    1. Predict the clean sample with one
       ``scheduler.step(unet(., t)["sample"], t, .).pred_original_sample`` call.
       In "CT" mode the prediction is first decoded with ``autoencoder.decoder``.
    2. Score the prediction with ``model_potential``.
    3. Turn the score into the weight
       ``exp(-strongness * (clamp(pot, -5, 5) - mean_k) + 1)``, where ``mean_k``
       is the mean of the *unclamped* potentials of row ``k``.

    Args:
        unet: noise-prediction model called as ``unet(sample, t)["sample"]``.
        scheduler: diffusion scheduler providing ``step(...)``.
        model_potential: callable scoring a batch of images.
        x: proposal samples, indexed as ``x[k][start:stop]``. Assumed to be
            (batch_size, n_samples, C, H, W); confirm against the caller.
        t: diffusion timestep.
        n_samples: proposals per batch element. Rows with 16 or more proposals
            are processed in two chunks to bound peak memory.
        n_samples_2: unused; kept for interface compatibility.
        device: device on which the returned weights are allocated.
        strongness: scale applied to the centred potentials.
        mode: "butterfly" (pixel space) or "CT" (latent space; needs ``autoencoder``).
        autoencoder: model whose ``decoder`` maps latents to images in "CT" mode.

    Returns:
        torch.Tensor of shape (batch_size, n_samples) with NaNs replaced by 0.

    Raises:
        ValueError: if ``mode == "CT"`` but no ``autoencoder`` is given.
    """
    if mode == "CT" and autoencoder is None:
        raise ValueError("mode='CT' requires an autoencoder")

    model_potential.eval()
    unet.eval()

    batch_size = x.shape[0]
    # Split large rows in two chunks so only half of them is on the GPU at once.
    mini_n = n_samples // 2 if n_samples >= 16 else n_samples

    print("n_samples:", n_samples)
    print("mini_n:", mini_n)

    pot_x = torch.zeros(batch_size, n_samples).to(device)
    print("## compute conditional expectation, t:", t)
    for k in tqdm(range(batch_size)):
        print("k:", k)
        raw_chunks = []  # (start index, raw potentials) for every chunk of row k
        for start in range(0, n_samples, mini_n):
            print("l:", start)
            chunk = x[k][start:start + mini_n]
            print("sample_kl.shape", chunk.shape)
            with torch.no_grad():
                x0_pred = scheduler.step(unet(chunk, t)["sample"], t, chunk).pred_original_sample
                if mode == "CT":
                    x0_pred = autoencoder.decoder(x0_pred)
                pot = model_potential(x0_pred)
                if mode == "CT":
                    pot = pot.squeeze(1)
            raw_chunks.append((start, pot))
            del chunk, x0_pred
            gc.collect()
            torch.cuda.empty_cache()

        with torch.no_grad():
            # Centre with ONE mean per row (previously one mean per chunk). A
            # single constant shift cancels in the caller's self-normalised
            # ratio. Per-chunk means re-weighted the chunks against each other
            # by exp(strongness * (mean_a - mean_b)), biasing the estimate.
            # With a single chunk this equals the previous behaviour.
            row_mean = torch.mean(torch.cat([p.reshape(-1) for _, p in raw_chunks]))
            for start, pot in raw_chunks:
                centred = strongness * (torch.clamp(pot, min=-5.0, max=5.0) - row_mean)
                print("pot_x_tmp[k] mean: ", torch.mean(centred))
                print("pot_x_tmp[k] std: ", torch.std(centred))
                # The +1 is a constant factor e^1 that cancels after normalisation.
                pot_x[k][start:start + mini_n] = torch.exp(-centred + 1).unsqueeze(0)
        del raw_chunks

    gc.collect()
    torch.cuda.empty_cache()

    print("check pot_x")
    return replace_nan_with_zero(pot_x)

def _doob_step_size(noise_scheduler, t, eps):
    """Return the step size ``s`` used by doob_h at timestep ``t`` as a tensor.

    For ``t >= 1`` it is ``0.5 * |log(1 - beta_t)| + eps``; for ``t == 0`` it is ``beta_t``.
    """
    # as_tensor avoids the copy-construct warning torch.tensor(tensor) emits.
    beta_t = torch.as_tensor(noise_scheduler.betas[t])
    if 1 <= t:
        # TODO: set this properly from beta (original author's note).
        return 0.5 * torch.abs(torch.log(1 - beta_t)) + eps
    return beta_t


def _predict_noise_in_chunks(unet, x, t):
    """Return ``unet(x, t)["sample"]`` computed without grad in at most four chunks."""
    batch_size = x.shape[0]
    # ceil(B / 4) keeps four chunks whenever B % 4 == 0 (as before) and, unlike
    # the old hard assert, also handles batch sizes that are not divisible by 4.
    chunk_size = -(-batch_size // 4) if batch_size >= 4 else batch_size
    outputs = []
    with torch.no_grad():
        for start in tqdm(range(0, batch_size, chunk_size)):
            outputs.append(unet(x[start:start + chunk_size], t)["sample"])
            gc.collect()
            torch.cuda.empty_cache()
    return torch.cat(outputs, dim=0)


def _unet_input_jacobian(unet, x, t):
    """Return the per-sample Jacobian of ``unet(., t)["sample"]`` with respect to its input.

    Shape is ``(B, C, H, W, 1, C, H, W)``. Entry ``[b, c2, h2, w2, 0, c, h, w]``
    is ``d out[b, c, h, w] / d x[b, c2, h2, w2]``, i.e. the input index comes first.
    """
    def output_wrt_x(inp):
        return unet(inp, t)["sample"]

    blocks = []
    for i in tqdm(range(x.shape[0])):
        # jacobian() enables grad on its own copy of the input; x itself stays untouched.
        jac_i = jacobian(output_wrt_x, x[i:i + 1])
        print("jacobian_matrix_x_i.shape:", jac_i.shape)
        # jacobian() lays out output dims first; move the input dims to the front.
        blocks.append(jac_i.permute(4, 5, 6, 7, 0, 1, 2, 3))
        del jac_i
        gc.collect()
        torch.cuda.empty_cache()
    # Concatenate once instead of re-copying the growing tensor each iteration.
    return torch.cat(blocks, dim=0)


def doob_h(unet, noise_scheduler, model_potential, x, t, device="cuda", config=None, eps=1e-10, strongness=10.0, mode="butterfly", autoencoder=None):
    """Estimate the Doob h-transform guidance term ``nabla_x log h_t(x)`` at timestep ``t``.

    Procedure, as implemented:

    1. ``residuals = unet(x, t)["sample"]`` (no grad).
    2. The per-sample input Jacobian of the UNet, converted to a score Jacobian
       ``J`` via ``-1 / sqrt(1 - alpha_bar_t)``.
    3. ``n_samples`` Gaussian proposals
       ``x_rand = mean + std * noise``, where
       ``mean = e^s x - 2 (1 - e^{-s}) residuals`` and ``std = sqrt(e^{2s} - 1)``.
    4. Each ``diff = x_rand - mean`` is mapped through
       ``D = e^s I - 2 (1 - e^{-s}) J``.
    5. The results are averaged with the importance weights from
       ``compute_conditional_expectation`` (self-normalised).

    Args:
        unet: noise-prediction model called as ``unet(sample, t)["sample"]``.
        noise_scheduler: scheduler exposing ``betas``, ``alphas_cumprod`` and ``step``.
        model_potential: potential model forwarded to ``compute_conditional_expectation``.
        x: current samples, (batch_size, channels, image_size, image_size).
            The caller's tensor is never modified.
        t: timestep (0 <= t <= 1000; 1000 is pure noise, 0 is noise-free).
        device: device used for all computations.
        config: merged config dict; loaded with ``load_config(mode)`` when None.
        eps: small constant for the step size and the weight normaliser.
        strongness: potential scale forwarded to ``compute_conditional_expectation``.
        mode: "butterfly" (uses ``image_size``) or "CT" (uses ``latent_size``).
        autoencoder: decoder model forwarded for "CT" mode.

    Returns:
        torch.Tensor of shape (batch_size, channels, image_size, image_size).

    Raises:
        ValueError: on an unknown ``mode``, an empty batch, or unexpected
            intermediate shapes.
    """
    if config is None:
        config = load_config(mode)

    if mode == "butterfly":
        image_size = config["image_size"]
    elif mode == "CT":
        image_size = config["latent_size"]
    else:
        raise ValueError(f"Unknown mode: {mode!r} (expected 'butterfly' or 'CT')")
    channels = config["channels"]
    n_samples = config["n_samples"]
    n_samples_2 = config["n_samples_2"]

    batch_size = x.shape[0]
    print("batch_size:", batch_size)
    if batch_size == 0:
        raise ValueError("doob_h needs a non-empty batch")

    step_size = _doob_step_size(noise_scheduler, t, eps)
    step_size_tensor = step_size.squeeze().clone().detach().to(device)
    print("step_size:", step_size)

    # Work on a detached tensor. The previous x.requires_grad_(True) mutated the
    # caller's tensor and made its later updates accumulate autograd history.
    x = x.detach().to(device)
    # Use eval mode for both the residuals and the Jacobian. Previously
    # unet.train() was active during the Jacobian, so any dropout made it
    # inconsistent with the residuals.
    unet.eval()

    print("compute residuals")
    residuals = _predict_noise_in_chunks(unet, x, t)

    print("calculate jacobian")
    jacobian_T = _unet_input_jacobian(unet, x, t)
    print("check jacobian")
    jacobian_T = replace_nan_with_zero(jacobian_T)

    # Noise-prediction Jacobian -> score Jacobian (score = -eps_theta / sqrt(1 - alpha_bar)).
    alpha_bar = noise_scheduler.alphas_cumprod[t]
    jacobian_T = -jacobian_T / torch.sqrt(1 - alpha_bar + 1e-10)
    expected = (batch_size, channels, image_size, image_size, 1, channels, image_size, image_size)
    if tuple(jacobian_T.shape) != expected:
        raise ValueError(f"unexpected Jacobian shape {tuple(jacobian_T.shape)}, expected {expected}")

    exp_s = torch.exp(step_size_tensor)
    one_minus_exp_neg_s = 1 - torch.exp(-step_size_tensor)
    # NOTE(review): the mean uses the raw noise prediction, while its derivative
    # below uses the score Jacobian (different sign/scale). The Gaussian
    # log-likelihood gradient would also carry a 1/std^2 factor. Both are kept
    # exactly as in the original; confirm they are intended.
    mean = (exp_s * x + (-2 * residuals) * one_minus_exp_neg_s).to(device)
    std = torch.sqrt(torch.exp(2 * step_size_tensor) - 1).to(device)

    mean = mean.unsqueeze(1)  # (B, 1, C, H, W), broadcast over the proposals
    x_rand = torch.randn(batch_size, n_samples, channels, image_size, image_size).to(device) * std + mean
    diff = x_rand - mean  # (B, n_samples, C, H, W)

    # mean_diff[b, :, s] = D_b @ diff[b, s], with D_b = e^s I - 2 (1 - e^{-s}) J_b
    # and J_b flattened to (D, D) with the input index first. A batched matmul
    # replaces the previous materialisation of D_b repeated n_samples times,
    # which cost B * n_samples * D^2 memory.
    dim = channels * image_size * image_size
    diff_cols = diff.reshape(batch_size, n_samples, dim).transpose(1, 2)  # (B, D, n)
    jac_flat = jacobian_T.reshape(batch_size, dim, dim)
    mean_diff = exp_s * diff_cols + (-2 * one_minus_exp_neg_s) * torch.bmm(jac_flat, diff_cols)
    mean_diff = mean_diff.reshape(batch_size, channels, image_size, image_size, n_samples)
    mean_diff = replace_nan_with_zero(mean_diff)

    conditioned_samples = compute_conditional_expectation(
        unet, noise_scheduler, model_potential, x_rand, t,
        n_samples=n_samples, n_samples_2=n_samples_2, device=device,
        strongness=strongness, mode=mode, autoencoder=autoencoder,
    )
    if tuple(conditioned_samples.shape) != (batch_size, n_samples):
        raise ValueError(
            f"unexpected weight shape {tuple(conditioned_samples.shape)}, "
            f"expected {(batch_size, n_samples)}"
        )
    conditioned_samples = replace_nan_with_zero(conditioned_samples)
    weights = conditioned_samples.view(batch_size, 1, 1, 1, n_samples)

    print("mean_diff mean:", torch.mean(mean_diff))
    print("mean_diff std:", torch.std(mean_diff))

    md_cs = replace_nan_with_zero(mean_diff * weights)

    # Self-normalised importance-weighted average over the proposals.
    weighted_sum = torch.sum(md_cs, dim=4)
    denom = (torch.sum(weights, dim=4) + eps).expand(-1, channels, image_size, image_size)

    print("sum mean:", torch.mean(weighted_sum))
    print("sum std:", torch.std(weighted_sum))
    print("denom mean:", torch.mean(denom))
    print("denom std:", torch.std(denom))

    doob = weighted_sum / denom

    del x, residuals, jacobian_T, jac_flat, mean, x_rand, diff, diff_cols, mean_diff, \
        conditioned_samples, weights, md_cs, weighted_sum, denom
    gc.collect()
    torch.cuda.empty_cache()

    return doob

def _render_and_save(images, mode, path):
    """Render ``images`` with show_images, save the figure to ``path`` (skipped when None), and return the rendered image."""
    plt.figure()
    img = show_images(images, mode=mode)
    plt.imshow(img)
    plt.axis("off")
    if path is not None:
        plt.savefig(path)
    plt.close()
    return img


def sampling_doob(unet, noise_scheduler, model_potential, batch_size=4, decay_rate = 0.90, doob_interval = 10, i_list=None,\
                  device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), dirname=None, config_device=None, config=None, mode="butterfly"):
    """Run DDPM sampling with Doob h-transform guidance next to an unguided reference chain.

    Both chains start from the same noise. The guided chain adds
    ``decay * sqrt(alpha_bar_t) * doob_h_term`` before each scheduler step.
    ``doob_h_term`` is recomputed every ``doob_interval`` steps while
    ``i <= config["doob_last_i"]``, and ``decay`` is then reset to 1 and
    multiplied by ``decay_rate`` on each step. Every 100 steps, and at the last
    step, the samples (decoded in "CT" mode) are logged to wandb and, when
    ``dirname`` is given, saved as ``doob_<i>.png`` / ``ref_<i>.png``.

    Args:
        unet: noise-prediction model called as ``unet(sample, t)["sample"]``.
        noise_scheduler: scheduler exposing ``timesteps``, ``alphas_cumprod`` and ``step``.
        model_potential: potential model used by ``doob_h``.
        batch_size: number of samples per chain.
        decay_rate: per-step multiplicative decay of the guidance term. It was
            previously overwritten with 0.90; it is now honoured.
        doob_interval: recompute the guidance term every this many steps.
        i_list: unused; kept for interface compatibility.
        device: device for sampling.
        dirname: directory for snapshot images; images are not written when None.
        config_device: optional dict with "device_num" (used in the seeds).
        config: merged config dict; loaded with ``load_config(mode)`` when None.
        mode: "butterfly" (pixel space) or "CT" (latent space with an autoencoder).

    Returns:
        The final guided sample tensor.

    Raises:
        ValueError: on an unknown ``mode``.
    """
    if config is None:
        config = load_config(mode)

    if mode == "butterfly":
        image_size = config["image_size"]
    elif mode == "CT":
        image_size = config["latent_size"]
    else:
        raise ValueError(f"Unknown mode: {mode!r} (expected 'butterfly' or 'CT')")
    strongness = config["strongness"]
    doob_last_i = config["doob_last_i"]
    channels = config["channels"]

    autoencoder = None
    if mode == "CT":
        autoencoder = Autoencoder32()
        # map_location lets a checkpoint saved on another device load onto ``device``.
        autoencoder.load_state_dict(torch.load(config["autoencoder_path"], map_location=device))
        autoencoder.to(device)
        autoencoder.eval()

    device_num = config_device["device_num"] if config_device is not None else 0
    seed_magic_num = config["seed_magic_num"]

    decay_0 = 1
    decay = decay_0
    set_seed(10000 + device_num + seed_magic_num)
    sample = torch.randn(batch_size, channels, image_size, image_size).to(device)
    sample_ref = sample.clone()
    doob_h_term = torch.zeros(batch_size, channels, image_size, image_size).to(device)
    last_i = len(noise_scheduler.timesteps) - 1

    def snapshot_path(prefix, i):
        return None if dirname is None else os.path.join(dirname, prefix + str(i) + ".png")

    print("## sampling doob ##")
    for i, t_tmp in enumerate(noise_scheduler.timesteps):
        # Seeds are derived from the plain integer timestep, not a tensor.
        t_int = int(t_tmp)
        if i % 100 == 0 or i == last_i:
            wandb.log({"doob_step": i})
            print("i:", i)
            with torch.no_grad():
                if mode == "CT":
                    decoded = autoencoder.decoder(sample)
                    decoded_ref = autoencoder.decoder(sample_ref)
                else:
                    decoded, decoded_ref = sample, sample_ref
            img = _render_and_save(decoded, mode, snapshot_path("doob_", i))
            img_ref = _render_and_save(decoded_ref, mode, snapshot_path("ref_", i))
            wandb.log({"doob_sample": wandb.Image(img, caption="doob, i: "+str(i))})
            wandb.log({"doob_ref": wandb.Image(img_ref, caption="ref, i: "+str(i))})
            del img, img_ref, decoded, decoded_ref

        # Same seed for both chains so they differ only through the guidance term.
        with torch.no_grad():
            set_seed(t_int * 100 + device_num + seed_magic_num)
            residual = unet(sample, t_tmp)["sample"]
            set_seed(t_int * 100 + device_num + seed_magic_num)
            residual_ref = unet(sample_ref, t_tmp)["sample"]

        # TODO(review): original author asked whether the guidance should be computed at i == 0.
        if i % doob_interval == 0 and i <= doob_last_i:
            print("calc doob, i:", i)
            start = datetime.datetime.now()
            doob_h_term = doob_h(unet, noise_scheduler, model_potential, sample, t_tmp, device=device, config=config,
                                 strongness=strongness, mode=mode, autoencoder=autoencoder)
            end = datetime.datetime.now()
            print("calculating doob time:", end - start)
            decay = decay_0
        decay *= decay_rate

        alpha_bar = noise_scheduler.alphas_cumprod[t_tmp]
        # Sampling never needs gradients. no_grad keeps both chains free of autograd history.
        with torch.no_grad():
            sample_doob = sample + decay * torch.sqrt(alpha_bar) * doob_h_term
            set_seed(t_int * 10 + device_num + seed_magic_num)
            sample = noise_scheduler.step(residual, t_tmp, sample_doob).prev_sample
            set_seed(t_int * 10 + device_num + seed_magic_num)
            sample_ref = noise_scheduler.step(residual_ref, t_tmp, sample_ref).prev_sample
        del residual, residual_ref, sample_doob
        gc.collect()
        torch.cuda.empty_cache()

    return sample

