import torch
import torch.nn.functional as F
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render, flashsplat_render
import torchvision
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel
import numpy as np
import json
from PIL import Image
import sys
from datetime import datetime


# Minimum cosine similarity (strictly greater-than) between a query's text
# feature and a per-object feature loaded from clip/clip_output_features.npy
# for that object to be treated as a match in render_set.
SIMILARITY_THRESHOLD = 0.9

def safe_state(silent):
    """Replace ``sys.stdout`` with a wrapper that timestamps or silences output.

    After this call, every write to stdout that ends a line gets a
    ``[dd/mm HH:MM:SS]`` stamp inserted before each newline. When ``silent``
    is true, all stdout output is discarded instead.

    Args:
        silent: If true, suppress everything written to ``sys.stdout``.
    """
    old_f = sys.stdout

    # Deliberately not named ``F`` so it cannot be confused with the
    # module-level alias for ``torch.nn.functional``.
    class _TimestampedStdout:
        def __init__(self, silent):
            self.silent = silent

        def write(self, x):
            if self.silent:
                return
            if x.endswith("\n"):
                stamp = datetime.now().strftime("%d/%m %H:%M:%S")
                old_f.write(x.replace("\n", " [{}]\n".format(stamp)))
            else:
                old_f.write(x)

        def flush(self):
            old_f.flush()

        def __getattr__(self, name):
            # Delegate everything else (encoding, isatty, fileno, ...) to the
            # wrapped stream so code that probes sys.stdout keeps working.
            return getattr(old_f, name)

    # Install the wrapper; without this assignment --quiet has no effect.
    sys.stdout = _TimestampedStdout(silent)

def save_mask(mask, save_path):
    """Threshold ``mask`` at 0.5 and save it as an 8-bit image at ``save_path``.

    Entries greater than 0.5 are written as 255 and all others as 0.
    Singleton dimensions are squeezed away before the image is built, so a
    ``(1, H, W)`` tensor becomes an ``H x W`` image.
    """
    binary = mask.gt(0.5).float()
    pixels = np.squeeze(binary.detach().cpu().numpy())
    Image.fromarray((pixels * 255).astype(np.uint8)).save(save_path)

# Seed NumPy's global RNG at import time.
# NOTE(review): nothing in this file's visible code draws from np.random —
# confirm whether an imported module relies on this before removing it.
np.random.seed(42)

def render_set(model_path, name, views, gaussians, pipeline, background, scene_name, total_categories, stats_counts_path, device):
    """Render, for each text query of a scene, all Gaussians of matching objects.

    For every query string configured for ``scene_name``:

    1. Compute the cosine similarity between the query's text feature (from
       ``assets/text_features.json``) and every row of
       ``clip/clip_output_features.npy``.
    2. For every row whose similarity exceeds ``SIMILARITY_THRESHOLD``, take
       its label from ``clip/clip_output_labels.npy``. OR together the
       per-label Gaussian masks stored under ``stats_counts_path``.
    3. Render the combined Gaussians for the scene's annotated frames. Write
       the render, the ground-truth image and the thresholded alpha to
       ``<stats_counts_path>/ans_<query>/{render,gt,mask}/<frame>.png``.

    Args:
        model_path: Model directory. ``<model_path>/<name>/ours_0/{renders,gt}``
            are created (nothing in this function writes into them).
        name: Sub-directory name under ``model_path``.
        views: Cameras to iterate. Only those whose ``image_name`` is listed
            for the scene below are rendered.
        gaussians: The loaded Gaussian model.
        pipeline: Pipeline parameters forwarded to the renderer.
        background: Background colour tensor forwarded to the renderer.
        scene_name: Key into the built-in query / annotated-frame tables.
        total_categories: Used only to build the per-label mask file names.
        stats_counts_path: Directory holding the per-label ``.pth`` masks.
            Outputs are written below it as well.
        device: Device for the feature math and the loaded masks.

    Raises:
        KeyError: If ``scene_name`` is not one of the configured scenes.
        ValueError: If a query string has no entry in ``text_features.json``,
            or a label mask does not have one entry per Gaussian.
    """
    makedirs(os.path.join(model_path, name, "ours_0/renders"), exist_ok=True)
    makedirs(os.path.join(model_path, name, "ours_0/gt"), exist_ok=True)

    features = np.load(os.path.join("clip", "clip_output_features.npy"))
    labels = np.load(os.path.join("clip", "clip_output_labels.npy"))

    with open(os.path.join('assets', 'text_features.json'), 'r') as f:
        data_loaded = json.load(f)
    all_texts = list(data_loaded.keys())
    text_features = torch.from_numpy(np.array(list(data_loaded.values()))).to(torch.float32)
    # O(1) lookup instead of a linear list.index() per query.
    text_index = {text: i for i, text in enumerate(all_texts)}

    scene_texts = {
        "waldo_kitchen": ['Stainless steel pots', 'dark cup', 'refrigerator', 'frog cup', 'pot', 'spatula', 'plate', 'spoon', 'toaster', 'ottolenghi', 'plastic ladle', 'sink', 'ketchup', 'cabinet', 'red cup', 'pour-over vessel', 'knife', 'yellow desk'],
        "ramen": ['nori', 'sake cup', 'kamaboko', 'corn', 'spoon', 'egg', 'onion segments', 'plate', 'napkin', 'bowl', 'glass of water', 'hand', 'chopsticks', 'wavy noodles'],
        "figurines": ['jake', 'pirate hat', 'pikachu', 'rubber duck with hat', 'porcelain hand', 'red apple', 'tesla door handle', 'waldo', 'bag', 'toy cat statue', 'miffy', 'green apple', 'pumpkin', 'rubics cube', 'old camera', 'rubber duck with buoy', 'red toy chair', 'pink ice cream', 'spatula', 'green toy chair', 'toy elephant'],
        "teatime": ['sheep', 'yellow pouf', 'stuffed bear', 'coffee mug', 'tea in a glass', 'apple', 'coffee', 'hooves', 'bear nose', 'dall-e brand', 'plate', 'paper napkin', 'three cookies', 'bag of cookies']
    }
    # Frames to render for each scene. Built once, not on every view iteration.
    scene_gt_frames = {
        "waldo_kitchen": ["frame_00053", "frame_00066", "frame_00089", "frame_00140", "frame_00154"],
        "ramen": ["frame_00006", "frame_00024", "frame_00060", "frame_00065", "frame_00081", "frame_00119", "frame_00128"],
        "figurines": ["frame_00152", "frame_00195"],
        "teatime": ["frame_00002", "frame_00025", "frame_00043", "frame_00107", "frame_00129", "frame_00140"]
    }
    target_text = scene_texts[scene_name]
    gt_frames = set(scene_gt_frames[scene_name])

    missing = [text for text in target_text if text not in text_index]
    if missing:
        raise ValueError(f"No entry in assets/text_features.json for query text(s): {missing}")
    query_text_feats = text_features[[text_index[text] for text in target_text]].to(device)

    features_torch = F.normalize(torch.from_numpy(features).to(torch.float32).to(device), dim=1)
    query_text_feats_norm = F.normalize(query_text_feats, dim=1)
    # Rows: queries; columns: rows of clip_output_features.npy.
    cosine_sim = torch.matmul(query_text_feats_norm, features_torch.T)

    num_gaussians = gaussians.get_xyz.shape[0]

    for idxx, text_query in enumerate(target_text):
        matched_indices_for_query = torch.where(cosine_sim[idxx] > SIMILARITY_THRESHOLD)[0].cpu().numpy()
        unique_labels_for_query = np.unique(labels[matched_indices_for_query])

        if len(unique_labels_for_query) == 0:
            continue

        print(f"Text '{text_query}' threshold-matched {len(unique_labels_for_query)} unique object(s) (sim>{SIMILARITY_THRESHOLD}): {unique_labels_for_query.tolist()}")
        print("Combining masks for joint rendering...")

        combined_mask = torch.zeros(num_gaussians, dtype=torch.bool, device=device)
        for label in unique_labels_for_query:
            label = int(label)
            cur_label_dir = os.path.join(stats_counts_path, f"class_id_{label:03d}_total_categories_{total_categories:03d}_label.pth")

            if not os.path.exists(cur_label_dir):
                print(f"Warning: Missing pre-trained label file for class {label}. Skipping.", flush=True)
                continue

            # NOTE: torch.load unpickles the file; stats_counts_path must be trusted data.
            unique_label_data = torch.load(cur_label_dir, map_location=device)
            # Flatten so an (N, 1)-shaped mask cannot silently broadcast to (N, N).
            current_obj_mask = (unique_label_data == 1).reshape(-1)
            if current_obj_mask.numel() != num_gaussians:
                raise ValueError(
                    f"Label mask {cur_label_dir} has {current_obj_mask.numel()} entries, "
                    f"expected one per Gaussian ({num_gaussians})."
                )
            combined_mask = torch.logical_or(combined_mask, current_obj_mask)

        if not torch.any(combined_mask):
            print(f"Warning: No valid points found for text '{text_query}' after checking all matched labels.")
            continue

        class_dir = os.path.join(stats_counts_path, f"ans_{text_query.replace(' ', '_')}")
        merged_render_path = os.path.join(class_dir, "render")
        merged_gt_path = os.path.join(class_dir, "gt")
        merged_mask_path = os.path.join(class_dir, "mask")
        os.makedirs(merged_render_path, exist_ok=True)
        os.makedirs(merged_gt_path, exist_ok=True)
        os.makedirs(merged_mask_path, exist_ok=True)

        desc = f"Rendering combined group for '{text_query}'"
        for view in tqdm(views, desc=desc):
            if view.image_name not in gt_frames:
                continue

            # Render only the Gaussians selected by combined_mask. This was an
            # undefined name (`flexirender`). NOTE(review): this assumes `render`
            # accepts `used_mask` and returns an "alpha" entry; confirm this
            # against gaussian_renderer.
            render_pkg = render(view, gaussians, pipeline, background, used_mask=combined_mask)

            rendering, render_alpha = render_pkg["render"], render_pkg["alpha"]
            gt = view.original_image[0:3, :, :]
            base_name = f"{view.image_name}.png"

            torchvision.utils.save_image(rendering, os.path.join(merged_render_path, base_name))
            torchvision.utils.save_image(gt, os.path.join(merged_gt_path, base_name))
            save_mask((render_alpha > 0.5), os.path.join(merged_mask_path, base_name))


def render_sets(dataset: ModelParams, pipeline: PipelineParams, args):
    """Load the trained scene and render every text-query object group.

    Args:
        dataset: Model/dataset parameters. ``sh_degree``, ``white_background``
            and ``model_path`` are read here.
        pipeline: Pipeline parameters forwarded to the renderer.
        args: Parsed CLI namespace. ``iteration``, ``scene_name``,
            ``total_categories`` and ``stats_counts_path`` are read.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=args.iteration, shuffle=False)
        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
        # Keep the background on the same device as everything else, instead
        # of hard-coding "cuda" while `device` may have fallen back to CPU.
        background = torch.tensor(bg_color, dtype=torch.float32, device=device)
        render_set(dataset.model_path, "text2obj_combined", scene.getTrainCameras(), gaussians, pipeline, background,
                   args.scene_name, args.total_categories, args.stats_counts_path, device)


if __name__ == "__main__":
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--stats_counts_path", type=str, required=True,
                        help="Directory containing the per-label class_id_*_label.pth masks; outputs go here too.")
    parser.add_argument("--total_categories", default=-1, type=int,
                        help="Category count embedded in the label-mask file names (must be set).")
    parser.add_argument("--scene_name", type=str, required=True, choices=["waldo_kitchen", "ramen", "figurines", "teatime"])

    args = get_combined_args(parser)

    # total_categories only feeds the label-mask file names. The -1 default
    # yields names like "..._total_categories_-01_label.pth" that never exist,
    # so every label would be skipped and nothing would be rendered.
    if args.total_categories < 0:
        parser.error("--total_categories must be set to a non-negative value")

    print("Rendering " + args.model_path)

    safe_state(args.quiet)

    render_sets(model.extract(args), pipeline.extract(args), args)