import torch as th
import numpy as np
from torch.optim import Adam, SGD
from typing import Optional


class Baseline(th.nn.Module):
    """A non-kernel baseline classifier built from a plain MLP.

    This baseline skips kernel computation entirely. The "feature extractor" is a
    stack of float64 ``Linear`` layers with ``Tanh`` activations between them; its
    final ``Linear`` layer maps directly to ``num_classes`` logits (there is no
    separate classifier head). It exposes the same ``train(X, Y, loss, ...)``
    interface as the kernel models so it plugs into the existing training code.

    Note: ``train`` shadows :meth:`torch.nn.Module.train`. Calling ``train`` with
    a single bool (as ``model.eval()``, ``model.train()`` and parent modules do)
    falls back to the standard train/eval mode switch.
    """

    def __init__(self, inputs: int, num_classes: int = 2,
                 hidden_dims: Optional[list[int]] = None,
                 optimizer: str = "Adam", lr: float = 0.01,
                 verbose: bool = True, **kwargs):
        """Baseline Classifier Args:
        inputs (int): input dimension
        num_classes (int): number of output classes, defaults to 2
        hidden_dims (list[int] | None): hidden layer widths of the MLP, defaults to
            no hidden layers (a single linear map)
        optimizer (str | type): optimizer name from ``torch.optim`` (e.g. "Adam",
            "SGD", "th.optim.RMSprop") or an optimizer class, defaults to "Adam"
        lr (float): learning rate for training, defaults to 0.01
        verbose (bool): print training progress, defaults to True
        **kwargs: accepted and ignored, for compatibility with kernel model configs
        """
        super().__init__()
        self.inputs = inputs
        self.num_classes = num_classes
        self.verbose = verbose
        self.calls = 0  # number of forward passes (incremented in forward)

        hidden_dims = [] if hidden_dims is None else list(hidden_dims)
        fe_dims = [self.inputs, *hidden_dims, self.num_classes]
        layers = []
        for d_in, d_out in zip(fe_dims[:-1], fe_dims[1:]):
            layers += [th.nn.Linear(d_in, d_out, dtype=th.float64), th.nn.Tanh()]
        # Drop the trailing Tanh so the network outputs raw logits.
        self.feature_extractor = th.nn.Sequential(*layers[:-1])

        self.optimizer = self._resolve_optimizer(optimizer)(self.parameters(), lr=lr)

    @staticmethod
    def _resolve_optimizer(optimizer):
        """Map an optimizer name (or class) to a ``torch.optim`` optimizer class.

        Replaces a former ``eval(optimizer)``: plain names ("Adam", "SGD") and
        dotted paths ("th.optim.X" / "torch.optim.X") are looked up in
        ``torch.optim`` instead of being evaluated as code.

        Raises:
            ValueError: if the name does not denote a ``torch.optim`` optimizer.
        """
        if callable(optimizer):
            return optimizer
        name = str(optimizer).removeprefix("torch.optim.").removeprefix("th.optim.")
        opt_cls = getattr(th.optim, name, None)
        if not (isinstance(opt_cls, type) and issubclass(opt_cls, th.optim.Optimizer)):
            raise ValueError(
                f"Unknown optimizer {optimizer!r}; expected a torch.optim class "
                f"name such as 'Adam' or 'SGD'"
            )
        return opt_cls

    def _as_input(self, X) -> th.Tensor:
        """Return X as a tensor with the dtype/device of the model's parameters.

        Accepts tensors, numpy arrays and nested lists. Inputs that already have
        the right dtype/device are returned without a copy.
        """
        ref = self.feature_extractor[0].weight  # first layer is always Linear
        return th.as_tensor(X, dtype=ref.dtype, device=ref.device)

    def forward(self, X: th.Tensor) -> th.Tensor:
        """Return the logits produced by the MLP for a batch of inputs X."""
        self.calls += 1
        return self.feature_extractor(X)

    def train(self, X=True, Y: Optional[th.Tensor] = None,
              loss: Optional[callable] = None, epochs: int = 100,
              batch_size: Optional[int] = None, callback: Optional[callable] = None,
              log_interval: int = 20):
        """Train the baseline classifier using the same interface as kernels.

        When called with a single bool (``model.train()``, ``model.train(False)``,
        or indirectly via ``model.eval()``), this behaves like
        :meth:`torch.nn.Module.train` and only switches the train/eval mode.

        Args:
            X (th.Tensor): input data (cast to the model's dtype/device)
            Y (th.Tensor): target labels; converted to ``long`` before training
            loss (callable): loss function called as ``loss(logits, Y_batch)``
            epochs (int): number of optimisation steps
            batch_size (int): if given, each step uses ``batch_size`` indices
                sampled with replacement; otherwise the full data is used
            callback (callable): called as ``callback(self, loss_val, step)``
                every ``log_interval`` steps
            log_interval (int): logging/callback interval in steps; 0 disables

        Returns:
            list[th.Tensor]: detached per-step loss values.
        """
        # nn.Module.eval() calls self.train(False); parents call child.train(mode).
        if isinstance(X, bool):
            return super().train(X)
        if Y is None or loss is None:
            raise TypeError("train() requires X, Y and loss to fit the model")

        X = self._as_input(X)
        # Class-index targets (e.g. for CrossEntropyLoss) must be integer typed.
        if Y.dtype != th.long:
            Y = Y.long()

        losses = []
        for ep in range(epochs):
            self.optimizer.zero_grad()

            if batch_size is not None:
                # Sampling with replacement; same draws as choice(range(n), k).
                batch_idx = np.random.choice(X.size(0), batch_size)
                X_batch, Y_batch = X[batch_idx], Y[batch_idx]
            else:
                X_batch, Y_batch = X, Y

            logits = self.forward(X_batch)
            loss_val = loss(logits, Y_batch)

            loss_val.backward()
            self.optimizer.step()

            # Store loss as tensor for compatibility with kernel training.
            losses.append(loss_val.detach())

            if log_interval and (ep + 1) % log_interval == 0:
                if callback is not None:
                    callback(self, loss_val, ep + 1)
                if self.verbose:
                    accuracy = self._compute_accuracy(X, Y)
                    print(f"Loss {loss_val.item():.6f}, Accuracy {accuracy:.3f} @Step {ep + 1}")

        return losses

    def predict(self, X) -> np.ndarray:
        """Return the argmax class index for each row of X as a numpy array.

        X may be a tensor, numpy array or nested list; it is cast to the model's
        dtype/device. Gradients are not tracked.
        """
        with th.no_grad():
            logits = self.forward(self._as_input(X))
            predictions = th.argmax(logits, dim=1)
        return predictions.cpu().numpy()

    def _compute_accuracy(self, X, y) -> float:
        """Return the fraction of rows in X whose predicted class equals y.

        y may be a tensor (on any device) or array-like.
        """
        predictions = self.predict(X)
        labels = y.detach().cpu().numpy() if isinstance(y, th.Tensor) else np.asarray(y)
        return float(np.mean(predictions == labels))


if __name__ == '__main__':
    from config import *
    from run.train import execute, fit, train
    from data import *

    # Seed both numpy (batch sampling) and torch (weight init) for reproducibility.
    seed = 42
    np.random.seed(seed)
    th.manual_seed(seed)

    # Build the train/test split described by the CIFAR10 config preset.
    config = {**CIFAR10}
    X_train, X_test, y_train, y_test = make(eval(config['dataset']), **config['data_kwargs'])

    model = Baseline(
        inputs=X_train.size(1),
        num_classes=10,
        hidden_dims=[256],
        lr=1e-4,
        verbose=True,
    )
    criterion = th.nn.CrossEntropyLoss()
    losses = model.train(X_train, y_train, criterion, epochs=100)

    test_accuracy = model._compute_accuracy(X_test, y_test)
    print(f"Final Test Accuracy on CIFAR10: {test_accuracy:.3f}")
    