import os

import numpy as np

from utils.sh_utils import sh2rgb

from .functions import fetch_ply, get_nerf_pp_norm, store_ply
from .wild_smoke_reader import read_cameras_from_transforms_wild_smoke
from .scene_info import SceneInfo
import torch

def _init_region_bounds(init_region_type, no_init_pcd):
    """Return ``(radius_max, x_mid, y_min, y_max, z_mid, delta)`` for the init region."""
    # "synthetic" takes precedence over ``no_init_pcd`` (same order as before).
    if init_region_type == "synthetic":
        return 0.028, 0.0, 0.0, 0.08, -0.2879, 0.00625
    if no_init_pcd:
        return 0.15, 0.0, -0.2, 0.2, -0.35, 0.03
    # Effectively unbounded crop region. ``delta`` is only needed to build a
    # grid, which never happens on this path (no_init_pcd is False here).
    return 100, 0.0, -100, 100, 0, None


def _flip_xyz(pts):
    """Negate every coordinate of a point tensor (read as (N, 3), since the original code concatenated a ones column and multiplied by a 4x4 matrix).

    For finite values this equals applying diag(-1, -1, -1, 1) in homogeneous
    coordinates and dropping the last column.
    """
    return -pts


def _cylinder_grid_points(radius_max, x_mid, y_min, y_max, z_mid, delta):
    """Build a regular grid of points inside a y-aligned cylinder.

    Points are spaced ``delta`` apart. A point is kept when its (x, z)
    distance to ``(x_mid, z_mid)`` is <= ``radius_max`` and ``y_min <= y < y_max``.
    Points come out in x-major / z-minor order. The result is a float32 tensor
    of shape (M, 3).
    """
    x_range = np.arange(x_mid - radius_max, x_mid + radius_max + delta, delta)
    y_range = np.arange(y_min, y_max, delta)
    z_range = np.arange(z_mid - radius_max, z_mid + radius_max + delta, delta)
    # indexing="ij" + C-order ravel reproduces the x -> y -> z nested-loop order.
    gx, gy, gz = np.meshgrid(x_range, y_range, z_range, indexing="ij")
    grid = np.stack([gx.ravel(), gy.ravel(), gz.ravel()], axis=1)
    inside = (grid[:, 0] - x_mid) ** 2 + (grid[:, 2] - z_mid) ** 2 <= radius_max**2
    return torch.tensor(grid[inside], dtype=torch.float32)


def _crop_to_cylinder(pts, radius_max, x_mid, y_min, y_max, z_mid):
    """Keep points strictly inside the cylinder radius and with y in [y_min, y_max]."""
    x, y, z = pts[:, 0], pts[:, 1], pts[:, 2]
    distance_squared = (x - x_mid) ** 2 + (z - z_mid) ** 2
    mask = (distance_squared < radius_max**2) & (y >= y_min) & (y <= y_max)
    return pts[mask]


def read_scene_wild_smoke(
    data_path,
    model_path,
    white_background,
    eval,
    extension=".png",
    start_time=50,
    duration=50,
    time_step=1,
    max_timestamp=1.0,
    gray_image=False,
    no_init_pcd=False,
    source_init=False,
    init_region_type="large",
    img_offset=False,
    init_num_pts_per_time=1000,
    init_trbf_c_fix=False,
    init_color_fix_value: float = None,
    train_json='',
    test_json='',
    *args,
    **kwargs,
):
    """Load the train/test cameras and the initial point cloud for a wild-smoke scene.

    The initial point cloud is built in one of two ways:

    * ``no_init_pcd=True``: a regular grid inside a y-aligned cylinder, whose
      bounds are chosen by ``init_region_type``.
    * otherwise: the foreground points stored at the first training camera's
      ``fg_pts`` path are loaded with ``torch.load``, negated on every axis,
      and cropped to the cylinder.

    Args:
        data_path: Dataset root passed to the camera reader.
        model_path: Directory in which ``initial_points3d_total.ply`` lives.
            An existing file there is deleted.
        white_background, start_time, duration, time_step, max_timestamp,
        gray_image, img_offset: Forwarded to
            ``read_cameras_from_transforms_wild_smoke``.
        no_init_pcd: Synthesize a grid instead of loading foreground points.
        init_region_type: ``"synthetic"`` selects a small preset region.
            Any other value selects a region that depends on ``no_init_pcd``.
        train_json, test_json: Transform files for the train and test splits.
        eval, extension, source_init, init_num_pts_per_time, init_trbf_c_fix,
        init_color_fix_value, *args, **kwargs: Accepted for interface
            compatibility; not used by this function.

    Returns:
        SceneInfo whose ``point_cloud`` is the initial point tensor.

    Raises:
        ValueError: if ``no_init_pcd`` is False and the first training camera
            has no foreground point file.
    """
    print("Reading Training ...")
    train_cam_infos = read_cameras_from_transforms_wild_smoke(
        data_path,
        train_json,
        white_background,
        start_time,
        duration,
        time_step,
        max_timestamp,
        gray_image,
        img_offset,
    )

    print("Reading Test ...")
    test_cam_infos = read_cameras_from_transforms_wild_smoke(
        data_path,
        test_json,
        white_background,
        start_time,
        duration,
        time_step,
        max_timestamp,
        gray_image,
        img_offset,
    )

    nerf_normalization = get_nerf_pp_norm(train_cam_infos)

    # A stale file from a previous run is removed. Nothing in this function
    # writes it back. NOTE(review): confirm a caller regenerates it.
    total_ply_path = os.path.join(model_path, "initial_points3d_total.ply")
    if os.path.exists(total_ply_path):
        os.remove(total_ply_path)

    radius_max, x_mid, y_min, y_max, z_mid, delta = _init_region_bounds(
        init_region_type, no_init_pcd
    )

    if no_init_pcd:
        print("No init pcd")
        pts_new_fg = _cylinder_grid_points(radius_max, x_mid, y_min, y_max, z_mid, delta)
    else:
        fg_path = train_cam_infos[0].fg_pts
        if fg_path is None:
            raise ValueError("No point cloud found in the training cameras!")
        fg_pts = torch.load(fg_path)
        pts_new_fg = _crop_to_cylinder(
            _flip_xyz(fg_pts), radius_max, x_mid, y_min, y_max, z_mid
        )

    return SceneInfo(
        point_cloud=pts_new_fg,
        train_cameras=train_cam_infos,
        test_cameras=test_cam_infos,
        nerf_normalization=nerf_normalization,
        ply_path=total_ply_path,
    )


def read_scene_wild_smoke_eval(
    data_path,
    model_path,
    white_background,
    eval,
    start_time=50,
    duration=50,
    time_step=1,
    max_timestamp=1.0,
    gray_image=False,
    img_offset=False,
    train_json='',
    test_json='',
    *args,
    **kwargs,
):
    """Load a wild-smoke scene for evaluation using only the test split.

    The test cameras serve as both the train and test cameras of the returned
    SceneInfo. The point cloud is the foreground tensor loaded with
    ``torch.load`` from the first test camera's ``fg_pts`` path and moved to
    CUDA. It is None when no such path exists.

    Args:
        data_path: Dataset root passed to the camera reader.
        model_path: Directory used to build the ``initial_points3d_total.ply``
            path reported in the SceneInfo. The file is not read here.
        white_background, start_time, duration, time_step, max_timestamp,
        gray_image, img_offset: Forwarded to
            ``read_cameras_from_transforms_wild_smoke``.
        test_json: Transform file for the test split.
        eval, train_json, *args, **kwargs: Accepted for interface
            compatibility; not used by this function.

    Returns:
        SceneInfo built from the test cameras.
    """
    print("Reading Test Transforms...")
    # Same positional layout as the training loader. The previous version
    # passed an undefined ``extension`` name here.
    test_cam_infos = read_cameras_from_transforms_wild_smoke(
        data_path,
        test_json,
        white_background,
        start_time,
        duration,
        time_step,
        max_timestamp,
        gray_image,
        img_offset,
    )

    nerf_normalization = get_nerf_pp_norm(test_cam_infos)

    total_ply_path = os.path.join(model_path, "initial_points3d_total.ply")

    # Only test cameras are loaded in eval mode, so the foreground points come
    # from the first test camera.
    # NOTE(review): unlike read_scene_wild_smoke, no coordinate flip is
    # applied here. Confirm this is intended.
    fg_path = test_cam_infos[0].fg_pts if test_cam_infos else None
    fg_pts = torch.load(fg_path) if fg_path is not None else None

    return SceneInfo(
        point_cloud=fg_pts.cuda() if fg_pts is not None else None,
        train_cameras=test_cam_infos,
        test_cameras=test_cam_infos,
        nerf_normalization=nerf_normalization,
        ply_path=total_ply_path,
    )
