import os
import torch
import datetime
from argparse import ArgumentParser
from scene_pyg_loader import ScenePyGLoaderSeq, text_ball_collate_fn_pyg
from models.text_grad_gnn import AssembleModel_Room, GradientFieldSampler
from torch.utils.data import DataLoader
from torch_geometric.data import Batch
from utils import create_directory, get_model_args, get_sampler_param, get_text_emb_handler #, save_network
from text_emb_handler import SentenceTransformerHandler
from arrange_tools import arrange_room_mesh, render_transformed_obj, arrange_3d_to_2d, arrange_2d_to_3d, render_frames
from mitsuba_render_func import mi_write_img

# Command-line interface of the interactive re-arrangement script.
parser = ArgumentParser()
parser.add_argument("--dataset_dir", type=str, default="my_dataset",
                    help="Dataset root; the test split is read from <dataset_dir>/test.")
parser.add_argument("--text_cache_dir", type=str, default="text_cache",
                    help="Cache folder for text embeddings; train/ and test/ subfolders are used.")
parser.add_argument("--model_dir", type=str, default="model.pt",
                    help="Path to the saved model checkpoint FILE (loaded with torch.load), "
                         "not a folder, despite the flag name.")

# Sampling args (all forwarded to the sampler through get_sampler_param)
parser.add_argument("--sampler", type=str, default="PC",
                    help="Sampler type.")
parser.add_argument("--num_steps", type=int, default=500,
                    help="Number of sampling steps.")
parser.add_argument("--t0", type=float, default=1.0,
                    help="Starting time t0 for the sampler.")
parser.add_argument("--snr", type=float, default=0.16,
                    help="SNR parameter for the sampler.")

# NOTE: there is intentionally no --text_emb_model_id flag; the text-embedding
# model id is read from the checkpoint's model_args ("text_emb_model_id").

# Output directory
parser.add_argument("--model3d_base_dir", type=str, default="models",
                    help="Base directory of the 3D model assets passed to render_frames.")
parser.add_argument("--output_dir", type=str, default="output",
                    help="Directory where rendered images and prompts are written.")
parser.add_argument("--render_cache_dir", type=str, default="cache_3d_render",
                    help="Render-cache subfolder, created inside --output_dir.")

args = parser.parse_args()


# Only the "test" split is evaluated by this script. The train-split cache
# path is resolved as well but is not read anywhere below.
test_dataset_dir = os.path.join(args.dataset_dir, "test")
train_cache_dir, test_cache_dir = (
    os.path.join(args.text_cache_dir, split_name) for split_name in ("train", "test")
)

test_dataset = ScenePyGLoaderSeq(test_dataset_dir, test_cache_dir)
# batch_size=1 and shuffle=False: one item per batch, in dataset order, so
# the i-th batch yielded by the loader corresponds to the index the user types.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=text_ball_collate_fn_pyg,
)

loader_len = len(test_dataloader)
print("Number of data in the test dataset: ", loader_len)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Deserialize the checkpoint onto the CPU first; the model is moved to DEVICE
# once it has been built. weights_only=False lets torch.load unpickle arbitrary
# objects, so only load checkpoints from a trusted source.
model_dict = torch.load(args.model_dir, map_location="cpu", weights_only=False)

# The checkpoint carries the constructor arguments alongside the weights.
model_args = model_dict["model_args"]
print("Model args: ", model_args)

model = AssembleModel_Room(model_args)
model = model.to(DEVICE)
model.load_state_dict(model_dict["model"])
model.eval()

# The text-embedding model id comes from the checkpoint's model_args, not from
# the command line.
print("Your text embedding model: ", model_args["text_emb_model_id"])
emb_handler = get_text_emb_handler(model_args["text_emb_model_id"], DEVICE)

def return_data(data):
    """Fetch one collated batch from ``test_dataloader`` by position.

    Args:
        data: Zero-based position of the wanted batch. The parameter name is
            kept for backward compatibility; it is an index, not a batch.

    Returns:
        The batch that ``test_dataloader`` yields at that position.

    Raises:
        IndexError: if ``data`` is negative or the loader yields fewer than
            ``data + 1`` batches.
    """
    index = data
    if index < 0:
        raise IndexError(f"batch index must be non-negative, got {index}")
    # A DataLoader has no random access, so walk it; iteration stops as soon
    # as the requested batch has been produced.
    for position, batch in enumerate(test_dataloader):
        if position == index:
            return batch
    raise IndexError(f"batch index {index} is out of range for the test loader")

# --- sampler -----------------------------------------------------------------
# get_sampler_param derives the sampler configuration from args (presumably
# from the --sampler/--num_steps/--t0/--snr flags; see utils to confirm).
sampler_param = get_sampler_param(args)
sampler = GradientFieldSampler(
    sampler_param,
    model,
    DEVICE,
)

# --- render cache ------------------------------------------------------------
# Intermediate renders go to <output_dir>/<render_cache_dir>.
render_cache_dir = os.path.join(
    args.output_dir,
    args.render_cache_dir,
)
create_directory(render_cache_dir, force=True)

def _prompt_data_index():
    """Ask for a test-set index until the user enters a valid one.

    Returns:
        An int in ``[0, loader_len)``.

    Raises:
        EOFError, KeyboardInterrupt: propagated from ``input()`` so the
            caller can end the session.
    """
    while True:
        raw_idx = input("Enter the index of the data you want to rearrange: ")
        try:
            idx = int(raw_idx)
        except ValueError:
            print("It must be an integer. Please enter a valid index.")
            continue
        if 0 <= idx < loader_len:
            return idx
        print("out of range. Please enter a valid index.")


# Interactive session: pick a test scene, type a prompt, sample, and render.
# EOF (Ctrl+D / closed stdin) or Ctrl+C ends the session cleanly instead of
# crashing with a traceback.
try:
    while True:
        data_idx = _prompt_data_index()
        print("Data index: ", data_idx)

        batch = return_data(data_idx)
        pyg_data = Batch.from_data_list(batch["data"]).to(DEVICE)
        print("Irregular data: ", batch["irregular_data"])

        your_prompt = input("Enter your prompt: ")
        print("Your prompt: ", your_prompt)

        # Embed the user's prompt and repeat it once per node (row of
        # pyg_data.x), replacing the text embedding stored in the dataset.
        # Assumes get_sentence_embedding returns a (1, emb_dim) tensor so that
        # expand() can broadcast it; TODO confirm.
        text_emb = emb_handler.get_sentence_embedding([your_prompt])
        how_many_points = pyg_data.x.size(0)
        text_emb = text_emb.expand(how_many_points, -1)
        pyg_data.text_emb = text_emb

        pos_states, samp_time = sampler.sample_one_batch(pyg_data)

        # Only the last entry of pos_states is rendered; the text descriptions
        # and furniture relations returned by render_frames are not used here.
        pred_batch_frames, gt_render_batch, text_des_list, furni_rel_list = render_frames(
            [pos_states[-1]],
            pyg_data,
            batch["irregular_data"],
            render_cache_dir,
            args.model3d_base_dir,
        )

        for pred_frame, gt_frame in zip(pred_batch_frames[0], gt_render_batch[0]):
            current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            current_folder = os.path.join(args.output_dir, f"test_render_{data_idx}_{current_time}")
            create_directory(current_folder, force=True)
            mi_write_img(os.path.join(current_folder, "test_render.png"), pred_frame)
            mi_write_img(os.path.join(current_folder, "gt_render.png"), gt_frame)
            # Explicit UTF-8 so non-ASCII prompts are saved on every platform.
            with open(os.path.join(current_folder, "test_render.txt"), "w", encoding="utf-8") as f:
                f.write(your_prompt)
except (EOFError, KeyboardInterrupt):
    print("\nExiting.")


