import os
import json
import numpy as np
from ase.neighborlist import NeighborList
from ase.io import read
from tqdm import tqdm
from ase.geometry.analysis import Analysis


def angle_df(ats, rmax):
    """
    Calculate the bond angle distribution of a list of ASE Atoms objects.

    For every atom, each pair of its neighbours (atoms within ``rmax``,
    including periodic images via the neighbour-list offsets) contributes
    one neighbour-centre-neighbour angle.

    Args:
    ats (list): List of ASE Atoms objects
    rmax (float): Maximum distance to consider for bond angle calculation

    Returns:
    np.array: 1-D array of bond angles in radians, each in [0, pi]. Empty
    when no atom has at least two (non-coincident) neighbours.
    """
    chunks = []
    for at in ats:
        # Every atom gets radius rmax/2, so a pair counts as neighbours when
        # their separation is within rmax/2 + rmax/2 = rmax.
        nl = NeighborList([rmax / 2.0] * len(at), skin=0.0, self_interaction=False, bothways=True)
        nl.update(at)
        pos = at.get_positions()
        cell = np.asarray(at.get_cell())
        for i_at in range(len(at)):
            indices, offsets = nl.get_neighbors(i_at)
            if len(indices) < 2:
                continue
            # Bond vectors from the central atom to each neighbour (image).
            vecs = pos[indices] + np.dot(offsets, cell) - pos[i_at]
            norms = np.linalg.norm(vecs, axis=1)
            # A neighbour sitting exactly on the central atom has no defined
            # bond direction; drop it rather than inventing an angle.
            keep = norms > 0.0
            if np.count_nonzero(keep) < 2:
                continue
            unit = vecs[keep] / norms[keep][:, None]
            # Round-off can push |cos| slightly above 1, which would make
            # arccos return NaN for (near-)linear triplets.
            cosines = np.clip(unit @ unit.T, -1.0, 1.0)
            # Strict lower triangle (j < i), row-major: same pair order as a
            # nested "for i: for j in range(i)" loop.
            rows, cols = np.tril_indices(len(unit), k=-1)
            chunks.append(np.arccos(cosines[rows, cols]))
    if not chunks:
        return np.array([])
    return np.concatenate(chunks)


def rdf(ats, rmax=5.0, nbins=100):
    """
    Compute the radial distribution function averaged over several structures.

    Args:
    ats (list): List of ASE Atoms objects, analysed together as the images
        of a single ASE ``Analysis``.
    rmax (float): Largest pair distance included in the RDF.
    nbins (int): Number of distance bins.

    Returns:
    tuple: ``(r, df)`` -- the bin distances reported by ASE for the first
    structure and the element-wise mean of the per-structure RDFs, or
    ``(None, None)`` when ``ats`` is empty.
    """
    if len(ats) == 0:
        return None, None

    # One (rdf, distances) entry per structure.
    per_image = Analysis(ats).get_rdf(rmax=rmax, nbins=nbins, return_dists=True)
    averaged = np.mean([entry[0] for entry in per_image], axis=0)
    distances = per_image[0][1]
    return distances, averaged


# Visualization functions moved to scripts/plot_dist.py


class StructureEvaluator:
    """Evaluates material structures by analyzing bond angles and radial distribution functions.

    The ``calculate_*_from_file`` methods read the frames of an extxyz file
    (optionally only the first ``n_samples``), remove ghost atoms (chemical
    symbol ``'X'``) and compute a distribution. ``save_data_to_json`` writes a
    result next to the input file, and ``evaluate_files`` runs both analyses
    over many files.
    """

    def _load_structures(self, extxyz_file, n_samples=None):
        """
        Read the frames of an extxyz file and strip ghost atoms.

        Args:
            extxyz_file (str): Path to extxyz file.
            n_samples (int): Optional, keep only the first ``n_samples`` frames. If None, keep all.

        Returns:
            list: ASE Atoms objects with every atom of symbol 'X' removed.
        """
        atoms_list = read(extxyz_file, index=':')
        if n_samples is not None:
            atoms_list = atoms_list[:n_samples]
        # Boolean-mask indexing keeps only the non-ghost atoms of each frame.
        return [atoms[[atom.symbol != 'X' for atom in atoms]] for atoms in atoms_list]

    def calculate_adf_from_file(self, extxyz_file, cutoff=3.0, nbins=180, n_samples=None):
        """
        Calculate bond angle distribution histogram from an extxyz file.

        Args:
            extxyz_file (str): Path to extxyz file.
            cutoff (float): Cutoff distance for considering atoms as bonded.
            nbins (int): Number of bins for histogram (default 180 for 1-degree bins).
            n_samples (int): Optional, number of samples to consider. If None, use all.

        Returns:
            tuple: (angle_centers, distribution) arrays. ``distribution`` is a
            probability density over [0, 180] degrees, or all zeros when no
            bond angle was found.
        """
        structures = self._load_structures(extxyz_file, n_samples)
        angles = angle_df(structures, cutoff) / np.pi * 180.0

        if angles.size == 0:
            # A density-normalised histogram of no samples is 0/0 (all NaN,
            # which is not valid JSON); report an all-zero distribution over
            # the same bins instead.
            bin_edges = np.linspace(0.0, 180.0, nbins + 1)
            hist = np.zeros(nbins)
        else:
            hist, bin_edges = np.histogram(angles, bins=nbins, range=(0, 180), density=True)

        angle_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        return angle_centers, hist

    def calculate_rdf_from_file(self, extxyz_file, cutoff=5.0, nbins=100, n_samples=None):
        """
        Calculate radial distribution function from an extxyz file.

        Args:
            extxyz_file (str): Path to extxyz file.
            cutoff (float): Maximum distance to consider for RDF calculation.
            nbins (int): Number of bins for RDF calculation.
            n_samples (int): Optional, number of samples to consider. If None, use all.

        Returns:
            tuple: (distances, distribution) arrays, or (None, None) when no
            structures were loaded.
        """
        structures = self._load_structures(extxyz_file, n_samples)
        return rdf(structures, cutoff, nbins)

    def save_data_to_json(self, data, data_type, extxyz_file):
        """
        Save calculated data to JSON file in the same directory as the extxyz file.

        The output is ``<dir>/<basename>-<data_type>.json``.

        Args:
            data: ``(angle_centers, distribution)`` for ADF, or
                ``(distances, distribution)`` for RDF; each element may be a
                NumPy array or any array-like.
            data_type (str): Type of data ('adf' or 'rdf').
            extxyz_file (str): Path to the extxyz file.

        Returns:
            str: Path of the written JSON file.

        Raises:
            ValueError: If ``data_type`` is neither 'adf' nor 'rdf'. Nothing
                is created on disk in that case.
        """
        # Build the payload first so an invalid data_type fails before any
        # directory is created.
        if data_type == 'adf':
            angles, distribution = data
            json_data = {
                "angles": np.asarray(angles).tolist(),
                "distribution": np.asarray(distribution).tolist()
            }
        elif data_type == 'rdf':
            rad, df = data
            json_data = {
                "distances": np.asarray(rad).tolist(),
                "distribution": np.asarray(df).tolist()
            }
        else:
            raise ValueError(f"Unknown data type: {data_type}")

        directory = os.path.dirname(extxyz_file)
        base_name = os.path.splitext(os.path.basename(extxyz_file))[0]
        output_file = os.path.join(directory, f'{base_name}-{data_type}.json')

        # dirname() is '' for a bare filename: os.makedirs('') would raise,
        # and the current directory already exists.
        if directory:
            os.makedirs(directory, exist_ok=True)

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(json_data, f, indent=2)

        return output_file

    def evaluate_files(self, extxyz_files, adf_config=None, rdf_config=None):
        """
        Evaluate multiple extxyz files by calculating ADF and RDF and saving to JSON.

        Args:
            extxyz_files (list): List of paths to extxyz files.
            adf_config (dict): Keyword arguments for ``calculate_adf_from_file``
                (e.g. cutoff, nbins, n_samples). None means defaults.
            rdf_config (dict): Keyword arguments for ``calculate_rdf_from_file``
                (e.g. cutoff, nbins, n_samples). None means defaults.
        """
        if adf_config is None:
            adf_config = {}
        if rdf_config is None:
            rdf_config = {}

        for extxyz_file in tqdm(extxyz_files, desc="Evaluating structure files"):
            name = os.path.basename(extxyz_file)

            angle_centers, adf_distribution = self.calculate_adf_from_file(extxyz_file, **adf_config)
            self.save_data_to_json((angle_centers, adf_distribution), 'adf', extxyz_file)

            rad, df = self.calculate_rdf_from_file(extxyz_file, **rdf_config)
            if rad is None:
                # No structures were loaded, so there is no RDF to write.
                tqdm.write(f"Processed {name}: ADF data saved (no structures for RDF)")
                continue

            self.save_data_to_json((rad, df), 'rdf', extxyz_file)
            tqdm.write(f"Processed {name}: ADF and RDF data saved")
