import numpy as np
import pandas as pd
from pathlib import Path
import os
# Hide every CUDA device from libraries that honor this variable, so the run stays on CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# Raise TensorFlow's native log threshold (level 2 filters INFO and WARNING) if TensorFlow gets imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
# Export the directory two levels above this file as PYTHONPATH. Assigning os.environ at
# runtime does not alter this interpreter's sys.path (that is done by the sys.path.append
# further down); it only changes the environment seen by child processes.
os.environ['PYTHONPATH'] = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

import warnings
warnings.filterwarnings('ignore', category=pd.errors.SettingWithCopyWarning)

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import lib
import delu
from tune import DirectTunerMixin
from data.data_processor import load_and_preprocess_dataset
from nam.wrapper import NAMClassifier
from sklearn.multiclass import OneVsRestClassifier

__all__ = ["NAMModule"]

import pandas as pd
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from lib.metrics import evaluate_model

class NAMBinaryAdapter(BaseEstimator, ClassifierMixin):
    """Scikit-learn style binary classifier backed by ``NAMClassifier``.

    The adapter exists so that a NAM can be used as the base estimator of
    meta-estimators such as ``OneVsRestClassifier``: it always exposes a
    two-column ``predict_proba`` regardless of the shape returned by the
    wrapped model.

    Every hyperparameter is a named ``__init__`` argument stored under the
    same attribute name, so ``get_params`` / ``clone`` carry it over to the
    per-class copies. Arguments captured by ``**kwargs`` are invisible to
    ``get_params`` and are therefore ignored.

    Parameters
    ----------
    num_epochs, batch_size, lr, random_state, patience, val_split
        Always forwarded to ``NAMClassifier``.
    monitor_loss, early_stop_mode, metric, device
        Stored on the instance only; they are not forwarded to
        ``NAMClassifier`` by this adapter.
    hidden_sizes, num_learners, l2_reg, output_reg, decay_rate
        Forwarded to ``NAMClassifier`` only when not ``None``; ``None`` means
        "keep the wrapped model's own default".
        (assumes NAMClassifier accepts these keyword names, as NAMModule
        passes the same names when constructing it directly - confirm
        against the installed ``nam`` version)
    **kwargs
        Accepted for backward compatibility and ignored.
    """

    def __init__(self,
                 num_epochs=100,
                 batch_size=256,
                 lr=0.01,
                 patience=20,
                 val_split=0.15,
                 monitor_loss=True,
                 early_stop_mode='min',
                 metric='accuracy',
                 device='cpu',
                 random_state=42,
                 hidden_sizes=None,
                 num_learners=None,
                 l2_reg=None,
                 output_reg=None,
                 decay_rate=None,
                 **kwargs):

        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.patience = patience
        self.val_split = val_split
        self.monitor_loss = monitor_loss
        self.early_stop_mode = early_stop_mode
        self.metric = metric
        self.device = device
        self.random_state = random_state
        # Previously these arrived through **kwargs and were silently lost.
        self.hidden_sizes = hidden_sizes
        self.num_learners = num_learners
        self.l2_reg = l2_reg
        self.output_reg = output_reg
        self.decay_rate = decay_rate

        self._model = None
        self.classes_ = None

    def fit(self, X, y):
        """Train a fresh ``NAMClassifier`` on ``X`` / ``y`` and return ``self``.

        ``y`` is flattened to 1-D; its sorted unique values become
        ``classes_``.
        """
        y = np.asarray(y).ravel()
        self.classes_ = np.unique(y)

        nam_config = {
            'num_epochs': self.num_epochs,
            'lr': self.lr,
            'batch_size': self.batch_size,
            'random_state': self.random_state,
            'patience': self.patience,
            'val_split': self.val_split,
        }
        # Optional hyperparameters: only override the wrapped model's
        # defaults when the caller actually supplied a value.
        optional = {
            'hidden_sizes': self.hidden_sizes,
            'num_learners': self.num_learners,
            'l2_reg': self.l2_reg,
            'output_reg': self.output_reg,
            'decay_rate': self.decay_rate,
        }
        nam_config.update(
            {name: value for name, value in optional.items() if value is not None}
        )

        print(f"Creating NAMClassifier with: {nam_config}")
        self._model = NAMClassifier(**nam_config)
        self._model.fit(X, y)

        return self

    def predict_proba(self, X):
        """Return an ``(n_samples, 2)`` array of row-normalized probabilities.

        A 1-D or single-column prediction is treated as the probability of
        the positive class; a wider prediction is truncated to its first two
        columns.

        Raises
        ------
        NotFittedError
            If called before ``fit`` (subclass of ``AttributeError`` and
            ``ValueError``).
        ValueError
            If the wrapped model returns an array with more than 2 dimensions.
        """
        if self._model is None:
            from sklearn.exceptions import NotFittedError
            raise NotFittedError(
                "This NAMBinaryAdapter instance is not fitted yet. "
                "Call 'fit' before using this estimator."
            )

        raw_pred = np.asarray(self._model.predict_proba(X))

        if raw_pred.ndim == 1 or (raw_pred.ndim == 2 and raw_pred.shape[1] == 1):
            pos_proba = raw_pred.ravel()
            proba = np.column_stack([1.0 - pos_proba, pos_proba])
        elif raw_pred.ndim == 2:
            # Exactly two columns are used as-is; extra columns are dropped.
            proba = raw_pred if raw_pred.shape[1] == 2 else raw_pred[:, :2]
        else:
            raise ValueError(f"Unexpected prediction shape: {raw_pred.shape}")

        # Re-normalize so each row sums to 1; the epsilon avoids division by zero.
        row_sums = proba.sum(axis=1, keepdims=True)
        return proba / (row_sums + 1e-10)

    def predict(self, X):
        """Return the label from ``classes_`` with the highest probability."""
        proba = self.predict_proba(X)
        return self.classes_[np.argmax(proba, axis=1)]

    def score(self, X, y, sample_weight=None):
        """Return the (optionally weighted) accuracy of ``predict(X)`` against ``y``."""
        from sklearn.metrics import accuracy_score
        return accuracy_score(y, self.predict(X), sample_weight=sample_weight)
    
class NAMModule(DirectTunerMixin):
    """Trains and evaluates a Neural Additive Model for one tuning trial."""

    def __init__(self, **kwargs) -> None:
        """Initialise task metadata; keyword arguments are accepted but unused."""
        self.device = lib.get_device()
        # Filled in by _train_and_evaluate once the dataset is loaded.
        self.is_classification = None
        self.is_binary = None
        self.num_classes = None

    def create_model_from_params(self, params):
        """Return a new ``NAMModule`` built from ``params``."""
        return NAMModule(**params)

    def get_optimization_target(self, stats):
        """Return ``stats['metrics'][lib.VAL]['score']``, or ``-inf`` if absent."""
        metrics = stats.get('metrics', {}).get(lib.VAL, {})
        return metrics.get('score', -float('inf'))

    @staticmethod
    def _first_not_none(*candidates, default=None):
        """Return the first candidate that is not ``None``, else ``default``.

        Unlike an ``a or b`` chain this keeps explicit falsy settings such as
        ``seed = 0`` or ``l2_reg = 0.0``.
        """
        for value in candidates:
            if value is not None:
                return value
        return default

    def _train_and_evaluate(self, config, trial_number, output_dir):
        """Run one trial: load data, fit the NAM, evaluate, and dump stats.

        Parameters
        ----------
        config : dict
            Trial configuration. ``config['data']`` is required; hyperparameters
            are read from the top level first, then from the ``training`` /
            ``model`` sub-tables, then from built-in defaults.
        trial_number : int
            Stored verbatim in the returned stats.
        output_dir : str or Path
            Directory (created if missing) that receives ``stats.json``.

        Returns
        -------
        dict
            Keys ``algorithm``, ``dataset``, ``metrics``, ``trial_number``
            and ``time``; the same dict is written to ``stats.json``.
        """
        timer = delu.Timer()

        output = Path(output_dir)
        output.mkdir(parents=True, exist_ok=True)

        dataset_id = config['data'].get('dataset_id')

        delu.random.seed(config.get('seed', 0))
        data = load_and_preprocess_dataset(
            dataset_id=dataset_id,
            config=config,
            device=self.device,
            return_tensors=False
        )

        X_train = data['X_train']
        X_val = data['X_val']
        X_test = data['X_test']
        y_train = data['y_train'].ravel()
        y_val = data['y_val'].ravel()
        y_test = data['y_test'].ravel()

        # The validation rows are appended to the training set and the model
        # is given the matching fraction as ``val_split``.
        # NOTE(review): X_val is still passed to evaluate_model below, so the
        # reported validation metrics may be computed on rows the model was
        # trained on - confirm this is intended.
        val_ratio = len(X_val) / (len(X_train) + len(X_val))
        X_train = np.vstack([X_train, X_val])
        y_train = np.hstack([y_train, y_val])

        self.num_classes = data['num_classes']
        dataset_info = data['dataset_info']
        self.is_classification = self.num_classes > 1
        self.is_binary = self.num_classes == 2
        training_config = config.get('training', {})
        model_config = config.get('model', {})

        # Resolution order: top-level key(s) -> sub-table key -> default.
        pick = self._first_not_none
        lr = pick(config.get('lr'), training_config.get('lr'), default=0.02082)
        n_epochs = pick(config.get('n_epochs'), training_config.get('n_epochs'), default=1000)
        batch_size = pick(config.get('batch_size'), training_config.get('batch_size'), default=1024)
        patience = pick(config.get('patience'), training_config.get('patience'), default=60)
        hidden_sizes = pick(config.get('hidden_sizes'), model_config.get('hidden_sizes'), default=[64, 32])
        num_learners = pick(config.get('num_learners'), model_config.get('num_learners'), default=1)
        l2_reg = pick(
            config.get('l2_reg'),
            config.get('l2_regularization'),
            model_config.get('l2_reg'),
            default=0.0,
        )
        output_reg = pick(config.get('output_reg'), model_config.get('output_reg'), default=0.2078)
        decay_rate = pick(config.get('decay_rate'), model_config.get('decay_rate'), default=0.0)
        random_state = pick(
            config.get('seed'),
            config.get('random_state'),
            training_config.get('random_state'),
            default=42,
        )

        nam_config = {
            'num_epochs': n_epochs,
            'batch_size': batch_size,
            'lr': lr,
            'patience': patience,
            'val_split': val_ratio,
            'hidden_sizes': hidden_sizes,
            'num_learners': num_learners,
            'l2_reg': l2_reg,
            'output_reg': output_reg,
            'decay_rate': decay_rate,
            'random_state': random_state,
        }

        num_features = X_train.shape[1]
        num_classes = int(self.num_classes)

        if self.is_classification and num_classes > 2:
            # More than two classes: one binary NAM per class.
            base = NAMBinaryAdapter(**nam_config)
            self.model = OneVsRestClassifier(estimator=base, n_jobs=1)
        else:
            # NOTE(review): this branch is also taken when num_classes <= 1
            # (is_classification False) - confirm a classifier is wanted there.
            self.model = NAMClassifier(**nam_config)

        print(f"Training NAM with {num_features} features...")
        self.model.fit(X_train, y_train)
        # NOTE(review): the fitted model is not passed to evaluate_model -
        # verify this call against the signature in lib.metrics.
        metrics = evaluate_model(X_val, y_val, X_test, y_test)

        stats = {
            'algorithm': 'NAM',
            'dataset': dataset_info.get('name', f'dataset_{dataset_id}'),
            'metrics': metrics,
            'trial_number': trial_number,
            'time': lib.format_seconds(timer())
        }

        lib.dump_json(stats, output / 'stats.json', indent=4)

        return stats
    
def main():
    """Command-line entry point: run a single NAM trial from a TOML config.

    The trial directory is the config path without its suffix when the file
    name starts with ``trial_``; otherwise it is ``trial_0`` next to the
    config file.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Run Neural Additive Model (NAM)')
    cli.add_argument('config_path', type=str, help='Path to configuration TOML file')
    parsed = cli.parse_args()

    settings = lib.load_toml(parsed.config_path)

    toml_path = Path(parsed.config_path)
    output_dir = (
        toml_path.with_suffix('')
        if toml_path.stem.startswith('trial_')
        else toml_path.parent / 'trial_0'
    )

    # Always reported as trial 0 when launched from the command line.
    stats = NAMModule()._train_and_evaluate(
        config=settings,
        trial_number=0,
        output_dir=output_dir,
    )
    print(stats)

    print(f"\nTraining completed. Results saved to: {output_dir / 'stats.json'}")


if __name__ == "__main__":
    main()