import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import datetime
import json

from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline

import sys
sys.path.append("/home/***/work/doob_apps/hug")

from src.models.model_potential import ModelPotential
from src.finetune.obj import objective_dpo
from src.finetune.inner_loop import Inner_Loop
from src.utils.set_seed import set_seed
from src.utils.img_func import show_images
from src.finetune.doob import doob_h, sampling_doob
from src.models.CT_model_predictor import RotationPredictorCNN
from src.models.CT_autoencoder import Autoencoder32

import matplotlib.pyplot as plt

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import wandb

import argparse
import copy

def load_config(config_path: str) -> dict:
    """Read a JSON configuration file and return its parsed content.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The decoded JSON document (a dict for the config files used here).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; decode explicitly so the result does not
    # depend on the locale of the machine the script runs on.
    with open(config_path, 'r', encoding='utf-8') as file:
        return json.load(file)


def main():
    """Fine-tune a copy of a pretrained DDPM UNet on the "upper bound" objective.

    Steps, in order:
      1. Parse ``--gpu`` (path to a JSON file providing ``device_num`` and
         ``devices``) and seed the RNGs with ``222 + device_num``.
      2. Create a timestamped output directory and record the device settings.
      3. Read the experiment config, load the pretrained ``DDPMPipeline`` from
         ``config["model_path"]`` and, in ``"CT"`` mode, the autoencoder.
      4. Deep-copy the pipeline's UNet and run 1000 Adam steps (lr=1e-7) on the
         value returned by ``objective_dpo._calc_upperbound``, logging it to
         Weights & Biases and saving the trainable UNet every 5 steps.

    The exact definition of the objective lives in ``src.finetune.obj`` and is
    not visible from this script.
    """
    print("## testing upperbound ##")

    # Parse the command-line arguments.
    parser = argparse.ArgumentParser(description="Load a config file.")
    parser.add_argument(
        '--gpu',
        type=str,
        required=True,
        help="Path to the config.json file"
    )
    # NOTE: a second, optional ``--doob`` config (decay_rate, i_list,
    # batch_size, n_samples, ...) used to be read here; it is currently
    # disabled and none of its values are used by this script.
    args = parser.parse_args()

    config_device = load_config(args.gpu)

    device_num = config_device["device_num"]
    devices = config_device["devices"]
    print("device_num:", device_num)
    print("devices:", devices)

    # Offset the seed by the device index so runs on different GPUs differ.
    set_seed(222 + device_num)

    # Create the directory that receives the logs and checkpoints of this run.
    now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    run_name = now_str + "_cu_" + str(device_num)
    dirname = "/home/***/work/doob_apps/hug/outputs/figures/test_upperbound_CT/" + run_name
    os.makedirs(dirname, exist_ok=True)

    # GPU on which every model of this script is placed.
    device = torch.device('cuda:' + str(device_num))
    print("device:", device)

    # Load the experiment configuration.
    config_path = "/home/***/work/doob_apps/hug/configs/configs_CT.json"
    config = load_config(config_path)
    model_path = config["model_path"]
    mode = config["mode"]

    # The autoencoder only exists in "CT" mode. Default to None so the later
    # uses do not raise NameError for other modes.
    autoencoder = None
    if mode == "CT":
        autoencoder = Autoencoder32()
        # map_location places the weights on the selected device regardless of
        # the device they were saved from.
        autoencoder.load_state_dict(
            torch.load(config["autoencoder_path"], map_location=device)
        )
        autoencoder.to(device)

    # Record the device settings of this run next to its outputs.
    with open(os.path.join(dirname, 'i_list_decay_rate.txt'), mode='w') as f:
        f.write("device: "+str(device)+"\n")
        f.write("devices for parallelize: "+str(devices)+"\n")

    wandb.init(
        project="upperbound-test-CT",
        name=run_name,
        config=config
    )

    # Load the pretrained DDPM pipeline.
    pipeline = DDPMPipeline.from_pretrained(
        model_path
    ).to(device)

    # Reference network; it is put in eval mode and never handed to the optimizer.
    unet = pipeline.unet.to(device)

    # Independent copy of the UNet that is actually trained.
    unet_training = copy.deepcopy(unet).to(device)

    noise_scheduler = pipeline.scheduler
    noise_scheduler.set_timesteps(1000)

    # Debug output: first entry of the cumulative alpha product.
    alpha_bar = noise_scheduler.alphas_cumprod[0]
    print("alpha_cumprod:", alpha_bar.shape)
    print(torch.sqrt(1-alpha_bar))

    objective_class = objective_dpo(device=device, mode=mode)
    unet_training.train()
    if autoencoder is not None:
        autoencoder.eval()
    unet.eval()

    # Training loop.
    n_steps = 1000
    save_every = 5
    optimizer = torch.optim.Adam(unet_training.parameters(), lr=1e-7)
    upperbound_list = []
    for step in range(n_steps):
        optimizer.zero_grad()
        upperbound = objective_class._calc_upperbound(unet, noise_scheduler, None, mode, autoencoder, None, unet_training)
        upperbound.backward()
        optimizer.step()
        print("upperbound:", upperbound)
        # Convert to a Python float once; each .item() call syncs with the GPU.
        upperbound_value = upperbound.item()
        upperbound_list.append(upperbound_value)
        wandb.log({"iteration": step, "upperbound": upperbound_value})
        # Save a checkpoint of the trainable UNet periodically.
        if step % save_every == 0:
            torch.save(unet_training.state_dict(), os.path.join(dirname, f"unet_training_{step}.pth"))
    print(upperbound_list)

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()