import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from typing import List, Tuple, Dict, Any

from .base_classifier import BaseMILClassifier, PredictionResult
from .classifier_factory import register_classifier


@register_classifier('mean_pooling')
class MeanPoolingMILClassifier(BaseMILClassifier):
    """Multiple-instance classifier: mean-pool a bag's patch features, then apply an MLP.

    Bags are stacked into a ``(n_bags, n_patches, feature_dim)`` tensor by
    ``prepare_data``; the network averages over the patch axis and classifies
    the pooled vector with ``nn.Linear`` / activation / ``nn.Dropout`` layers.

    Config keys read here (defaults in parentheses): ``hidden_layers``
    ([256, 128]), ``dropout_rate`` (0.3), ``activation`` ('relu'),
    ``learning_rate`` (0.001), ``weight_decay`` (1e-4), ``batch_size`` (32),
    ``max_patches`` (1000; per-bag patch cap applied in ``prepare_data``).
    """

    class MeanPoolingNet(nn.Module):
        """Average features over dim 1, then classify the pooled vector with an MLP.

        Defined at class scope (rather than inside ``build_model``) so its
        instances are not "local objects", which ``pickle`` / ``torch.save``
        cannot serialize.
        """

        # Activation names accepted (case-insensitively) -> module constructors.
        ACTIVATIONS = {
            'relu': nn.ReLU,
            'tanh': nn.Tanh,
            'sigmoid': nn.Sigmoid,
            'gelu': nn.GELU,
            'elu': nn.ELU,
            'leaky_relu': nn.LeakyReLU,
        }

        def __init__(self, feature_dim, hidden_layers, n_classes,
                     dropout_rate, activation):
            super().__init__()

            act_fn = self.ACTIVATIONS.get(str(activation).lower())
            if act_fn is None:
                # Keep the historical ReLU fallback, but make it visible.
                import warnings
                warnings.warn(
                    f"Unknown activation {activation!r}; falling back to 'relu'.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                act_fn = nn.ReLU

            layers = []
            input_dim = feature_dim
            for hidden_dim in hidden_layers:
                layers.extend([
                    nn.Linear(input_dim, hidden_dim),
                    act_fn(),
                    nn.Dropout(dropout_rate),
                ])
                input_dim = hidden_dim
            layers.append(nn.Linear(input_dim, n_classes))

            # Attribute name kept as ``classifier`` so state_dict keys are unchanged.
            self.classifier = nn.Sequential(*layers)

        def forward(self, x):
            # x: [batch_size, n_patches, feature_dim]
            pooled = x.mean(dim=1)  # [batch_size, feature_dim]
            return self.classifier(pooled)

    def __init__(self, config: Dict[str, Any]):
        """Read model and optimisation hyper-parameters from ``config``.

        Args:
            config: Mapping of hyper-parameters; see the class docstring for
                the keys read here. Also forwarded to the base class.
        """
        super().__init__(config)

        self.hidden_layers = [int(h) for h in config.get('hidden_layers', [256, 128])]
        self.dropout_rate = float(config.get('dropout_rate', 0.3))
        self.activation = config.get('activation', 'relu')
        self.learning_rate = float(config.get('learning_rate', 0.001))
        self.weight_decay = float(config.get('weight_decay', 1e-4))
        self.batch_size = int(config.get('batch_size', 32))
        # Upper bound on patches kept per bag (memory cap for the padded tensor).
        self.max_patches = int(config.get('max_patches', 1000))

    def build_model(self, feature_dim: int, n_classes: int) -> nn.Module:
        """Instantiate the mean-pooling network from this classifier's config.

        Args:
            feature_dim: Width of each patch feature vector.
            n_classes: Number of output logits.

        Returns:
            A ``MeanPoolingNet`` module.
        """
        return self.MeanPoolingNet(
            feature_dim=feature_dim,
            hidden_layers=self.hidden_layers,
            n_classes=n_classes,
            dropout_rate=self.dropout_rate,
            activation=self.activation,
        )

    def _pad_bags_to_same_length(self, bags: List[np.ndarray],
                                 max_patches: int = None,
                                 pad_mode: str = 'mean') -> np.ndarray:
        """Stack variable-length bags into one ``(n_bags, max_patches, dim)`` array.

        Bags longer than ``max_patches`` keep only their first ``max_patches``
        rows. Shorter bags are padded with:

        * ``pad_mode='mean'`` (default): copies of the bag's own mean row. The
          padded bag then has exactly the same mean as the original, so mean
          pooling is unbiased and independent of the other bags' lengths.
        * ``pad_mode='last'``: copies of the bag's last row (legacy behaviour).

        Empty bags become all-zero rows.

        Args:
            bags: Per-bag feature arrays, presumably ``(n_patches, dim)``.
            max_patches: Target length; defaults to the longest bag.
            pad_mode: ``'mean'`` or ``'last'``.

        Returns:
            The stacked array built by ``np.array``.

        Raises:
            ValueError: If ``pad_mode`` is not recognised.
        """
        if pad_mode not in ('mean', 'last'):
            raise ValueError(f"pad_mode must be 'mean' or 'last', got {pad_mode!r}")

        bags = [np.asarray(bag) for bag in bags]
        if max_patches is None:
            max_patches = max(len(bag) for bag in bags)

        # Width used for empty bags: taken from any 2-D bag (even an empty
        # (0, dim) one); 512 only when nothing reveals the width.
        feature_dim = next((bag.shape[1] for bag in bags if bag.ndim == 2), 512)

        padded_bags = []
        for bag in bags:
            n_patches = len(bag)
            if n_patches >= max_patches:
                padded_bag = bag[:max_patches]
            elif n_patches == 0:
                padded_bag = np.zeros((max_patches, feature_dim))
            else:
                if pad_mode == 'mean':
                    filler = bag.mean(axis=0, keepdims=True)
                else:
                    filler = bag[-1:]
                padding = np.repeat(filler, max_patches - n_patches, axis=0)
                padded_bag = np.vstack([bag, padding])
            padded_bags.append(padded_bag)

        return np.array(padded_bags)

    def prepare_data(self, bags: List[Tuple[np.ndarray, Any]],
                     labels: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert ``(features, extra)`` bags and integer labels into tensors.

        Only the first element of each bag tuple is used. Bags are padded or
        truncated to ``min(self.max_patches, longest bag)`` patches (at least 1).

        Returns:
            ``(bags_tensor, labels_tensor)``: float32 bags and int64 labels.

        Raises:
            ValueError: If ``bags`` is empty (from ``max`` of an empty sequence).
        """
        bag_features = [features for features, _ in bags]

        longest = max(len(features) for features in bag_features)
        # At least one patch, so all-empty inputs don't yield a NaN mean.
        max_patches = max(1, min(self.max_patches, longest))
        padded_bags = self._pad_bags_to_same_length(bag_features, max_patches)

        bags_tensor = torch.as_tensor(padded_bags, dtype=torch.float32)
        labels_tensor = torch.as_tensor(labels, dtype=torch.long)

        return bags_tensor, labels_tensor

    def _optimizer_tracks_model(self, optimizer) -> bool:
        """Return True if every parameter of ``self.model`` is in ``optimizer``."""
        tracked = {id(p) for group in optimizer.param_groups for p in group['params']}
        return all(id(p) in tracked for p in self.model.parameters())

    def _ensure_optimizer(self) -> None:
        """Create Adam / cross-entropy unless usable ones already exist.

        The optimizer is rebuilt when it is missing or bound to another model's
        parameters (e.g. after ``self.model`` was rebuilt). A stale optimizer
        would otherwise update the old model and leave the new one untrained.
        """
        optimizer = getattr(self, 'optimizer', None)
        if optimizer is None or not self._optimizer_tracks_model(optimizer):
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.learning_rate,
                weight_decay=self.weight_decay,
            )
        if getattr(self, 'criterion', None) is None:
            self.criterion = nn.CrossEntropyLoss()

    def _evaluate(self, data: torch.Tensor, labels: torch.Tensor) -> Tuple[float, float]:
        """Return ``(loss, accuracy)`` of the model on ``data`` without gradients.

        Forward passes run in ``batch_size`` chunks (moved with ``_to_device``)
        to bound memory. The loss is computed once over the concatenated logits,
        so it equals a single full-batch evaluation. Returns NaNs for empty input.
        """
        n_samples = len(labels)
        if n_samples == 0:
            return float('nan'), float('nan')

        self.model.eval()
        step = max(1, self.batch_size)
        with torch.no_grad():
            chunks = [
                self.model(self._to_device(data[start:start + step]))
                for start in range(0, n_samples, step)
            ]
            outputs = torch.cat(chunks)
            labels = self._to_device(labels)
            loss = self.criterion(outputs, labels).item()
            correct = (outputs.argmax(dim=1) == labels).sum().item()

        return loss, correct / n_samples

    def train_epoch(self, train_data: torch.Tensor, train_labels: torch.Tensor,
                    val_data: torch.Tensor, val_labels: torch.Tensor,
                    epoch: int) -> Tuple[float, float, float, float]:
        """Run one shuffled training epoch, then evaluate on the validation set.

        Per-batch optimisation is delegated to the base class's ``_train_batch``,
        unpacked here as ``(loss_value, outputs, batch_labels)``. ``epoch`` is
        not used by this implementation.

        Returns:
            ``(train_loss, train_acc, val_loss, val_acc)``. A metric is NaN when
            its dataset is empty.
        """
        self._ensure_optimizer()
        self.model.train()

        dataset = torch.utils.data.TensorDataset(train_data, train_labels)
        dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=True
        )

        train_losses = []
        train_correct = 0
        train_total = 0

        for batch_data, batch_labels in dataloader:
            loss_value, outputs, batch_labels = self._train_batch(
                batch_data, batch_labels, self.optimizer, self.criterion
            )
            train_losses.append(loss_value)

            predicted = outputs.detach().argmax(dim=1)
            train_total += batch_labels.size(0)
            train_correct += (predicted == batch_labels).sum().item()

        train_loss = np.mean(train_losses) if train_losses else float('nan')
        train_acc = train_correct / train_total if train_total else float('nan')

        val_loss, val_acc = self._evaluate(val_data, val_labels)

        return train_loss, train_acc, val_loss, val_acc

    def predict_bags(self, bags: List[Tuple[np.ndarray, Any]]) -> PredictionResult:
        """Predict class labels and probabilities for ``bags``.

        Bags are padded and run through the model in chunks of ``batch_size``
        to bound memory. With mean padding, chunking does not change any bag's
        prediction.

        Returns:
            ``PredictionResult`` with argmax ``predictions``, softmax
            ``probabilities``, max-probability ``confidence`` and an empty
            ``bag_names`` list.

        Raises:
            ValueError: If ``bags`` is empty.
        """
        if len(bags) == 0:
            raise ValueError("predict_bags() requires at least one bag")

        self.model.eval()
        step = max(1, self.batch_size)
        logits = []
        with torch.no_grad():
            for start in range(0, len(bags), step):
                chunk = bags[start:start + step]
                # Labels are placeholders; prepare_data requires them.
                chunk_data, _ = self.prepare_data(chunk, [0] * len(chunk))
                logits.append(self.model(self._to_device(chunk_data)).cpu())

        outputs = torch.cat(logits)
        probabilities = torch.softmax(outputs, dim=1).numpy()
        predictions = outputs.argmax(dim=1).numpy()
        confidence = np.max(probabilities, axis=1)

        return PredictionResult(
            predictions=predictions,
            probabilities=probabilities,
            confidence=confidence,
            bag_names=[]
        )
