import os
import sys
from pathlib import Path

import tomlkit
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Add the parent directory to sys.path so the imports below resolve when this
# file is run directly as a script. Every project-local import (including
# lib.metrics) must come after this line.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import lib
from lib.metrics import evaluate_model_torch, compute_val_loss
import delu
from tune import DirectTunerMixin
from data.data_processor import load_and_preprocess_dataset

class MLP(nn.Module):
    """Feed-forward network: hidden ``Linear`` layers with ReLU, then a linear head.

    Args:
        d_in: Number of input features.
        d_layers: Widths of the hidden layers, in order. May be empty, in which
            case the head maps the input straight to the output.
        d_out: Number of output units.

    ``forward`` squeezes the trailing dimension. For a 2-D input and
    ``d_out == 1`` it therefore returns a 1-D tensor of shape ``(batch,)``.
    """

    def __init__(
        self,
        *,
        d_in: int,
        d_layers: list[int],
        d_out: int,
    ) -> None:
        super().__init__()
        widths = [d_in, *d_layers]
        # Each pair of consecutive widths gives one hidden layer its (in, out) sizes.
        hidden = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            hidden.append(nn.Linear(fan_in, fan_out))
        self.layers = nn.ModuleList(hidden)

        # The head reads the last hidden width, or the raw input width when
        # there are no hidden layers.
        self.head = nn.Linear(widths[-1], d_out)
        print(f"Head layer: {self.head.in_features} -> {self.head.out_features}")

    def forward(self, x):
        """Apply every hidden layer followed by ReLU, then the head; drop a size-1 last dim."""
        hidden = x
        for linear in self.layers:
            hidden = F.relu(linear(hidden))
        return self.head(hidden).squeeze(-1)

class MLPClassifierModule(DirectTunerMixin):
    """Tuner-facing wrapper that trains an :class:`MLP` and reports its stats.

    Keyword ``params`` given to the constructor are stored on ``self.params``
    but are not read by this class. The hyperparameters used for training come
    from the ``config`` mapping passed to :meth:`_train_and_evaluate`.
    """

    def __init__(self, **params):
        self.params = params
        self.model = None  # assigned by _train_and_evaluate
        self.device = lib.get_device()

    def create_model_from_params(self, params):
        """Return a new ``MLPClassifierModule`` constructed from ``params``."""
        return MLPClassifierModule(**params)

    def get_optimization_target(self, stats):
        """Return ``stats['metrics'][lib.VAL]['score']``, or ``-inf`` if a key is absent."""
        metrics = stats.get('metrics', {}).get(lib.VAL, {})
        return metrics.get('score', -float('inf'))

    @staticmethod
    def _first_set(*candidates, default):
        """Return the first candidate that is not ``None``, else ``default``.

        Unlike chaining with ``or``, explicit falsy settings such as
        ``weight_decay = 0``, ``patience = 0`` or ``d_layers = []`` are honoured
        instead of being silently replaced by the default.
        """
        for value in candidates:
            if value is not None:
                return value
        return default

    @staticmethod
    def _expand_d_layers(spec):
        """Turn an ``{n_layers, d_first, d_middle, d_last}`` mapping into a width list.

        Anything that is not a ``dict`` (e.g. an explicit list of widths) is
        returned unchanged.
        """
        if not isinstance(spec, dict):
            return spec
        n_layers = spec.get('n_layers', 2)
        d_first = spec.get('d_first', 128)
        d_middle = spec.get('d_middle', 64)
        d_last = spec.get('d_last', 32)
        if n_layers == 1:
            return [d_first]
        if n_layers == 2:
            return [d_first, d_last]
        return [d_first] + [d_middle] * (n_layers - 2) + [d_last]

    @staticmethod
    def _prepare_targets(batch_y, *, is_binary, is_classification):
        """Flatten and cast a batch of targets for the selected loss function.

        ``reshape(-1)`` is used instead of ``squeeze(-1)``. This keeps a final
        single-sample batch at shape ``(1,)``, matching the ``(batch,)`` output
        ``MLP.forward`` produces when ``d_out == 1``.
        """
        flat = batch_y.reshape(-1)
        if is_classification and not is_binary:
            return flat.long()  # class indices for cross_entropy
        # binary_cross_entropy_with_logits and mse_loss need float targets with
        # the same shape as the (batch,) prediction. A (batch, 1) regression
        # target would otherwise broadcast to a (batch, batch) loss.
        return flat.float()

    def _train_and_evaluate(self, model, config, trial_number, output_dir):
        """Train an MLP as described by ``config``, keep the best epoch, write stats.

        Args:
            model: Unused; kept for interface compatibility.
            config: Mapping (e.g. a parsed TOML document). Reads
                ``config['data']['dataset_id']``, an optional ``seed``, the
                optional ``training`` and ``model`` sections, and top-level
                overrides (``lr``, ``batch_size``, ``epochs``, ``patience``,
                ``weight_decay``, ``d_layers``, ``optimizer_type``). A top-level
                value takes precedence over the section value whenever it is
                present, even if it is falsy (e.g. ``0``).
            trial_number: Recorded verbatim in the returned stats.
            output_dir: Directory for ``best_model.pt`` and ``stats.json``;
                created if missing.

        Returns:
            The stats dict that is also written to ``output_dir/stats.json``.
        """
        timer = delu.Timer()
        output = Path(output_dir)
        output.mkdir(parents=True, exist_ok=True)
        best_path = output / 'best_model.pt'

        dataset_id = config['data'].get('dataset_id')

        delu.random.seed(config.get('seed', 0))
        data = load_and_preprocess_dataset(
            dataset_id=dataset_id,
            config=config,
            device=self.device
        )

        X_train, X_val, X_test = data['X_train'], data['X_val'], data['X_test']
        y_train, y_val, y_test = data['y_train'], data['y_val'], data['y_test']
        input_dim = data['input_dim']
        num_classes = data['num_classes']
        dataset_info = data['dataset_info']

        is_classification = num_classes > 1
        is_binary = num_classes == 2

        # One logit for binary, one logit per class for multiclass, one value
        # for regression.
        if is_binary:
            loss_fn, d_out = F.binary_cross_entropy_with_logits, 1
        elif is_classification:
            loss_fn, d_out = F.cross_entropy, num_classes
        else:
            loss_fn, d_out = F.mse_loss, 1

        training_config = config.get('training', {})
        model_config = config.get('model', {})
        pick = self._first_set

        lr = pick(config.get('lr'), training_config.get('lr'), default=1e-3)
        batch_size = pick(
            config.get('batch_size'), training_config.get('batch_size'), default=64
        )
        n_epochs = pick(
            config.get('epochs'),
            training_config.get('n_epochs'),
            training_config.get('max_epochs'),
            default=500,
        )
        patience = pick(config.get('patience'), training_config.get('patience'), default=20)
        weight_decay = pick(
            config.get('weight_decay'), training_config.get('weight_decay'), default=1e-4
        )
        d_layers = self._expand_d_layers(
            pick(config.get('d_layers'), model_config.get('d_layers'), default=[128, 64])
        )
        optimizer_type = pick(
            config.get('optimizer_type'),
            training_config.get('optimizer_type'),
            training_config.get('optimizer'),
            default='AdamW',
        ).lower()

        self.model = MLP(
            d_in=input_dim,
            d_out=d_out,
            d_layers=d_layers,
        ).to(self.device)

        optimizer = lib.make_optimizer(
            optimizer_type,
            self.model.parameters(),
            lr,
            weight_decay,
        )

        train_loader = DataLoader(
            TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True
        )

        best_val_loss = float('inf')
        val_losses = []
        patience_counter = 0
        best_metrics = {}
        # Only reload best_model.pt if *this* run wrote it. A file left over from
        # an earlier run in the same directory must not be picked up when no
        # epoch improved (e.g. NaN validation loss, or zero epochs).
        saved_best = False

        for epoch in range(n_epochs):
            self.model.train()
            for batch_X, batch_y in train_loader:
                batch_X = batch_X.to(self.device)
                batch_y = self._prepare_targets(
                    batch_y.to(self.device),
                    is_binary=is_binary,
                    is_classification=is_classification,
                )

                optimizer.zero_grad()
                loss = loss_fn(self.model(batch_X), batch_y)
                loss.backward()
                optimizer.step()

            val_loss = compute_val_loss(
                self.model, X_val, y_val, loss_fn,
                is_binary, is_classification, num_classes, self.device
            )
            val_losses.append(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                torch.save(self.model.state_dict(), best_path)
                saved_best = True
                # Val/test metrics are only kept for the best epoch, so compute
                # them here rather than after every epoch.
                best_metrics = evaluate_model_torch(
                    self.model, X_val, y_val, X_test, y_test,
                    is_classification, is_binary, self.device
                )
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        if saved_best:
            # Leave self.model holding the best weights. best_metrics were
            # computed on exactly these weights when they were saved.
            self.model.load_state_dict(torch.load(best_path, map_location=self.device))
        final_metrics = best_metrics

        min_val_loss = min(val_losses) if val_losses else best_val_loss
        stats = {
            'algorithm': 'MLPClassifier',
            'dataset': dataset_info.get('name', f'dataset_{dataset_id}'),
            'metrics': final_metrics,
            'validation_loss': min_val_loss,
            'trial_number': trial_number,
            'time': lib.format_seconds(timer())
        }
        lib.dump_json(stats, output / 'stats.json', indent=4)

        return stats


def main():
    """Command-line entry point: train one MLP from a TOML config file.

    If the config file name starts with ``trial_`` (presumably written by an
    Optuna driver — TODO confirm), results go to a directory named after the
    file without its suffix. Otherwise they go to ``<config dir>/trial_0``.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Run MLP Classifier')
    parser.add_argument('config_path', type=str, help='Path to configuration TOML file')
    args = parser.parse_args()

    config_path = Path(args.config_path)
    # TOML documents are UTF-8 by specification; don't depend on the locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = tomlkit.load(f)

    if config_path.stem.startswith('trial_'):
        output_dir = config_path.with_suffix('')
        print(f"Optuna mode: Config path: {config_path}")
        print(f"Optuna mode: Output dir: {output_dir}")
    else:
        output_dir = config_path.parent / 'trial_0'
        print(f"Standalone mode: Output dir: {output_dir}")

    model = MLPClassifierModule()
    stats = model._train_and_evaluate(
        model=model,
        config=config,
        trial_number=0,
        output_dir=output_dir
    )

    print(f"\nTraining completed. Results saved to: {output_dir / 'stats.json'}")
    metrics = stats.get('metrics')
    if not isinstance(metrics, dict):
        print("Results:", stats)
        return
    for part in ('val', 'test'):
        summary = metrics.get(part)
        if not isinstance(summary, dict):
            continue
        score = summary.get('score', 'N/A')
        # Scores that are not numeric (missing, 'N/A', None) are printed as-is
        # instead of crashing the summary on the ':.4f' format spec.
        try:
            print(f"[{part}] Score: {score:.4f}")
        except (TypeError, ValueError):
            print(f"[{part}] Score: {score}")


if __name__ == "__main__":
    main()
