from tabpfn import TabPFNClassifier
from TabZilla.models.basemodel import BaseModel
import torch
import numpy as np

class TabPFNModel(BaseModel):
    """Adapter exposing ``TabPFNClassifier`` through the ``BaseModel`` interface.

    Recognised ``params`` keys:
        ignore_pretraining_limits: forwarded to ``TabPFNClassifier``
            (default ``True``).
        max_train_samples: maximum number of training rows kept by ``fit``
            (default 1000); ``None`` disables truncation.
    """

    # Default cap on the number of training rows passed to the classifier.
    DEFAULT_MAX_TRAIN_SAMPLES = 1000

    def __init__(self, params, args):
        """Build the underlying classifier on the device named by ``args.device``."""
        super().__init__(params, args)
        self.device = torch.device(args.device)
        self.max_train_samples = params.get(
            "max_train_samples", self.DEFAULT_MAX_TRAIN_SAMPLES
        )

        ignore_limits = params.get("ignore_pretraining_limits", True)
        # Pass the parsed device through unchanged so that an explicit index
        # (e.g. "cuda:1") or a non-CUDA accelerator is respected instead of
        # being collapsed to a bare "cuda".
        self.model = TabPFNClassifier(
            device=self.device,
            ignore_pretraining_limits=ignore_limits,
        )

    @staticmethod
    def _to_numpy(data):
        """Return ``data`` as a NumPy array if it is a tensor, else unchanged."""
        if isinstance(data, torch.Tensor):
            # detach() first: .numpy() raises on tensors that require grad.
            return data.detach().cpu().numpy()
        return data

    def fit(self, X, y, X_val=None, y_val=None):
        """Fit the classifier on ``(X, y)``.

        ``X_val`` and ``y_val`` are accepted for interface compatibility but
        are not used. If there are more than ``max_train_samples`` rows, only
        the leading rows are kept and a warning is printed.

        Returns:
            A pair of empty lists (presumably the loss/validation-loss
            histories expected by the caller -- nothing is recorded here).
        """
        X = self._to_numpy(X)
        # Convert labels independently of X; a tensor y must become an
        # integer array even when X was already a NumPy array.
        if isinstance(y, torch.Tensor):
            y = self._to_numpy(y).astype(int)

        limit = self.max_train_samples
        if limit is not None and len(X) > limit:
            X = X[:limit]
            y = y[:limit]
            print(f"Warning: Truncated to first {limit} samples (TabPFN limit)")

        self.model.fit(X, y)
        return [], []

    def predict(self, X):
        """Return the predicted class labels for ``X``."""
        return self.model.predict(self._to_numpy(X))

    def predict_proba(self, X):
        """Return class probabilities for ``X`` as a 2-D array.

        If the underlying model returns a 1-D vector, it is treated as the
        positive-class probability and expanded to two columns
        ``[1 - p, p]``.
        """
        proba = self.model.predict_proba(self._to_numpy(X))
        if proba.ndim == 1:
            return np.column_stack([1 - proba, proba])
        return proba

