import os
import sys
from pathlib import Path
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.pipeline import Pipeline

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import lib
import delu
from tune import DirectTunerMixin
from data.data_processor import load_and_preprocess_dataset
from lib.metrics import evaluate_model

# Public API of this module: the classifier class defined below.
__all__ = ["PolyLogisticModule"]


class PolyLogisticModule(DirectTunerMixin):
    """Logistic regression on polynomially expanded, standardized features.

    Training (``_train_and_evaluate``) chains ``PolynomialFeatures`` ->
    ``StandardScaler`` -> ``LogisticRegression``. The hyperparameters used
    there are read from the run ``config``; the constructor arguments are
    stored on the instance but are not consulted during training.
    """

    def __init__(
        self,
        *,
        degree: int = 2,
        include_bias: bool = False,
        interaction_only: bool = False,

        penalty: str = 'l2',
        C: float = 1.0,
        solver: str = 'lbfgs',
        max_iter: int = 1000,
        tol: float = 1e-4,
        fit_intercept: bool = True,
        intercept_scaling: float = 1.0,
        l1_ratio: float = None,

        use_class_weight: bool = False,
        n_jobs: int = None,
        random_state: int = 42,
        **kwargs
    ) -> None:
        # Hyperparameters are stored as given; extra keyword arguments are
        # accepted and ignored.
        self.degree = degree
        self.include_bias = include_bias
        self.interaction_only = interaction_only

        self.penalty = penalty
        self.C = C
        self.solver = solver
        self.max_iter = max_iter
        self.tol = tol
        self.fit_intercept = fit_intercept
        self.intercept_scaling = intercept_scaling
        self.l1_ratio = l1_ratio

        self.use_class_weight = use_class_weight
        self.n_jobs = n_jobs
        self.random_state = random_state

        # Runtime state, populated by _train_and_evaluate.
        self.device = "cpu"
        self.model = None
        self.poly_features = None
        self.scaler = StandardScaler()
        self.is_classification = None
        self.is_binary = None
        self.num_classes = None

    def create_model_from_params(self, params):
        """Create a new model instance from a dict of constructor parameters."""
        return PolyLogisticModule(**params)

    def get_optimization_target(self, stats):
        """Return the validation ``score`` from ``stats``, or -inf if absent."""
        metrics = stats.get('metrics', {}).get(lib.VAL, {})
        return metrics.get('score', -float('inf'))

    @staticmethod
    def _config_value(override, section, key, default):
        """Return ``override[key]`` if set (not None), else ``section.get(key, default)``.

        An explicit ``is None`` test is used instead of ``or`` so that falsy
        overrides such as ``fit_intercept=False`` or ``l1_ratio=0.0`` are
        honoured rather than silently replaced by the fallback.
        """
        value = override.get(key)
        return section.get(key, default) if value is None else value

    def _resolve_hyperparams(self, config):
        """Build ``PolynomialFeatures`` and ``LogisticRegression`` kwargs from ``config``.

        Top-level ``config`` keys take precedence over ``config['model']``
        (feature and regression settings) and ``config['training']``
        (``use_class_weight``, ``n_jobs``).

        Returns:
            tuple[dict, dict]: ``(poly_kwargs, lr_kwargs)``.
        """
        model_config = config.get('model', {})
        training_config = config.get('training', {})
        pick = self._config_value

        degree = pick(config, model_config, 'degree', 2)
        include_bias = pick(config, model_config, 'include_bias', False)
        interaction_only = pick(config, model_config, 'interaction_only', False)

        penalty = pick(config, model_config, 'penalty', 'l2')
        C = pick(config, model_config, 'C', 1.0)
        solver = pick(config, model_config, 'solver', 'lbfgs')
        max_iter = pick(config, model_config, 'max_iter', 1000)
        tol = pick(config, model_config, 'tol', 1e-4)
        fit_intercept = pick(config, model_config, 'fit_intercept', True)
        intercept_scaling = pick(config, model_config, 'intercept_scaling', 1.0)
        l1_ratio = pick(config, model_config, 'l1_ratio', None)
        use_class_weight = pick(config, training_config, 'use_class_weight', False)
        n_jobs = pick(config, training_config, 'n_jobs', None)

        # Clamp degree to 1..3 and max_iter to 100..5000 (int-coerced, since
        # sklearn rejects floats for these), and floor C/tol at 1e-8.
        degree = max(1, min(int(degree), 3))
        C = max(C, 1e-8)
        max_iter = max(100, min(int(max_iter), 5000))
        tol = max(tol, 1e-8)
        if penalty == 'none':
            penalty = None

        # Not every sklearn solver supports every penalty; switch to one that does.
        if penalty is None and solver not in ['newton-cg', 'lbfgs', 'sag', 'saga']:
            solver = 'lbfgs'
        elif penalty == 'l1' and solver not in ['liblinear', 'saga']:
            solver = 'liblinear'
        elif penalty == 'elasticnet' and solver != 'saga':
            solver = 'saga'

        # sklearn rejects elasticnet without a numeric l1_ratio; default to an
        # even L1/L2 mix, matching the solver auto-correction above.
        if penalty == 'elasticnet' and l1_ratio is None:
            l1_ratio = 0.5

        poly_kwargs = {
            'degree': degree,
            'include_bias': include_bias,
            'interaction_only': interaction_only
        }

        lr_kwargs = {
            'penalty': penalty,
            'C': C,
            'solver': solver,
            'max_iter': max_iter,
            'tol': tol,
            'fit_intercept': fit_intercept,
            'intercept_scaling': intercept_scaling,
            'warm_start': False,
            'n_jobs': n_jobs,
            'random_state': config.get('seed', 0)
        }

        if penalty == 'elasticnet':
            lr_kwargs['l1_ratio'] = l1_ratio

        if use_class_weight:
            lr_kwargs['class_weight'] = 'balanced'

        if solver == 'liblinear' and penalty == 'l2':
            lr_kwargs['dual'] = False  # Usually False is better for n_samples > n_features

        return poly_kwargs, lr_kwargs

    def _save_artifacts(self, output, poly_kwargs, lr_kwargs, stats):
        """Pickle the fitted pipeline to ``best_model.pkl`` and write ``stats.json`` in ``output``."""
        import json
        import pickle

        with open(output / 'best_model.pkl', 'wb') as f:
            pickle.dump({
                'model': self.model,
                'poly_features': self.poly_features,
                'scaler': self.scaler,
                'poly_config': poly_kwargs,
                'lr_config': lr_kwargs
            }, f)

        with open(output / 'stats.json', 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=4)

    def _train_and_evaluate(self, model, config, trial_number, output_dir):
        """Fit the polynomial logistic pipeline on one dataset and record results.

        Args:
            model: Accepted for the tuner interface; not used by this method.
                All fitted state is stored on ``self``.
            config: Run configuration. ``config['data']['dataset_id']`` is
                required. Hyperparameters come from top-level keys first, then
                ``config['model']`` / ``config['training']``.
            trial_number: Recorded verbatim in the returned stats.
            output_dir: Directory (created if missing) that receives
                ``best_model.pkl`` and ``stats.json``.

        Returns:
            dict: Stats containing ``metrics`` from ``evaluate_model`` plus
            timing, dataset name and feature-count info.

        Raises:
            ValueError: If no ``dataset_id`` is configured.
        """
        timer = delu.Timer()

        output = Path(output_dir)
        output.mkdir(parents=True, exist_ok=True)
        # A missing [data] section raises the same ValueError as a missing id.
        dataset_id = (config.get('data') or {}).get('dataset_id')
        if dataset_id is None:
            raise ValueError("dataset_id must be specified in config['data']")

        delu.random.seed(config.get('seed', 0))
        data = load_and_preprocess_dataset(
            dataset_id=dataset_id,
            config=config,
            device=self.device,
            return_tensors=False
        )

        X_train, X_val, X_test = data['X_train'], data['X_val'], data['X_test']
        # LogisticRegression.fit expects 1-D targets.
        y_train = data['y_train'].ravel()
        y_val = data['y_val'].ravel()
        y_test = data['y_test'].ravel()

        input_dim = data['input_dim']
        self.num_classes = data['num_classes']
        dataset_info = data['dataset_info']

        self.is_classification = True
        self.is_binary = self.num_classes == 2

        poly_kwargs, lr_kwargs = self._resolve_hyperparams(config)

        # Expand then standardize; both transformers are fit on train only.
        self.poly_features = PolynomialFeatures(**poly_kwargs)
        X_train_poly = self.scaler.fit_transform(self.poly_features.fit_transform(X_train))
        X_val_poly = self.scaler.transform(self.poly_features.transform(X_val))
        X_test_poly = self.scaler.transform(self.poly_features.transform(X_test))

        self.model = LogisticRegression(**lr_kwargs)
        self.model.fit(X_train_poly, y_train)

        # NOTE(review): the fitted self.model is not passed here; confirm that
        # evaluate_model's signature really takes only the data arrays.
        metrics = evaluate_model(X_val_poly, y_val, X_test_poly, y_test)

        stats = {
            'algorithm': 'PolyLogistic',
            'dataset': dataset_info.get('name', f'dataset_{dataset_id}'),
            'metrics': metrics,
            'trial_number': trial_number,
            'time': lib.format_seconds(timer()),
            'model_info': {
                'n_features_original': input_dim,
                'n_features_poly': X_train_poly.shape[1],
                'n_samples_train': len(y_train),
                'degree': poly_kwargs['degree'],
                'include_bias': poly_kwargs['include_bias'],
                'interaction_only': poly_kwargs['interaction_only']
            }
        }

        self._save_artifacts(output, poly_kwargs, lr_kwargs, stats)
        return stats

def main():
    """Run one training/evaluation pass configured by a TOML file.

    The single positional CLI argument is the config path. If the config
    file's stem starts with ``trial_``, results go to a directory named after
    that stem next to the config file. Otherwise they go to ``trial_0`` in the
    config's directory.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Run Polynomial Logistic Regression Classifier')
    cli.add_argument('config_path', type=str, help='Path to configuration TOML file')
    cli_args = cli.parse_args()

    config = lib.load_toml(cli_args.config_path)
    cfg_file = Path(cli_args.config_path)
    optuna_mode = cfg_file.stem.startswith('trial_')
    run_dir = cfg_file.parent / (cfg_file.stem if optuna_mode else 'trial_0')

    if optuna_mode:
        print(f"Optuna mode: Config path: {cfg_file}")
        print(f"Optuna mode: Output dir: {run_dir}")
    else:
        print(f"Standalone mode: Output dir: {run_dir}")

    runner = PolyLogisticModule()
    stats = runner._train_and_evaluate(
        model=runner,
        config=config,
        trial_number=0,
        output_dir=run_dir,
    )

    print(f"\nTraining completed. Results saved to: {run_dir / 'stats.json'}")
    print("Training completed. Stats:")
    print(stats)
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
