import numpy as np
import math
import random
from scipy.spatial.transform import Rotation as R
import torch
from torch import nn


def getProjectionMatrix(znear, zfar, fovX, fovY):
    """Build a 4x4 perspective projection matrix as a float ``torch`` tensor.

    Args:
        znear: Distance to the near clipping plane.
        zfar: Distance to the far clipping plane.
        fovX: Horizontal field of view in radians (fed to ``math.tan``).
        fovY: Vertical field of view in radians (fed to ``math.tan``).

    Returns:
        A ``(4, 4)`` tensor. The last row copies ``z`` into ``w`` and the
        third row maps ``z = znear`` to depth 0 and ``z = zfar`` to depth 1
        after the perspective divide.
    """
    # Frustum extents on the near plane; the frustum is symmetric, so the
    # left/bottom bounds are just the negated right/top bounds.
    right = math.tan(fovX / 2) * znear
    top = math.tan(fovY / 2) * znear
    left, bottom = -right, -top

    width = right - left
    height = top - bottom
    depth = zfar - znear

    # +1 means the camera looks down the positive z axis.
    z_sign = 1.0

    proj = torch.zeros(4, 4)
    # x / y scaling
    proj[0, 0] = 2.0 * znear / width
    proj[1, 1] = 2.0 * znear / height
    # principal-point offsets (zero for a symmetric frustum)
    proj[0, 2] = (right + left) / width
    proj[1, 2] = (top + bottom) / height
    # depth remapping and perspective divide
    proj[2, 2] = z_sign * zfar / depth
    proj[2, 3] = -(zfar * znear) / depth
    proj[3, 2] = z_sign
    return proj


def getWorld2View2_tensor(R, t, translate=torch.tensor([.0, .0, .0]), scale=1.0):
    """Assemble a 4x4 world-to-view matrix, optionally re-centring the camera.

    Args:
        R: ``(3, 3)`` rotation tensor; its transpose is written into the
            upper-left block of the result.
        t: Length-3 translation written into the last column.
        translate: Offset added to the camera centre before scaling.
        scale: Factor applied to the shifted camera centre.

    Returns:
        A float ``(4, 4)`` tensor. With the default ``translate`` and
        ``scale`` this is ``[[R^T, t], [0, 1]]`` up to the round-off of two
        matrix inversions.
    """
    world_to_view = torch.zeros((4, 4))
    world_to_view[3, 3] = 1.0
    world_to_view[:3, 3] = t
    world_to_view[:3, :3] = R.transpose(0, 1)

    # The camera centre is the translation part of the inverse transform;
    # shift and scale it there, then invert back.
    view_to_world = torch.linalg.inv(world_to_view)
    view_to_world[:3, 3] = (view_to_world[:3, 3] + translate) * scale
    return torch.linalg.inv(view_to_world).float()


def fov2focal(fov, pixels):
    """Return the focal length (in pixels) for a field of view in radians.

    ``pixels`` is the image extent along the axis that ``fov`` spans.
    """
    half_angle = fov / 2
    return pixels / (2 * math.tan(half_angle))


def focal2fov(focal, pixels):
    """Return the field of view in radians for a focal length in pixels.

    ``pixels`` is the image extent along the axis the returned angle spans.
    """
    half_extent_ratio = pixels / (2 * focal)
    return 2 * math.atan(half_extent_ratio)


class Camera(nn.Module):
    """Camera description holding intrinsics and derived view/projection matrices.

    The constructor takes a 4x4 pose matrix plus a vertical field of view and
    the image size, and precomputes the transposed world-to-view matrix, the
    transposed projection matrix, their product and the camera centre.
    """

    def __init__(self, c2w, FoVy, height, width,
                 trans=torch.tensor([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
                 ):
        """Store the camera parameters and build the derived matrices.

        Args:
            c2w: Pose tensor; ``c2w[:3, :3]`` is used as ``R`` and
                ``c2w[:3, 3]`` as ``T``.
            FoVy: Vertical field of view in radians.
            height: Image height in pixels.
            width: Image width in pixels.
            trans: Translation stored on the instance as ``self.trans``.
            scale: Scale stored on the instance as ``self.scale``.
            data_device: Device spec stored as ``self.data_device``; an
                invalid spec falls back to ``"cuda"``.
        """
        super().__init__()

        # --- intrinsics -----------------------------------------------------
        # The horizontal FoV follows from the vertical one via the shared
        # focal length (square pixels).
        shared_focal = fov2focal(FoVy, height)
        self.FoVx = focal2fov(shared_focal, width)
        self.FoVy = FoVy
        self.image_height = height
        self.image_width = width
        self.znear = 0.01
        self.zfar = 100.0

        # --- extrinsics -----------------------------------------------------
        rotation = c2w[:3, :3]
        translation = c2w[:3, 3]
        self.R = rotation.float()
        self.T = translation.float()

        # NOTE(review): trans / scale are stored but are not forwarded to
        # getWorld2View2_tensor below — confirm whether that is intended.
        self.trans = trans.float()
        self.scale = scale

        # --- device ---------------------------------------------------------
        # NOTE(review): data_device is only recorded here; the matrices below
        # are always moved with .cuda() regardless of this value.
        try:
            self.data_device = torch.device(data_device)
        except Exception as e:
            print(e)
            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
            self.data_device = torch.device("cuda")

        # --- derived matrices (stored transposed) -----------------------------
        view = getWorld2View2_tensor(rotation, translation)
        self.world_view_transform = view.t().float().cuda()

        projection = getProjectionMatrix(
            znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
        )
        self.projection_matrix = projection.t().float().cuda()

        combined = torch.bmm(
            self.world_view_transform[None], self.projection_matrix[None]
        )
        self.full_proj_transform = combined[0].float()

        # With the transposed layout the camera position sits in the last
        # row of the inverse view matrix.
        self.camera_center = torch.inverse(self.world_view_transform)[3, :3].float()


def trans_t(t):
    """Return a 4x4 homogeneous matrix translating by ``t`` along the z-axis."""
    rows = [
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1],
    ]
    return torch.Tensor(rows).float()

def rot_phi(phi):
    """Return a 4x4 homogeneous rotation about the x-axis by ``phi`` radians."""
    cos_phi, sin_phi = np.cos(phi), np.sin(phi)
    rows = [
        [1, 0, 0, 0],
        [0, cos_phi, -sin_phi, 0],
        [0, sin_phi, cos_phi, 0],
        [0, 0, 0, 1],
    ]
    return torch.Tensor(rows).float()

def rot_theta(th):
    """Return a 4x4 homogeneous rotation about the y-axis by ``th`` radians.

    The layout is ``[[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]`` in the
    upper-left 3x3 block.
    """
    cos_th, sin_th = np.cos(th), np.sin(th)
    rows = [
        [cos_th, 0, -sin_th, 0],
        [0, 1, 0, 0],
        [sin_th, 0, cos_th, 0],
        [0, 0, 0, 1],
    ]
    return torch.Tensor(rows).float()

def pose_spherical(theta, phi, radius):
    """Compose a 4x4 pose from spherical coordinates.

    Args:
        theta: Angle in degrees fed (after conversion to radians) to
            ``rot_theta``.
        phi: Angle in degrees fed (after conversion to radians) to
            ``rot_phi``.
        radius: Translation along z applied first via ``trans_t``.

    Returns:
        The product ``axis_swap @ rot_theta @ rot_phi @ trans_t`` as a
        ``(4, 4)`` tensor.
    """
    phi_rad = phi/180.*np.pi
    theta_rad = theta/180.*np.pi

    # Flips the x axis and exchanges the y and z axes.
    axis_swap = torch.Tensor(
        np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
    )

    pose = rot_phi(phi_rad) @ trans_t(radius)
    pose = rot_theta(theta_rad) @ pose
    return axis_swap @ pose


class RandomCameraCam:
    """Sampler of camera poses on a sphere around the origin.

    Every sampling method returns a dict with the keys ``c2w_3dgs``,
    ``elevation``, ``azimuth``, ``camera_distances``, ``height``, ``width``
    and ``fovy``. Elevation and azimuth are reported in degrees, ``fovy`` in
    radians.
    """

    # Settings shared by the deterministic sweeps (surroundAzimuth /
    # surroundElevation).
    _SWEEP_STEP_DEG = 30
    _SWEEP_ELEVATION_DEG = 20
    _SWEEP_AZIMUTH_DEG = 0
    _SWEEP_FOVY_DEG = 50

    def __init__(self, opt):
        """Copy the sampling configuration from ``opt``.

        ``opt`` must expose ``height``, ``width``, ``batch_size``,
        ``azimuth_range``, ``camera_distance_range``, ``fovy_range`` and
        ``elevation_range``. Each ``*_range`` is read as ``[0]`` (low) and
        ``[1]`` (high).
        """
        self.height = opt.height
        self.width = opt.width
        self.batch_size = opt.batch_size
        self.azimuth_range = opt.azimuth_range
        self.camera_distance_range = opt.camera_distance_range
        self.fovy_range = opt.fovy_range
        self.elevation_range = opt.elevation_range

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _uniform(count, low, high):
        """Draw ``count`` samples uniformly from ``[low, high)``."""
        return torch.rand(count) * (high - low) + low

    @staticmethod
    def _spherical_to_c2w(azimuth_deg, elevation_deg, camera_distance):
        """Build the 4x4 matrix stored under ``c2w_3dgs`` for one camera.

        The spherical pose is inverted, its rotation block is transposed and
        negated except for the first column, and its translation is negated.
        NOTE(review): presumably this matches the R/T convention expected by
        ``Camera`` — verify against the renderer.
        """
        render_pose = pose_spherical(azimuth_deg + 180.0, -elevation_deg, camera_distance)
        inv_pose = torch.linalg.inv(render_pose)

        rotation = -torch.transpose(inv_pose[:3, :3], 0, 1)
        # Undo the negation for the first column only.
        rotation[:, 0] = -rotation[:, 0]
        translation = -inv_pose[:3, 3]

        top_rows = torch.cat([rotation, translation[:, None]], 1)
        last_row = torch.tensor([[0.0, 0.0, 0.0, 1.0]], dtype=top_rows.dtype)
        return torch.cat([top_rows, last_row], 0)

    def _assemble_batch(self, elevation_deg, azimuth_deg, camera_distances, fovy):
        """Stack one pose per batch entry and package the result dict."""
        c2w_3dgs = torch.stack(
            [
                self._spherical_to_c2w(
                    azimuth_deg[idx], elevation_deg[idx], camera_distances[idx]
                )
                for idx in range(self.batch_size)
            ],
            0,
        )
        return {
            "c2w_3dgs": c2w_3dgs,
            "elevation": elevation_deg,
            "azimuth": azimuth_deg,
            "camera_distances": camera_distances,
            "height": self.height,
            "width": self.width,
            "fovy": fovy,
        }

    def _sweep_angles(self, start_deg):
        """Return ``batch_size`` angles starting at ``start_deg``, 30 degrees apart.

        The sweep length follows ``batch_size`` so that it always lines up
        with the other per-view tensors. For ``batch_size == 4`` this is the
        sequence ``start, start + 30, start + 60, start + 90``.
        """
        count = self.batch_size
        end_deg = start_deg + self._SWEEP_STEP_DEG * (count - 1)
        return torch.linspace(start_deg, end_deg, count)

    def _sweep_batch(self, elevation_deg, azimuth_deg):
        """Finish a deterministic sweep: random distances, fixed 50 degree fovy."""
        camera_distances = self._uniform(
            self.batch_size,
            self.camera_distance_range[0],
            self.camera_distance_range[1],
        )
        fovy_deg = torch.full((self.batch_size,), self._SWEEP_FOVY_DEG)
        fovy = fovy_deg * math.pi / 180
        return self._assemble_batch(elevation_deg, azimuth_deg, camera_distances, fovy)

    # --------------------------------------------------------------- public API

    def collate(self):
        """Sample a random batch of ``batch_size`` cameras.

        Elevation is drawn either uniformly in degrees or through an
        arcsine transform (chosen with probability 0.5 each). Azimuth is
        stratified: one sample per equal-width slice of ``azimuth_range``.
        Distance and fovy are uniform in their ranges.
        """
        count = self.batch_size

        # 1. elevation
        if random.random() < 0.5:
            # uniform in angle (biased towards poles)
            elevation_deg = self._uniform(
                count, self.elevation_range[0], self.elevation_range[1]
            )
        else:
            # otherwise sample via inverse transform (intended to be uniform
            # on the sphere): map the range to [0, 1], draw uniformly there,
            # then apply asin(2u - 1).
            low_frac = (self.elevation_range[0] + 90.0) / 180.0
            high_frac = (self.elevation_range[1] + 90.0) / 180.0
            elevation_rad = torch.asin(2 * self._uniform(count, low_frac, high_frac) - 1.0)
            elevation_deg = elevation_rad / math.pi * 180.0

        # 2. azimuth: jittered position inside each of `count` equal strata
        strata = (torch.rand(count) + torch.arange(count)) / count
        azimuth_deg = (
            strata * (self.azimuth_range[1] - self.azimuth_range[0])
            + self.azimuth_range[0]
        )

        # 3. camera distance, uniform in camera_distance_range
        camera_distances = self._uniform(
            count, self.camera_distance_range[0], self.camera_distance_range[1]
        )

        # 4. vertical field of view, uniform in fovy_range (degrees -> radians)
        fovy_deg = self._uniform(count, self.fovy_range[0], self.fovy_range[1])
        fovy = fovy_deg * math.pi / 180

        return self._assemble_batch(elevation_deg, azimuth_deg, camera_distances, fovy)

    def surroundAzimuth(self, initial_azimuth):
        """Return ``batch_size`` views at 20 degrees elevation, sweeping azimuth.

        Azimuth starts at ``initial_azimuth`` and advances 30 degrees per
        view. Camera distances are random; fovy is fixed at 50 degrees.
        """
        elevation_deg = torch.full((self.batch_size,), self._SWEEP_ELEVATION_DEG)
        azimuth_deg = self._sweep_angles(initial_azimuth)
        return self._sweep_batch(elevation_deg, azimuth_deg)

    def surroundElevation(self, initial_elevation):
        """Return ``batch_size`` views at azimuth 0, sweeping elevation.

        Elevation starts at ``initial_elevation`` and advances 30 degrees per
        view. Camera distances are random; fovy is fixed at 50 degrees.
        """
        elevation_deg = self._sweep_angles(initial_elevation)
        azimuth_deg = torch.full((self.batch_size,), self._SWEEP_AZIMUTH_DEG)
        return self._sweep_batch(elevation_deg, azimuth_deg)

    def input_view(self):
        """Return the single fixed reference view.

        The view uses elevation 30 degrees, azimuth 0, distance 3 and a
        60 degree fovy at 256x256. Unlike the batch methods, ``c2w_3dgs`` is
        one unbatched 4x4 matrix and the other entries are plain scalars.
        """
        elevation_deg = 30
        azimuth_deg = 0
        camera_distances = 3
        fovy_deg = 60
        fovy = fovy_deg * math.pi / 180

        c2w_3dgs = self._spherical_to_c2w(azimuth_deg, elevation_deg, camera_distances)

        return {
            "c2w_3dgs": c2w_3dgs,
            "elevation": elevation_deg,
            "azimuth": azimuth_deg,
            "camera_distances": camera_distances,
            "height": 256,
            "width": 256,
            "fovy": fovy,
        }