"""ElasticNet classifier runner.

Wraps scikit-learn's ``LogisticRegression`` (``saga`` solver) or
``SGDClassifier`` with an elastic-net penalty behind ``ElasticNetModule``,
which plugs into the project's ``DirectTunerMixin`` tuning interface, and
exposes a ``main()`` CLI that runs a single configuration from a TOML file.
"""
import os
import sys
import numpy as np
from pathlib import Path
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.preprocessing import StandardScaler

# Append the directory one level above this file's directory to sys.path so
# the project-local imports below (``lib``, ``delu`` helpers, ``tune``,
# ``data``) can be resolved when this file is executed directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import lib
import delu
from tune import DirectTunerMixin
from data.data_processor import load_and_preprocess_dataset
from lib.metrics import evaluate_model 

# Public API of this module.
__all__ = ["ElasticNetModule"]


class ElasticNetModule(DirectTunerMixin):
    """Elastic-net-regularised linear classifier wrapped for the tuning pipeline.

    Depending on ``use_sgd`` the estimator is either a ``LogisticRegression``
    fitted with the ``saga`` solver or an ``SGDClassifier`` with
    ``loss='log_loss'``; both use ``penalty='elasticnet'``.

    NOTE(review): the constructor only records hyperparameters on ``self``;
    ``_train_and_evaluate`` resolves the values it actually uses from the
    ``config`` dict it receives, not from these attributes.
    """

    def __init__(
        self,
        *,
        alpha: float = 1.0,
        l1_ratio: float = 0.5,

        use_sgd: bool = False,

        tol: float = 1e-4,
        fit_intercept: bool = True,
        max_iter: int = 1000,
        random_state: int = 42,

        learning_rate: str = 'optimal',
        eta0: float = 0.01,
        early_stopping: bool = True,
        validation_fraction: float = 0.1,
        n_iter_no_change: int = 5,
        intercept_scaling: float = 1.0,

        use_class_weight: bool = False,
        n_jobs: "int | None" = None,
        **kwargs
    ) -> None:
        """Store hyperparameters; extra keyword arguments are accepted and ignored."""
        self.alpha = alpha
        self.l1_ratio = l1_ratio
        self.use_sgd = use_sgd
        self.tol = tol
        self.fit_intercept = fit_intercept
        self.max_iter = max_iter
        self.random_state = random_state

        self.learning_rate = learning_rate
        self.eta0 = eta0
        self.early_stopping = early_stopping
        self.validation_fraction = validation_fraction
        self.n_iter_no_change = n_iter_no_change

        self.intercept_scaling = intercept_scaling

        self.use_class_weight = use_class_weight
        self.n_jobs = n_jobs

        self.device = "cpu"
        self.model = None
        # Created here and pickled with the model, but nothing in this class
        # ever fits it.
        self.scaler = StandardScaler()
        # Populated by _train_and_evaluate.
        self.is_classification = None
        self.is_binary = None
        self.num_classes = None

    def create_model_from_params(self, params):
        """Return a new ``ElasticNetModule`` built from the ``params`` mapping."""
        return ElasticNetModule(**params)

    def get_optimization_target(self, stats):
        """Return the validation ``score`` from ``stats``, or ``-inf`` if absent."""
        metrics = stats.get('metrics', {}).get(lib.VAL, {})
        return metrics.get('score', -float('inf'))

    @staticmethod
    def _config_value(config, section, key, default=None):
        """Resolve ``key`` from top-level ``config``, then ``section``, then ``default``.

        Only a missing key or an explicit ``None`` triggers the fallback. This
        keeps legitimate falsy settings such as ``0.0``, ``0`` or ``False``.
        """
        value = config.get(key)
        if value is None:
            value = section.get(key)
        if value is None:
            value = default
        return value

    @staticmethod
    def _json_default(obj):
        """``json.dump`` fallback that converts NumPy scalars/arrays to Python types."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def _save_artifacts(self, output, model_kwargs, stats):
        """Pickle the model/scaler/kwargs and write ``stats`` as JSON into ``output``."""
        import json
        import pickle

        with open(output / 'best_model.pkl', 'wb') as f:
            pickle.dump({
                'model': self.model,
                'scaler': self.scaler,
                'config': model_kwargs
            }, f)

        with open(output / 'stats.json', 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=4, default=self._json_default)

    def _train_and_evaluate(self, model, config, trial_number, output_dir):
        """Load data, fit the configured classifier, evaluate it and save artifacts.

        Args:
            model: Not used by this method; presumably kept so the tuner can
                call every module uniformly (TODO confirm against the caller).
            config: Run configuration. Reads ``config['data']['dataset_id']``
                and an optional ``seed``. Hyperparameters are looked up at the
                top level first, then under ``config['model']`` (estimator
                settings) or ``config['training']`` (``use_class_weight``,
                ``n_jobs``), then fall back to built-in defaults.
            trial_number: Stored unchanged in the returned stats.
            output_dir: Directory for ``best_model.pkl`` and ``stats.json``;
                created if it does not exist.

        Returns:
            dict: Run statistics, also written to ``stats.json``.
        """
        timer = delu.Timer()

        output = Path(output_dir)
        output.mkdir(parents=True, exist_ok=True)

        dataset_id = config['data'].get('dataset_id')
        delu.random.seed(config.get('seed', 0))
        data = load_and_preprocess_dataset(
            dataset_id=dataset_id,
            config=config,
            device=self.device,
            return_tensors=False
        )

        X_train = data['X_train']
        X_val = data['X_val']
        X_test = data['X_test']
        y_train = data['y_train'].ravel()
        y_val = data['y_val'].ravel()
        y_test = data['y_test'].ravel()

        input_dim = data['input_dim']
        self.num_classes = data['num_classes']
        dataset_info = data['dataset_info']

        self.is_classification = True
        self.is_binary = self.num_classes == 2

        training_config = config.get('training', {})
        model_config = config.get('model', {})

        def from_model(key, default):
            return self._config_value(config, model_config, key, default)

        def from_training(key, default):
            return self._config_value(config, training_config, key, default)

        alpha = from_model('alpha', 1.0)
        l1_ratio = from_model('l1_ratio', 0.5)
        use_sgd = from_model('use_sgd', False)
        max_iter = from_model('max_iter', 1000)
        tol = from_model('tol', 1e-4)
        fit_intercept = from_model('fit_intercept', True)

        learning_rate = from_model('learning_rate', 'optimal')
        eta0 = from_model('eta0', 0.01)
        validation_fraction = from_model('validation_fraction', 0.1)
        n_iter_no_change = from_model('n_iter_no_change', 5)
        intercept_scaling = from_model('intercept_scaling', 1.0)

        use_class_weight = from_training('use_class_weight', False)
        n_jobs = from_training('n_jobs', None)

        # Clamp to valid / bounded ranges; alpha is strictly positive afterwards.
        alpha = max(alpha, 1e-8)
        l1_ratio = max(0.0, min(1.0, l1_ratio))
        max_iter = max(100, min(max_iter, 5000))
        tol = max(tol, 1e-8)

        if use_sgd:
            eta0 = max(eta0, 1e-6)
            validation_fraction = max(0.01, min(validation_fraction, 0.5))
            n_iter_no_change = max(3, min(n_iter_no_change, 50))

            model_kwargs = {
                'loss': 'log_loss',
                'penalty': 'elasticnet',
                'alpha': alpha,
                'l1_ratio': l1_ratio,
                'fit_intercept': fit_intercept,
                'max_iter': max_iter,
                'tol': tol,
                'learning_rate': learning_rate,
                'eta0': eta0,
                # Always on, regardless of the constructor's early_stopping.
                'early_stopping': True,
                'validation_fraction': validation_fraction,
                'n_iter_no_change': n_iter_no_change,
                'random_state': config.get('seed', 42),
                'n_jobs': n_jobs
            }

            if use_class_weight:
                model_kwargs['class_weight'] = 'balanced'

            self.model = SGDClassifier(**model_kwargs)
            model_type = "SGDClassifier"

        else:
            # sklearn's LogisticRegression uses inverse regularisation strength.
            C = max(1.0 / alpha, 1e-8)
            intercept_scaling = max(intercept_scaling, 0.01)

            # `multi_class` is intentionally omitted. Its default is already
            # 'auto' on every sklearn version that supports loss='log_loss'
            # above, and the parameter is deprecated since sklearn 1.5.
            model_kwargs = {
                'penalty': 'elasticnet',
                'C': C,
                'l1_ratio': l1_ratio,
                'solver': 'saga',
                'fit_intercept': fit_intercept,
                'intercept_scaling': intercept_scaling,
                'max_iter': max_iter,
                'tol': tol,
                'warm_start': False,
                'n_jobs': n_jobs,
                'random_state': config.get('seed', 42)
            }

            if use_class_weight:
                model_kwargs['class_weight'] = 'balanced'

            self.model = LogisticRegression(**model_kwargs)
            model_type = "LogisticRegression"

        self.model.fit(X_train, y_train)

        # NOTE(review): the fitted model (self.model) is not passed to
        # evaluate_model, so it is unclear how it obtains predictions. Confirm
        # evaluate_model's signature in lib.metrics; it likely expects the model.
        metrics = evaluate_model(X_val, y_val, X_test, y_test)
        stats = {
            'algorithm': f'ElasticNet_{model_type}',
            'dataset': dataset_info.get('name', f'dataset_{dataset_id}'),
            'metrics': metrics,
            'trial_number': trial_number,
            'time': lib.format_seconds(timer()),
            'model_info': {
                'model_type': model_type,
                'n_features': input_dim,
                'n_samples_train': len(y_train),
                'alpha': alpha,
                'l1_ratio': l1_ratio,
                'use_sgd': use_sgd
            },
        }

        self._save_artifacts(output, model_kwargs, stats)

        return stats
    
def main():
    """CLI entry point: train and evaluate one ElasticNet run from a TOML config.

    If the config file's stem starts with ``trial_`` (Optuna mode), outputs go
    to a directory named after the stem, next to the file. Otherwise
    (standalone mode) they go to ``trial_0`` in the config's directory.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Run ElasticNet Classifier')
    cli.add_argument('config_path', type=str, help='Path to configuration TOML file')
    cli_args = cli.parse_args()
    run_config = lib.load_toml(cli_args.config_path)

    cfg_file = Path(cli_args.config_path)
    is_trial = cfg_file.stem.startswith('trial_')
    if not is_trial:
        out_dir = cfg_file.parent / 'trial_0'
        print(f"Standalone mode: Output dir: {out_dir}")
    else:
        # Strip the extension: ".../trial_N.toml" -> ".../trial_N"
        out_dir = cfg_file.with_suffix('')
        print(f"Optuna mode: Config path: {cfg_file}")
        print(f"Optuna mode: Output dir: {out_dir}")

    runner = ElasticNetModule()
    result = runner._train_and_evaluate(
        model=runner,
        config=run_config,
        trial_number=0,
        output_dir=out_dir,
    )
    print("Training completed. Stats:")
    print(result)


if __name__ == "__main__":
    main()
