import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import datetime
import json, gc

import sys
sys.path.append("/home//work/doob_apps/hug")

from src.models.model_potential import ModelPotential
from src.finetune.obj import objective_dpo
from src.finetune.inner_loop import Inner_Loop
from src.utils.set_seed import set_seed


from src.models.CT_model_predictor import RotationPredictorCNN

import matplotlib.pyplot as plt

import wandb

def show_samples(newPotential, idx_loop, dirname, device, mode):
    """Rank a fixed batch of reference images by potential, then save and log them.

    Loads the pre-generated reference images for ``mode``, scores the first 64
    of them with ``newPotential`` in mini-batches of 8, reorders the images by
    ascending score, writes the rendered grid to
    ``<dirname>/ct_preferred_<idx_loop>.png`` and, when ``idx_loop < 3`` or
    ``idx_loop`` is a multiple of 10, uploads every image with its score to wandb.

    Args:
        newPotential: Callable applied to an image mini-batch; its outputs are
            concatenated along dim 0 and sorted.
        idx_loop: Outer-loop index; used in the file name and the wandb key.
        dirname: Directory the figure is written into.
        device: Device the image batch is moved to before scoring.
        mode: ``"butterfly"`` or ``"CT"``; selects the reference image file.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    # Validate up front: previously an unknown mode left `images` unbound and
    # failed later with an unrelated-looking UnboundLocalError.
    if mode == "butterfly":
        samples_path = "/home//work/doob_apps/hug/outputs/samples/20240910_104356/samples_ref.pth"
    elif mode == "CT":
        samples_path = "/home//work/doob_apps/hug/outputs/CT_diffusion/20240919_1459/decoded_images.pth"
    else:
        raise ValueError(f"unsupported mode: {mode!r} (expected 'butterfly' or 'CT')")
    images = torch.load(samples_path)

    batch_size = 64
    chunk_size = 8
    images_batch = images[:batch_size].to(device)

    # Evaluate the potential chunk by chunk and concatenate the results.
    # Gradients are not needed for visualisation, so skip building the graph.
    with torch.no_grad():
        newpot_batch = torch.cat(
            [
                newPotential(images_batch[start:start + chunk_size])
                for start in range(0, len(images_batch), chunk_size)
            ],
            dim=0,
        )
    if mode == "CT":
        newpot_batch = newpot_batch.squeeze()

    # Reorder the images from the smallest potential value to the largest.
    newpot_batch, indices = torch.sort(newpot_batch)
    images_batch = images_batch[indices]

    # NOTE(review): `show_images` is neither defined nor imported in the part of
    # the file visible here - confirm where it is expected to come from.
    img = show_images(images_batch, mode=mode)

    # Use a dedicated figure and close it afterwards; drawing on the implicit
    # current figure on every call kept stacking images on the same axes.
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        ax.axis("off")
        fig.savefig(dirname + "/ct_preferred_" + str(idx_loop) + ".png")
    finally:
        plt.close(fig)

    if idx_loop < 3 or idx_loop % 10 == 0:
        try:
            wandb.log({"ct_"+str(idx_loop): [wandb.Image(images_batch[i], caption=f'preference: {newpot_batch[i]}') for i in range(images_batch.size(0))]})
        except Exception as e:
            # Logging is best-effort, but keep the reason visible.
            print(f"wandb.log failed: {e}")

def _load_json(path):
    """Read and return the JSON document stored at ``path``."""
    with open(path) as f:
        return json.load(f)


def _load_config(mode):
    """Return ``(config, (label, value))`` for ``mode``.

    ``config`` is the merged configuration dict; the ``(label, value)`` pair is
    the mode-specific entry recorded in the run summary file.

    Raises:
        ValueError: If ``mode`` is not ``"butterfly"`` or ``"CT"``.
    """
    if mode == "butterfly":
        config = _load_json('/home//work/doob_apps/hug/configs/configs.json')
        # Read before merging so the value from the base file is the one recorded.
        extra = ("preferred_color", config["preferred_color"])
        # Merge the training config, then the regularization config, into config.
        config.update(_load_json('/home//work/doob_apps/hug/configs/train_potential.json'))
        config.update(_load_json('/home//work/doob_apps/hug/configs/regularization.json'))
    elif mode == "CT":
        config = _load_json('/home//work/doob_apps/hug/configs/configs_CT.json')
        # NOTE(review): unlike the butterfly branch, the CT training config is
        # loaded but never merged into `config`. Behaviour is kept as it was -
        # confirm whether a `config.update(config_train)` is missing here.
        config_train = _load_json('/home//work/doob_apps/hug/configs/train_potential_CT.json')
        # Merge the regularization config into config.
        config.update(_load_json('/home//work/doob_apps/hug/configs/regularization_CT.json'))
        extra = ("predictor_path", config["predictor_path"])
    else:
        raise ValueError(f"unsupported mode: {mode!r} (expected 'butterfly' or 'CT')")
    return config, extra


def _as_scalar(value):
    """Turn a one-element tensor into a Python number; pass anything else through."""
    if isinstance(value, torch.Tensor) and value.numel() == 1:
        return value.item()
    return value


def main():
    """Run the dual-averaging outer loop and record its metrics.

    For ``tot_loop`` iterations, a fresh ``Inner_Loop`` trains a new potential
    starting from the previous one; the regularized objective is then evaluated,
    sample images are rendered via ``show_samples``, and the metrics are logged
    to wandb. The collected metrics are written to ``DA_result.csv`` in a
    timestamped output directory, even when the loop stops early with an error.
    """
    print("## Dual averaging ##")

    set_seed(222)
    mode = "butterfly"
    # Single device decision. The original computed a CPU fallback and then
    # unconditionally overwrote it with "cuda", which made the fallback dead code.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device_ids = [0, 1]
    tot_loop = 100
    print("tot_loop: ", tot_loop)
    print("mode: ", mode)

    config, (extra_label, extra_value) = _load_config(mode)
    beta = config["beta"]
    gamma = config["gamma"]
    y_min = config["y_min"]
    y_max = config["y_max"]
    pot_min = config["pot_min"]
    pot_max = config["pot_max"]

    # Create the directory that results are saved into.
    now_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    if mode == "butterfly":
        dirname = "/home//work/doob_apps/hug/outputs/figures/DA/"+now_str
    else:
        # NOTE(review): there is no "/" between "DA_CT" and the timestamp, so CT
        # runs land in ".../figures/DA_CT<timestamp>" - confirm this is intended.
        dirname = "/home//work/doob_apps/hug/outputs/figures/DA_CT"+now_str
    os.makedirs(dirname, exist_ok=True)

    objective_class = objective_dpo(device=device, mode=mode)

    # Save the run settings (loop count, beta, ...) to a text file.
    summary = [
        ("tot_loop", tot_loop),
        ("mode", mode),
        ("beta", beta),
        ("gamma", gamma),
        ("y_min", y_min),
        ("y_max", y_max),
        ("pot_min", pot_min),
        ("pot_max", pot_max),
        (extra_label, extra_value),
    ]
    with open(dirname+'/idx_loop_beta.txt', mode='w') as f:
        f.writelines(f"{label}: {value}\n" for label, value in summary)

    wandb.init(
        project="Dual Averaging for "+str(mode),
        config=config
    )

    if mode == "butterfly":
        previousPotential = ModelPotential().to(device)
    else:
        previousPotential = RotationPredictorCNN().to(device)
    # Parallelise across GPUs with DataParallel when more than one is present.
    if torch.cuda.device_count() > 1:
        previousPotential = torch.nn.DataParallel(previousPotential, device_ids=device_ids, output_device=device)

    history = {"idx_loop": [], "regularized objective": [], "DPO": [], "beta * KL": []}
    try:
        for idx_loop in range(1, tot_loop+1):
            potential_for_inner_loop = Inner_Loop(beta=beta, device=device, mode=mode, device_ids=device_ids)
            newPotential = potential_for_inner_loop.train_potential(previousPotential, idx_loop, savedirname=dirname, plot_data=True)
            previousPotential = newPotential
            # Store plain numbers when the metrics come back as one-element
            # tensors, so the history does not pin tensors for the whole run
            # and the CSV holds values instead of tensor reprs.
            after_obj, after_dpo, after_kl = (
                _as_scalar(v) for v in objective_class.reglarized_objective(newPotential, beta)
            )
            show_samples(newPotential, idx_loop, dirname, device, mode)
            wandb.log({"idx_loop": idx_loop, "regularized objective": after_obj, "DPO": after_dpo, "beta * KL": after_kl})
            print("idx_loop: ", idx_loop)
            print("obj: ", after_obj)
            print("dpo: ", after_dpo)
            print("kl: ", after_kl)
            history["idx_loop"].append(idx_loop)
            history["regularized objective"].append(after_obj)
            history["DPO"].append(after_dpo)
            history["beta * KL"].append(after_kl)
            del potential_for_inner_loop, newPotential
            gc.collect()
            torch.cuda.empty_cache()
    finally:
        # Save the per-loop metrics as CSV. Doing this in `finally` keeps the
        # metrics gathered so far when a long run fails part-way through.
        pd.DataFrame(history).to_csv(dirname+"/DA_result.csv")

    wandb.alert(title="DA finished", text="DA finished")
    wandb.finish()

    print("## DA end ##")

# Run the dual-averaging experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()