import os
import numpy as np
import json
from collections import Counter
import pickle
import xgboost as xgb
from sklearn.metrics import mean_squared_error, r2_score
from ase.io import read
from tqdm import tqdm


class CompositionEvaluator:
    """Evaluates material samples by analyzing their elemental composition.

    Every structure is featurized as the fraction of each element among its
    non-ghost atoms (ghosts are element number 0 / symbol 'X'), and one
    XGBoost regressor per property maps those fractions to a property value.
    Trained models are pickled into ``save_dir`` next to a JSON metadata file
    that records the element -> feature-column mapping used at training time;
    prediction always reuses that mapping so feature columns line up with
    the columns the model was fitted on.
    """

    def __init__(self, save_dir='cache/comp_eval'):
        """
        Initialize the composition evaluator.

        Args:
            save_dir (str): Directory to save trained models and evaluation
                results. Created if it does not exist.
        """
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
        self.models = {}  # Property name -> trained XGBoost model
        # Maps element numbers to indices in feature vectors for the most
        # recently featurized / trained / loaded data (kept for backward
        # compatibility; prediction uses the per-property maps below).
        self.element_map = {}
        # Property name -> element map that property's model was trained with.
        self.element_maps = {}

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_float(value):
        """Convert a single-element property value (tensor, array or number) to float."""
        if hasattr(value, 'detach'):
            # Tensors that require grad cannot be converted to numpy directly.
            value = value.detach()
        if hasattr(value, 'cpu'):
            value = value.cpu().numpy()
        # .item() accepts any single-element array regardless of ndim, avoiding
        # NumPy's deprecated float() conversion of ndim > 0 arrays.
        return float(np.asarray(value).item())

    @classmethod
    def _collect_targets(cls, samples, property_name):
        """Return a float array with ``property_name`` per sample, NaN where missing/None."""
        values = []
        for sample in samples:
            props = sample.properties
            value = props[property_name] if props and property_name in props else None
            values.append(np.nan if value is None else cls._to_float(value))
        return np.array(values, dtype=float)

    @staticmethod
    def _sample_elements(sample):
        """Return the sample's element numbers with ghost atoms (element number 0) removed."""
        elements = sample.get_elements().cpu().numpy()
        return elements[elements != 0]

    @staticmethod
    def _proportions_row(row, atomic_numbers, element_map):
        """Write per-element fractions of ``atomic_numbers`` into ``row`` (in place).

        Elements that are not keys of ``element_map`` have no column and are
        skipped, so such a row may sum to less than 1. An empty
        ``atomic_numbers`` leaves the row at zero (no division happens).
        """
        total = len(atomic_numbers)
        for element, count in Counter(atomic_numbers).items():
            idx = element_map.get(int(element))
            if idx is not None:
                row[idx] = count / total

    def _ensure_model(self, property_name):
        """Return the model for ``property_name``, loading it from ``save_dir`` if needed.

        Raises:
            ValueError: If the model is neither in memory nor saved on disk.
        """
        if property_name not in self.models:
            try:
                self.load_model(property_name)
            except FileNotFoundError as err:
                raise ValueError(f"No trained model found for {property_name}") from err
        return self.models[property_name]

    def _element_map_for(self, property_name):
        """Return the element map the model for ``property_name`` was trained with.

        Falls back to ``self.element_map`` for models that were placed into
        ``self.models`` without going through ``train_model``/``load_model``.

        Raises:
            ValueError: If no mapping is available at all.
        """
        element_map = self.element_maps.get(property_name) or self.element_map
        if not element_map:
            raise ValueError(
                f"No element mapping found for {property_name}. Train or load a model first."
            )
        return element_map

    @staticmethod
    def _write_predictions(predictions, property_name, output_file):
        """Write ``[{property_name: value}, ...]`` as JSON to ``output_file``."""
        directory = os.path.dirname(output_file)
        # dirname is '' for a bare filename, and os.makedirs('') raises.
        if directory:
            os.makedirs(directory, exist_ok=True)
        data = [{property_name: float(pred)} for pred in predictions]
        with open(output_file, 'w') as f:
            json.dump(data, f, indent=2)
        return output_file

    # ------------------------------------------------------------------
    # Featurization
    # ------------------------------------------------------------------

    def extract_composition_features(self, samples, element_map=None):
        """
        Extract composition features (element proportions) from samples.

        Args:
            samples (list): List of Sample objects exposing ``get_elements()``.
            element_map (dict, optional): Fixed mapping element number ->
                column index, e.g. the mapping a model was trained with. When
                given, it defines the feature columns, elements outside it are
                ignored, and ``self.element_map`` is left untouched. When
                omitted, a new mapping is built from the (non-ghost) elements
                present in ``samples``, sorted by element number, and stored in
                ``self.element_map``.

        Returns:
            tuple: (features array of shape (len(samples), len(mapping)),
                list of element numbers in column order)
        """
        element_lists = [self._sample_elements(sample) for sample in samples]

        if element_map is None:
            all_elements = sorted({int(e) for elems in element_lists for e in elems})
            element_map = {elem: i for i, elem in enumerate(all_elements)}
            self.element_map = element_map
        else:
            all_elements = sorted(element_map, key=element_map.get)

        features = np.zeros((len(samples), len(element_map)))
        # Iterating a 2-D array yields row views, so rows are filled in place.
        for row, elems in zip(features, element_lists):
            self._proportions_row(row, elems, element_map)

        return features, all_elements

    def extract_composition_features_from_atoms(self, atoms_list, element_map=None):
        """
        Extract composition features (element proportions) from ASE Atoms objects.

        Args:
            atoms_list (list): List of ASE Atoms objects.
            element_map (dict, optional): Mapping element number -> column
                index to use. Defaults to ``self.element_map`` (the mapping of
                the most recently trained/loaded model).

        Returns:
            tuple: (features array, list of element numbers in the mapping)

        Raises:
            ValueError: If no element mapping is available.
        """
        if element_map is None:
            element_map = self.element_map
        if not element_map:
            raise ValueError("No element mapping found. Load a trained model first.")

        features = np.zeros((len(atoms_list), len(element_map)))
        for row, atoms in zip(features, atoms_list):
            # Filter out ghost atoms (symbol 'X') and get atomic numbers
            atomic_numbers = [atom.number for atom in atoms if atom.symbol != 'X']
            self._proportions_row(row, atomic_numbers, element_map)

        return features, list(element_map.keys())

    # ------------------------------------------------------------------
    # Training / persistence
    # ------------------------------------------------------------------

    def train_model(self, dataset, property_name, model_params=None):
        """
        Train an XGBoost regressor to predict a property from composition.

        The element mapping is built from all samples in the dataset; only
        samples with a non-missing target are used for fitting. The model is
        pickled to ``<save_dir>/<property_name>_model.pkl`` and its metadata
        (including the element mapping) to ``<property_name>_model_info.json``.

        Args:
            dataset (MaterialDataset): Dataset exposing ``samples``.
            property_name (str): Name of the property to predict.
            model_params (dict, optional): Parameters for the XGBoost model.
                Defaults to ``{'n_estimators': 100, 'max_depth': 3}``.

        Returns:
            xgb.XGBRegressor: Trained XGBoost model.

        Raises:
            ValueError: If no sample has a valid value for ``property_name``.
        """
        samples = dataset.samples

        # Builds a fresh mapping from this dataset and stores it in self.element_map.
        X, elements = self.extract_composition_features(samples)
        element_map = self.element_map

        y = self._collect_targets(samples, property_name)
        valid_indices = ~np.isnan(y)

        if not np.any(valid_indices):
            raise ValueError(f"No valid {property_name} values found in the dataset")

        X_train = X[valid_indices]
        y_train = y[valid_indices]

        if model_params is None:
            model_params = {
                'n_estimators': 100,
                'max_depth': 3
            }

        model = xgb.XGBRegressor(**model_params)
        model.fit(X_train, y_train)

        model_info = {
            'property': property_name,
            'elements': list(map(int, elements)),
            'element_map': {int(k): v for k, v in element_map.items()},
            'train_samples': int(np.sum(valid_indices)),
            'feature_importances': model.feature_importances_.tolist(),
        }

        with open(os.path.join(self.save_dir, f"{property_name}_model_info.json"), 'w') as f:
            json.dump(model_info, f, indent=2)

        model_path = os.path.join(self.save_dir, f"{property_name}_model.pkl")
        with open(model_path, 'wb') as f:
            pickle.dump(model, f)

        self.models[property_name] = model
        # Remember this model's column layout so prediction can reproduce it.
        self.element_maps[property_name] = element_map

        # Training-set metrics (in-sample, so optimistic)
        y_pred = model.predict(X_train)
        mse = mean_squared_error(y_train, y_pred)
        r2 = r2_score(y_train, y_pred)

        print(f"Model for {property_name} trained on {len(X_train)} samples")
        print(f"Training MSE: {mse:.4f}, R²: {r2:.4f}")

        return model

    def load_model(self, property_name):
        """
        Load a trained model and its element mapping for a specific property.

        Args:
            property_name (str): Name of the property to load model for.

        Returns:
            xgb.XGBRegressor: Loaded XGBoost model.

        Raises:
            FileNotFoundError: If the model or its metadata file is missing.
        """
        model_path = os.path.join(self.save_dir, f"{property_name}_model.pkl")
        info_path = os.path.join(self.save_dir, f"{property_name}_model_info.json")

        if not os.path.exists(model_path) or not os.path.exists(info_path):
            raise FileNotFoundError(f"Model files for {property_name} not found")

        with open(info_path, 'r') as f:
            model_info = json.load(f)
        # JSON object keys are strings; restore integer element numbers.
        element_map = {int(k): v for k, v in model_info['element_map'].items()}

        # NOTE: pickle can execute arbitrary code while loading; only load
        # model files from a trusted save_dir.
        with open(model_path, 'rb') as f:
            model = pickle.load(f)

        # Commit state only after both files loaded successfully.
        self.element_map = element_map
        self.element_maps[property_name] = element_map
        self.models[property_name] = model

        return model

    # ------------------------------------------------------------------
    # Prediction / evaluation
    # ------------------------------------------------------------------

    def predict_property(self, samples, property_name):
        """
        Predict properties for samples using a trained model.

        Features are built with the element mapping the model was trained
        with; elements not seen during training are ignored.

        Args:
            samples (list): List of Sample objects.
            property_name (str): Name of the property to predict.

        Returns:
            np.ndarray: Predicted property values.

        Raises:
            ValueError: If no trained model exists for ``property_name``.
        """
        model = self._ensure_model(property_name)

        # Re-deriving the mapping from these samples would misalign the
        # columns with those the model was fitted on.
        X, _ = self.extract_composition_features(
            samples, element_map=self._element_map_for(property_name)
        )

        return model.predict(X)

    def evaluate_samples(self, samples, property_name, dataset_path=None):
        """
        Evaluate samples by comparing their target properties with predicted properties.

        Predictions are made (and optionally saved) for all samples; metrics
        use only samples whose target is present and not None.

        Args:
            samples (list): List of Sample objects.
            property_name (str): Property name in sample.properties to use as target.
            dataset_path (str, optional): Path to the input dataset file for saving predictions.

        Returns:
            tuple: (target values, predicted values, metrics dictionary) where
                the metrics hold 'mse', 'rmse', 'r2', 'mean_error' (mean
                absolute error), 'std_error' (std of absolute error) and
                'num_samples'.

        Raises:
            ValueError: If no model exists or no sample has a valid target.
        """
        predicted_all = np.asarray(self.predict_property(samples, property_name))

        if dataset_path:
            self.save_predictions_to_json(predicted_all, property_name, dataset_path)

        target_values = self._collect_targets(samples, property_name)
        valid_indices = ~np.isnan(target_values)

        if not np.any(valid_indices):
            raise ValueError(f"No valid {property_name} values found in the samples")

        target_values = target_values[valid_indices]
        predicted = predicted_all[valid_indices]

        mse = mean_squared_error(target_values, predicted)
        r2 = r2_score(target_values, predicted)
        abs_errors = np.abs(predicted - target_values)

        metrics = {
            'mse': mse,
            'rmse': np.sqrt(mse),
            'r2': r2,
            'mean_error': np.mean(abs_errors),
            'std_error': np.std(abs_errors),
            'num_samples': int(np.sum(valid_indices)),
        }

        return target_values, predicted, metrics

    def save_predictions_to_json(self, predictions, property_name, dataset_path):
        """
        Save predicted property values to ``<dataset dir>/<base>-<property_name>.json``.

        Args:
            predictions (np.ndarray): Predicted property values.
            property_name (str): Name of the property (used verbatim in the filename).
            dataset_path (str): Path to the input dataset file.

        Returns:
            str: Path of the written JSON file.
        """
        directory = os.path.dirname(dataset_path)
        base_name = os.path.splitext(os.path.basename(dataset_path))[0]
        output_file = os.path.join(directory, f'{base_name}-{property_name}.json')
        return self._write_predictions(predictions, property_name, output_file)

    def predict_property_from_file(self, extxyz_file, property_name):
        """
        Predict properties for structures in an extxyz file and save them as JSON.

        Args:
            extxyz_file (str): Path to extxyz file.
            property_name (str): Name of the property to predict.

        Returns:
            np.ndarray: Predicted property values.

        Raises:
            ValueError: If no trained model exists for ``property_name``.
        """
        model = self._ensure_model(property_name)

        atoms_list = read(extxyz_file, index=':')

        X, _ = self.extract_composition_features_from_atoms(
            atoms_list, element_map=self._element_map_for(property_name)
        )

        predictions = model.predict(X)

        self.save_predictions_to_json_simple(predictions, property_name, extxyz_file)

        return predictions

    def save_predictions_to_json_simple(self, predictions, property_name, extxyz_file):
        """
        Save predicted property values next to the extxyz file.

        The output is ``<dir>/<base>-<clean name>.json`` where the clean name
        is ``property_name`` with spaces replaced by '_' and '[' / ']' removed.

        Args:
            predictions (np.ndarray): Predicted property values.
            property_name (str): Name of the property.
            extxyz_file (str): Path to the extxyz file.

        Returns:
            str: Path of the written JSON file.
        """
        directory = os.path.dirname(extxyz_file)
        base_name = os.path.splitext(os.path.basename(extxyz_file))[0]
        clean_property_name = property_name.replace(' ', '_').replace('[', '').replace(']', '')
        output_file = os.path.join(directory, f'{base_name}-{clean_property_name}.json')
        return self._write_predictions(predictions, property_name, output_file)

    def evaluate_files(self, extxyz_files, property_name):
        """
        Evaluate multiple extxyz files by predicting properties and saving to JSON.

        Args:
            extxyz_files (list): List of paths to extxyz files.
            property_name (str): Name of the property to predict.

        Raises:
            ValueError: If no trained model exists for ``property_name``.
        """
        # Fail fast before iterating if the model is unavailable.
        self._ensure_model(property_name)

        for extxyz_file in tqdm(extxyz_files, desc=f"Evaluating {property_name}"):
            predictions = self.predict_property_from_file(extxyz_file, property_name)
            tqdm.write(f"Predicted {len(predictions)} samples from {os.path.basename(extxyz_file)}")