from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Any, Optional

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score, f1_score


@dataclass
class TrainingResult:
    """Summary returned by ``BaseMILClassifier.fit``.

    The field order defines the positional constructor; keep it stable.
    """

    # The network produced by training.
    model: nn.Module
    # Metric lists keyed by metric name (fit() also records an 'epochs' list).
    history: Dict[str, List[float]]
    # Validation metrics as fit() fills them: taken from the last epoch.
    final_accuracy: float
    final_auc: float
    final_f1: float
    # Zero-based index of the epoch with the highest validation accuracy.
    best_epoch: int
    # Elapsed seconds measured around fit().
    training_time: float


@dataclass
class PredictionResult:
    """Bag-level inference output produced by ``predict_bags`` / ``predict``.

    ``bag_names`` defaults to an empty list because ``predict`` overwrites it
    after ``predict_bags`` returns, so implementations may omit it. Passing
    it explicitly (positionally or by keyword) still works as before.
    """

    # Presumably one predicted class index per bag -- confirm in predict_bags.
    predictions: np.ndarray
    # Presumably per-bag confidence / probability scores -- confirm in predict_bags.
    confidence: np.ndarray
    # Names aligned with `predictions`; default_factory avoids a shared list.
    bag_names: List[str] = field(default_factory=list)


class BaseMILClassifier(ABC):
    """Abstract base for multiple-instance-learning (MIL) bag classifiers.

    Subclasses provide the network (``build_model``), the conversion of raw
    bags into model inputs (``prepare_data``), one epoch of optimisation
    (``train_epoch``) and bag-level inference (``predict_bags``). This base
    class handles device selection, the epoch loop in ``fit``, AUC/F1
    bookkeeping and shared helpers.

    Bags are handled as ``(features, metadata)`` tuples; ``fit`` reads the
    feature dimension from ``features.shape[1]`` (presumably features are
    ``(n_instances, feature_dim)`` -- confirm against the data loader).
    """

    def __init__(self, config: Dict[str, Any]):
        """Store *config* and resolve the compute device.

        Args:
            config: Hyper-parameter dict. This class reads ``'device'``
                (default ``'auto'``: CUDA when available, else CPU) and,
                in ``fit``, ``'epochs'`` (default 100). Subclasses may read
                further keys.
        """
        self.config = config
        self.model: Optional[nn.Module] = None
        self.is_trained = False

        device = config.get('device', 'auto')
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print(f"Using device: {self.device}")

        # Filled in by fit().
        self.feature_dim: Optional[int] = None
        self.n_classes: Optional[int] = None
        self.class_names: Optional[List[str]] = None
        # Assigned by fit(); declared here so the attribute always exists.
        self.training_history: Optional[Dict[str, List[float]]] = None
        # Exposed via get_wsi_features(); nothing in this class writes to it,
        # presumably subclasses populate it.
        self.wsi_features: Dict[str, np.ndarray] = {}

    def _to_device(self, data):
        """Recursively move every tensor inside *data* to ``self.device``.

        Tensors are moved; lists, tuples (including namedtuples) and dicts are
        rebuilt with their elements moved; anything else is returned as-is.
        """
        if isinstance(data, torch.Tensor):
            return data.to(self.device)
        if isinstance(data, dict):
            return {key: self._to_device(value) for key, value in data.items()}
        if isinstance(data, tuple) and hasattr(data, '_fields'):
            # namedtuple constructors take the fields positionally, not an iterable.
            return type(data)(*(self._to_device(item) for item in data))
        if isinstance(data, (list, tuple)):
            return type(data)(self._to_device(item) for item in data)
        return data

    def _train_batch(self, batch_data, batch_labels, optimizer, criterion):
        """Run one optimisation step on a single batch.

        Returns:
            ``(loss value as a Python float, raw model outputs, labels on device)``.
        """
        batch_data = self._to_device(batch_data)
        batch_labels = self._to_device(batch_labels)

        optimizer.zero_grad()

        outputs = self.model(batch_data)
        loss = criterion(outputs, batch_labels)

        loss.backward()
        optimizer.step()

        return loss.item(), outputs, batch_labels

    @abstractmethod
    def build_model(self, feature_dim: int, n_classes: int) -> nn.Module:
        """Construct the network; ``fit`` calls this and moves the result to ``self.device``."""

    @abstractmethod
    def prepare_data(self, bags: List[Tuple[np.ndarray, Any]],
                    labels: List[int]) -> Tuple[Any, Any]:
        """Convert raw bags/labels into the ``(data, labels)`` pair used by training.

        ``fit`` passes both results through ``_to_device`` and hands them to
        ``train_epoch`` and ``_calculate_epoch_metrics``.
        """

    @abstractmethod
    def train_epoch(self, train_data: Any, train_labels: Any,
                   val_data: Any, val_labels: Any,
                   epoch: int) -> Tuple[float, float, float, float]:
        """Run one epoch (0-based index *epoch*).

        Returns:
            ``(train_loss, train_acc, val_loss, val_acc)``.
        """

    @abstractmethod
    def predict_bags(self, bags: List[Tuple[np.ndarray, Any]]) -> PredictionResult:
        """Run inference on *bags*; ``predict`` overwrites the result's ``bag_names``."""

    def _calculate_auc(self, y_true: np.ndarray, y_prob: np.ndarray) -> float:
        """ROC-AUC of *y_prob* against *y_true*.

        Binary problems score column 1 of *y_prob*; multi-class problems use
        macro-averaged one-vs-rest AUC. Any sklearn failure (for example only
        one class present in *y_true*) is reported and mapped to 0.0 so that
        training is not interrupted.
        """
        try:
            if self.n_classes == 2:
                return roc_auc_score(y_true, y_prob[:, 1])
            return roc_auc_score(y_true, y_prob, multi_class='ovr', average='macro')
        except Exception as e:  # deliberate best-effort metric
            print(f"Warning: Cannot calculate AUC: {e}")
            return 0.0

    def _calculate_f1(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Binary F1 for two classes, macro F1 otherwise; 0.0 (with a warning) on failure."""
        try:
            if self.n_classes == 2:
                return f1_score(y_true, y_pred, average='binary')
            return f1_score(y_true, y_pred, average='macro')
        except Exception as e:  # deliberate best-effort metric
            print(f"Warning: Cannot calculate F1: {e}")
            return 0.0

    def _calculate_epoch_metrics(self, data, labels, data_type: str) -> tuple:
        """Compute ``(auc, f1)`` of the current model on *data* / *labels*.

        Args:
            data: Either a list of per-bag tensors (each run separately, with
                a leading batch axis added to 2-D bags) or a single input
                passed to the model in one call.
            labels: Ground-truth labels as a tensor or array-like.
            data_type: Caller tag such as ``'train'`` / ``'val'``; unused here.

        The model is put in eval mode for the forward passes and afterwards
        restored to the mode it was in before, in case the subclass's
        ``train_epoch`` does not call ``model.train()`` itself.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                if isinstance(data, list):
                    all_probs = []
                    for bag_data in data:
                        if len(bag_data.shape) == 2:
                            # Add a leading batch axis of size 1.
                            bag_data = bag_data.unsqueeze(0)
                        bag_data = bag_data.to(self.device)
                        output = self.model(bag_data)
                        prob = torch.softmax(output, dim=1)
                        all_probs.append(prob.cpu().numpy())
                    y_prob = np.vstack(all_probs)
                else:
                    outputs = self.model(self._to_device(data))
                    y_prob = torch.softmax(outputs, dim=1).cpu().numpy()
        finally:
            self.model.train(was_training)

        if isinstance(labels, torch.Tensor):
            y_true = labels.cpu().numpy()
        else:
            y_true = labels

        y_pred = np.argmax(y_prob, axis=1)

        auc = self._calculate_auc(y_true, y_prob)
        f1 = self._calculate_f1(y_true, y_pred)
        return auc, f1

    def fit(self, train_bags: List[Tuple[np.ndarray, Any]],
            train_labels: List[int],
            val_bags: List[Tuple[np.ndarray, Any]],
            val_labels: List[int],
            class_names: List[str]) -> TrainingResult:
        """Build a fresh model, train it for ``config['epochs']`` epochs and evaluate.

        Args:
            train_bags: ``(features, metadata)`` bags; the feature dimension
                is read from ``train_bags[0][0].shape[1]``.
            train_labels: Integer training labels.
            val_bags: Validation bags in the same format.
            val_labels: Integer validation labels.
            class_names: Stored on the instance (reported by get_model_info).

        Returns:
            TrainingResult whose ``final_*`` metrics come from the LAST epoch,
            while ``best_epoch`` is the 0-based epoch with the highest
            validation accuracy. The model is not rolled back to that epoch.

        Raises:
            ValueError: If ``train_bags`` is empty.
        """
        import time

        if len(train_bags) == 0:
            raise ValueError("train_bags must contain at least one bag")

        # perf_counter is monotonic, unlike time.time(), so durations are reliable.
        start_time = time.perf_counter()

        # A new model replaces the old one below; until this run completes the
        # instance must not claim to be trained (e.g. after a failed re-fit).
        self.is_trained = False

        self.feature_dim = train_bags[0][0].shape[1]
        # NOTE(review): counts only classes present in train_labels. If a class
        # is missing from the training split, the head is too small; consider
        # len(class_names). Behaviour kept unchanged here.
        self.n_classes = len(set(train_labels))
        self.class_names = class_names

        self.model = self.build_model(self.feature_dim, self.n_classes)
        self.model = self.model.to(self.device)

        train_data, train_labels_processed = self.prepare_data(train_bags, train_labels)
        val_data, val_labels_processed = self.prepare_data(val_bags, val_labels)

        train_data = self._to_device(train_data)
        train_labels_processed = self._to_device(train_labels_processed)
        val_data = self._to_device(val_data)
        val_labels_processed = self._to_device(val_labels_processed)

        history = {
            'train_loss': [],
            'train_acc': [],
            'train_auc': [],
            'train_f1': [],
            'val_loss': [],
            'val_acc': [],
            'val_auc': [],
            'val_f1': [],
            'epochs': [],
        }

        best_val_acc = 0.0
        best_val_auc = 0.0
        best_val_f1 = 0.0
        best_epoch = 0
        epochs = self.config.get('epochs', 100)

        print(f"Starting training {self.__class__.__name__}...")
        print(f"Feature dimension: {self.feature_dim}, Number of classes: {self.n_classes}")
        print(f"Training bags: {len(train_bags)}, Validation bags: {len(val_bags)}")

        for epoch in range(epochs):
            train_loss, train_acc, val_loss, val_acc = self.train_epoch(
                train_data, train_labels_processed,
                val_data, val_labels_processed,
                epoch
            )

            train_auc, train_f1 = self._calculate_epoch_metrics(train_data, train_labels_processed, 'train')
            val_auc, val_f1 = self._calculate_epoch_metrics(val_data, val_labels_processed, 'val')

            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['train_auc'].append(train_auc)
            history['train_f1'].append(train_f1)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            history['val_auc'].append(val_auc)
            history['val_f1'].append(val_f1)
            history['epochs'].append(epoch + 1)

            # Strict '>' keeps the earliest epoch among ties.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_val_auc = val_auc
                best_val_f1 = val_f1
                best_epoch = epoch

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs}: "
                      f"Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, Train AUC={train_auc:.4f}, Train F1={train_f1:.4f}, "
                      f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}, Val AUC={val_auc:.4f}, Val F1={val_f1:.4f}")

        training_time = time.perf_counter() - start_time
        self.is_trained = True
        self.training_history = history

        print(f"Training completed! Best validation accuracy: {best_val_acc:.4f}, AUC: {best_val_auc:.4f}, F1: {best_val_f1:.4f} (Epoch {best_epoch+1})")
        print(f"Training time: {training_time:.2f} seconds")

        # With epochs == 0 the history is empty; report zeros in that case.
        final_val_acc = history['val_acc'][-1] if history['val_acc'] else 0.0
        final_val_auc = history['val_auc'][-1] if history['val_auc'] else 0.0
        final_val_f1 = history['val_f1'][-1] if history['val_f1'] else 0.0

        return TrainingResult(
            model=self.model,
            history=history,
            final_accuracy=final_val_acc,
            final_auc=final_val_auc,
            final_f1=final_val_f1,
            best_epoch=best_epoch,
            training_time=training_time,
        )

    def predict(self, bags: List[Tuple[np.ndarray, Any]],
                bag_names: Optional[List[str]] = None) -> PredictionResult:
        """Predict *bags* with the trained model.

        Args:
            bags: Bags in the same format as used for ``fit``.
            bag_names: Optional names, one per bag; defaults to ``bag_<i>``.

        Raises:
            RuntimeError: If ``fit`` has not completed successfully.
            ValueError: If ``bag_names`` does not have one entry per bag.
        """
        if not self.is_trained:
            raise RuntimeError("Model has not been trained yet, please call fit() method first")

        if bag_names is None:
            bag_names = [f"bag_{i}" for i in range(len(bags))]
        elif len(bag_names) != len(bags):
            # A mismatch would silently misalign names with predictions.
            raise ValueError(
                f"bag_names has {len(bag_names)} entries but {len(bags)} bags were given"
            )

        result = self.predict_bags(bags)
        result.bag_names = bag_names

        return result

    def get_config(self) -> Dict[str, Any]:
        """Return a shallow copy of the configuration dict."""
        return self.config.copy()

    def get_model_info(self) -> Dict[str, Any]:
        """Return a summary of the classifier's type, dimensions and state."""
        return {
            'classifier_type': self.__class__.__name__,
            'feature_dim': self.feature_dim,
            'n_classes': self.n_classes,
            'class_names': self.class_names,
            'is_trained': self.is_trained,
            # Copy, consistent with get_config(), so callers cannot mutate it.
            'config': self.config.copy(),
        }

    def get_wsi_features(self) -> Dict[str, np.ndarray]:
        """Return a shallow copy of ``self.wsi_features``."""
        return self.wsi_features.copy()

    def _extract_wsi_features(self, bags: List[Tuple[np.ndarray, Any]]) -> Dict[str, np.ndarray]:
        """Return one feature vector per bag, keyed by slide name.

        A bag's metadata is used as its key when it is a string; otherwise the
        key is ``"wsi_<index>"`` with the bag's position in *bags*. If the
        subclass defines ``_get_wsi_feature_for_bag(bag_features)`` it is used;
        otherwise the mean of the bag's features over axis 0 is taken.
        Returns an empty dict when the classifier is not trained. The model's
        train/eval mode is restored afterwards.
        """
        wsi_features = {}

        if not self.is_trained:
            return wsi_features

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for index, (bag_features, metadata) in enumerate(bags):
                    # Key unnamed bags by position so the key depends only on
                    # where the bag sits in the input, not on the dict's size.
                    wsi_name = metadata if isinstance(metadata, str) else f"wsi_{index}"

                    if hasattr(self, '_get_wsi_feature_for_bag'):
                        wsi_feature = self._get_wsi_feature_for_bag(bag_features)
                    else:
                        wsi_feature = np.mean(bag_features, axis=0)

                    wsi_features[wsi_name] = wsi_feature
        finally:
            self.model.train(was_training)

        return wsi_features
