import os
import sys
import tomlkit 
import warnings
from typing import Optional, Tuple, Dict, Any, List
from pathlib import Path
import torch
import numpy as np
from torch.utils.data import DataLoader
from lightgbm import LGBMClassifier
import lightgbm as lgb
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import lib
import delu
from tune import DirectTunerMixin
from data.data_processor import load_and_preprocess_dataset
from lib.metrics import evaluate_model

__all__ = ["LightGBMClassifierModule"]

class LightGBMClassifierModule(DirectTunerMixin):
    def __init__(
        self,
        *,
        lr: float = 0.1,
        n_estimators: int = 100,
        batch_size: Optional[int] = None,
        val_batch_size: Optional[int] = None,
        objective: str = "binary",
        boosting_type: str = "gbdt",
        patience: int = 10,
        use_class_weight: bool = False,
        max_grad_norm: Optional[float] = None,
        num_class: Optional[int] = None,
        use_gpu: bool = True,
        random_state: int = 42,
        **lgb_kwargs: Any,
    ) -> None:
        """Store the training settings and build the wrapped ``LGBMClassifier``.

        Args:
            lr: Forwarded to LightGBM as ``learning_rate``.
            n_estimators: Number of boosting rounds requested from LightGBM.
            batch_size: Stored on the instance; not otherwise used here.
            val_batch_size: Stored on the instance; not otherwise used here.
            objective: LightGBM objective. The special value ``"auto"`` is
                resolved to ``"multiclass"`` when ``num_class`` is greater
                than 2 and to ``"binary"`` otherwise.
            boosting_type: Forwarded to LightGBM unchanged.
            patience: Early-stopping rounds, stored for use by ``fit``.
            use_class_weight: Stored for use by ``fit``.
            max_grad_norm: Accepted for signature compatibility; this
                constructor does not use it.
            num_class: Passed to LightGBM only when greater than 2; otherwise
                ``None`` is passed instead.
            use_gpu: When true and ``torch.cuda.is_available()``, GPU device
                defaults are added to ``lgb_kwargs`` (without overriding
                values the caller supplied).
            random_state: Forwarded to LightGBM unchanged.
            **lgb_kwargs: Extra keyword arguments forwarded to
                ``LGBMClassifier``.
        """
        # CUDA availability (as reported by torch) gates both the torch device
        # and the LightGBM GPU defaults.
        gpu_active = torch.cuda.is_available() and use_gpu
        multiclass = bool(num_class and num_class > 2)

        self.device = torch.device('cuda' if gpu_active else 'cpu')
        self.use_gpu = use_gpu
        self.num_class = num_class
        self.use_class_weight = use_class_weight
        self.patience = patience
        self.val_batch_size = val_batch_size
        self.batch_size = batch_size
        self.n_estimators = n_estimators
        self.lr = lr

        if gpu_active:
            # setdefault keeps any device settings the caller passed explicitly.
            gpu_defaults = (
                ('device_type', 'gpu'),
                ('gpu_platform_id', 0),
                ('gpu_device_id', 0),
            )
            for option, value in gpu_defaults:
                lgb_kwargs.setdefault(option, value)

        resolved_objective = objective
        if objective == "auto":
            resolved_objective = "multiclass" if multiclass else "binary"

        self._model = LGBMClassifier(
            learning_rate=lr,
            n_estimators=n_estimators,
            objective=resolved_objective,
            boosting_type=boosting_type,
            num_class=num_class if multiclass else None,
            random_state=random_state,
            **lgb_kwargs,
        )

        # Filled in after fitting or loading a model.
        self.classes_: Optional[np.ndarray] = None
    
    def fit(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        verbose: bool = False,
    ) -> "LightGBMClassifierModule":
        X_train, y_train = self._stack_from_loader(train_loader)
        eval_set = None
        if val_loader is not None:
            X_val, y_val = self._stack_from_loader(val_loader)
            eval_set = [(X_val, y_val)]

        if self.use_class_weight:
            from sklearn.utils.class_weight import compute_sample_weight
            sample_weight = compute_sample_weight('balanced', y_train)

        else:
            sample_weight = None

        callbacks = [lgb.early_stopping(stopping_rounds=self.patience, verbose=False)]
        self._model.fit(
            X_train,
            y_train,
            sample_weight=sample_weight,
            eval_set=eval_set,
            eval_metric='logloss',
            callbacks=callbacks
        )
        self.classes_ = self._model.classes_
        return self

    def predict(self, X_tensor: torch.Tensor, as_prob: bool = False) -> torch.Tensor:
        X_np = X_tensor.detach().cpu().numpy()
        if as_prob:
            preds = self._model.predict_proba(X_np)
        else:
            preds = self._model.predict(X_np)
        return torch.from_numpy(preds)

    def save(self, path: str) -> None:
        """Write the wrapped estimator to ``path`` using joblib.

        The directory that will contain ``path`` is created first when it
        does not exist yet. Only ``self._model`` is serialized.
        """
        parent_dir = os.path.dirname(os.path.abspath(path))
        os.makedirs(parent_dir, exist_ok=True)

        import joblib
        joblib.dump(self._model, path)

    def load(self, path: str) -> "LightGBMClassifierModule":
        """Replace the wrapped estimator with one read from ``path``.

        ``classes_`` is refreshed from the loaded object, or set to ``None``
        when the loaded object has no such attribute.

        Returns:
            ``self``, to allow chaining.
        """
        import joblib

        # NOTE(review): joblib files are pickle-based, so loading one can
        # execute arbitrary code; only load files from trusted sources.
        restored = joblib.load(path)
        self._model = restored
        self.classes_ = getattr(restored, "classes_", None)
        return self

    def _stack_from_loader(self, loader: DataLoader) -> Tuple[np.ndarray, np.ndarray]:
        X_list: List[torch.Tensor] = []
        y_list: List[torch.Tensor] = []
        for xb, yb in loader:
            X_list.append(xb.detach().cpu())
            y_list.append(yb.detach().cpu())
        X = torch.vstack(X_list).numpy().astype(np.float32)
        y = torch.hstack(y_list).numpy().astype(np.int64)
        return X, y

    def _train_and_evaluate(self, model, config, trial_number, output_dir):
        """Train a fresh LightGBM classifier from ``config`` and persist results.

        The method loads the dataset named by ``config['data']['dataset_id']``,
        rebuilds ``self._model`` from hyperparameters found in ``config``,
        fits it with early stopping on the validation split, and writes
        ``best_model.pkl`` and ``stats.json`` into ``output_dir``.

        Args:
            model: Unused by this method; kept for signature compatibility.
            config: Mapping with a ``'data'`` section and optional ``'seed'``,
                ``'training'`` and ``'model'`` entries. A hyperparameter is
                looked up at the top level first, then in ``'training'``,
                then in ``'model'``.
            trial_number: Recorded verbatim in the returned stats.
            output_dir: Directory for the artifacts; created if missing.

        Returns:
            The stats dictionary that is also written to ``stats.json``.

        Raises:
            ValueError: If ``dataset_id`` is missing from ``config['data']``
                or the dataset reports fewer than two classes.
        """
        timer = delu.Timer()
        # The timer must be started explicitly, otherwise the elapsed time
        # reported in the stats stays at zero.
        timer.run()

        output = Path(output_dir)
        output.mkdir(parents=True, exist_ok=True)

        dataset_id = config['data'].get('dataset_id')
        if dataset_id is None:
            raise ValueError("dataset_id must be specified in config['data']")

        seed = config.get('seed', 0)
        delu.random.seed(seed)
        data = load_and_preprocess_dataset(
            dataset_id=dataset_id,
            config=config,
            device=self.device,
            return_tensors=False
        )

        X_train = data['X_train']
        X_val = data['X_val']
        X_test = data['X_test']
        # Targets are flattened to 1-D before being handed to LightGBM.
        y_train = data['y_train'].ravel()
        y_val = data['y_val'].ravel()
        y_test = data['y_test'].ravel()

        num_classes = data['num_classes']
        dataset_info = data['dataset_info']

        if num_classes < 2:
            # Without this check the code would request a "multiclass"
            # objective with num_class < 2, which LightGBM rejects.
            raise ValueError(
                f"LightGBM classifier requires at least 2 classes, got {num_classes}"
            )

        is_classification = num_classes > 1
        is_binary = num_classes == 2

        def get_param(name, default, config_sections=None):
            """Return ``name`` from the top level or a config section, else ``default``."""
            if config_sections is None:
                config_sections = ['training', 'model']
            if name in config:
                return config[name]

            for section in config_sections:
                section_config = config.get(section, {})
                if name in section_config:
                    return section_config[name]

            return default

        lr = get_param('lr', 0.1)
        n_estimators = get_param('n_estimators', 1000)
        patience = get_param('patience', 50)
        use_class_weight = get_param('use_class_weight', False)
        boosting_type = get_param('boosting_type', 'gbdt')

        # Tree/regularization hyperparameters and their fallback values.
        tree_defaults = {
            'max_depth': -1,
            'num_leaves': 31,
            'min_child_weight': 1e-3,
            'min_child_samples': 20,
            'colsample_bytree': 1.0,
            'reg_alpha': 0,
            'reg_lambda': 0,
            'min_split_gain': 0.0,
        }
        lgb_kwargs = {
            name: get_param(name, default)
            for name, default in tree_defaults.items()
        }
        lgb_kwargs['boosting_type'] = boosting_type
        lgb_kwargs['verbosity'] = -1

        # Row subsampling options are left out for GOSS boosting.
        if boosting_type != 'goss':
            lgb_kwargs['subsample'] = get_param('subsample', 1.0)
            lgb_kwargs['subsample_freq'] = get_param('subsample_freq', 0)

        if torch.cuda.is_available() and self.use_gpu:
            lgb_kwargs.update({
                'device_type': 'gpu',
                'gpu_platform_id': 0,
                'gpu_device_id': 0
            })

        if is_binary:
            objective = "binary"
        else:
            objective = "multiclass"
            lgb_kwargs['num_class'] = num_classes

        self._model = LGBMClassifier(
            learning_rate=lr,
            n_estimators=n_estimators,
            objective=objective,
            random_state=seed,
            **lgb_kwargs
        )

        sample_weight = None
        if use_class_weight and is_classification:
            from sklearn.utils.class_weight import compute_sample_weight
            sample_weight = compute_sample_weight('balanced', y_train)

        eval_set = [(X_val, y_val)]

        callbacks = [lgb.early_stopping(stopping_rounds=patience, verbose=False)]
        self._model.fit(
            X_train,
            y_train,
            sample_weight=sample_weight,
            eval_set=eval_set,
            eval_metric='logloss',
            callbacks=callbacks
        )
        # Keep the wrapper's class list in sync, as fit() and load() do.
        self.classes_ = self._model.classes_

        # NOTE(review): the fitted model is not passed to evaluate_model here;
        # confirm against lib.metrics that this matches its signature.
        metrics = evaluate_model(X_val, y_val, X_test, y_test, is_classification, is_binary)

        # Summary persisted next to the model and returned to the caller.
        stats = {
            'algorithm': 'LightGBM',
            'dataset': dataset_info.get('name', f'dataset_{dataset_id}'),
            'metrics': metrics,
            'trial_number': trial_number,
            'time': lib.format_seconds(timer()),
            'model_info': {
                'n_estimators': self._model.n_estimators,
                'max_depth': self._model.max_depth,
                'num_leaves': self._model.num_leaves,
                'learning_rate': self._model.learning_rate,
                'boosting_type': self._model.boosting_type,
                'feature_importances': self._model.feature_importances_.tolist() if hasattr(self._model, 'feature_importances_') else None
            }
        }

        model_path = output / 'best_model.pkl'
        import joblib
        joblib.dump(self._model, str(model_path))
        lib.dump_json(stats, output / 'stats.json', indent=4)

        return stats
    
def main():
    """Command-line entry point: train and evaluate from a TOML config file.

    The output directory is derived from the config path. A config whose
    stem starts with ``trial_`` (e.g. ``trial_7.toml``) writes next to itself
    into ``trial_7/`` and reports the parsed trial number; any other config
    writes into a ``trial_0`` directory beside it and reports trial 0.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Run LightGBM Classifier')
    parser.add_argument('config_path', type=str, help='Path to configuration TOML file')
    args = parser.parse_args()

    config = lib.load_toml(args.config_path)

    config_path = Path(args.config_path)
    trial_prefix = 'trial_'
    trial_number = 0
    if config_path.stem.startswith(trial_prefix):
        output_dir = config_path.with_suffix('')
        # Keep the reported trial number consistent with the trial_<N>
        # directory name; fall back to 0 when the suffix is not an integer.
        try:
            trial_number = int(config_path.stem[len(trial_prefix):])
        except ValueError:
            trial_number = 0
        print(f"Optuna mode: Config path: {config_path}")
        print(f"Optuna mode: Output dir: {output_dir}")
    else:
        output_dir = config_path.parent / 'trial_0'
        print(f"Standalone mode: Output dir: {output_dir}")

    model = LightGBMClassifierModule()
    stats = model._train_and_evaluate(
        model=model,
        config=config,
        trial_number=trial_number,
        output_dir=output_dir
    )

    print(f"\nTraining completed. Results saved to: {output_dir / 'stats.json'}")
    print(stats)

if __name__ == "__main__":
    import gc

    try:
        main()
    finally:
        # Release cached GPU memory and collect garbage even when main()
        # fails. torch is already imported at module level, so there is no
        # ImportError to guard against here; check CUDA availability instead.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
