import numpy as np
import torch


def random_rotate_z(pc):
    """Rotate a point cloud about the z axis by a random angle.

    One angle is drawn uniformly from [0, 2*pi) with ``np.random`` and the
    points are right-multiplied by the matching 3x3 rotation matrix, so the
    last dimension of ``pc`` must have size 3.  The z coordinate of every
    point is left unchanged.

    Returns:
        The rotated points as a new array (``pc`` is not modified).
    """
    angle = np.random.uniform(0, 2*np.pi)
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    rotation = np.array([
        [cos_a, -sin_a, 0],
        [sin_a, cos_a, 0],
        [0, 0, 1],
    ])
    return np.matmul(pc, rotation)

def normalize_pc(pc):
    """Center a point cloud on its centroid and scale it into the unit sphere.

    The per-axis mean is subtracted, then every point is divided by the
    largest distance from the centroid, so the farthest point ends up at
    distance 1 (and every coordinate therefore lies in [-1, 1]).

    Args:
        pc: 2-D array with one point per row.

    Returns:
        A new array of the same shape.  If all points (nearly) coincide, i.e.
        the largest distance from the centroid is below 1e-6, an all-zero
        array is returned instead of dividing by a vanishing radius.
    """
    centered = pc - np.mean(pc, axis=0)
    # Radius of the farthest point; computed once and reused for both the
    # degeneracy check and the scaling.
    max_radius = np.max(np.linalg.norm(centered, axis=1))
    if max_radius < 1e-6:
        return np.zeros_like(centered)
    return centered / max_radius

def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
    """Randomly "drop" points from every cloud of a BxNx3 batch.

    For each cloud a dropout ratio is drawn uniformly from
    [0, max_dropout_ratio); every point is then selected independently with
    that probability and overwritten with the cloud's first point, so the
    array keeps its shape.

    The batch is modified in place and also returned.
    """
    points_per_cloud = batch_pc.shape[1]
    for cloud_idx in range(batch_pc.shape[0]):
        ratio = np.random.random() * max_dropout_ratio
        dropped = np.random.random(points_per_cloud) <= ratio
        if dropped.any():
            # Replace the dropped points by duplicates of point 0.
            batch_pc[cloud_idx, dropped, :] = batch_pc[cloud_idx, 0, :]
    return batch_pc

def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
    """Scale every point cloud of a batch by its own random factor.

    One factor per cloud is drawn uniformly from [scale_low, scale_high) and
    multiplied into that cloud.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        The same BxNx3 array, scaled in place
    """
    # Unpacking also enforces that the batch is 3-dimensional.
    num_clouds, _num_points, _num_channels = batch_data.shape
    factors = np.random.uniform(scale_low, scale_high, num_clouds)
    for cloud, factor in zip(batch_data, factors):
        # `cloud` is a view into batch_data, so this updates the batch itself.
        cloud *= factor
    return batch_data

def shift_point_cloud(batch_data, shift_range=0.1):
    """Translate every point cloud of a batch by its own random offset.

    One 3-vector per cloud is drawn uniformly from
    [-shift_range, shift_range) and added to all points of that cloud.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        The same BxNx3 array, shifted in place
    """
    # Unpacking also enforces that the batch is 3-dimensional.
    num_clouds, _num_points, _num_channels = batch_data.shape
    offsets = np.random.uniform(-shift_range, shift_range, (num_clouds, 3))
    for cloud, offset in zip(batch_data, offsets):
        # `cloud` is a view into batch_data; the offset broadcasts over points.
        cloud += offset
    return batch_data

def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
    """Apply a small random rotation to every point cloud of a batch.

    For each cloud three angles (about x, y and z) are drawn from a normal
    distribution with standard deviation ``angle_sigma`` and clipped to
    [-angle_clip, angle_clip].  The points are right-multiplied by
    ``Rz . Ry . Rx``.

    Input:
        BxNx3 array, original batch of point clouds
    Return:
        New BxNx3 float32 array with the rotated clouds; the input is left
        untouched
    """
    perturbed = np.zeros(batch_data.shape, dtype=np.float32)
    for cloud_idx, cloud in enumerate(batch_data):
        ang_x, ang_y, ang_z = np.clip(
            angle_sigma*np.random.randn(3), -angle_clip, angle_clip
        )
        cos_x, sin_x = np.cos(ang_x), np.sin(ang_x)
        cos_y, sin_y = np.cos(ang_y), np.sin(ang_y)
        cos_z, sin_z = np.cos(ang_z), np.sin(ang_z)

        rot_x = np.array([[1, 0, 0],
                          [0, cos_x, -sin_x],
                          [0, sin_x, cos_x]])
        rot_y = np.array([[cos_y, 0, sin_y],
                          [0, 1, 0],
                          [-sin_y, 0, cos_y]])
        rot_z = np.array([[cos_z, -sin_z, 0],
                          [sin_z, cos_z, 0],
                          [0, 0, 1]])

        combined = np.dot(rot_z, np.dot(rot_y, rot_x))
        flat_points = cloud.reshape((-1, 3))
        perturbed[cloud_idx, ...] = np.dot(flat_points, combined)
    return perturbed

def augment_pc(data):
    """Apply the full augmentation chain to a single point cloud.

    The cloud is wrapped into a batch of one and passed through random point
    dropout, random scaling, random shifting and a small random rotation, in
    that order.

    Args:
        data: Nx3 array holding one point cloud.

    Returns:
        New Nx3 float32 array with the augmented cloud.  ``data`` itself is
        not modified.
    """
    # `data[None, ...]` alone would be a view, and the dropout/scale/shift
    # steps work in place, so copy first to keep the caller's array intact.
    batch = np.array(data)[None, ...]
    batch = random_point_dropout(batch)
    batch = random_scale_point_cloud(batch)
    batch = shift_point_cloud(batch)
    batch = rotate_perturbation_point_cloud(batch)
    # Remove only the batch axis; a bare squeeze() would also collapse a
    # single-point cloud from (1, 3) to (3,).
    return batch[0]


# --- Utility: Subsampling Method ---
def get_subsampling_method(method="random"):
    """Return the callable used to subsample a point cloud.

    NOTE: ``method`` is currently not consulted; the random subsampler is
    returned whatever value is passed.

    The returned function has the signature ``(pc, sample_size)``.
    """
    def random_subsample(pc, sample_size):
        """Keep ``sample_size`` randomly chosen columns of ``pc``.

        ``pc`` is a torch tensor of shape (6, N).  When it already has no more
        than ``sample_size`` points it is returned as is (same object);
        otherwise a new tensor with the selected columns is returned.
        """
        num_points = pc.shape[1]
        if num_points > sample_size:
            chosen = torch.randperm(num_points)[:sample_size]
            return pc[:, chosen]
        return pc

    return random_subsample


# --- Scene Combination and Caption Refinement ---
def _apply_random_orientation(xyz):
    """Rotate ``xyz`` (3 x N torch tensor) by a random yaw and two small tilts.

    Three angles are drawn with ``torch.rand``: a rotation about z in
    [0, 2*pi), then rotations about x and about y in [0, 0.2618) rad (about
    15 degrees).  The rotations are applied in that order and the result is
    returned as a new tensor.
    """
    max_tilt = 0.2618  # 15 degrees in radians

    def cos_sin(max_angle):
        angle = torch.rand(1) * max_angle
        return torch.cos(angle).item(), torch.sin(angle).item()

    c, s = cos_sin(2 * np.pi)
    about_z = [[c, s, 0.0], [-s, c, 0.0], [0.0, 0.0, 1.0]]
    c, s = cos_sin(max_tilt)
    about_x = [[1.0, 0.0, 0.0], [0.0, c, s], [0.0, -s, c]]
    c, s = cos_sin(max_tilt)
    about_y = [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]

    # Build the matrices in the cloud's own floating dtype/device so that
    # torch.mm also accepts e.g. float64 clouds.
    dtype = xyz.dtype if xyz.is_floating_point() else torch.float32
    for matrix in (about_z, about_x, about_y):
        rotation = torch.tensor(matrix, dtype=dtype, device=xyz.device)
        xyz = torch.mm(rotation, xyz)
    return xyz


def _refine_caption(tokenizer, llm, sampling_params, raw_caption):
    """Ask the language model to rewrite ``raw_caption`` and return its stripped reply."""
    # NOTE(review): there is no space between "...the previous one." and
    # "Your goal..." in the system prompt; left unchanged to keep the prompt
    # identical.
    messages = [
        {"role": "system", "content": (
            "You are an advanced language model specializing in natural language refinement. "
            "Your task is to transform a raw combined caption into a fluent and well-structured description. "
            "The raw caption consists of multiple object descriptions connected by spatial relations such as 'next to', 'over', or 'under'. "
            "Spatial relations connect an object to the previous one."
            "Your goal is to ensure proper grammar, capitalization, and punctuation while making the text more readable and natural. "
            "If the sentence is too long or unnatural, split it into multiple fluent sentences while preserving the original meaning and spatial relationships. "
            "The final output should feel like a human-written description, maintaining clarity and coherence."
        )},
        {"role": "user", "content": f"Raw Caption: {raw_caption}\nRefined Caption:"}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    outputs = llm.generate([text], sampling_params)
    return outputs[0].outputs[0].text.strip()


def build_scene_from_point_clouds(tokenizer, llm, sampling_params, point_clouds, captions, out_size, min_size_per_sample, subsampling_method="random", return_pos=False):
    """
    Builds a 3D scene by arranging point clouds with spatial relations and refines the combined caption.

    Every object is subsampled, randomly re-oriented and then placed either
    "next to" or "over" the previous object (chosen at random).  The raw
    caption, formed by chaining the object captions with these relations, is
    rewritten by the language model.  Finally the xyz rows of the scene are
    shifted and scaled into [0, 1] using the global minimum and maximum.

    The caller's ``point_clouds`` list and the tensors in it are not modified.

    Args:
        tokenizer: Object providing ``apply_chat_template`` (used to render the chat prompt).
        llm: Object providing ``generate`` (the source comments mention Qwen2.5-Instruct via vLLM).
        sampling_params: Passed through unchanged to ``llm.generate``.
        point_clouds (list of torch.Tensor): List of point clouds (shape (6, N)); rows 0:3 are xyz.
        captions (list of str): List of object captions.
        out_size (int): Target total number of points in the final scene.
        min_size_per_sample (int): Minimum number of points to sample per object.
        subsampling_method (str): Subsampling method (default "random").
        return_pos (bool): Whether to return placement info.

    Returns:
        tuple: (numpy.ndarray, str) Combined scene point cloud (shape (6, total_points)) and refined caption.
        When ``return_pos`` is true a third element is appended: the list of
        relations ("next to" / "over") chosen for objects 2..n.

    Raises:
        ValueError: If ``point_clouds`` is empty or its length differs from ``captions``.
        AssertionError: If ``out_size // len(point_clouds)`` is below ``min_size_per_sample``.
    """
    if not point_clouds:
        raise ValueError("The list of point clouds must not be empty.")
    if len(point_clouds) != len(captions):
        raise ValueError("The number of point clouds must match the number of captions.")

    num_objects = len(point_clouds)
    sample_size = out_size // num_objects
    # Raised explicitly (instead of via `assert`) so the check survives
    # `python -O`, while keeping the exception type callers already see.
    if sample_size < min_size_per_sample:
        raise AssertionError("Too many objects for the given out size and minimum points per object")

    subsampling = get_subsampling_method(subsampling_method)
    # The last object absorbs the remainder so the sizes add up to out_size.
    last_sample_size = out_size - sample_size * (num_objects - 1)

    # Subsample and re-orient each object on a private copy, then convert it
    # to numpy.  These arrays stay in object-local coordinates throughout.
    objects = []
    for i, pc in enumerate(point_clouds):
        current_sample_size = sample_size if i < num_objects - 1 else last_sample_size
        local = subsampling(pc, current_sample_size).detach().clone()
        local[:3] = _apply_random_orientation(local[:3])
        objects.append(local.cpu().numpy())

    vertical_offset = np.array([0.0, 0.0, 1.0])
    delta_offset = 0.2  # Fixed extra offset.
    noise_std = 0.001    # Noise standard deviation.

    # World position of the most recently placed object's origin.
    last_position = np.array([0.0, 0.0, 0.0])
    placements = []
    scene_parts = [objects[0]]
    caption_parts = [captions[0].strip() + "."]

    for previous, current, caption in zip(objects, objects[1:], captions[1:]):
        if np.random.rand() > 0.5:
            # "Next to" relation: sample a random horizontal unit vector.
            theta = np.random.uniform(0, 2 * np.pi)
            d = np.array([np.cos(theta), np.sin(theta)])
            d3 = np.array([d[0], d[1], 0])
            # Both extents are measured in local coordinates; the step is
            # then accumulated on top of the previous object's position, so
            # earlier translations are counted exactly once.
            max_proj_prev = np.max(np.dot(previous[:2].T, d))
            min_proj_current = np.min(np.dot(current[:2].T, d))
            displacement = (max_proj_prev - min_proj_current + delta_offset) * d3
            displacement += np.random.normal(0, noise_std, size=3)
            last_position += displacement
            placement = "next to"
        else:
            # "Over" relation: place current object above the previous one.
            # NOTE(review): the step uses the previous object's full height
            # (max - min), so the realized clearance is gap - min(previous z)
            # rather than gap; formula kept as is - confirm intent.
            gap = np.random.uniform(0.2, 0.5)
            last_position += vertical_offset * (
                np.max(previous[2]) - np.min(previous[2]) - np.min(current[2]) + gap
            )
            placement = "over"

        placements.append(placement)
        placed = current.copy()
        placed[:3] = placed[:3] + last_position[:, None]
        scene_parts.append(placed)

        # Slicing (rather than indexing) tolerates an empty caption.
        caption_parts.append(f"{placement.capitalize()} it, {caption[:1].lower() + caption[1:]}.")

    scene = np.concatenate(scene_parts, axis=1)
    combined_caption = " ".join(caption_parts)

    print(f'Caption before Refinement: {combined_caption}')

    # --- Caption Refinement using Qwen2.5-Instruct via vLLM ---
    combined_caption = _refine_caption(tokenizer, llm, sampling_params, combined_caption)

    # Normalize the scene's xyz coordinates with one global offset and scale.
    xyz = scene[:3]  # view: in-place updates below write into `scene`
    xyz -= xyz.min()
    extent = xyz.max()
    if extent > 0:  # all coordinates identical -> nothing to scale
        xyz /= extent

    if return_pos:
        return scene, combined_caption, placements
    return scene, combined_caption
