import argparse
from ase.io import read
from tqdm import tqdm
import numpy as np
from collections import defaultdict

def parse_arguments():
    """Build the command-line interface and parse ``sys.argv``.

    Two path options are mandatory (``--traj_file`` and ``--ref_file``); the
    three numeric options (``--bond_threshold``, ``--tolerance`` and
    ``--timestep_fs``) are floats with defaults.

    Returns:
        argparse.Namespace: Parsed values, with one attribute per option.
    """
    cli = argparse.ArgumentParser(
        description='Analyze bond stability in molecular dynamics trajectory'
    )

    # Mandatory string options: (flag, help text).
    path_options = (
        ('--traj_file', 'Path to trajectory file (.traj)'),
        ('--ref_file', 'Path to reference file (.xyz)'),
    )
    for flag, help_text in path_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    # Optional float options: (flag, default, help text).
    numeric_options = (
        ('--bond_threshold', 1.8, 'Bond length threshold in Å (default: 1.8)'),
        ('--tolerance', 0.5, 'Bond length tolerance (default: 0.5)'),
        ('--timestep_fs', 1.0, 'Timestep in femtoseconds (default: 1.0)'),
    )
    for flag, fallback, help_text in numeric_options:
        cli.add_argument(flag, type=float, default=fallback, help=help_text)

    parsed = cli.parse_args()
    return parsed

def analyze_bond_lengths(ref_file, bond_threshold):
    """Find the bonds present in every reference frame and their mean lengths.

    All frames of ``ref_file`` are read as XYZ. Within a frame, an atom pair
    ``(i, j)`` with ``i < j`` counts as bonded when the Euclidean distance
    between the two positions is strictly below ``bond_threshold``. Only pairs
    that are bonded in *every* frame are reported.

    Args:
        ref_file: Path to the reference XYZ file.
        bond_threshold: Distance cutoff, in the same length unit as the
            coordinates stored in the file.

    Returns:
        tuple: ``(edge_index, true_lengths)``. ``edge_index`` is an integer
        array of shape ``(n_bonds, 2)`` holding the atom index pairs; the
        shape is ``(0, 2)`` when no pair qualifies, so column indexing by
        callers keeps working. ``true_lengths`` is a list with the mean
        length of each bond over all frames, in the same order.
    """
    molecules = read(ref_file, index=':', format='xyz')

    # (i, j) -> one distance per frame in which the pair is within the cutoff.
    bond_lengths = defaultdict(list)
    for mol in molecules:
        positions = mol.get_positions()
        for i in range(len(positions) - 1):
            # Distances from atom i to every later atom, computed in one shot.
            dists = np.linalg.norm(positions[i + 1:] - positions[i], axis=1)
            for offset in np.flatnonzero(dists < bond_threshold):
                j = i + 1 + int(offset)
                bond_lengths[(i, j)].append(dists[offset])

    # A pair is recorded at most once per frame, so a count equal to the
    # number of frames means the bond exists in all of them.
    n_molecules = len(molecules)
    edge_index = []
    true_lengths = []
    for (a, b), lengths in bond_lengths.items():
        if len(lengths) == n_molecules:
            edge_index.append([a, b])
            true_lengths.append(np.mean(lengths))

    # reshape(-1, 2) keeps the result two-dimensional even with zero bonds.
    return np.array(edge_index, dtype=int).reshape(-1, 2), true_lengths

def check_edge_lengths(reference_lengths, calculated_lengths, tolerance):
    """Tell whether every calculated length lies within tolerance of its reference.

    Values are compared pairwise, in order; a calculated value passes when it
    falls inside the closed interval ``[ref - tolerance, ref + tolerance]``.

    Args:
        reference_lengths: Sequence of expected lengths.
        calculated_lengths: Sequence of measured lengths, same size.
        tolerance: Allowed absolute deviation on either side of the reference.

    Returns:
        bool: ``True`` if all pairs pass (also for two empty sequences),
        ``False`` as soon as one pair is out of range.

    Raises:
        ValueError: If the two sequences differ in length.
    """
    if len(reference_lengths) != len(calculated_lengths):
        raise ValueError("Reference and calculated lengths lists must have the same length")

    return all(
        expected - tolerance <= measured <= expected + tolerance
        for expected, measured in zip(reference_lengths, calculated_lengths)
    )

def analyze_trajectory(traj_file, edge_index, true_lengths, tolerance, timestep_fs):
    """Measure how long the reference bonds stay within tolerance in a trajectory.

    Frames are scanned in order. For each frame the length of every bond in
    ``edge_index`` is computed from the atomic positions and compared with
    ``true_lengths`` via ``check_edge_lengths``. Scanning stops at the first
    frame in which any bond is out of tolerance.

    Args:
        traj_file: Path to the trajectory file; all frames are read.
        edge_index: Atom index pairs, shape ``(n_bonds, 2)``. An empty input
            is accepted and makes every frame count as stable.
        true_lengths: Reference length for each bond, same order as
            ``edge_index``.
        tolerance: Allowed absolute deviation from the reference length.
        timestep_fs: Time between consecutive frames, in femtoseconds.

    Returns:
        tuple: ``(stable_time_ps, stable_time_fs)``. The time is
        ``(n_stable - 1) * timestep_fs``, where ``n_stable`` is the number of
        leading frames that passed the check, i.e. the time stamp of the last
        stable frame with frame 0 at t = 0. It is ``0.0`` when the first
        frame already fails or the trajectory has no frames.

    Raises:
        ValueError: Propagated from ``check_edge_lengths`` when the number of
            bonds and the number of reference lengths differ.
    """
    traj = read(traj_file, index=':')

    # An empty bond list typically arrives as a 1-D array (np.array([]));
    # force an (n_bonds, 2) integer layout so the column indexing is valid.
    edge_index = np.asarray(edge_index, dtype=int).reshape(-1, 2)
    first_atoms = edge_index[:, 0]
    second_atoms = edge_index[:, 1]

    # Assume the whole trajectory is stable until a frame proves otherwise.
    stable_frames = len(traj)
    for k, atoms in enumerate(tqdm(traj)):
        pos = atoms.get_positions()
        p1 = pos[first_atoms]   # Start points
        p2 = pos[second_atoms]  # End points
        edges = np.sqrt(np.sum((p2 - p1) ** 2, axis=1))
        if not check_edge_lengths(true_lengths, edges, tolerance):
            # Frame k is the first failure, so exactly k frames were stable.
            stable_frames = k
            break

    # Compute the time once in fs and derive ps from it.
    if stable_frames > 0:
        stable_time_fs = (stable_frames - 1) * timestep_fs
    else:
        stable_time_fs = 0.0
    stable_time_ps = stable_time_fs / 1000.0

    return stable_time_ps, stable_time_fs

def main():
    """Run the full analysis from the command line and print the results.

    Parses the CLI options, derives the consistent bonds and their mean
    lengths from the reference file, prints the bond index array, then scans
    the trajectory and prints the stable time in picoseconds and femtoseconds.
    """
    options = parse_arguments()

    # Reference structure -> bond pairs and their mean lengths.
    bonds, reference_lengths = analyze_bond_lengths(
        options.ref_file,
        options.bond_threshold,
    )
    print(bonds)

    # Trajectory scan -> time until the first out-of-tolerance frame.
    durations = analyze_trajectory(
        options.traj_file,
        bonds,
        reference_lengths,
        options.tolerance,
        options.timestep_fs,
    )
    picoseconds, femtoseconds = durations

    print(f"Stability of the trajectory: {picoseconds} ps")
    print(f"Stability of the trajectory: {femtoseconds} fs")

# Run the analysis only when this file is executed as a script; importing the
# module just defines the functions above.
if __name__ == "__main__":
    main()