import pandas as pd
import numpy as np
from tqdm import tqdm
import math


def circdiff(circular_1, circular_2):
    """Return the absolute angular distance between two angles (radians).

    The raw difference is mapped through ``arctan2(sin, cos)`` so that it is
    wrapped to the principal range before the absolute value is taken.
    Works element-wise on NumPy arrays as well as on scalars.
    """
    delta = circular_1 - circular_2
    wrapped = np.arctan2(np.sin(delta), np.cos(delta))
    return abs(wrapped)

############################### Read data ###################################
def read_test_data(datafile, dataset="ATC", x_min=-60.0, x_max=80.0, y_min=-40.0, y_max=20.0):
    """Load test observations and keep those inside a rectangular region.

    Args:
        datafile: A single header-less CSV source accepted by
            ``pandas.read_csv`` or a list/tuple of such sources. Each file is
            read as six columns: time, person_id, x, y, speed, motion_angle.
        dataset: Dataset identifier; only ``"ATC"`` is implemented.
        x_min, x_max, y_min, y_max: Inclusive bounds of the region to keep.

    Returns:
        DataFrame with columns ``time, x, y, speed, motion_angle`` where
        ``motion_angle`` has been wrapped into ``[0, 2*pi)``.

    Raises:
        ValueError: If ``dataset`` is not ``"ATC"``.
    """
    if dataset != "ATC":
        # Previously this fell through and failed later with an
        # UnboundLocalError on ``test_data``; fail early and clearly instead.
        raise ValueError(f"Unsupported dataset {dataset!r}; only 'ATC' is implemented")

    raw_columns = ["time", "person_id", "x", "y", "speed", "motion_angle"]
    kept_columns = ["time", "x", "y", "speed", "motion_angle"]

    if isinstance(datafile, (list, tuple)):
        frames = [pd.read_csv(f, header=None, names=raw_columns) for f in datafile]
    else:
        single = pd.read_csv(datafile, header=None)
        single.columns = raw_columns
        frames = [single]

    test_data = pd.concat(frames, ignore_index=True)
    test_data["motion_angle"] = np.mod(test_data["motion_angle"], 2 * np.pi)
    test_data = test_data[kept_columns]

    inside = (
        (test_data["x"] >= x_min) & (test_data["x"] <= x_max)
        & (test_data["y"] >= y_min) & (test_data["y"] <= y_max)
    )
    return test_data[inside]


def read_test_data_with_hour(hour, datafile, dataset="ATC", x_min=-60.0, x_max=80.0, y_min=-40.0, y_max=20.0):
    """Load test observations inside a region and recorded in a given JST hour.

    Args:
        hour: Hour of day (as returned by ``Series.dt.hour``) in the
            ``Asia/Tokyo`` time zone to keep.
        datafile: A single header-less CSV source accepted by
            ``pandas.read_csv`` or a list/tuple of such sources. Each file is
            read as six columns: time, person_id, x, y, speed, motion_angle.
        dataset: Dataset identifier; only ``"ATC"`` is implemented.
        x_min, x_max, y_min, y_max: Inclusive bounds of the region to keep.

    Returns:
        A new DataFrame with columns ``time, x, y, speed, motion_angle``;
        ``motion_angle`` is wrapped into ``[0, 2*pi)``.

    Raises:
        ValueError: If ``dataset`` is not ``"ATC"``.
    """
    if dataset != "ATC":
        # Previously this fell through and failed later with an
        # UnboundLocalError on ``test_data``; fail early and clearly instead.
        raise ValueError(f"Unsupported dataset {dataset!r}; only 'ATC' is implemented")

    raw_columns = ["time", "person_id", "x", "y", "speed", "motion_angle"]
    kept_columns = ["time", "x", "y", "speed", "motion_angle"]

    if isinstance(datafile, (list, tuple)):
        frames = [pd.read_csv(f, header=None, names=raw_columns) for f in datafile]
    else:
        single = pd.read_csv(datafile, header=None)
        single.columns = raw_columns
        frames = [single]

    test_data = pd.concat(frames, ignore_index=True)
    test_data["motion_angle"] = np.mod(test_data["motion_angle"], 2 * np.pi)
    test_data = test_data[kept_columns]

    inside = (
        (test_data["x"] >= x_min) & (test_data["x"] <= x_max)
        & (test_data["y"] >= y_min) & (test_data["y"] <= y_max)
    )
    test_data = test_data[inside]

    # ``time`` is interpreted as seconds since the Unix epoch (UTC). The hour
    # is derived in a standalone Series so the filtered frame is never
    # mutated (avoids chained-assignment warnings on a slice).
    ts_utc = pd.to_datetime(test_data["time"], unit="s", utc=True)
    jst_hour = ts_utc.dt.tz_convert("Asia/Tokyo").dt.hour

    return test_data.loc[jst_hour == hour].copy()


def round_to_nearest_half(x):
    """Return ``x`` rounded down to a multiple of 0.5.

    Note: despite the name, the value is floored (``math.floor``), not
    rounded to the nearest half; e.g. 1.7 -> 1.5 and -0.3 -> -0.5.
    """
    half_steps = math.floor(x * 2)
    return half_steps / 2


# Loader for the MoD map CSV file, which is generated by NIR.
def read_NIR_map_data(datafile):
    """Load a NIR-generated map CSV and convert it to the GMM component layout.

    The CSV must have a header row with the columns
    ``x, y, mean_speed, mean_motion_angle, var_speed, var_motion_angle, coef,
    weight``.

    Args:
        datafile: CSV source accepted by ``pandas.read_csv``.

    Returns:
        DataFrame with columns ``x, y, speed, motion_angle, cov1, cov2, cov3,
        cov4, weight, motion_ratio``. ``cov1..cov4`` are the row-major entries
        of the 2x2 (speed, angle) covariance built from the two variances and
        ``coef``; ``motion_ratio`` is the constant 1.0; ``motion_angle`` is
        wrapped into ``[0, 2*pi)``.
    """
    raw = pd.read_csv(datafile)

    # Off-diagonal covariance term: coef * sqrt(var_speed * var_motion_angle).
    cross_cov = np.sqrt(raw["var_speed"] * raw["var_motion_angle"]) * raw["coef"]

    # The wrap is applied to the mean angle. The previous code applied it to
    # the angular *variance* after cov4 had been taken and the column dropped,
    # so it had no effect and the mean angle was left unwrapped.
    mean_angle = np.mod(raw["mean_motion_angle"], 2 * np.pi)

    MoD = pd.DataFrame({
        "x": raw["x"],
        "y": raw["y"],
        "speed": raw["mean_speed"],
        "motion_angle": mean_angle,
        "cov1": raw["var_speed"],
        "cov2": cross_cov,
        "cov3": cross_cov,
        "cov4": raw["var_motion_angle"],
        "weight": raw["weight"],
        "motion_ratio": 1.0,
    })

    return MoD
#############################################################################


############################### Check Cov ###################################
def is_valid_covariance(cov):
    """Return True when ``np.linalg.cholesky`` can factorise ``cov``.

    A failed factorisation (``LinAlgError``) is taken to mean the matrix is
    not positive definite, in which case False is returned.
    """
    try:
        np.linalg.cholesky(cov)
    except np.linalg.LinAlgError:
        return False
    else:
        return True
    
    
def is_too_narrow(cov, threshold=1e-6):
    """Return whether the determinant of ``cov`` is below ``threshold``."""
    return np.linalg.det(cov) < threshold


def regularize_covariance(cov, epsilon=1e-6):
    """Add ``epsilon`` to the diagonal of ``cov`` and return it.

    The addition is done in place: the array passed in is modified and the
    same object is returned.
    """
    ridge = epsilon * np.eye(cov.shape[0])
    cov += ridge
    return cov
#############################################################################


############################### GMM #########################################
def circ_diff_signed(a, b):
    """Return the signed angular difference ``a - b`` wrapped into [-pi, pi).

    Works element-wise when either argument is a NumPy array.
    """
    half_turn = np.pi
    shifted = a - b + half_turn
    return shifted % (2 * half_turn) - half_turn


#### align with DL GMM nll calculation
def nll_of_point(gmm_components, point):
    """Negative log-likelihood of one observation under a 2-D Gaussian mixture.

    Args:
        gmm_components: Table with one row per mixture component and the
            columns ``speed``, ``motion_angle`` (means), ``cov1..cov4``
            (row-major 2x2 covariance) and ``weight``.
        point: Mapping/row providing ``speed`` and ``motion_angle``.

    Returns:
        ``-log(p)`` where ``p`` is the accumulated mixture density, floored
        at 1e-12 before the logarithm.

    NOTE(review): ``circ_diff_signed`` wraps its result into one period, so
    the three shifted angles (a - 2*pi, a, a + 2*pi) appear to yield the same
    angular difference up to floating-point rounding, which makes the sum
    roughly three times the single-term density. Left unchanged because this
    routine is meant to mirror the DL GMM NLL computation -- confirm whether
    that is intended.
    """
    obs_speed = float(point['speed'])
    obs_angle = float(point['motion_angle'])

    # Per-component parameters, each of shape (K,).
    mean_speed = gmm_components['speed'].to_numpy()
    mean_angle = gmm_components['motion_angle'].to_numpy()
    mix_weight = gmm_components['weight'].to_numpy()

    s11 = gmm_components['cov1'].to_numpy()
    s12 = gmm_components['cov2'].to_numpy()
    s21 = gmm_components['cov3'].to_numpy()
    s22 = gmm_components['cov4'].to_numpy()

    # Terms that do not depend on the shifted angle are computed once.
    d_speed = obs_speed - mean_speed
    # Determinant is floored so the inverse and normaliser stay finite.
    det = np.clip(s11 * s22 - s12 * s21, 1e-12, None)
    p00 = s22 / det
    p01 = -s12 / det
    p10 = -s21 / det
    p11 = s11 / det
    normaliser = 2*np.pi * np.sqrt(det)

    shifted_angles = np.array(
        [obs_angle - 2*np.pi, obs_angle, obs_angle + 2*np.pi], dtype=np.float64
    )

    likelihood = 0
    for angle in shifted_angles:
        d_angle = circ_diff_signed(angle, mean_angle)
        # Quadratic form d^T * inv(Sigma) * d, evaluated per component.
        quad = d_speed*(p00*d_speed + p01*d_angle) + d_angle*(p10*d_speed + p11*d_angle)
        density = np.exp(-0.5 * quad) / normaliser
        # Weighted sum over components gives a scalar.
        likelihood += np.sum(mix_weight * density)

    likelihood = max(likelihood, 1e-12)
    return -np.log(likelihood)


def find_closest_location(locations, point, threshold):
    """Return the location nearest to ``point``, or None if it is too far.

    Args:
        locations: Iterable of indexable ``(x, y)`` pairs.
        point: Mapping/row providing ``x`` and ``y``.
        threshold: Maximum accepted Euclidean distance.

    Returns:
        The first location with the smallest distance, or None when that
        distance exceeds ``threshold`` (also when ``locations`` is empty).
    """
    best = None
    best_dist = float('inf')
    for candidate in locations:
        gap = np.sqrt((candidate[0] - point['x'])**2 + (candidate[1] - point['y'])**2)
        if gap < best_dist:
            best, best_dist = candidate, gap

    return None if best_dist > threshold else best


def preprocess_gmm_data(gmm_data):
    """Group GMM components by location and drop unusable components.

    A component is kept when its weight is positive, its 2x2 covariance
    (``cov1..cov4``, row-major) passes ``is_valid_covariance`` and it is not
    flagged by ``is_too_narrow``. The weights of the kept components are
    renormalised to sum to one per location.

    Args:
        gmm_data: DataFrame with at least the columns ``x``, ``y``,
            ``cov1..cov4`` and ``weight``.

    Returns:
        Dict mapping ``(x, y)`` to a DataFrame of the kept components.
        Locations with no kept component are omitted. Keys appear in order of
        first occurrence in ``gmm_data``.
    """
    gmm_dict = {}
    # One pass over the table instead of re-filtering the whole frame for
    # every location; sort=False keeps first-occurrence order.
    for loc, loc_data in gmm_data.groupby(['x', 'y'], sort=False):
        valid_rows = []
        for _, row in loc_data.iterrows():
            if not row['weight'] > 0:
                continue
            cov = np.array([[row['cov1'], row['cov2']], [row['cov3'], row['cov4']]])
            if is_valid_covariance(cov) and not is_too_narrow(cov):
                valid_rows.append(row)

        if not valid_rows:
            continue

        components = pd.DataFrame(valid_rows)
        # Normalise through a new Series rather than dividing ``.values`` in
        # place: the in-place form fails on read-only arrays (pandas
        # Copy-on-Write) and on integer-typed weights.
        components['weight'] = components['weight'] / components['weight'].sum()
        gmm_dict[(loc[0], loc[1])] = components

    return gmm_dict


def compute_nll(test_data, MoD_data, threshold):
    """Evaluate the negative log-likelihood of every test point under the map.

    Each test point is matched to the closest map location within
    ``threshold``; points without a match are counted and given the NLL
    ``-log(1e-12)``.

    Args:
        test_data: DataFrame of observations (rows are passed to
            ``find_closest_location`` and ``nll_of_point``).
        MoD_data: GMM component table passed to ``preprocess_gmm_data``.
        threshold: Maximum distance for matching a point to a location.

    Returns:
        Tuple ``(nlls, not_find, average_nll, std_nll)``: the per-point NLL
        list, the number of unmatched points, and the mean and standard
        deviation of the list.
    """
    gmm_dict = preprocess_gmm_data(MoD_data)
    locations = list(gmm_dict)

    nlls = []
    not_find = 0

    for _, point in tqdm(test_data.iterrows(), total=len(test_data)):
        closest_loc = find_closest_location(locations, point, threshold)

        # Unmatched point: report it and record the fallback NLL.
        if closest_loc is None:
            print(f"Cannot find the closest location for point {point}")
            not_find += 1
            nlls.append(-np.log(1e-12))
            continue

        nll = nll_of_point(gmm_dict[closest_loc], point)
        if nll is not None:
            nlls.append(nll)

    return nlls, not_find, np.mean(nlls), np.std(nlls)

#############################################################################


def gmm_tensor_to_dataframe(GMM_params_tensor, point):
    """Convert a tensor of GMM parameters into the component DataFrame layout.

    Args:
        GMM_params_tensor: Tensor that, after ``squeeze(0)``, has one row per
            component with six values in the order
            ``weight, speed_mean, angle_mean, speed_var, angle_var, rho``.
        point: Mapping/row providing the ``x`` and ``y`` written to every
            output row.

    Returns:
        DataFrame with columns ``x, y, speed, motion_angle, cov1, cov2, cov3,
        cov4, weight, motion_ratio`` (one row per component).
    """
    # detach() so tensors that still require grad can be converted;
    # Tensor.numpy() raises on them otherwise.
    GMM_params = GMM_params_tensor.detach().squeeze(0).cpu().numpy()  # shape: (num_components, 6)

    gmm_rows = []
    for weight, speed_mean, angle_mean, speed_var, angle_var, rho in GMM_params:
        # Off-diagonal covariance from the correlation coefficient.
        cross_cov = rho * np.sqrt(speed_var * angle_var)

        gmm_rows.append({
            "x": point['x'],
            "y": point['y'],
            "speed": speed_mean,
            "motion_angle": angle_mean,
            "cov1": speed_var,
            "cov2": cross_cov,
            "cov3": cross_cov,
            "cov4": angle_var,
            "weight": weight,
            "motion_ratio": 1.0,  # Assuming motion_ratio is always 1.0
        })

    return pd.DataFrame(gmm_rows)