import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import datetime
import json

from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline

import sys
sys.path.append("/home/***/work/doob_apps/hug")

from src.models.model_potential import ModelPotential
from src.finetune.obj import objective_dpo
from src.finetune.inner_loop import Inner_Loop
from src.utils.set_seed import set_seed
from src.utils.img_func import show_images
from src.finetune.doob import doob_h, sampling_doob

import matplotlib.pyplot as plt

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import wandb

import argparse

def load_config(config_path):
    """Read a JSON configuration file and return the decoded object.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The parsed JSON value (callers in this script index it like a dict).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform's locale encoding,
    # which would mis-decode non-ASCII paths/names on e.g. a C/latin-1 locale.
    with open(config_path, "r", encoding="utf-8") as file:
        return json.load(file)


def _parse_args():
    """Parse the command-line arguments naming the device and Doob JSON config files."""
    parser = argparse.ArgumentParser(description="Load a config file.")
    parser.add_argument(
        '--gpu',
        type=str,
        required=True,
        help="Path to the device config JSON file (device_num, devices)",
    )
    parser.add_argument(
        '--doob',
        type=str,
        required=True,
        help="Path to the Doob sampling config JSON file",
    )
    return parser.parse_args()


def _save_image(img, path):
    """Draw *img* on a fresh axis-less figure, save it to *path*, and close the figure."""
    fig = plt.figure()
    plt.imshow(img)
    plt.axis("off")
    plt.savefig(path)
    # Close explicitly so figures do not accumulate in pyplot's global registry.
    plt.close(fig)


def _sync(device):
    """Wait for queued CUDA work on *device* so wall-clock timings cover it."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def main():
    """Smoke-test the Doob inner loop on a pretrained DDPM pipeline.

    Reads a device config (``--gpu``) and a Doob config (``--doob``). It creates a
    timestamped output directory, records the run parameters there, and starts a wandb
    run with the merged config. It then runs single-step UNet/scheduler sanity checks,
    evaluates ``doob_h`` once, and runs ``sampling_doob``. The resulting images are
    saved as PNG files in the output directory.
    """
    print("## testing Doob inner loop ##")
    args = _parse_args()

    config_device = load_config(args.gpu)
    config_doob = load_config(args.doob)

    device_num = config_device["device_num"]
    devices = config_device["devices"]
    print("device_num:", device_num)
    print("devices:", devices)

    # Seed is offset by the GPU index (presumably so runs on different GPUs differ).
    set_seed(222 + device_num)

    # Create a timestamped output directory for this run.
    now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    dirname = "/home/***/work/doob_apps/hug/outputs/figures/test_doob/" + now_str + "_cu_" + str(device_num)
    os.makedirs(dirname, exist_ok=True)

    device = torch.device('cuda:' + str(device_num))  # primary GPU for this process
    print("device:", device)

    # Base config (reference images, pretrained model). The Doob config is merged in
    # only after image_path/model_path are read, so it cannot override those two.
    config_path = "/home/***/work/doob_apps/hug/configs/configs.json"
    config = load_config(config_path)
    image_path = config["image_ref_path"]
    # Load onto the CPU so tensors saved on some other GPU neither land on nor require
    # that GPU; only the slice used below is moved to `device`.
    images = torch.load(image_path, map_location="cpu")
    model_path = config["model_path"]

    config.update(config_doob)
    decay_rate = config_doob["decay_rate"]
    i_list = config_doob["i_list"]
    batch_size = config_doob["batch_size"]
    n_samples = config_doob["n_samples"]
    n_samples_2 = config_doob["n_samples_2"]
    potential_path = config_doob["potential_path"]
    doob_interval = config_doob["doob_interval"]
    doob_last_i = config_doob["doob_last_i"]

    # Record the run parameters next to the generated figures ("key: value" per line).
    run_info = [
        ("device", device),
        ("devices for parallelize", devices),
        ("i_list", i_list),
        ("doob_interval", doob_interval),
        ("doob_last_i", doob_last_i),
        ("decay_rate", decay_rate),
        ("batch_size", batch_size),
        ("n_samples", n_samples),
        ("n_samples_2", n_samples_2),
        ("image_path", image_path),
        ("model_path", model_path),
        ("potential_path", potential_path),
    ]
    with open(os.path.join(dirname, "i_list_decay_rate.txt"), mode="w", encoding="utf-8") as f:
        f.writelines(f"{key}: {value}\n" for key, value in run_info)

    wandb.init(
        project="doob-test-butterflies-2",
        name=now_str + "_cu_" + str(device_num) + "_" + args.doob,
        config=config,
    )
    try:
        # Load the pretrained butterfly DDPM pipeline.
        pipeline = DDPMPipeline.from_pretrained(model_path).to(device)
        unet = pipeline.unet.to(device)

        noise_scheduler = pipeline.scheduler
        noise_scheduler.set_timesteps(1000)

        alpha_bar = noise_scheduler.alphas_cumprod[0]
        print("alpha_cumprod:", alpha_bar.shape)
        print(torch.sqrt(1 - alpha_bar))

        # Pure-noise input with the shape/dtype of a single reference image.
        images_batch = images[:2].to(device)
        sample = torch.randn_like(images_batch[0].unsqueeze(0)).to(device)
        with torch.no_grad():
            residual = unet(sample, 999)["sample"]
        print("residual:", residual)
        print("residual type:", type(residual))
        print("residual shape:", residual.shape)
        # Smoke-test a step at another timestep. The result is unused, but the call is
        # kept because step() may draw random noise; removing it could shift the
        # seeded RNG stream seen by everything below.
        noise_scheduler.step(residual, 10, sample)

        # Time only the two pred_original_sample computations (plotting excluded).
        _sync(device)
        start = datetime.datetime.now()
        sample_0 = noise_scheduler.step(residual, 999, sample).pred_original_sample
        sample_1 = noise_scheduler.step(residual, 999, sample + 0.01).pred_original_sample
        _sync(device)
        end = datetime.datetime.now()
        print("time of pred_origin_sample:", end - start)
        _save_image(show_images(sample_0), dirname + "/butterfly_0.png")
        _save_image(show_images(sample_1), dirname + "/butterfly_1.png")

        # Potential model used to guide sampling.
        newPotential = ModelPotential().to(device)
        # NOTE(review): wrapping depends on the machine's GPU count, not on len(devices),
        # and DataParallel expects its parameters on devices[0]. The checkpoint's
        # state-dict keys ("module." prefix or not) must match the form chosen here;
        # confirm against how potential_path was saved.
        if torch.cuda.device_count() > 1:
            newPotential = torch.nn.DataParallel(newPotential, device_ids=devices, output_device=device_num)
        # map_location keeps checkpoint tensors off whichever GPU they were saved from.
        newPotential.load_state_dict(torch.load(potential_path, map_location=device))

        newPotential.eval()
        with torch.no_grad():
            print("potential:", newPotential(sample).shape)

        _sync(device)
        start = datetime.datetime.now()
        db = doob_h(unet, noise_scheduler, newPotential, sample, 500, device="cuda:" + str(device_num), config=config)
        # Scale by sqrt(1 - alphas_cumprod[t]) at the same t=500 passed to doob_h.
        alpha_bar = noise_scheduler.alphas_cumprod[500]
        db = torch.sqrt(1 - alpha_bar) * db
        _sync(device)
        end = datetime.datetime.now()
        print("time of doob_h:", end - start)
        print("mean of db:", torch.mean(db))
        print("std of db:", torch.std(db))
        print("db.shape:", db.shape)

        # Time the full Doob-guided sampling run.
        _sync(device)
        start = datetime.datetime.now()
        samples = sampling_doob(unet, noise_scheduler, newPotential, doob_interval=doob_interval, i_list=i_list,
                                decay_rate=decay_rate, batch_size=batch_size, dirname=dirname, device=device,
                                config_device=config_device, config=config)
        _sync(device)
        end = datetime.datetime.now()
        print("time:", end - start)
        _save_image(show_images(samples), dirname + "/butterfly_doob.png")
    finally:
        # Always close the wandb run, even if a step above raised.
        wandb.finish()
    print("## finish testing Doob inner loop ##")


if __name__ == "__main__":
    main()