import os
import random
import numpy as np
import torch
import wandb
import imageio
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch_geometric.nn import DataParallel
from torch_geometric.data import Batch
from argparse import ArgumentParser
from utils import create_directory, get_model_args, get_sampler_param, get_bert_model_type #, save_network
from text_emb_handler import SentenceTransformerHandler, emb_all_txt_files, BertTextEmbHandler
from scene_pyg_loader import ScenePyGLoader, text_ball_collate_fn_pyg
from models.text_grad_gnn import AssembleModel_Room, GradientFieldSampler
from arrange_tools import arrange_room_mesh, render_transformed_obj, arrange_3d_to_2d, arrange_2d_to_3d, render_frames
from mitsuba_render_func import mi_write_img

def save_network(network, dir):
    """Save a model checkpoint holding its weights and constructor arguments.

    Data-parallel wrappers (``torch.nn.DataParallel`` -- which PyG's
    ``DataParallel`` subclasses -- and ``DistributedDataParallel``) are
    unwrapped first. The stored state-dict keys therefore carry no ``module.``
    prefix and can be loaded straight into a bare model.

    Args:
        network: The model to save. The unwrapped model must expose a
            ``model_args`` attribute.
        dir: Destination file path handed to ``torch.save`` (despite the name,
            this is a file path such as ``.../model_30.pt``, not a directory).

    The saved object is a dict with keys ``"model"`` (the state dict) and
    ``"model_args"``.
    """
    wrapper_types = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    # Unwrap parallel wrappers so checkpoints are independent of how training ran.
    model = network.module if isinstance(network, wrapper_types) else network
    save_dict = {
        "model": model.state_dict(),
        "model_args": model.model_args,
    }
    torch.save(save_dict, dir)

parser = ArgumentParser()

# General experiment settings
parser.add_argument("--exp_name", type=str, default="text_room")
parser.add_argument("--seed", type=int, default=100)  # a negative value disables seeding
parser.add_argument("--log_dir", type=str, default="logs")
parser.add_argument("--dataset_dir", type=str, default="my_dataset")  # read as <dataset_dir>/train and <dataset_dir>/test
parser.add_argument("--text_cache_dir", type=str, default="text_cache")  # cache folder for text embeddings
parser.add_argument("--print_every_iter", type=int, default=100)
parser.add_argument("--save_every_epoch", type=int, default=30)
parser.add_argument("--val_every_epoch", type=int, default=30)
# Validation currently only supports rendering the last sampled frame; without
# this flag the validation step raises NotImplementedError.
parser.add_argument("--only_last_frame", action="store_true")

# Model settings
parser.add_argument("--label_guidance", action="store_true")
parser.add_argument("--obj_feat_len", type=int, default=256)
parser.add_argument("--text_feat_len", type=int, default=128)
parser.add_argument("--time_feat_len", type=int, default=128)
parser.add_argument("--mid_lay_input_len", type=int, default=128)  # actually it is pos feat len
# parser.add_argument("--target_len", type=int, default=4) # actually it is pos input len
# Only "2D" and "3D" are handled by the training/validation code. Restricting
# the choices turns a typo into an immediate argparse error instead of a run
# that silently trains on 3D targets and never renders a validation batch.
parser.add_argument("--learning_obj", type=str, default="3D", choices=["2D", "3D"])
parser.add_argument("--n_layers", type=int, default=3)
parser.add_argument("--sigma", type=float, default=25.0)

# Sampling args
parser.add_argument("--sampler", type=str, default="PC")
parser.add_argument("--num_steps", type=int, default=500)
parser.add_argument("--t0", type=float, default=1.0)
parser.add_argument("--snr", type=float, default=0.16)

# Training parameters
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--val_batch_size", type=int, default=4)
parser.add_argument("--epochs", type=int, default=3000)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--num_workers", type=int, default=4)

# Language model settings
parser.add_argument("--text_emb_model_id", type=str, default="paraphrase-MiniLM-L12-v2")
# When this flag is absent, the embedding caches are created with force=True.
parser.add_argument("--ask_for_delete", action="store_true", help="Ask for delete the existed cache folder")

# 3D render options
parser.add_argument("--model3d_base_dir", type=str, default="models")
parser.add_argument("--render_cache_dir", type=str, default="cache_3d_render")

args = parser.parse_args()

# Seed the Python, NumPy and PyTorch generators so runs are more repeatable.
# Passing a negative --seed leaves every generator unseeded.
if args.seed >= 0:
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)


######### Global settings #########
# DEVICE: where the model and single-process batches live.
# MULTI_GPU: True only when more than one CUDA device is visible; this enables
# the multi-GPU code path of the script.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MULTI_GPU = torch.cuda.device_count() > 1

if __name__ == "__main__":
    # Script entry point: embed all text descriptions, train the layout score
    # model, and periodically checkpoint it and render one validation batch.

    ########### Device Info ###########
    print(f"Device: {DEVICE}")
    print(f"Multi-GPU: {MULTI_GPU}")

    # Prepare the experiment log directory *before* starting the wandb run, so
    # bailing out here does not leave an empty run behind on wandb.
    exp_dir = os.path.join(args.log_dir, args.exp_name)
    create_flag = create_directory(exp_dir)
    if create_flag:
        print("Log directory is created.")
    else:
        print("Exiting...")
        # Same effect as the site-provided exit(), without relying on `site`.
        raise SystemExit

    # init wandb
    wandb.init(project="text_room", name=args.exp_name, config=args)

    model_save_dir = os.path.join(exp_dir, "models")
    test_dir = os.path.join(exp_dir, "test")
    create_directory(model_save_dir)
    create_directory(test_dir)

    ########### Text embedding ###########
    print(f"You are using model {args.text_emb_model_id}")

    bert_model_type = get_bert_model_type(args.text_emb_model_id)
    if bert_model_type == "SenBert":
        emb_handler = SentenceTransformerHandler(DEVICE, args.text_emb_model_id)
    elif bert_model_type == "Bert":
        emb_handler = BertTextEmbHandler(DEVICE, args.text_emb_model_id)
    else:
        print("Not a recognized model type.")
        raise ValueError(
            f"Unrecognized text embedding model type {bert_model_type!r} "
            f"for model {args.text_emb_model_id!r}"
        )

    print("Start to embed all text files...")

    def emb_txt(dataset_dir, cache_dir, force=False):
        """Embed the text files of dataset_dir into cache_dir.

        When create_directory(cache_dir, force=force) reports that the cache
        directory was created, a model_log.txt naming the embedding model is
        written and emb_all_txt_files fills the cache. Otherwise the existing
        cache is used without re-embedding.
        """
        create_flag = create_directory(cache_dir, force=force)
        if create_flag:
            # create a model log file
            with open(os.path.join(cache_dir, "model_log.txt"), "w") as f:
                f.write(f"Text Embedding Model: {args.text_emb_model_id}")
            emb_all_txt_files(emb_handler, dataset_dir, cache_dir)
            print("Text embedding is done, new embeddings created.")
        else:
            print("You are using the existed text embedding cache.")

    train_dataset_dir = os.path.join(args.dataset_dir, "train")
    test_dataset_dir = os.path.join(args.dataset_dir, "test")
    train_cache_dir = os.path.join(args.text_cache_dir, "train")
    test_cache_dir = os.path.join(args.text_cache_dir, "test")

    # force=True unless --ask_for_delete is given (presumably this skips the
    # "delete existing folder?" prompt in create_directory -- confirm in utils).
    emb_txt(train_dataset_dir, train_cache_dir, force=(not args.ask_for_delete))
    emb_txt(test_dataset_dir, test_cache_dir, force=(not args.ask_for_delete))

    ########### Data ###########
    train_dataset = ScenePyGLoader(train_dataset_dir, train_cache_dir)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=text_ball_collate_fn_pyg,
        num_workers=args.num_workers,
    )

    test_dataset = ScenePyGLoader(test_dataset_dir, test_cache_dir)

    print("Label length: ", train_dataset.label_len)

    ########### Model ###########
    # The label length is only passed when label guidance is enabled; -1 otherwise.
    if args.label_guidance:
        model_args = get_model_args(args, train_dataset.label_len)
    else:
        model_args = get_model_args(args, -1)

    model = AssembleModel_Room(model_args).to(DEVICE)
    if MULTI_GPU:
        print("Using multi GPUs")
        model = DataParallel(model)
    else:
        print(f"Using single device: {DEVICE}")
    # nn.DataParallel does not forward custom methods (such as get_global_loss)
    # to the wrapped module, so keep a handle on the bare model for them.
    base_model = model.module if MULTI_GPU else model

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    iteration = 0

    render_cache_dir = os.path.join(test_dir, args.render_cache_dir)
    create_directory(render_cache_dir, force=True)

    for epoch in range(1, args.epochs + 1):
        ########### Training ###########
        model.train()
        for batch in train_dataloader:
            if not MULTI_GPU:
                pyg_data = Batch.from_data_list(batch["data"]).to(DEVICE)
            else:
                pyg_data = batch["data"]  # it will automatically move to the device when using DataParallel
            if args.learning_obj == "2D":
                # NOTE(review): under MULTI_GPU pyg_data is the raw list from the
                # collate fn and has no `.y`; 2D + multi-GPU looks unsupported -- confirm.
                pyg_data.y = arrange_3d_to_2d(pyg_data.y)  # only use x and z
            local_score_loss = model(pyg_data)
            global_loss = base_model.get_global_loss(local_score_loss)
            optimizer.zero_grad()
            global_loss.backward()
            optimizer.step()
            iteration += 1
            # .item() synchronises with the device; fetch the scalar only once.
            loss_value = global_loss.item()
            if iteration % args.print_every_iter == 0:
                tqdm.write(f"Epoch: {epoch}, Iteration: {iteration}, Loss: {loss_value}")
            wandb.log({"loss": loss_value})

        ########### Checkpoint ###########
        if epoch % args.save_every_epoch == 0:
            save_path = os.path.join(model_save_dir, f"model_{epoch}.pt")
            save_network(model, save_path)

        ########### Validation ###########
        if epoch % args.val_every_epoch == 0:
            model.eval()
            test_dataloader = DataLoader(
                test_dataset,
                batch_size=args.val_batch_size,
                shuffle=True,
                collate_fn=text_ball_collate_fn_pyg,
            )
            # create image/video save folder
            epoch_test_folder = os.path.join(test_dir, f"epoch_{epoch}")
            create_directory(epoch_test_folder, force=True)
            save_idx = 0
            for batch in test_dataloader:
                if not MULTI_GPU:
                    pyg_data = Batch.from_data_list(batch["data"]).to(DEVICE)
                else:
                    pyg_data = batch["data"]
                # Sampler
                sampler_param = get_sampler_param(args)
                sampler = GradientFieldSampler(sampler_param, model, DEVICE)
                pos_states, _ = sampler.sample_one_batch(pyg_data)
                if args.only_last_frame:
                    # save one frame as image
                    if args.learning_obj == "2D":
                        raise NotImplementedError
                    elif args.learning_obj == "3D":
                        # Only the last sampled state is rendered; the outputs are
                        # indexed [0] for that single frame (presumably
                        # frame-major -- confirm in arrange_tools.render_frames).
                        (
                            pred_batch_frames,
                            gt_render_batch,
                            text_des_list,
                            furni_rel_list,
                            pred_bbox_list,
                            gt_bbox_list,
                        ) = render_frames(
                            [pos_states[-1]],
                            pyg_data,
                            batch["irregular_data"],
                            render_cache_dir,
                            args.model3d_base_dir,
                        )
                        for batch_idx in range(len(pred_batch_frames[0])):
                            # save predicted and ground-truth renders
                            mi_write_img(os.path.join(epoch_test_folder, f"test_render_{save_idx}.png"), pred_batch_frames[0][batch_idx])
                            mi_write_img(os.path.join(epoch_test_folder, f"gt_render_{save_idx}.png"), gt_render_batch[0][batch_idx])
                            # save text description
                            with open(os.path.join(epoch_test_folder, f"test_render_{save_idx}.txt"), "w") as f:
                                f.write(text_des_list[0][batch_idx])
                            # save furniture relationship
                            with open(os.path.join(epoch_test_folder, f"test_render_{save_idx}_furni_rel.txt"), "w") as f:
                                f.write(furni_rel_list[0][batch_idx])
                            # save predicted and ground-truth bounding boxes
                            pred_bbox_path = os.path.join(epoch_test_folder, f"test_pred_{save_idx}_bbox.npy")
                            gt_bbox_path = os.path.join(epoch_test_folder, f"test_gt_{save_idx}_bbox.npy")
                            np.save(pred_bbox_path, np.array(pred_bbox_list[0][batch_idx]))
                            np.save(gt_bbox_path, np.array(gt_bbox_list[0][batch_idx]))

                            save_idx += 1
                else:
                    # we save videos
                    raise NotImplementedError
                break  # currently only sample one batch

    # save the final model
    save_path = os.path.join(model_save_dir, "model_final.pt")
    save_network(model, save_path)
    wandb.finish()
