import pandas as pd
import numpy as np
from scipy.interpolate import UnivariateSpline
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
import torch
from torch import Tensor
import os
from tqdm import tqdm
from einops import rearrange


def read_recordings(path: str, features: list[str]) -> pd.DataFrame:
    """Load the recordings CSV and keep the id columns plus ``features``.

    Every distinct ``File_name`` value receives an integer recording number,
    assigned in order of first appearance, which is stored in a new
    ``RECORDING`` column.

    Args:
        path: Location of the CSV file; its first row is used as the header.
        features: Names of the additional columns to keep.

    Returns:
        DataFrame with the columns ``RECORDING``, ``TRACK_ID`` and ``FRAME``
        followed by ``features``.
    """
    table = pd.read_csv(path, header=[0])

    # Number the recordings 0, 1, 2, ... in the order their names first occur
    names = table["File_name"].unique()
    number_of = dict(zip(names, range(len(names))))
    table["RECORDING"] = table["File_name"].map(number_of)

    id_columns = ["RECORDING", "TRACK_ID", "FRAME"]
    return table[id_columns + features]


def sliding_window(
    data: np.ndarray, config: dict, window_size: int, step: int
) -> list[np.ndarray]:
    """Cuts data of shape (n_trajectories, n_frames, 2) into windows of
    ``window_size`` consecutive frames, starting a new window every ``step`` frames.

    Within each window only the trajectories whose NaN count is at most
    ``config["preprocess_nans_in_trajectory_threshold"] * window_size`` are kept;
    a window in which no trajectory qualifies is skipped entirely."""

    last_start = data.shape[1] - window_size
    windows = []

    for start in tqdm(
        range(0, last_start + 1, step), desc="Sliding over data using window"
    ):
        chunk = data[:, start : start + window_size]

        # NaNs are counted per trajectory over both the frame and coordinate axes
        nan_counts = np.isnan(chunk).sum(axis=(1, 2))
        keep = nan_counts <= config["preprocess_nans_in_trajectory_threshold"] * window_size

        if keep.any():
            windows.append(chunk[keep])

    return windows


def calculate_local_density(positions: Tensor, config: dict) -> Tensor:
    """
    Compute local cell density using Gaussian kernel density estimation.

    Args:
        positions: Tensor of shape (frames, num_cells, 2) containing cell positions
        config: dict; ``config["density_sigma"]`` is the width of the Gaussian kernel
    Returns:
        densities: Tensor of shape (frames, num_cells) containing local density
            estimates, divided per frame by the largest density in that frame.
            A frame whose densities are all zero (e.g. a single cell) yields
            zeros instead of NaN.
    """

    # Compute pairwise offsets between all cells, reshape positions for broadcasting
    pos_i = positions.unsqueeze(2)  # (frames, num_cells, 1, 2)
    pos_j = positions.unsqueeze(1)  # (frames, 1, num_cells, 2)

    # Squared Euclidean distances: a Gaussian kernel is exp(-d^2 / (2 sigma^2)),
    # so the plain norm must not be used here
    squared_distances = (pos_i - pos_j).pow(2).sum(dim=-1)  # (frames, num_cells, num_cells)

    # Apply Gaussian kernel
    kernel_values = torch.exp(-squared_distances / (2 * config["density_sigma"] ** 2))

    # Zero the diagonal so a cell does not contribute to its own density
    self_mask = torch.eye(
        positions.shape[1], dtype=kernel_values.dtype, device=positions.device
    )[None, :, :]
    kernel_values = kernel_values * (1 - self_mask)

    # Compute density by summing kernel values over all other cells
    densities = torch.sum(kernel_values, dim=-1)  # (frames, num_cells)

    # Nothing to normalise when there are no cells (max over an empty axis would fail)
    if densities.shape[1] == 0:
        return densities

    # Normalize per frame; the clamp avoids 0 / 0 when every density is zero
    frame_max = torch.max(densities, dim=1, keepdim=True)[0]
    smallest = torch.finfo(densities.dtype).tiny
    return densities / frame_max.clamp_min(smallest)


def calculate_edge_features_per_frame(
    trajectories, velocities, edge_index, frame, config
) -> torch.Tensor:
    """Calculates edge features based on positions of source and target nodes
    args:
        trajectories: Tensor of shape (n_trajectories, n_frames, 2)
        velocities: Tensor of shape (n_trajectories, n_frames, 2)
        edge_index: Tensor of shape (2, n_edges)
        frame: int, current frame
        config: dict; ``config["edge_features"]`` maps a feature name to its
            column in the output. "relative_density" additionally reads
            whatever ``calculate_local_density`` needs from ``config``.
    returns:
        Tensor of shape (n_edges, len(config["edge_features"])); an empty
        (0, n_features) tensor when there are no edges."""

    source, target = edge_index
    n_features = len(config["edge_features"])

    # Without edges there is nothing to scale (min/max of an empty tensor would fail)
    if source.shape[0] == 0:
        return torch.empty(
            (0, n_features), dtype=trajectories.dtype, device=trajectories.device
        )

    vel_i, vel_j = velocities[source, frame], velocities[target, frame]

    # Offset from target to source; its columns are the relative x and y positions
    rel_pos = trajectories[source, frame] - trajectories[target, frame]
    raw_distances = torch.norm(rel_pos, dim=1)

    # Standardize distances to [0, 1]; the clamp yields 0 instead of NaN when
    # every edge has the same length
    dist_min, dist_max = raw_distances.min(), raw_distances.max()
    distances = (raw_distances - dist_min) / (dist_max - dist_min).clamp_min(1e-8)

    # Unit vectors need the unscaled length: dividing by the [0, 1] distance
    # would blow up for the shortest edge
    direction = rel_pos / (raw_distances.unsqueeze(1) + 1e-8)

    # A feature name that is not handled below keeps its empty placeholder
    edge_attr_list = [torch.tensor([]) for _ in range(n_features)]
    for feature, index in config["edge_features"].items():
        if feature == "dist":
            edge_attr_list[index] = distances
        elif feature == "delta_x":
            edge_attr_list[index] = rel_pos[:, 0]
        elif feature == "delta_y":
            edge_attr_list[index] = rel_pos[:, 1]
        elif feature == "delta_vel_x":
            edge_attr_list[index] = vel_i[:, 0] - vel_j[:, 0]
        elif feature == "delta_vel_y":
            edge_attr_list[index] = vel_i[:, 1] - vel_j[:, 1]
        elif feature == "relative_motion":
            # Relative velocity projected onto the target-to-source direction
            edge_attr_list[index] = torch.sum((vel_i - vel_j) * direction, dim=1)
        elif feature == "approaching_speed":
            projection_i = torch.sum(vel_i * direction, dim=1)
            projection_j = torch.sum(vel_j * direction, dim=1)
            edge_attr_list[index] = -(projection_i - projection_j) / 2
        elif feature == "relative_density":
            trajs = trajectories[:, frame].unsqueeze(0)
            local_density = calculate_local_density(trajs, config)[0]
            edge_attr_list[index] = local_density[source] - local_density[target]

    edge_attr = torch.stack(edge_attr_list, dim=1)

    return edge_attr


def calculate_edge_features(
    trajectories, node_features, edge_index, config, context=None
) -> torch.Tensor:
    """Calculates edge features of whole spatio temporal graph based
    on positions of source and target nodes
    args:
        trajectories: Tensor of shape (batch_size * n_frames * n_trajectories, x)
        node_features: Tensor of shape (batch_size * n_frames * n_trajectories, y)
        edge_index: Tensor of shape (2, n_edges)
        config: dict with
            "edge_features": feature name -> column in the returned tensor
            "node_features": node feature name -> column in ``node_features``;
                "vel_x" and "vel_y" are always read
            "dist_kernel": width of the Gaussian kernel, read only for "dist"
        context: optional per-edge values written to the "spring" column
    returns:
        Tensor of shape (n_edges, len(config["edge_features"])). A column whose
        feature name is not handled here, or "spring" without ``context``,
        is left at zero."""

    source, target = edge_index
    edge_features = config["edge_features"]

    if source.shape[0] == 0 or target.shape[0] == 0:
        return torch.empty(
            (0, len(edge_features)), device=trajectories.device, dtype=trajectories.dtype
        )

    vel_idx = [config["node_features"]["vel_x"], config["node_features"]["vel_y"]]
    velocities = node_features[:, vel_idx]
    vel_i = velocities[source]
    vel_j = velocities[target]
    trajs = trajectories

    distances = torch.norm(trajs[source] - trajs[target], dim=1)

    # Use relative x and y position as edge attribute
    x_rel_pos = trajs[source, 0] - trajs[target, 0]
    y_rel_pos = trajs[source, 1] - trajs[target, 1]

    # Relative position expressed in the frame of the source node's heading.
    # Computed when either heading feature is requested, so configuring only
    # one of them no longer hits an undefined variable.
    if "delta_x_heading" in edge_features or "delta_y_heading" in edge_features:
        norms = torch.norm(vel_i, dim=1).clip(min=1e-5).unsqueeze(1)
        heading = vel_i / norms  # (heading direction)
        thetas = torch.arctan2(heading[:, 1], heading[:, 0]) - torch.pi / 2
        cos_t, sin_t = torch.cos(thetas), torch.sin(thetas)
        # Rotation by -theta: [[cos, sin], [-sin, cos]] applied to (x, y)
        rotated_x = cos_t * x_rel_pos + sin_t * y_rel_pos
        rotated_y = -sin_t * x_rel_pos + cos_t * y_rel_pos

    # Zero-initialised so columns that no branch fills are deterministic
    edge_attr = torch.zeros(
        (distances.shape[0], len(edge_features)),
        device=trajectories.device,
        dtype=trajectories.dtype,
    )
    for feature, index in edge_features.items():
        if feature == "dist":
            # Use gaussian kernel
            edge_attr[:, index] = torch.exp(
                -(distances**2) / (2 * config["dist_kernel"] ** 2)
            )
        elif feature == "euclidean_dist":
            edge_attr[:, index] = distances
        elif feature == "delta_x":
            edge_attr[:, index] = x_rel_pos
        elif feature == "delta_y":
            edge_attr[:, index] = y_rel_pos
        elif feature == "delta_z":
            edge_attr[:, index] = trajs[source, 2] - trajs[target, 2]
        elif feature == "delta_vel_x":
            edge_attr[:, index] = vel_i[:, 0] - vel_j[:, 0]
        elif feature == "delta_vel_y":
            edge_attr[:, index] = vel_i[:, 1] - vel_j[:, 1]
        elif feature == "delta_vel_z":
            vel_z = node_features[:, config["node_features"]["vel_z"]]
            edge_attr[:, index] = vel_z[source] - vel_z[target]
        elif feature == "delta_x_heading":
            edge_attr[:, index] = rotated_x
        elif feature == "delta_y_heading":
            edge_attr[:, index] = rotated_y
        elif feature == "relative_turning_angle":
            theta = torch.atan2(vel_j[:, 1], vel_j[:, 0]) - torch.atan2(
                vel_i[:, 1], vel_i[:, 0]
            )
            edge_attr[:, index] = theta / 10
        elif feature == "relative_density":
            density = node_features[:, config["node_features"]["local_density"]]
            edge_attr[:, index] = density[source] - density[target]
        elif feature == "charge":
            charge = node_features[:, config["node_features"]["charge"]]
            # +1 for equal charges, -1 otherwise
            edge_attr[:, index] = torch.where(charge[source] == charge[target], 1, -1)
        elif feature == "spring" and context is not None:
            edge_attr[:, index] = context

    return edge_attr


def construct_graphs(
    positions: np.ndarray,
    features: np.ndarray,
    recording_indices: np.ndarray,
    k: int,
    config: dict,
) -> list[list[Data]]:
    """Construct knn graphs from trajectories of shape (n_trajectories, n_frames, 2)
    based on recording indices

    args:
        positions: trajectories of all recordings, one row per trajectory
        features: per-trajectory, per-frame features aligned with ``positions``
        recording_indices: recording id of every trajectory
        k: number of neighbours passed to ``knn_graph``
        config: dict forwarded to ``calculate_edge_features_per_frame``
    returns:
        One list of per-frame graphs for every distinct recording id, ordered
        like ``np.unique(recording_indices)``."""

    recording_ids = np.unique(recording_indices)
    graphs = [[] for _ in recording_ids]

    # Address the output by position, not by id value, so ids need not be the
    # contiguous integers 0..n-1
    for slot, recording_id in enumerate(recording_ids):
        # Get all trajectories from the same recording
        in_recording = recording_indices == recording_id
        trajectories = torch.tensor(positions[in_recording], dtype=torch.float32)
        rec_features = torch.tensor(features[in_recording], dtype=torch.float32)

        n_frames = trajectories.shape[1]
        if n_frames == 0:
            continue

        # Frame-to-frame displacement, computed once per recording. torch.diff
        # appends *before* differencing, so repeating the last frame gives a
        # zero velocity at the end (appending zeros would give -position).
        velocities = torch.diff(trajectories, dim=1, append=trajectories[:, -1:])

        for frame in range(n_frames):
            edge_indices = knn_graph(trajectories[:, frame], k=k)

            edge_attr = calculate_edge_features_per_frame(
                trajectories, velocities, edge_indices, frame, config
            )

            graph = Data(
                x=trajectories[:, frame],
                features=rec_features[:, frame],
                edge_index=edge_indices,
                edge_attr=edge_attr,
                pos=trajectories[:, frame, :2],
                time=torch.tensor([frame]),
            )

            graphs[slot].append(graph)

    return graphs


def preprocess_nba_file(path: str, name: str, config: dict) -> list[np.ndarray]:
    """Read one space-separated NBA tracking file and cut it into time windows.

    The file has no header; its columns are read as ``frame``, ``agent_id``,
    ``x``, ``y`` and ``group`` (``PLAYER`` or ``BALL``).

    Args:
        path: Directory containing the file.
        name: File name inside ``path``.
        config: dict; ``config["time_window"]`` is the number of frames per
            sample and ``config["preprocess_sliding_window_step"]`` the offset
            between consecutive samples.

    Returns:
        List of arrays of shape (time_window, n_agents, 6). The last axis holds
        frame, agent rank, x, y, group (0 = PLAYER, 1 = BALL) and team
        (0 for non-players, 1 for the first five players of a frame, 2 for the
        remaining players).
    """
    input_filename = f"{path}/{name}"
    df = pd.read_csv(
        input_filename, sep=" ", header=None, names=["frame", "agent_id", "x", "y", "group"]
    )
    # sort_values returns a new frame; the result has to be kept, otherwise the
    # rows stay in file order and the per-frame reshape below is not guaranteed
    df = df.sort_values(by=["frame", "agent_id"], ignore_index=True)

    df["team"] = -1

    # Within every frame the first five players form one team, the rest the other
    player_data = df[df["group"] == "PLAYER"]
    player_ranks = player_data.groupby("frame").cumcount()
    df.loc[player_data.index, "team"] = (player_ranks >= 5).astype(int)
    df[["agent_id", "team"]] = df[["agent_id", "team"]] + 1
    df["group"] = df["group"].map({"PLAYER": 0, "BALL": 1})

    # Replace agent ids by their rank in order of first appearance
    agent_rank = {agent_id: rank for rank, agent_id in enumerate(df["agent_id"].unique())}
    df["agent_id"] = df["agent_id"].map(agent_rank)

    # (frames * agents, columns) -> (frames, agents, columns)
    n_frames = df["frame"].nunique()
    data = df.values.reshape(n_frames, -1, df.shape[1])

    window, step = config["time_window"], config["preprocess_sliding_window_step"]
    return [
        data[start : start + window]
        for start in range(0, data.shape[0] - window + 1, step)
    ]
