import os
import math
import math

from typing import Tuple

import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm


def cmpt_energy(
    data_points: np.ndarray,
    *,
    a: float = 4.0,
    b: float = 0.1,
    chunk_size: int = 200,
) -> np.ndarray:
    """Sum a pairwise log-cosh potential over all particle pairs.

    For every trajectory and time step, the Euclidean distance ``r`` of each
    unordered particle pair ``i < j`` is evaluated and the pair potential

        P(r) = (1 / a) * log(cosh(a * (1 - r))) + b * (1 - r)

    is summed over all ``N * (N - 1) / 2`` pairs. The result is NOT divided
    by the number of pairs.

    Args:
        data_points: Particle positions, shape ``(n_traj, T, N, D)``.
            Integer/bool input is promoted to torch's default float dtype.
        a: Scale inside the log-cosh term; must be non-zero.
        b: Weight of the linear ``(1 - r)`` term.
        chunk_size: Number of trajectories processed per iteration. Each
            chunk materialises a ``(C, T, N, N, D)`` difference tensor, so
            this bounds peak memory; trajectories are handled independently.

    Returns:
        Array of shape ``(n_traj, T)`` with the summed pair energy.

    Raises:
        ValueError: If ``data_points`` is not 4-dimensional or
            ``chunk_size`` is smaller than 1.
    """
    X = torch.as_tensor(data_points)
    # Explicit raise instead of `assert`: asserts are stripped under `-O`.
    if X.ndim != 4:
        raise ValueError(
            f"cmpt_energy expects shape (n_traj, T, N, D), got {tuple(X.shape)}"
        )
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    # torch.linalg.norm only accepts floating/complex tensors.
    if not (X.is_floating_point() or X.is_complex()):
        X = X.to(torch.get_default_dtype())

    n_traj, T, N, _ = X.shape
    if n_traj == 0:
        # torch.cat rejects an empty list, so short-circuit empty input.
        return X.new_empty((0, T)).cpu().numpy()

    log_two = math.log(2.0)

    def _logcosh(x: torch.Tensor) -> torch.Tensor:
        # Overflow-safe identity: log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2)
        ax = torch.abs(x)
        return ax + torch.log1p(torch.exp(-2.0 * ax)) - log_two

    # Row/column indices of the strict upper triangle, i.e. all pairs i < j.
    pair_i, pair_j = torch.triu_indices(N, N, offset=1, device=X.device)

    energy_chunks = []
    for start in tqdm(range(0, n_traj, chunk_size), desc="Computing energy in chunks"):
        X_chunk = X[start:start + chunk_size]  # (C, T, N, D); slice clamps at n_traj
        diff = X_chunk[..., :, None, :] - X_chunk[..., None, :, :]  # (C, T, N, N, D)
        r = torch.linalg.norm(diff, dim=-1)  # (C, T, N, N)
        rij = r[..., pair_i, pair_j]  # (C, T, n_pairs)

        gap = 1.0 - rij
        P = (1.0 / a) * _logcosh(a * gap) + b * gap
        energy_chunks.append(torch.sum(P, dim=-1))  # (C, T)

    energy_all = torch.cat(energy_chunks, dim=0)  # (n_traj, T)
    return energy_all.cpu().numpy()
        
    


def cmpt_Manis(data_points: np.ndarray, *, chunk_size: int = 20) -> np.ndarray:
    """Compute an anisotropy measure of the particle cloud per time step.

    For every trajectory and time step the particle positions are centred on
    their mean, the ``D x D`` covariance (normalised by ``N``) is formed, and
    with ``l1 >= l2`` its two largest eigenvalues the value

        M = (l1 - l2) / (l1 + l2)

    is returned, clamped to ``[0, 1]``. Where ``l1 + l2 <= 0`` the result is
    set to 0 instead of dividing.

    Args:
        data_points: Particle positions, shape ``(n_traj, T, N, D)`` with
            ``D >= 2``. Integer/bool input is promoted to torch's default
            float dtype.
        chunk_size: Number of trajectories processed per iteration; limits
            peak memory, trajectories are handled independently.

    Returns:
        Array of shape ``(n_traj, T)``.

    Raises:
        ValueError: If ``data_points`` is not 4-dimensional, ``D < 2``, or
            ``chunk_size`` is smaller than 1.
    """
    X = torch.as_tensor(data_points)
    # Explicit raise instead of `assert`: asserts are stripped under `-O`.
    if X.ndim != 4:
        raise ValueError(
            f"cmpt_Manis expects shape (n_traj, T, N, D), got {tuple(X.shape)}"
        )
    n_traj, T, N, D = X.shape
    if D < 2:
        raise ValueError(f"cmpt_Manis expects D>=2, got D={D}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    # mean() and eigvalsh() only accept floating/complex tensors.
    if not (X.is_floating_point() or X.is_complex()):
        X = X.to(torch.get_default_dtype())

    if n_traj == 0:
        # torch.cat rejects an empty list, so short-circuit empty input.
        return X.new_empty((0, T)).cpu().numpy()

    manis_chunks = []
    for start in tqdm(range(0, n_traj, chunk_size), desc="Computing Manis in chunks"):
        X_chunk = X[start:start + chunk_size]  # (C, T, N, D); slice clamps at n_traj
        X_centered = X_chunk - X_chunk.mean(dim=2, keepdim=True)
        # Covariance over the particle axis: (C, T, D, D).
        cov = torch.einsum("ctnd,ctne->ctde", X_centered, X_centered) / float(N)
        eigvals = torch.linalg.eigvalsh(cov)  # ascending order, (C, T, D)
        lambda1 = eigvals[..., -1]  # largest
        lambda2 = eigvals[..., -2]  # second largest
        denom = lambda1 + lambda2
        # Degenerate clouds (all particles coincide) give denom == 0 -> report 0.
        manis = torch.where(denom > 0, (lambda1 - lambda2) / denom, torch.zeros_like(denom))
        # Round-off can push the ratio marginally outside [0, 1].
        manis_chunks.append(torch.clamp(manis, 0.0, 1.0))

    manis_all = torch.cat(manis_chunks, dim=0)  # (n_traj, T)
    return manis_all.cpu().numpy()




# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Script entry point: load a trajectory file, compute the per-pair energy
    # macro feature, rescale it to [-1, 1] and save it next to the input.
    #
    # Select the dataset by (un)commenting exactly one in_npy/out_npy pair.

    # in_npy = "trajectories_inDistribution.npy"
    # out_npy = "macro_feature_inDistribution.npy"

    # in_npy = "trajectories_outDistribution_3gmm.npy"
    # out_npy = "macro_feature_outDistribution_3gmm.npy"

    in_npy = "trajectories_outDistribution_400N.npy"
    out_npy = "macro_feature_outDistribution_400N.npy"

    # in_npy = "trajectories.npy"
    # out_npy = "macro_feature_new.npy"

    data_dir = "./data"
    input_file = os.path.join(data_dir, in_npy)
    output_npy_path = os.path.join(data_dir, out_npy)
    norm_path = os.path.join(data_dir, "macro_feature_normalization.npz")

    # "trajectories.npy" is the reference set: its min/max define the
    # normalization range, which is saved and reused for every other input.
    is_reference = os.path.basename(input_file) == "trajectories.npy"

    ## load data
    input_data = np.load(input_file)
    print(f"input_data shape: {input_data.shape}")  # (E,T,N,D)

    print("\n\n get energy over time...")
    macro_feat_E_T = cmpt_energy(
        data_points=input_data,  # [E,T,N,D]
    )
    # Alternative macro feature:
    # macro_feat_E_T = cmpt_Manis(
    #     data_points=input_data, # [E,T,N,D]
    # )
    print(f"macro_feat_E_T shape: {macro_feat_E_T.shape}")

    print("\n\n")
    print(f"Final macro feature shape: {macro_feat_E_T.shape}")
    macro_feat_E_T_1 = macro_feat_E_T[:, :, None]  # (E, T, 1)

    # Convert the summed energy to a mean per pair: divide by N*(N-1)/2.
    N_particles = input_data.shape[2]
    # assert N_particles == 300, f"Expect N_particles=300, got {N_particles}"
    macro_feat_E_T_1 = macro_feat_E_T_1 / float(N_particles) / float(N_particles - 1) * 2.0

    # normalize macro_feature to [-1, 1] over all experiments and time steps
    if is_reference:
        Z_min = macro_feat_E_T_1.min(axis=(0, 1), keepdims=True)
        Z_max = macro_feat_E_T_1.max(axis=(0, 1), keepdims=True)
    else:
        print("## load normalization info ...")
        # NpzFile keeps the archive open; the context manager closes it once
        # both arrays have been read into memory.
        with np.load(norm_path) as norm_data:
            Z_min = norm_data["Z_min"]
            Z_max = norm_data["Z_max"]
    # NOTE(review): if Z_max == Z_min this divides by zero and yields NaN/inf.
    macro_feat_E_T_1 = 2.0 * (macro_feat_E_T_1 - Z_min) / (Z_max - Z_min) - 1.0

    # save to npy
    np.save(output_npy_path, macro_feat_E_T_1)

    if is_reference:
        print("## save normalization info ...")
        normalization_info = {
            "Z_min": Z_min,
            "Z_max": Z_max,
        }
        np.savez(norm_path, **normalization_info)
        print(f"Saved macro feature normalization info to: {norm_path}")
