import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from typing import List, Tuple, Dict, Any

from .base_classifier import BaseMILClassifier, PredictionResult
from .classifier_factory import register_classifier


@register_classifier('attention')
class AttentionMILClassifier(BaseMILClassifier):
    """Attention-pooling multiple-instance-learning (MIL) bag classifier.

    Each bag is a ``[n_patches, feature_dim]`` matrix of patch features. The
    network embeds every patch, scores it with a (optionally gated) attention
    head, softmax-normalises the scores over the patches of the bag, and
    classifies the attention-weighted sum of the patch embeddings.

    Recognised ``config`` keys (with defaults):

    * ``attention_dim`` (64): width of the attention head.
    * ``use_gated_attention`` (False): use the tanh/sigmoid gated head.
    * ``hidden_dim`` (256): width of the patch embedding.
    * ``dropout_rate`` (0.3): dropout after the embedding and inside the
      classification head.
    * ``learning_rate`` (0.001) and ``weight_decay`` (1e-4): Adam settings.
    * ``batch_size`` (1): stored on the instance but not used by the code in
      this class; training always processes one bag per optimizer step.

    NOTE(review): ``self.model`` is read but never assigned in this class; it
    is presumably set by ``BaseMILClassifier`` from ``build_model`` - confirm.
    """

    def __init__(self, config: Dict[str, Any]):
        """Read the hyper-parameters from ``config`` (see the class docstring)."""
        super().__init__(config)

        # Numeric settings are coerced explicitly so that values arriving as
        # strings or floats (e.g. "1e-4", 256.0) behave like native numbers.
        self.attention_dim = int(config.get('attention_dim', 64))
        self.use_gated_attention = config.get('use_gated_attention', False)

        self.hidden_dim = int(config.get('hidden_dim', 256))
        self.dropout_rate = float(config.get('dropout_rate', 0.3))

        self.learning_rate = float(config.get('learning_rate', 0.001))
        self.weight_decay = float(config.get('weight_decay', 1e-4))
        self.batch_size = int(config.get('batch_size', 1))

    def build_model(self, feature_dim: int, n_classes: int) -> nn.Module:
        """Create the attention MIL network.

        Args:
            feature_dim: Length of each patch feature vector.
            n_classes: Number of output classes (size of the logit vector).

        Returns:
            A freshly initialised ``nn.Module`` whose ``forward`` maps a
            ``[batch_size, n_patches, feature_dim]`` tensor to
            ``[batch_size, n_classes]`` logits.
        """

        class AttentionMILNet(nn.Module):
            """Patch embedding -> attention pooling -> classification head."""

            def __init__(self, feature_dim, attention_dim, hidden_dim,
                         n_classes, dropout_rate, use_gated_attention):
                super().__init__()

                self.feature_dim = feature_dim
                self.attention_dim = attention_dim
                self.use_gated_attention = use_gated_attention

                # Submodule attribute names are kept unchanged so that
                # existing state_dict checkpoints remain loadable.
                self.feature_extractor = nn.Sequential(
                    nn.Linear(feature_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate)
                )

                if use_gated_attention:
                    self.attention_V = nn.Sequential(
                        nn.Linear(hidden_dim, attention_dim),
                        nn.Tanh()
                    )
                    self.attention_U = nn.Sequential(
                        nn.Linear(hidden_dim, attention_dim),
                        nn.Sigmoid()
                    )
                    self.attention_w = nn.Linear(attention_dim, 1)
                else:
                    self.attention = nn.Sequential(
                        nn.Linear(hidden_dim, attention_dim),
                        nn.Tanh(),
                        nn.Linear(attention_dim, 1)
                    )

                self.classifier = nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim // 2),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(hidden_dim // 2, n_classes)
                )

            def aggregate(self, x):
                """Pool the patches of each bag into one embedding.

                Args:
                    x: Tensor of shape [batch_size, n_patches, feature_dim].

                Returns:
                    Tuple ``(M, A)`` with the bag embedding ``M`` of shape
                    [batch_size, hidden_dim] and the attention weights ``A``
                    of shape [batch_size, n_patches, 1], which sum to 1 over
                    the patch dimension.
                """
                batch_size, n_patches, _ = x.size()

                # reshape (rather than view) also accepts non-contiguous x.
                h = self.feature_extractor(x.reshape(-1, self.feature_dim))
                h = h.view(batch_size, n_patches, -1)
                # h: [batch_size, n_patches, hidden_dim]

                if self.use_gated_attention:
                    gate_v = self.attention_V(h)  # [batch_size, n_patches, attention_dim]
                    gate_u = self.attention_U(h)  # [batch_size, n_patches, attention_dim]
                    scores = self.attention_w(gate_v * gate_u)  # [batch_size, n_patches, 1]
                else:
                    scores = self.attention(h)  # [batch_size, n_patches, 1]

                # Normalise across the patches of each bag.
                A = torch.softmax(scores, dim=1)  # [batch_size, n_patches, 1]
                M = torch.sum(A * h, dim=1)  # [batch_size, hidden_dim]
                return M, A

            def forward(self, x, return_attention=False):
                """Return class logits, optionally with attention weights.

                Args:
                    x: Tensor of shape [batch_size, n_patches, feature_dim].
                    return_attention: When true, also return the attention
                        weights with shape [batch_size, n_patches].

                Returns:
                    Unnormalised logits of shape [batch_size, n_classes]
                    (no softmax is applied here), or ``(logits, attention)``.
                """
                M, A = self.aggregate(x)
                logits = self.classifier(M)

                if return_attention:
                    return logits, A.squeeze(-1)  # [batch_size, n_patches]
                return logits

        return AttentionMILNet(
            feature_dim=feature_dim,
            attention_dim=self.attention_dim,
            hidden_dim=self.hidden_dim,
            n_classes=n_classes,
            dropout_rate=self.dropout_rate,
            use_gated_attention=self.use_gated_attention
        )

    @staticmethod
    def _to_bag_tensor(bag_features) -> torch.Tensor:
        """Copy one bag's patch features into a float32 tensor."""
        return torch.tensor(np.asarray(bag_features), dtype=torch.float32)

    def _model_device(self) -> torch.device:
        """Return the device holding the model's parameters (CPU if none)."""
        first_param = next(iter(self.model.parameters()), None)
        return torch.device('cpu') if first_param is None else first_param.device

    @staticmethod
    def _epoch_stats(losses: List[float], n_correct: int) -> Tuple[float, float]:
        """Return (mean loss, accuracy) over the processed bags.

        An empty pass yields ``(nan, 0.0)`` instead of raising
        ``ZeroDivisionError``.
        """
        if not losses:
            return float('nan'), 0.0
        return float(np.mean(losses)), n_correct / len(losses)

    def prepare_data(self, bags: List[Tuple[np.ndarray, Any]],
                    labels: List[int]) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Convert bags and labels into tensors.

        Args:
            bags: ``(features, extra)`` pairs; only ``features`` is used here,
                the second element is ignored.
            labels: One integer class label per bag.

        Returns:
            A list with one float32 tensor per bag (bags are kept separate, so
            they may contain different numbers of patches) and a 1-D int64
            tensor of labels.
        """
        bag_tensors = [self._to_bag_tensor(bag_features) for bag_features, _ in bags]
        labels_tensor = torch.tensor(np.asarray(labels), dtype=torch.long)

        return bag_tensors, labels_tensor

    def train_epoch(self, train_data: List[torch.Tensor], train_labels: torch.Tensor,
                   val_data: List[torch.Tensor], val_labels: torch.Tensor,
                   epoch: int) -> Tuple[float, float, float, float]:
        """Run one training pass followed by one validation pass.

        Every bag is its own optimisation step (effective batch size 1). The
        Adam optimizer and cross-entropy criterion are created lazily on the
        first call.

        Args:
            train_data: One ``[n_patches, feature_dim]`` tensor per bag.
            train_labels: 1-D tensor of class indices aligned with
                ``train_data``.
            val_data: Validation bags, same layout as ``train_data``.
            val_labels: Validation labels, same layout as ``train_labels``.
            epoch: Epoch index; not used by this implementation.

        Returns:
            ``(train_loss, train_acc, val_loss, val_acc)``. For an empty
            split the loss is ``nan`` and the accuracy ``0.0``.
        """
        # An optimizer created here is bound to the parameters of the model
        # that existed at that time. If self.model has since been replaced,
        # that optimizer would silently keep updating the old parameters, so
        # it is rebuilt. An optimizer assigned from outside this method (no
        # _optimizer_model marker) is left untouched.
        bound_model = getattr(self, '_optimizer_model', None)
        optimizer_is_stale = bound_model is not None and bound_model is not self.model
        if not hasattr(self, 'optimizer') or optimizer_is_stale:
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.learning_rate,
                weight_decay=self.weight_decay
            )
            self._optimizer_model = self.model
        if not hasattr(self, 'criterion'):
            self.criterion = nn.CrossEntropyLoss()

        device = self._model_device()

        self.model.train()
        train_losses = []
        train_correct = 0

        for bag_data, label in zip(train_data, train_labels):
            bag = bag_data.unsqueeze(0).to(device)  # [1, n_patches, feature_dim]
            target = label.reshape(1).to(device)  # [1]

            self.optimizer.zero_grad()
            output = self.model(bag)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            train_losses.append(loss.item())
            train_correct += int((output.argmax(dim=1) == target).sum().item())

        train_loss, train_acc = self._epoch_stats(train_losses, train_correct)

        self.model.eval()
        val_losses = []
        val_correct = 0

        with torch.no_grad():
            for bag_data, label in zip(val_data, val_labels):
                bag = bag_data.unsqueeze(0).to(device)
                target = label.reshape(1).to(device)

                output = self.model(bag)
                val_losses.append(self.criterion(output, target).item())
                val_correct += int((output.argmax(dim=1) == target).sum().item())

        val_loss, val_acc = self._epoch_stats(val_losses, val_correct)

        return train_loss, train_acc, val_loss, val_acc

    def predict_bags(self, bags: List[Tuple[np.ndarray, Any]]) -> PredictionResult:
        """Predict a class for every bag.

        Args:
            bags: ``(features, extra)`` pairs; only ``features`` is used.

        Returns:
            A ``PredictionResult`` built with ``predictions`` (int64 class
            index per bag), ``probabilities`` (softmax of the logits, shape
            ``[n_bags, n_classes]``), ``confidence`` (highest probability per
            bag) and an empty ``bag_names`` list. An empty ``bags`` list gives
            empty arrays rather than an error.
        """
        self.model.eval()
        device = self._model_device()

        predictions = []
        probabilities = []

        # Attention weights are not part of the result built below, so they
        # are not requested; model(x, return_attention=True) provides them.
        with torch.no_grad():
            for bag_features, _ in bags:
                bag = self._to_bag_tensor(bag_features).unsqueeze(0).to(device)
                logits = self.model(bag)

                probabilities.append(
                    torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()
                )
                predictions.append(int(torch.argmax(logits, dim=1).item()))

        if probabilities:
            probabilities = np.array(probabilities)
        else:
            # Keep a well-formed [0, n_classes] array so the axis=1 reduction
            # below still works when there is nothing to predict.
            n_classes = self.model.classifier[-1].out_features
            probabilities = np.empty((0, n_classes), dtype=np.float32)

        predictions = np.array(predictions, dtype=np.int64)
        confidence = np.max(probabilities, axis=1)

        return PredictionResult(
            predictions=predictions,
            probabilities=probabilities,
            confidence=confidence,
            bag_names=[]
        )

    def _get_wsi_feature_for_bag(self, bag_features: np.ndarray) -> np.ndarray:
        """Return the attention-pooled embedding of a single bag.

        Args:
            bag_features: Patch features of shape ``[n_patches, feature_dim]``.

        Returns:
            1-D array of length ``hidden_dim``: the attention-weighted sum of
            the patch embeddings, i.e. the input of the classification head.
        """
        self.model.eval()
        with torch.no_grad():
            bag = self._to_bag_tensor(bag_features).unsqueeze(0)
            bag = bag.to(self._model_device())

            # Reuse the network's own pooling so this always matches forward().
            wsi_feature, _ = self.model.aggregate(bag)

            return wsi_feature.squeeze(0).cpu().numpy()
