import os
import sys
import numpy as np
from pathlib import Path
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import lib
import delu
from tune import DirectTunerMixin
from data.data_processor import load_and_preprocess_dataset
from lib.metrics import evaluate_model

__all__ = ["DecisionTreeModule"]


class DecisionTreeModule(DirectTunerMixin):
    """Train and evaluate a scikit-learn decision tree for a single trial.

    The keyword arguments accepted by ``__init__`` are stored as attributes,
    but ``_train_and_evaluate`` builds the estimator from the ``config``
    mapping it receives, not from these attributes.
    """

    def __init__(
        self,
        *,
        criterion: str = 'gini',
        max_depth: int = 10,
        min_samples_split: int = 2,
        min_samples_leaf: int = 1,
        min_weight_fraction_leaf: float = 0.0,
        max_features: str = 'sqrt',
        max_leaf_nodes: int = 10,
        min_impurity_decrease: float = 0.0,
        ccp_alpha: float = 0.0,
        use_class_weight: bool = False,
        random_state: int = 42,
        **kwargs
    ) -> None:
        """Store hyperparameters and reset per-run state.

        Extra keyword arguments (``**kwargs``) are accepted and ignored.
        """
        self.criterion = criterion
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_features = max_features
        self.max_leaf_nodes = max_leaf_nodes
        self.min_impurity_decrease = min_impurity_decrease
        self.ccp_alpha = ccp_alpha
        self.use_class_weight = use_class_weight
        self.random_state = random_state

        self.device = lib.get_device()
        # Populated by _train_and_evaluate.
        self.model = None
        self.is_classification = None
        self.is_binary = None
        self.num_classes = None

    def create_model_from_params(self, params):
        """Return a new ``DecisionTreeModule`` constructed from ``params`` as kwargs."""
        return DecisionTreeModule(**params)

    def get_optimization_target(self, stats):
        """Return the validation ``score`` from ``stats``, or ``-inf`` if absent."""
        metrics = stats.get('metrics', {}).get(lib.VAL, {})
        return metrics.get('score', -float('inf'))

    @staticmethod
    def _config_value(config, section, key, default):
        """Return ``config[key]`` if set, else ``section.get(key, default)``.

        Uses ``is not None`` instead of truthiness so that explicit falsy
        top-level overrides (``0``, ``0.0``, ``False``) are honoured rather
        than silently replaced by the section value.
        """
        value = config.get(key)
        if value is not None:
            return value
        return section.get(key, default)

    def _build_estimator_kwargs(self, config):
        """Resolve scikit-learn decision-tree kwargs from ``config``.

        Top-level keys take precedence over the ``model`` section (and the
        ``training`` section for ``use_class_weight``). Must be called after
        ``self.is_classification`` has been set.

        Returns:
            dict of keyword arguments for the tree estimator.
        """
        training_config = config.get('training', {})
        model_config = config.get('model', {})
        lookup = self._config_value

        # 'gini' is a classification criterion; regressors need their own default.
        default_criterion = 'gini' if self.is_classification else 'squared_error'

        max_features = lookup(config, model_config, 'max_features', None)
        if max_features == "auto":
            # scikit-learn removed 'auto' in 1.3; it meant sqrt(n_features) for
            # classifiers and n_features (== None) for regressors.
            max_features = 'sqrt' if self.is_classification else None

        dt_kwargs = {
            'criterion': lookup(config, model_config, 'criterion', default_criterion),
            'max_depth': lookup(config, model_config, 'max_depth', None),
            'min_samples_split': lookup(config, model_config, 'min_samples_split', 2),
            'min_samples_leaf': lookup(config, model_config, 'min_samples_leaf', 1),
            'min_weight_fraction_leaf': lookup(
                config, model_config, 'min_weight_fraction_leaf', 0.0),
            'max_features': max_features,
            'max_leaf_nodes': lookup(config, model_config, 'max_leaf_nodes', None),
            'min_impurity_decrease': lookup(
                config, model_config, 'min_impurity_decrease', 0.0),
            'ccp_alpha': lookup(config, model_config, 'ccp_alpha', 0.0),
            # NOTE(review): the global RNG seed defaults to 0 in
            # _train_and_evaluate but random_state defaults to 42 here; confirm
            # the mismatch is intended when 'seed' is absent.
            'random_state': config.get('seed', 42),
        }

        use_class_weight = lookup(config, training_config, 'use_class_weight', False)
        if self.is_classification and use_class_weight:
            dt_kwargs['class_weight'] = 'balanced'
        return dt_kwargs

    def _save_artifacts(self, output, stats, dt_kwargs):
        """Pickle the fitted model with its kwargs, then write ``stats.json``."""
        import json
        import pickle

        with open(output / 'best_model.pkl', 'wb') as f:
            pickle.dump({
                'model': self.model,
                'config': dt_kwargs
            }, f)

        with open(output / 'stats.json', 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=4)

    def _train_and_evaluate(self, model, config, trial_number, output_dir):
        """Load data, fit a decision tree, evaluate it and save artifacts.

        Args:
            model: unused; kept for the tuner's calling convention.
            config: mapping with a ``data`` section (``dataset_id``) and
                optional ``model`` / ``training`` sections and top-level
                hyperparameter overrides.
            trial_number: recorded verbatim in the returned stats.
            output_dir: directory (created if missing) that receives
                ``best_model.pkl`` and ``stats.json``.

        Returns:
            dict of run statistics, the same content written to ``stats.json``.
        """
        timer = delu.Timer()
        timer.run()  # delu.Timer only starts counting once run() is called

        output = Path(output_dir)
        output.mkdir(parents=True, exist_ok=True)

        dataset_id = config['data'].get('dataset_id')
        delu.random.seed(config.get('seed', 0))
        data = load_and_preprocess_dataset(
            dataset_id=dataset_id,
            config=config,
            device=self.device,
            return_tensors=False
        )

        X_train = data['X_train']
        X_val = data['X_val']
        X_test = data['X_test']
        y_train = data['y_train'].ravel()
        y_val = data['y_val'].ravel()
        y_test = data['y_test'].ravel()

        input_dim = data['input_dim']
        self.num_classes = data['num_classes']
        dataset_info = data['dataset_info']

        # Assumes num_classes is an int where values <= 1 denote regression
        # — TODO confirm against data_processor.
        self.is_classification = self.num_classes > 1
        self.is_binary = self.num_classes == 2

        dt_kwargs = self._build_estimator_kwargs(config)
        estimator_cls = (DecisionTreeClassifier if self.is_classification
                         else DecisionTreeRegressor)
        self.model = estimator_cls(**dt_kwargs)

        print(f"Training Decision Tree with config: {dt_kwargs}")

        self.model.fit(X_train, y_train)
        metrics = evaluate_model(self.model, X_val, y_val, X_test, y_test,
                                 self.is_classification, self.is_binary)

        stats = {
            'algorithm': 'DecisionTree',
            'dataset': dataset_info.get('name', f'dataset_{dataset_id}'),
            'metrics': metrics,
            'trial_number': trial_number,
            'time': lib.format_seconds(timer()),
            'model_info': {
                'n_features': input_dim,
                'tree_depth': int(self.model.get_depth()),
                'n_leaves': int(self.model.get_n_leaves()),
                'n_samples_train': len(y_train)
            }
        }

        self._save_artifacts(output, stats, dt_kwargs)
        return stats


def main():
    """Command-line entry point: train one decision tree from a TOML config.

    If the config file's stem starts with ``trial_`` (Optuna mode), outputs go
    to a directory named like the config file minus its suffix. Otherwise
    (standalone mode) they go to ``trial_0`` beside the config file.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Run Decision Tree Classifier/Regressor')
    cli.add_argument('config_path', type=str, help='Path to configuration TOML file')
    config_file = cli.parse_args().config_path

    config = lib.load_toml(config_file)

    cfg_path = Path(config_file)
    optuna_mode = cfg_path.stem.startswith('trial_')
    if not optuna_mode:
        output_dir = cfg_path.parent / 'trial_0'
        print(f"Standalone mode: Output dir: {output_dir}")
    else:
        output_dir = cfg_path.with_suffix('')
        print(f"Optuna mode: Config path: {cfg_path}")
        print(f"Optuna mode: Output dir: {output_dir}")

    module = DecisionTreeModule()
    result = module._train_and_evaluate(
        model=module,
        config=config,
        trial_number=0,
        output_dir=output_dir,
    )
    print("Training completed. Stats:")
    print(result)


if __name__ == "__main__":
    main()
