import os
import sys
import numpy as np
from pathlib import Path
from sklearn.ensemble import RandomForestClassifier
# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from lib.metrics import evaluate_model
import lib
import delu
from tune import DirectTunerMixin
from data.data_processor import load_and_preprocess_dataset

# Public API of this module: only the Random Forest wrapper class is exported
# by `from <module> import *`.
__all__ = ["RandomForestModule"]


class RandomForestModule(DirectTunerMixin):
    """Train and evaluate a scikit-learn Random Forest from a config dict.

    The class stores a set of hyperparameters on construction and exposes the
    hooks used together with ``DirectTunerMixin`` (``create_model_from_params``,
    ``get_optimization_target`` and ``_train_and_evaluate``).

    Note that ``_train_and_evaluate`` builds the forest from the ``config``
    dictionary it receives, not from the attributes stored by ``__init__``.
    """

    def __init__(
        self,
        *,
        n_estimators: int = 100,
        criterion: str = 'gini',
        max_depth: int = None,
        min_samples_split: int = 2,
        min_samples_leaf: int = 1,
        min_weight_fraction_leaf: float = 0.0,
        max_features: str = 'sqrt',
        max_leaf_nodes: int = None,
        min_impurity_decrease: float = 0.0,
        bootstrap: bool = True,
        oob_score: bool = False,
        n_jobs: int = -1,
        ccp_alpha: float = 0.0,
        max_samples: float = None,
        use_class_weight: bool = False,
        random_state: int = 42,
        **kwargs
    ) -> None:
        """Record the hyperparameters and reset the fitted-model state.

        All arguments are keyword-only and are stored unchanged as attributes
        of the same name. ``max_depth``, ``max_leaf_nodes`` and ``max_samples``
        may be ``None``. Any additional keyword arguments are accepted and
        ignored, so parameter dictionaries with extra keys do not raise.
        """
        self.n_estimators = n_estimators
        self.criterion = criterion
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_features = max_features
        self.max_leaf_nodes = max_leaf_nodes
        self.min_impurity_decrease = min_impurity_decrease
        self.bootstrap = bootstrap
        self.oob_score = oob_score
        self.n_jobs = n_jobs
        self.ccp_alpha = ccp_alpha
        self.max_samples = max_samples
        self.use_class_weight = use_class_weight
        self.random_state = random_state

        # Forwarded to the data loader as its `device` argument.
        self.device = "cpu"
        # Populated by `_train_and_evaluate`.
        self.model = None
        self.is_classification = None
        self.is_binary = None
        self.num_classes = None

    def create_model_from_params(self, params):
        """Return a new ``RandomForestModule`` built from a parameter dict."""
        return RandomForestModule(**params)

    def get_optimization_target(self, stats):
        """Extract optimization target from stats.

        Returns the ``'score'`` entry of the validation metrics
        (``stats['metrics'][lib.VAL]``), or ``-inf`` when it is missing.
        """
        metrics = stats.get('metrics', {}).get(lib.VAL, {})
        return metrics.get('score', -float('inf'))

    @staticmethod
    def _resolve_option(primary, fallback, key, default):
        """Return ``primary[key]`` if set, else ``fallback[key]``, else ``default``.

        A value counts as "set" when it is present and not ``None``. Falsy but
        meaningful values such as ``False``, ``0`` or ``0.0`` are therefore
        honoured instead of silently falling through to the fallback.
        """
        value = primary.get(key)
        if value is not None:
            return value
        return fallback.get(key, default)

    def _train_and_evaluate(self, model, config, trial_number, output_dir):
        """Fit a Random Forest as described by ``config`` and persist results.

        Args:
            model: Unused by this implementation; kept for interface
                compatibility with callers.
            config: Configuration mapping. ``config['data']`` must exist; a
                hyperparameter is read from the top level of ``config`` first
                and then from ``config['model']`` (or ``config['training']``
                for ``use_class_weight`` and ``n_jobs``).
            trial_number: Identifier copied into the returned stats.
            output_dir: Directory (created if needed) that receives
                ``best_model.pkl`` and ``stats.json``.

        Returns:
            The stats dictionary that was also written to ``stats.json``.
        """
        import json
        import pickle

        timer = delu.Timer()

        output = Path(output_dir)
        output.mkdir(parents=True, exist_ok=True)
        dataset_id = config['data'].get('dataset_id')

        delu.random.seed(config.get('seed', 0))
        data = load_and_preprocess_dataset(
            dataset_id=dataset_id,
            config=config,
            device=self.device,
            return_tensors=False
        )

        X_train = data['X_train']
        X_val = data['X_val']
        X_test = data['X_test']
        y_train = data['y_train'].ravel()
        y_val = data['y_val'].ravel()
        y_test = data['y_test'].ravel()

        input_dim = data['input_dim']
        self.num_classes = data['num_classes']
        dataset_info = data['dataset_info']

        self.is_classification = self.num_classes > 1
        self.is_binary = self.num_classes == 2

        training_config = config.get('training', {})
        model_config = config.get('model', {})

        def from_model(key, default):
            # Top-level config wins over the [model] section.
            return self._resolve_option(config, model_config, key, default)

        def from_training(key, default):
            # Top-level config wins over the [training] section.
            return self._resolve_option(config, training_config, key, default)

        max_features = from_model('max_features', 'sqrt')
        # NOTE(review): "auto" is mapped to None, which scikit-learn treats as
        # "use all features" -- confirm this is the intended meaning.
        if max_features == "auto":
            max_features = None

        use_class_weight = from_training('use_class_weight', False)

        # Key order is kept stable because the dict is printed and pickled.
        rf_kwargs = {
            'n_estimators': from_model('n_estimators', 100),
            'criterion': from_model('criterion', 'gini'),
            'max_depth': from_model('max_depth', None),
            'min_samples_split': from_model('min_samples_split', 2),
            'min_samples_leaf': from_model('min_samples_leaf', 1),
            'min_weight_fraction_leaf': from_model('min_weight_fraction_leaf', 0.0),
            'max_features': max_features,
            'max_leaf_nodes': from_model('max_leaf_nodes', None),
            'min_impurity_decrease': from_model('min_impurity_decrease', 0.0),
            'bootstrap': from_model('bootstrap', True),
            'oob_score': from_model('oob_score', False),
            'n_jobs': from_training('n_jobs', -1),
            'ccp_alpha': from_model('ccp_alpha', 0.0),
            'max_samples': from_model('max_samples', None),
            # NOTE(review): the fallback seed here (42) differs from the one
            # passed to delu.random.seed above (0) -- confirm this is intended.
            'random_state': config.get('seed', 42)
        }

        if self.is_classification and use_class_weight:
            rf_kwargs['class_weight'] = 'balanced'

        # NOTE(review): a classifier is fitted even when num_classes <= 1
        # (is_classification is False) -- confirm regression is out of scope.
        self.model = RandomForestClassifier(**rf_kwargs)
        print(f"Training Random Forest with config: {rf_kwargs}")
        self.model.fit(X_train, y_train)

        # NOTE(review): the fitted model is not passed to evaluate_model --
        # verify against lib.metrics.evaluate_model's expected arguments.
        metrics = evaluate_model(X_val, y_val, X_test, y_test)
        stats = {
            'algorithm': 'RandomForest',
            'dataset': dataset_info.get('name', f'dataset_{dataset_id}'),
            'metrics': metrics,
            'trial_number': trial_number,
            'time': lib.format_seconds(timer()),
            'model_info': {
                'n_features': input_dim,
                'n_estimators': self.model.n_estimators,
                'feature_importances': self.model.feature_importances_.tolist(),
                'n_samples_train': len(y_train),
                # The attribute is read defensively; None is stored if absent.
                'oob_score': getattr(self.model, 'oob_score_', None)
            }
        }

        model_path = output / 'best_model.pkl'
        with open(model_path, 'wb') as f:
            pickle.dump({
                'model': self.model,
                'config': rf_kwargs
            }, f)

        with open(output / 'stats.json', 'w') as f:
            json.dump(stats, f, indent=4)

        return stats

def main():
    """Command-line entry point: train one Random Forest from a TOML config.

    The output directory is derived from the config path: a file whose stem
    starts with ``trial_`` writes next to itself (path without the suffix);
    any other file writes to a ``trial_0`` directory beside it.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Run Random Forest Classifier/Regressor')
    cli.add_argument('config_path', type=str, help='Path to configuration TOML file')
    cli_args = cli.parse_args()

    config = lib.load_toml(cli_args.config_path)
    toml_path = Path(cli_args.config_path)

    if not toml_path.stem.startswith('trial_'):
        # Plain config file: results go to a sibling "trial_0" directory.
        output_dir = toml_path.parent / 'trial_0'
        print(f"Standalone mode: Output dir: {output_dir}")
    else:
        # Per-trial config file: results go to "<config path without suffix>".
        output_dir = toml_path.with_suffix('')
        print(f"Optuna mode: Config path: {toml_path}")
        print(f"Optuna mode: Output dir: {output_dir}")

    module = RandomForestModule()
    stats = module._train_and_evaluate(
        model=module,
        config=config,
        trial_number=0,
        output_dir=output_dir,
    )

    print("Training completed. Stats:")
    print(stats)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
