import os
import sys
import tomlkit 
import torch
import numpy as np
from pathlib import Path
from pytorch_tabnet.tab_model import TabNetClassifier
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import lib
import delu
from tune import DirectTunerMixin
from data.data_processor import load_and_preprocess_dataset
from lib.metrics import evaluate_model

# Names exported by `from <module> import *`; the `main` CLI entry point below is not listed.
__all__ = ["TabNetModule"]

class TabNetModule(DirectTunerMixin):
    """TabNet classifier wrapper that plugs into the ``DirectTunerMixin`` flow.

    The constructor only records hyperparameters. ``_train_and_evaluate``
    loads the dataset, builds a ``TabNetClassifier`` from values found in the
    run configuration, trains it with early stopping on the validation split,
    and writes the saved model plus ``stats.json`` to the output directory.
    """

    def __init__(
        self,
        *,
        lr: float = 2e-2,
        n_epochs: int = 100,
        batch_size: int = 1024,
        virtual_batch_size: int = 128,
        n_d: int = 8,
        n_a: int = 8,
        n_steps: int = 3,
        gamma: float = 1.3,
        n_independent: int = 2,
        n_shared: int = 2,
        lambda_sparse: float = 1e-3,
        momentum: float = 0.02,
        clip_value: float = 1.0,
        patience: int = 15,
        use_class_weight: bool = False,
        seed: int = 0,
        **kwargs
    ) -> None:
        """Record hyperparameters; extra keyword arguments are accepted and ignored."""
        # NOTE(review): these attributes are stored, but `_train_and_evaluate`
        # reads its hyperparameters from `config`, not from them.
        self.lr = lr
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.virtual_batch_size = virtual_batch_size
        self.n_d = n_d
        self.n_a = n_a
        self.n_steps = n_steps
        self.gamma = gamma
        self.n_independent = n_independent
        self.n_shared = n_shared
        self.lambda_sparse = lambda_sparse
        self.momentum = momentum
        self.clip_value = clip_value
        self.patience = patience
        self.use_class_weight = use_class_weight
        self.seed = seed

        self.device = lib.get_device()
        self.model = None  # set to a TabNetClassifier by `_train_and_evaluate`
        self.is_classification = None
        self.is_binary = None

    def create_model_from_params(self, params):
        """Return a fresh ``TabNetModule`` built from a hyperparameter dict."""
        return TabNetModule(**params)

    def get_optimization_target(self, stats):
        """Return the validation ``score`` from ``stats``, or ``-inf`` if absent."""
        metrics = stats.get('metrics', {}).get(lib.VAL, {})
        return metrics.get('score', -float('inf'))

    @staticmethod
    def _first_configured(default, *lookups):
        """Return the first non-``None`` value among ``(mapping, key)`` lookups.

        Falls back to ``default`` when every lookup is missing or ``None``.
        Unlike chaining with ``or``, this keeps deliberately falsy settings
        such as ``0``, ``0.0`` or ``False``.
        """
        for mapping, key in lookups:
            value = mapping.get(key)
            if value is not None:
                return value
        return default

    @staticmethod
    def _as_numpy(array):
        """Return ``array`` as a NumPy array, moving torch tensors to the CPU first."""
        if isinstance(array, torch.Tensor):
            return array.detach().cpu().numpy()
        return np.asarray(array)

    @staticmethod
    def _balanced_class_weights(y):
        """Map each class label in ``y`` to its scikit-learn ``'balanced'`` weight."""
        from sklearn.utils.class_weight import compute_class_weight

        classes = np.unique(y)
        class_weights = compute_class_weight('balanced', classes=classes, y=y)
        return dict(zip(classes, class_weights))

    def _train_and_evaluate(self, model, config, trial_number, output_dir):
        """Train a TabNet classifier for one run and persist its results.

        Args:
            model: Unused; ``self`` is the object that gets trained. Kept for
                interface compatibility with callers.
            config: Run configuration. Top-level keys take precedence over the
                ``training`` / ``model`` sub-tables. ``config['data']`` must
                exist and may hold ``dataset_id``.
            trial_number: Recorded verbatim in the returned stats.
            output_dir: Directory (created if needed) that receives
                ``best_model`` and ``stats.json``.

        Returns:
            dict with ``algorithm``, ``dataset``, ``metrics``, ``trial_number``
            and ``time`` entries.

        Raises:
            NotImplementedError: if ``num_classes <= 1`` (not a classification
                task); only ``TabNetClassifier`` is supported.
        """
        timer = delu.Timer()
        # delu timers are created stopped; without run() the elapsed time stays 0.
        timer.run()

        output = Path(output_dir)
        output.mkdir(parents=True, exist_ok=True)

        dataset_id = config['data'].get('dataset_id')

        delu.random.seed(config.get('seed', 0))
        data = load_and_preprocess_dataset(
            dataset_id=dataset_id,
            config=config,
            device=self.device,
            return_tensors=True,
        )

        X_train = data['X_train']
        X_val = data['X_val']
        X_test = data['X_test']
        y_train = data['y_train'].ravel()
        y_val = data['y_val'].ravel()
        y_test = data['y_test'].ravel()

        num_classes = data['num_classes']
        dataset_info = data['dataset_info']

        self.is_classification = num_classes > 1
        self.is_binary = num_classes == 2
        if not self.is_classification:
            # Only TabNetClassifier is wired up; fail clearly instead of calling
            # fit on None or on a classifier left over from an earlier call.
            raise NotImplementedError(
                f"TabNetModule supports classification only (num_classes={num_classes})"
            )

        training_config = config.get('training', {})
        model_config = config.get('model', {})
        pick = self._first_configured

        lr = pick(1e-3, (config, 'lr'), (training_config, 'lr'))
        n_epochs = pick(
            100,
            (config, 'n_epochs'),
            (config, 'epochs'),
            (training_config, 'n_epochs'),
        )
        batch_size = pick(1024, (config, 'batch_size'), (training_config, 'batch_size'))
        virtual_batch_size = pick(
            128, (config, 'virtual_batch_size'), (model_config, 'virtual_batch_size')
        )
        patience = pick(15, (config, 'patience'), (training_config, 'patience'))
        use_class_weight = pick(
            False, (config, 'use_class_weight'), (training_config, 'use_class_weight')
        )

        n_d = pick(8, (config, 'n_d'), (model_config, 'n_d'))
        # Attention width; defaults to n_d when not configured explicitly.
        n_a = pick(n_d, (config, 'n_a'), (model_config, 'n_a'))
        n_steps = pick(3, (config, 'n_steps'), (model_config, 'n_steps'))
        gamma = pick(1.3, (config, 'gamma'), (model_config, 'gamma'))
        n_independent = pick(2, (config, 'n_independent'), (model_config, 'n_independent'))
        n_shared = pick(2, (config, 'n_shared'), (model_config, 'n_shared'))
        lambda_sparse = pick(1e-3, (config, 'lambda_sparse'), (model_config, 'lambda_sparse'))

        self.model = TabNetClassifier(
            n_d=n_d,
            n_a=n_a,
            n_steps=n_steps,
            gamma=gamma,
            n_independent=n_independent,
            n_shared=n_shared,
            lambda_sparse=lambda_sparse,
            optimizer_fn=torch.optim.Adam,
            optimizer_params=dict(lr=lr),
            mask_type='entmax',
            scheduler_params={"step_size": 10, "gamma": 0.9},
            scheduler_fn=torch.optim.lr_scheduler.StepLR,
            verbose=1,
            seed=config.get('seed', 0),
            device_name='cuda' if torch.cuda.is_available() else 'cpu',
        )

        # The loader was asked for tensors, but TabNet's fit and scikit-learn
        # work on NumPy arrays.
        X_train_np = self._as_numpy(X_train)
        y_train_np = self._as_numpy(y_train)
        X_val_np = self._as_numpy(X_val)
        y_val_np = self._as_numpy(y_val)

        # TabNet's `weights`: 0 disables re-sampling, a dict maps class -> weight.
        weights = self._balanced_class_weights(y_train_np) if use_class_weight else 0

        self.model.fit(
            X_train=X_train_np,
            y_train=y_train_np,
            eval_set=[(X_val_np, y_val_np)],
            eval_name=['val'],
            eval_metric=['auc' if self.is_binary else 'accuracy'],
            max_epochs=n_epochs,
            patience=patience,
            batch_size=batch_size,
            virtual_batch_size=virtual_batch_size,
            num_workers=0,
            drop_last=False,
            weights=weights,
        )

        # NOTE(review): `evaluate_model` is not given the trained model
        # (`self.model`); confirm how it obtains predictions — this looks like
        # a missing argument.
        metrics = evaluate_model(X_val, y_val, X_test, y_test)

        stats = {
            'algorithm': 'TabNet',
            'dataset': dataset_info.get('name', f'dataset_{dataset_id}'),
            'metrics': metrics,
            'trial_number': trial_number,
            'time': lib.format_seconds(timer()),
        }

        model_path = output / 'best_model'
        self.model.save_model(str(model_path))
        lib.dump_json(stats, output / 'stats.json', indent=4)

        return stats

def main():
    """Command-line entry point: train and evaluate TabNet from a TOML config.

    When the config file's stem starts with ``trial_`` (Optuna mode), the
    output directory is the config path with its suffix stripped. Otherwise
    (standalone mode), it is ``trial_0`` next to the config file.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Run TabNet Classifier')
    cli.add_argument('config_path', type=str, help='Path to configuration TOML file')
    cli_args = cli.parse_args()

    cfg = lib.load_toml(cli_args.config_path)

    cfg_file = Path(cli_args.config_path)
    optuna_trial = cfg_file.stem.startswith('trial_')
    if not optuna_trial:
        out_dir = cfg_file.parent / 'trial_0'
        print(f"Standalone mode: Output dir: {out_dir}")
    else:
        out_dir = cfg_file.with_suffix('')
        print(f"Optuna mode: Config path: {cfg_file}")
        print(f"Optuna mode: Output dir: {out_dir}")

    runner = TabNetModule()
    result = runner._train_and_evaluate(
        model=runner,
        config=cfg,
        trial_number=0,
        output_dir=out_dir,
    )

    print("Training completed. Stats:")
    print(result)

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
