import rembg
import torch
import os
import numpy as np
import torch.nn.functional as F
import time
import trimesh
from omegaconf import OmegaConf
import torchvision
from tqdm import tqdm
import cv2
from rembg import remove

from utils.train_util import instantiate_from_config


class NormalTransfer:
    """Rotate per-pixel normal maps between per-view (local) and global frames.

    Camera poses are placed on a sphere around the origin by
    ``generate_target_pose``; a normal map is re-expressed in another frame by
    applying the 3x3 relative rotation between two poses (``trans_normal``).
    """

    def __init__(self):
        # Fixed 3x4 reference pose for the "global" frame. Only its 3x3
        # rotation block is used by ``trans_normal``.
        self.identity_w2c = torch.tensor([
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [-1.0, 0.0, 0.0, 4.5]]).float()

    def look_at(self, camera_position, target_position, up_vector=None):
        """Build a 4x4 pose ``R @ T`` for a camera at ``camera_position`` aimed at ``target_position``.

        ``T`` translates by ``-camera_position``; the columns of ``R`` are the
        camera's right, up and forward (camera minus target) unit axes.

        :param camera_position: (3,), numpy array
        :param target_position: (3,), numpy array
        :param up_vector: (3,), numpy array, defaults to +z. It must not be
            parallel to the viewing direction, otherwise ``right`` has zero
            length and the normalisation yields NaN.
        :return: (4,4), numpy array
        """
        if up_vector is None:
            up_vector = np.array([0, 0, 1])
        forward = camera_position - target_position
        forward = forward / np.linalg.norm(forward)

        right = np.cross(up_vector, forward)
        right = right / np.linalg.norm(right)

        up = np.cross(forward, right)

        # Axes are stored as columns. NOTE(review): a strict world-to-camera
        # rotation would hold them as rows; confirm the intended convention.
        rotation_matrix = np.array([right, up, forward]).T

        translation_matrix = np.eye(4)
        translation_matrix[:3, 3] = -camera_position

        rotation_homogeneous = np.eye(4)
        rotation_homogeneous[:3, :3] = rotation_matrix

        w2c = rotation_homogeneous @ translation_matrix
        return w2c

    def generate_target_pose(self, azimuths_deg, elevations_deg, radius=4.5):
        """Place one camera per (azimuth, elevation) on a sphere and build its ``look_at`` pose.

        :param azimuths_deg: (B,), numpy array or torch tensor, degrees
        :param elevations_deg: (B,), numpy array or torch tensor, degrees
        :param radius: float, sphere radius
        :return: (B,4,4), numpy array of poses looking at the origin
        """
        if isinstance(azimuths_deg, torch.Tensor):
            azimuths_deg = azimuths_deg.cpu().numpy()
        if isinstance(elevations_deg, torch.Tensor):
            elevations_deg = elevations_deg.cpu().numpy()
        azimuths = np.deg2rad(azimuths_deg)
        elevations = np.deg2rad(elevations_deg)

        x = radius * np.cos(azimuths) * np.cos(elevations)
        y = radius * np.sin(azimuths) * np.cos(elevations)
        z = radius * np.sin(elevations)
        camera_positions = np.stack([x, y, z], axis=-1)

        target_position = np.array([0, 0, 0])

        w2c_matrices = [self.look_at(cam_pos, target_position) for cam_pos in camera_positions]
        w2c_matrices = np.stack(w2c_matrices, axis=0)
        return w2c_matrices

    def convert_to_blender(self, pose):
        """Convert a 4x4 ``look_at`` pose into a 3x4 ``[R | t]`` matrix.

        Rows 1 and 2 are swapped and the new row 1 is negated. The rotation is
        then transposed, ``t`` becomes ``-R^T t``, and both are left-multiplied
        by ``diag(1, -1, -1)``. Works on a copy, so ``pose`` is not modified.

        :param pose: (4,4), numpy array
        :return: (3,4), numpy array
        """
        w2c_opengl = np.array(pose, copy=True)
        # Swap the y and z axes
        w2c_opengl[[1, 2], :] = w2c_opengl[[2, 1], :]

        # Invert the y axis
        w2c_opengl[1] *= -1
        R = w2c_opengl[:3, :3]
        t = w2c_opengl[:3, 3]

        cam_rec = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)
        R = R.T
        t = -R @ t
        R_world2cv = cam_rec @ R
        t_world2cv = cam_rec @ t

        RT = np.concatenate([R_world2cv, t_world2cv[:, None]], 1)
        return RT

    def worldNormal2camNormal(self, rot_w2c, normal_map_world):
        """Rotate every normal in a map: ``n_cam = rot_w2c @ n_world``.

        :param rot_w2c: (3,3), torch tensor
        :param normal_map_world: (..., C) torch tensor with C >= 3; only the
            first 3 channels are used
        :return: (..., 3), float32 torch tensor
        """
        normal_map_world = normal_map_world[..., :3]
        # Row-vector form: (N,3) @ R^T applies R to all normals in one matmul.
        normal_map_flat = normal_map_world.contiguous().view(-1, 3)
        normal_map_camera_flat = torch.matmul(normal_map_flat.float(), rot_w2c.T.float())
        return normal_map_camera_flat.view(normal_map_world.shape)

    def trans_normal(self, normal, RT_w2c, RT_w2c_target):
        """
        Re-express ``normal`` from the frame of ``RT_w2c`` in the frame of ``RT_w2c_target``.

        :param normal: (H,W,3), torch tensor, range [-1,1]
        :param RT_w2c: (3,4) or (4,4), torch tensor, world to camera; only [:3,:3] is used
        :param RT_w2c_target: (3,4) or (4,4), torch tensor, world to camera; only [:3,:3] is used
        :return: normal_target_cam: (H,W,3), torch tensor, range [-1,1]
        """
        relative_RT = torch.matmul(RT_w2c_target[:3, :3], torch.linalg.inv(RT_w2c[:3, :3]))
        normal_target_cam = self.worldNormal2camNormal(relative_RT[:3, :3], normal)

        return normal_target_cam

    def trans_local_2_global(self, normal_local, azimuths_deg, elevations_deg, radius=4.5, for_lotus=True):
        """
        Rotate per-view normal maps into the frame of ``self.identity_w2c``.

        :param normal_local: (B,H,W,3), torch tensor, range [-1,1]
        :param azimuths_deg: (B,), numpy array or torch tensor, range [0,360]
        :param elevations_deg: (B,), numpy array or torch tensor, range [-90,90]
        :param radius: float, default 4.5
        :param for_lotus: if True, negate the x component of the result
        :return: global_normal: (B,H,W,3), torch tensor, unit length per pixel;
            all-zero input normals stay zero instead of becoming NaN
        """
        assert normal_local.shape[0] == azimuths_deg.shape[0] == elevations_deg.shape[0]
        identity_w2c = self.identity_w2c

        # generate target pose
        target_w2c = self.generate_target_pose(azimuths_deg, elevations_deg, radius)
        target_w2c = torch.from_numpy(np.stack([self.convert_to_blender(w2c) for w2c in target_w2c])).float()

        # transform normal: each view's frame -> reference frame
        global_normal = torch.stack(
            [self.trans_normal(normal_i, w2c_i, identity_w2c)
             for normal_i, w2c_i in zip(normal_local, target_w2c)],
            dim=0)
        if for_lotus:
            global_normal[..., 0] *= -1
        # F.normalize clamps the norm at a small eps, so zero vectors map to
        # zero instead of 0/0 = NaN.
        global_normal = F.normalize(global_normal, dim=-1)
        return global_normal

    def trans_global_2_local(self, normal_local, azimuths_deg, elevations_deg, radius=4.5):
        """
        Rotate normal maps from the frame of ``self.identity_w2c`` into each view's frame.

        :param normal_local: (B,H,W,3), torch tensor, range [-1,1]; the normals
            to transform (despite the name, expected in the global frame)
        :param azimuths_deg: (B,), numpy array or torch tensor, range [0,360]
        :param elevations_deg: (B,), numpy array or torch tensor, range [-90,90]
        :param radius: float, default 4.5
        :return: local_normal: (B,H,W,3), torch tensor, unit length per pixel;
            all-zero input normals stay zero instead of becoming NaN

        NOTE(review): this applies neither ``convert_to_blender`` nor the
        ``for_lotus`` x-flip, so it is not an exact inverse of
        ``trans_local_2_global`` -- confirm intended.
        """
        assert normal_local.shape[0] == azimuths_deg.shape[0] == elevations_deg.shape[0]
        identity_w2c = self.identity_w2c

        # generate target pose (used directly, without convert_to_blender)
        target_w2c = torch.from_numpy(
            self.generate_target_pose(azimuths_deg, elevations_deg, radius)).float()

        # transform normal: reference frame -> each view's frame
        local_normal = torch.stack(
            [self.trans_normal(normal_i, identity_w2c, w2c_i)
             for normal_i, w2c_i in zip(normal_local, target_w2c)],
            dim=0)
        local_normal = F.normalize(local_normal, dim=-1)
        return local_normal


def pad_camera_extrinsics_4x4(extrinsics):
    """Append the homogeneous row ``[0, 0, 0, 1]`` to (..., 3, 4) extrinsics.

    Matrices that already have 4 rows are returned unchanged (same object).
    Any number of leading batch dimensions is supported. The padding row takes
    the dtype and device of ``extrinsics``.

    :param extrinsics: (..., 3, 4) or (..., 4, 4) tensor
    :return: (..., 4, 4) tensor
    """
    if extrinsics.shape[-2] == 4:
        return extrinsics
    padding = torch.tensor([[0, 0, 0, 1]]).to(extrinsics)
    # Broadcast the (1, 4) row over every leading batch dimension.
    padding = padding.expand(*extrinsics.shape[:-2], 1, 4)
    return torch.cat([extrinsics, padding], dim=-2)

def center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None):
    """
    Create OpenGL camera extrinsics from camera locations and look-at position.

    camera_position: (M, 3) or (3,)
    look_at: (3), defaults to the origin
    up_world: (3), defaults to +z
    return: (M, 4, 4) or (4, 4) matrices whose columns are the camera x/y/z
        axes and the camera position, with a [0, 0, 0, 1] last row

    The default ``look_at`` / ``up_world`` tensors are created on
    ``camera_position``'s device, so CUDA positions work without passing them.
    The x axis degenerates to zero when the view direction is parallel to
    ``up_world``.
    """
    device = camera_position.device
    # by default, looking at the origin and world up is z-axis
    if look_at is None:
        look_at = torch.tensor([0, 0, 0], dtype=torch.float32, device=device)
    if up_world is None:
        up_world = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
    if camera_position.ndim == 2:
        look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
        up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)

    # OpenGL camera: z-backward, x-right, y-up
    z_axis = F.normalize(camera_position - look_at, dim=-1).float()
    x_axis = F.normalize(torch.linalg.cross(up_world, z_axis, dim=-1), dim=-1).float()
    y_axis = F.normalize(torch.linalg.cross(z_axis, x_axis, dim=-1), dim=-1).float()

    extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
    return pad_camera_extrinsics_4x4(extrinsics)

def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
    """Place cameras on a sphere of ``radius`` and aim them at the origin.

    :param azimuths: azimuth angles in degrees
    :param elevations: elevation angles in degrees
    :param radius: sphere radius
    :return: float32 pose matrices produced by ``center_looking_at_camera_pose``
    """
    azim_rad = np.deg2rad(azimuths)
    elev_rad = np.deg2rad(elevations)

    # Length of the position vector projected onto the xy-plane.
    horizontal = radius * np.cos(elev_rad)
    positions = np.stack(
        [horizontal * np.cos(azim_rad),
         horizontal * np.sin(azim_rad),
         radius * np.sin(elev_rad)],
        axis=-1,
    )
    return center_looking_at_camera_pose(torch.from_numpy(positions).float())

def FOV_to_intrinsics(fov, device='cpu'):
    """
    Build a normalized 3x3 pinhole intrinsics matrix from a field of view in degrees.

    The focal length is expressed in units of image size (not pixels), and the
    principal point is fixed at the image center (0.5, 0.5).
    """
    half_angle = np.deg2rad(fov) * 0.5
    f = 0.5 / np.tan(half_angle)
    rows = [
        [f, 0, 0.5],
        [0, f, 0.5],
        [0, 0, 1],
    ]
    return torch.tensor(rows, device=device)

def get_custom_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0):
    """
    Build the 6 fixed input cameras of the zero123plus layout, repeated over a batch.

    Views sit at azimuths 30..330 deg (60 deg apart) with elevations
    alternating between +20 and -10 deg.

    :return: (batch_size, 6, 16) tensor; per view, 12 entries of the flattened
        top 3x4 of the pose matrix followed by normalized (fx, fy, cx, cy)
    """
    view_azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
    view_elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)
    num_views = len(view_azimuths)

    pose_flat = spherical_camera_pose(view_azimuths, view_elevations, radius).float().flatten(-2)
    K_flat = FOV_to_intrinsics(fov).unsqueeze(0).repeat(num_views, 1, 1).float().flatten(-2)

    # Entries 0, 4, 2, 5 of a flattened 3x3 K are fx, fy, cx, cy.
    per_view = torch.cat([pose_flat[:, :12], K_flat[:, [0, 4, 2, 5]]], dim=-1)
    return per_view.unsqueeze(0).repeat(batch_size, 1, 1)

def get_custom_era3D_input_cameras(batch_size=1, radius=4.0, fov=30.0):
    """
    Build the 6 fixed input cameras of the Era3D layout, repeated over a batch.

    All views are at elevation 0 deg, at azimuths 0, 45, 90, 180, 270 and 315 deg.

    :return: (batch_size, 6, 16) tensor; per view, 12 entries of the flattened
        top 3x4 of the pose matrix followed by normalized (fx, fy, cx, cy)
    """
    view_azimuths = np.array([0, 45, 90, 180, 270, 315]).astype(float)
    view_elevations = np.array([0, 0, 0, 0, 0, 0]).astype(float)
    num_views = len(view_azimuths)

    pose_flat = spherical_camera_pose(view_azimuths, view_elevations, radius).float().flatten(-2)
    K_flat = FOV_to_intrinsics(fov).unsqueeze(0).repeat(num_views, 1, 1).float().flatten(-2)

    # Entries 0, 4, 2, 5 of a flattened 3x3 K are fx, fy, cx, cy.
    per_view = torch.cat([pose_flat[:, :12], K_flat[:, [0, 4, 2, 5]]], dim=-1)
    return per_view.unsqueeze(0).repeat(batch_size, 1, 1)

def rotate_x(a, device=None):
    """Return the 4x4 homogeneous rotation by angle ``a`` (radians) about the x axis, float32."""
    cos_a, sin_a = np.cos(a), np.sin(a)
    rows = [
        [1, 0, 0, 0],
        [0, cos_a, -sin_a, 0],
        [0, sin_a, cos_a, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)

def rotate_y(a, device=None):
    """Return the 4x4 homogeneous rotation by angle ``a`` (radians) about the y axis, float32."""
    cos_a, sin_a = np.cos(a), np.sin(a)
    rows = [
        [cos_a, 0, sin_a, 0],
        [0, 1, 0, 0],
        [-sin_a, 0, cos_a, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)

def rotate_z(a, device=None):
    """Return the 4x4 homogeneous rotation by angle ``a`` (radians) about the z axis, float32."""
    cos_a, sin_a = np.cos(a), np.sin(a)
    rows = [
        [cos_a, -sin_a, 0, 0],
        [sin_a, cos_a, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)

def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fpath):
    """Write a vertex-coloured triangle mesh to ``fpath`` in OBJ format via trimesh.

    Vertices are mirrored along z (right-multiplied by diag(1, 1, -1)). The
    vertex order of every face is reversed, so triangle orientation is
    preserved under the mirror.

    :param pointnp_px3: (P, 3) vertex positions
    :param facenp_fx3: (F, 3) vertex indices per triangle
    :param colornp_px3: per-vertex colours handed to trimesh
    :param fpath: output file path
    """
    mirror_z = np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]])
    mirrored_vertices = pointnp_px3 @ mirror_z
    flipped_faces = facenp_fx3[:, [2, 1, 0]]

    trimesh.Trimesh(
        vertices=mirrored_vertices,
        faces=flipped_faces,
        vertex_colors=colornp_px3,
    ).export(fpath, 'obj')


def get_background(img_tensor, session=None):
    """Predict a binary foreground mask for each image with rembg.

    :param img_tensor: (B, 3, H, W) tensor; values are scaled by 255 and cast
        to uint8, so a [0, 1] range is assumed.
    :param session: optional rembg session to reuse across calls. When omitted,
        one session is created here and shared by all B images. Without an
        explicit session, rembg's ``remove`` builds a new one on every call.
    :return: (B, 1, H, W) float tensor on the CPU with values in {0, 1}.
    """
    B, C, H, W = img_tensor.shape
    assert C == 3, "Input tensor must have 3 channels (RGB)."

    if session is None:
        session = rembg.new_session()

    img_numpy = (img_tensor.permute(0, 2, 3, 1) * 255).byte().cpu().numpy()  # (B, H, W, C)

    masks = []
    for i in range(B):
        mask = remove(img_numpy[i], only_mask=True, session=session)
        # Hard threshold of the soft mask: > 127 -> 255, otherwise 0.
        mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
        masks.append(mask_binary[..., None])

    masks = np.stack(masks, axis=0)

    mask_tensor = torch.from_numpy(masks).permute(0, 3, 1, 2).float() / 255.0
    return mask_tensor

# ---------------------------------tool------------------------------------------

def _load_lrm_generator(config_path, ckpt_path, device):
    """Instantiate ``config.model_config`` and load its ``lrm_generator.*`` weights.

    :param config_path: OmegaConf YAML file with a ``model_config`` section
    :param ckpt_path: checkpoint whose state-dict keys are prefixed with ``lrm_generator.``
    :param device: device the model is moved to
    :return: (model, config); the model is in eval mode
    :raises FileNotFoundError: if ``ckpt_path`` is empty or not an existing file
    """
    # Fail before building the model when the checkpoint path is still a placeholder.
    if not ckpt_path or not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"DiMeR checkpoint not found: {ckpt_path!r}")
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model_config)
    prefix = 'lrm_generator.'
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()
    return model, config


def load_models(recon_config_path="config/DiMeR_default.yaml", recon_ckpt_path="",
                texture_config_path="config/DiMeR_default.yaml", texture_ckpt_path="",
                recon_device="cuda:0", texture_device="cuda:0"):
    """Load the DiMeR geometry (reconstruction) and texture models.

    The checkpoint paths default to empty placeholders and must be supplied
    before use.

    :return: (recon_model, texture_model, recon_model_config, texture_model_config)
    :raises FileNotFoundError: if a checkpoint path is empty or missing
    """
    print('==> Loading reconstruction model ...')
    recon_model, recon_model_config = _load_lrm_generator(
        recon_config_path, recon_ckpt_path, recon_device)
    recon_model.init_flexicubes_geometry(recon_device, fovy=50.0)

    print('==> Loading texture model ...')
    texture_model, texture_model_config = _load_lrm_generator(
        texture_config_path, texture_ckpt_path, texture_device)

    return recon_model, texture_model, recon_model_config, texture_model_config

def _composite_on_white(views, mask, need_padding, device):
    """Blend ``views`` onto white using ``mask``, optionally pad, and add a leading batch dim.

    :param views: (N, C, H, W) tensor
    :param mask: tensor broadcastable to ``views``; 1 keeps a pixel, 0 turns it white
    :param need_padding: if True, add a 50-px white border and resize to 512x512
    :return: (1, N, C, H', W') tensor clamped to [0, 1] on ``device``
    """
    mask = mask.to(device)
    views = views * mask + (1 - mask)
    if need_padding:
        views = F.pad(views, (50, 50, 50, 50), value=1.)
        views = F.interpolate(views, (512, 512), mode='bilinear', align_corners=False)
    return views.unsqueeze(0).clamp(0.0, 1.0).to(device)


def _default_local_normal_poses(num_views):
    """Return default (azimuths, elevations) in degrees for ``num_views`` per-view normal maps."""
    if num_views == 4:
        return [0, 90, 180, 270], [5, 5, 5, 5]
    if num_views == 6:
        # Same layout as get_custom_zero123plus_input_cameras, which supplies
        # the reconstruction cameras. NOTE(review): confirm the local normal
        # maps were predicted from these viewpoints.
        return [30, 90, 150, 210, 270, 330], [20, -10, 20, -10, 20, -10]
    raise ValueError(
        f"No default camera poses for {num_views} local normal maps; "
        "pass normal_azimuths and normal_elevations explicitly.")


@torch.no_grad()
def DiMeR_reconstruct(model, infer_config, texture_model, images, normals, multi_view_mask,
                      name='', export_texmap=False,
                      is_local=False, need_padding=True,
                      camera_radius=3.5,
                      save_dir='./tmp',
                      normal_azimuths=None, normal_elevations=None):
    """
    Reconstruct a vertex-coloured mesh from multi-view images and normals and save it as OBJ.

    images: Tensor, shape (N, c, h, w)
    normals: Tensor, shape (N, c, h, w); its device is used for all computation
    multi_view_mask: Tensor broadcastable to images/normals; 1 keeps a pixel,
        0 replaces it with white
    infer_config: mapping unpacked as keyword arguments into ``model.extract_mesh``
    is_local: if True, ``normals`` are per-view and are first rotated into the
        global frame with ``NormalTransfer.trans_local_2_global``
    normal_azimuths / normal_elevations: per-view camera angles in degrees used
        for that rotation. When omitted, defaults exist for 4 and 6 views; any
        other view count raises ValueError.
    camera_radius: radius of the input cameras given to both models
    save_dir: output directory, created if missing
    return: path of the written ``{name}.obj``
    """
    os.makedirs(save_dir, exist_ok=True)
    mesh_path_idx = os.path.join(save_dir, f'{name}.obj')

    device = normals.device

    input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=camera_radius, fov=30).to(device)
    # input_cameras = get_custom_era3D_input_cameras(batch_size=1, radius=camera_radius, fov=30).to(device)

    if is_local:
        if normal_azimuths is None or normal_elevations is None:
            default_azimuths, default_elevations = _default_local_normal_poses(normals.shape[0])
            if normal_azimuths is None:
                normal_azimuths = default_azimuths
            if normal_elevations is None:
                normal_elevations = default_elevations
        # NOTE(review): trans_local_2_global documents its input range as
        # [-1, 1]; confirm whether normals stored in [0, 1] should be remapped.
        normal_transfer = NormalTransfer()
        global_normals = normal_transfer.trans_local_2_global(normals.cpu().permute(0, 2, 3, 1),
                                                              torch.as_tensor(normal_azimuths),
                                                              torch.as_tensor(normal_elevations),
                                                              radius=4.5,
                                                              for_lotus=True).to(device)
        global_normals = global_normals.permute(0, 3, 1, 2)
    else:
        global_normals = normals

    global_normals = _composite_on_white(global_normals, multi_view_mask, need_padding, device)

    print(f"{time.time()} ==> local normal to global normal done")

    images = _composite_on_white(images, multi_view_mask, need_padding, device)

    print(f"{time.time()} ==> Runing DiMeR geometry reconstruction ...")
    planes = model.forward_planes(global_normals, input_cameras)
    vertices, faces, _ = model.extract_mesh(
        planes,
        use_texture_map=export_texmap,
        **infer_config,
    )

    print(f"{time.time()} ==> Runing DiMeR texture reconstruction ...")
    # as_tensor accepts numpy arrays or tensors without copy-constructing a tensor.
    vertices = torch.as_tensor(vertices, device=device)
    faces = torch.as_tensor(faces, device=device)
    texture_planes = texture_model.forward_planes(images, input_cameras)
    vertex_colors, _, _ = texture_model.synthesizer.get_texture_prediction(
        texture_planes, vertices.unsqueeze(0))
    # Re-orient the (row-vector) vertices for export using the 3x3 blocks of
    # rotate_x / rotate_y at pi/2.
    vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
    vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]

    vertices = vertices.cpu().numpy()
    faces = faces.cpu().numpy()
    vertex_colors = vertex_colors.clamp(0, 1).squeeze(0).cpu().numpy()
    vertex_colors = vertex_colors * 255.0
    vertex_colors = vertex_colors.astype(np.uint8)
    save_obj(vertices, faces, vertex_colors, mesh_path_idx)
    print(f"Mesh saved to {mesh_path_idx}")
    return mesh_path_idx


if __name__ == '__main__':

    name_dir = "path/to/input_single"
    images_dir = "path/to/zero123plus_6view"
    normals_dir = "path/to/zero123plus_6view_StableNormal"
    output_dir = "path/to/results"

    # Refuse to reuse an existing output directory so earlier results are not overwritten.
    if os.path.exists(output_dir):
        print(f"Output directory {output_dir} already exists. Exiting.")
        raise SystemExit(0)
    os.makedirs(output_dir, exist_ok=True)

    device = "cuda:0"
    num_views = 6  # files per sample: {name}_{i}.png and {name}_{i}_normal.png, i in [0, 6)

    # Sample names are the input file names without their extension.
    names = [os.path.splitext(fname)[0] for fname in sorted(os.listdir(name_dir))]

    recon_model, texture_model, recon_model_config, texture_model_config = load_models()

    # Decode as 3-channel RGB: RGBA PNGs would otherwise give 4 channels and
    # fail the RGB check in get_background.
    rgb_mode = torchvision.io.ImageReadMode.RGB

    for name in tqdm(names):
        images = []
        normals = []
        for index in range(num_views):
            images.append(torchvision.io.read_image(
                os.path.join(images_dir, f"{name}_{index}.png"), mode=rgb_mode).unsqueeze(0))
            normals.append(torchvision.io.read_image(
                os.path.join(normals_dir, f"{name}_{index}_normal.png"), mode=rgb_mode).unsqueeze(0))
        images = torch.cat(images, dim=0).float().to(device) / 255.0
        normals = torch.cat(normals, dim=0).float().to(device) / 255.0
        images = F.interpolate(images, (512, 512), mode='bilinear', align_corners=False)
        normals = F.interpolate(normals, (512, 512), mode='bilinear', align_corners=False)

        masks = get_background(normals).to(device)

        # NOTE(review): the whole recon config is passed as `infer_config`, and
        # DiMeR_reconstruct unpacks it into extract_mesh(**infer_config);
        # confirm its top-level keys are meant as extract_mesh kwargs.
        DiMeR_reconstruct(recon_model, recon_model_config, texture_model, images, normals, masks,
                          name=name, export_texmap=False,
                          is_local=True, need_padding=False,
                          camera_radius=5.0,  # 3.5
                          save_dir=output_dir)


