import os
import torch
import datetime
import random
import numpy as np
from argparse import ArgumentParser
from scene_pyg_loader import ScenePyGLoaderSeq, text_ball_collate_fn_pyg
from models.text_grad_gnn import AssembleModel_Room, GradientFieldSampler
from torch.utils.data import DataLoader
from torch_geometric.data import Batch
from utils import create_directory, get_model_args, get_sampler_param, get_text_emb_handler #, save_network
from text_emb_handler import SentenceTransformerHandler
from arrange_tools import arrange_room_mesh, render_transformed_obj, arrange_3d_to_2d, arrange_2d_to_3d, render_frames
from mitsuba_render_func import mi_write_img

# Command-line interface. Every option below is read from `args` further down
# in this script.
parser = ArgumentParser(
    description=(
        "Run the trained model's sampler on every scene of the test split and save "
        "predicted/ground-truth renders, sampled states, bounding boxes and prompts."
    )
)
parser.add_argument(
    "--seed", type=int, default=100,
    help="Seed for the python, numpy and torch RNGs; a negative value disables seeding.",
)
parser.add_argument(
    "--dataset_dir", type=str, default="my_dataset",
    help="Dataset root; scenes are read from its 'test' subfolder.",
)
parser.add_argument(
    "--text_cache_dir", type=str, default="text_cache",
    help="Root of the text-embedding cache; its 'test' subfolder is given to the test loader.",
)
# NOTE: despite the name, this is a checkpoint FILE passed to torch.load, not a folder.
parser.add_argument(
    "--model_dir", type=str, default="model.pt",
    help="Path to the checkpoint file (a dict holding 'model_args' and 'model').",
)

# Sampling args (the whole `args` namespace is handed to get_sampler_param).
parser.add_argument(
    "--sampler", type=str, default="PC",
    help="Sampler type, read by get_sampler_param.",
)
parser.add_argument(
    "--num_steps", type=int, default=500,
    help="Number of sampling steps, read by get_sampler_param.",
)
parser.add_argument(
    "--t0", type=float, default=1.0,
    help="Sampler 't0' setting, read by get_sampler_param.",
)
parser.add_argument(
    "--snr", type=float, default=0.16,
    help="Sampler 'snr' setting, read by get_sampler_param.",
)

# The text-embedding model is intentionally not a CLI option: its id is taken
# from the checkpoint's model_args["text_emb_model_id"].

# Output locations.
parser.add_argument(
    "--model3d_base_dir", type=str, default="models",
    help="Base directory of 3D model assets, passed to render_frames.",
)
parser.add_argument(
    "--output_dir", type=str, default="output",
    help="Root folder for all outputs (gt_images, pred_images, text, pos_states, bbox, render cache).",
)
parser.add_argument(
    "--render_cache_dir", type=str, default="cache_3d_render",
    help="Name of the render scratch folder, created under --output_dir.",
)

args = parser.parse_args()

# Reproducibility: seed python's `random`, numpy and torch (in that order) with
# the same value. A negative --seed leaves all generators unseeded.
if args.seed >= 0:
    for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        _seed_fn(args.seed)


# Only the "test" split is evaluated. Each split has its own text-embedding
# cache subfolder under --text_cache_dir.
# NOTE(review): train_cache_dir is not used anywhere in this script; it is kept
# only in case other code relies on the name.
test_dataset_dir, train_cache_dir, test_cache_dir = (
    os.path.join(args.dataset_dir, "test"),
    os.path.join(args.text_cache_dir, "train"),
    os.path.join(args.text_cache_dir, "test"),
)

test_dataset = ScenePyGLoaderSeq(test_dataset_dir, test_cache_dir)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,  # downstream code indexes [0] on every per-batch result
    shuffle=False,  # keep a stable order so output file indices are reproducible
    collate_fn=text_ball_collate_fn_pyg,
)

loader_len = len(test_dataloader)
print("Number of data in the test dataset: ", loader_len)

# Use the GPU when one is available; the model and text encoder both go to DEVICE.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Deserialize the checkpoint onto the CPU first. load_state_dict then copies
# the CPU tensors into the model's parameters, which already live on DEVICE.
# SECURITY: weights_only=False unpickles arbitrary Python objects; only load
# checkpoints from a trusted source.
model_dict = torch.load(args.model_dir, map_location="cpu", weights_only=False)

model_args = model_dict["model_args"]
print("Model args: ", model_args)

model = AssembleModel_Room(model_args)
model = model.to(DEVICE)
model.load_state_dict(model_dict["model"])
model.eval()  # inference only

# The text encoder id comes from the checkpoint rather than the CLI,
# presumably so prompts are embedded as they were in training.
text_model_id = model_args["text_emb_model_id"]
print("Your text embedding model: ", text_model_id)
emb_handler = get_text_emb_handler(text_model_id, DEVICE)

# Sampler configured from the CLI sampling options and driving the loaded model.
sampler_param = get_sampler_param(args)
sampler = GradientFieldSampler(sampler_param, model, DEVICE)

# Scratch folder for render_frames, placed under --output_dir.
render_cache_dir = os.path.join(args.output_dir, args.render_cache_dir)
create_directory(render_cache_dir, force=True)

# One output folder per artifact type under --output_dir. Each is created with
# create_directory(..., force=True), in the order listed.
gt_images_dir, pred_images_dir, text_dir, pos_states_dir, bbox_dir = (
    os.path.join(args.output_dir, sub_dir)
    for sub_dir in ("gt_images", "pred_images", "text", "pos_states", "bbox")
)
for out_dir in (gt_images_dir, pred_images_dir, text_dir, pos_states_dir, bbox_dir):
    create_directory(out_dir, force=True)

# For every test scene: embed its prompt, run the sampler, render the result
# next to the ground truth, and save states, bounding boxes, images and prompt
# under files indexed by data_idx.
# The dataloader uses batch_size=1, so each "[0]" below selects the single
# scene of the current batch.
for data_idx, batch in enumerate(test_dataloader):
    pyg_data = Batch.from_data_list(batch["data"]).to(DEVICE)

    print("Irregular data: ", batch["irregular_data"])

    your_prompt = batch["irregular_data"]["text_des"][0]  # Only for batch size = 1!
    print("Your prompt: ", your_prompt)

    # Re-embed the prompt with the checkpoint's text encoder and give every node
    # the same embedding, replacing the text_emb attached by the loader.
    text_emb = emb_handler.get_sentence_embedding([your_prompt])
    how_many_points = pyg_data.x.size(0)
    text_emb = text_emb.expand(how_many_points, -1)
    pyg_data.text_emb = text_emb

    pos_states, samp_time = sampler.sample_one_batch(pyg_data)

    # Only the final element of pos_states is rendered.
    (
        pred_batch_frames,
        gt_render_batch,
        text_des_list,
        furni_rel_list,
        pred_bbox_list,
        gt_bbox_list,
    ) = render_frames(
        [pos_states[-1]],
        pyg_data,
        batch["irregular_data"],
        render_cache_dir,
        args.model3d_base_dir,
        keep_source_file=False,
    )

    # The full pos_states sequence and the bounding boxes of the rendered state.
    torch.save(pos_states, os.path.join(pos_states_dir, f"pos_states_{data_idx}.pt"))
    torch.save(
        {"pred_bbox": pred_bbox_list[0][0], "gt_bbox": gt_bbox_list[0][0]},
        os.path.join(bbox_dir, f"bbox_{data_idx}.pt"),
    )

    # Predicted and ground-truth renders (Only for batch size = 1!).
    mi_write_img(os.path.join(pred_images_dir, f"test_{data_idx}.png"), pred_batch_frames[0][0])
    mi_write_img(os.path.join(gt_images_dir, f"gt_{data_idx}.png"), gt_render_batch[0][0])

    # Prompt text. The explicit UTF-8 encoding keeps non-ASCII prompts from
    # failing or being mis-encoded under a non-UTF-8 locale default.
    with open(os.path.join(text_dir, f"test_{data_idx}.txt"), "w", encoding="utf-8") as f:
        f.write(your_prompt)


