import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image
import torchvision
from datasets import load_dataset
from torchvision import transforms

import matplotlib.pyplot as plt

import base64
from io import BytesIO

import datetime

import os, json

from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline

from torch.autograd.functional import jacobian
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm

import sys
sys.path.append("/home//work/doob_apps/hug")

from src.utils.img_func import show_images, make_grid, preprocess, transform

def load_config(path="/home//work/doob_apps/hug/configs/configs.json"):
    """Load the project's JSON configuration file.

    Args:
        path: Location of the JSON config file. Defaults to the project's
            ``configs/configs.json`` so existing zero-argument callers keep
            working.

    Returns:
        The parsed JSON document. ``main`` in this file expects a dict with
        at least ``"image_size"`` and ``"model_path"`` keys.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding (matters for non-ASCII values).
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def _make_output_dir(base_dir):
    """Create (if needed) and return a ``YYYYmmdd_HHMMSS`` subdirectory of *base_dir*."""
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    out_dir = os.path.join(base_dir, stamp)
    os.makedirs(out_dir, exist_ok=True)
    return out_dir


def _save_scheduler_params(noise_scheduler, out_dir):
    """Write selected scheduler config fields to ``scheduler_params.json`` in *out_dir*."""
    cfg = noise_scheduler.config
    scheduler_params = {
        "num_train_timesteps": cfg.num_train_timesteps,
        "beta_start": cfg.beta_start,
        "beta_end": cfg.beta_end,
        "beta_schedule": cfg.beta_schedule,
        "trained_betas": cfg.trained_betas,
        "variance_type": cfg.variance_type,
    }
    with open(os.path.join(out_dir, "scheduler_params.json"), "w", encoding="utf-8") as f:
        json.dump(scheduler_params, f, indent=4)


def _denoise(unet, noise_scheduler, sample):
    """Run every timestep of *noise_scheduler* starting from *sample*; return the final sample."""
    for t in noise_scheduler.timesteps:
        # Model output at timestep t (presumably the predicted noise, depending
        # on the scheduler's prediction_type — not visible here).
        residual = unet(sample, t)["sample"]
        sample = noise_scheduler.step(residual, t, sample).prev_sample
    return sample


def _save_preview(sample, path):
    """Render *sample* with ``show_images`` and save it as an image at *path*."""
    fig = plt.figure(figsize=(6, 6))
    try:
        img = show_images(sample)
        plt.imshow(img)
        plt.axis("off")
        plt.savefig(path)
    finally:
        # Release the figure; otherwise pyplot keeps it alive for the process.
        plt.close(fig)


def main(batch_size=64, num_iterations=100):
    """Generate reference samples from a pretrained DDPM pipeline.

    Creates a timestamped folder under ``outputs/samples`` and writes:
      * ``scheduler_params.json`` - the noise scheduler's beta/timestep config,
      * ``butterfly.png``         - a preview of the first generated batch,
      * ``samples_ref.pth``       - all samples concatenated along dim 0,
        shape ``(num_iterations * batch_size, C, image_size, image_size)``,
      * ``config.json``           - batch size, iteration count and model path.

    Args:
        batch_size: Number of images generated per denoising run.
        num_iterations: Number of batches to generate.

    Raises:
        ValueError: If ``batch_size`` or ``num_iterations`` is not positive.
    """
    if batch_size < 1 or num_iterations < 1:
        raise ValueError(
            f"batch_size and num_iterations must be >= 1 "
            f"(got {batch_size}, {num_iterations})"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dirname = _make_output_dir("/home//work/doob_apps/hug/outputs/samples")

    config = load_config()
    image_size = config["image_size"]
    model_path = config["model_path"]

    # Load the pretrained ("butterfly") pipeline.
    butterfly_pipeline = DDPMPipeline.from_pretrained(model_path).to(device)
    unet = butterfly_pipeline.unet
    # Take the channel count from the model rather than hard-coding 3.
    num_channels = unet.config.in_channels

    noise_scheduler = butterfly_pipeline.scheduler
    # NOTE: diffusers rejects a step count larger than num_train_timesteps.
    noise_scheduler.set_timesteps(1000)

    modify = False
    if modify:
        # Replace the scheduler with one using a different beta_end.
        # NOTE(review): this is a fresh DDPMScheduler, so it does NOT inherit
        # the pretrained scheduler's other settings (e.g. beta_schedule).
        noise_scheduler = DDPMScheduler(num_train_timesteps=500, beta_end=0.1)

    _save_scheduler_params(noise_scheduler, dirname)

    samples = []
    # No gradients are needed anywhere in sampling.
    with torch.no_grad():
        for k in tqdm(range(num_iterations)):
            # Pure Gaussian noise as the starting point of reverse diffusion;
            # the spatial size now follows the configured image_size.
            sample = torch.randn(batch_size, num_channels, image_size, image_size).to(device)
            sample = _denoise(unet, noise_scheduler, sample)
            if k == 0:
                _save_preview(sample, os.path.join(dirname, "butterfly.png"))
            samples.append(sample)
    samples = torch.cat(samples, dim=0)

    torch.save(samples, os.path.join(dirname, "samples_ref.pth"))
    # Record the run parameters alongside the samples.
    with open(os.path.join(dirname, "config.json"), "w", encoding="utf-8") as f:
        json.dump({
            "batch_size": batch_size,
            "num_iterations": num_iterations,
            "model_path": model_path
        }, f)

if __name__ == "__main__":
    # Run the sampling script only when executed directly, not when imported.
    main()