"""
TransMIL (Transformer-based Multiple Instance Learning) 
"Transformer based Correlated Multiple Instance Learning for Whole Slide Image Classification"

"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from typing import List, Tuple, Dict, Any

from .base_classifier import BaseMILClassifier, PredictionResult
from .classifier_factory import register_classifier

from nystrom_attention import NystromAttention


@register_classifier('transmil')
class TransMILClassifier(BaseMILClassifier):
    """Bag-level classifier built around the TransMIL architecture.

    A bag is a variable-length set of instance feature vectors. Bags are fed
    to the network one at a time: instance features are projected to
    ``hidden_dim``, padded to a square number of tokens, prefixed with a
    learnable class token and passed through two Nystrom-attention layers with
    a convolutional position encoding (PPEG) in between. The normalised class
    token is classified by a linear head.

    Config keys read by this class (defaults in parentheses):
        input_dim (512): stored only; ``build_model`` receives the real
            feature size as an argument.
        hidden_dim (512): token width used throughout the network.
        num_heads (8): attention heads per Nystrom-attention layer.
        num_landmarks (256): landmark count for Nystrom attention.
        dropout (0.1): dropout inside the attention layers.
        learning_rate (2e-4), weight_decay (1e-5): optimizer settings.
        optimizer_type ('adamw'): ``'adamw'`` selects AdamW, anything else Adam.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store architecture and optimisation settings taken from ``config``."""
        super().__init__(config)

        # Architecture settings (see build_model).
        self.input_dim = config.get('input_dim', 512)
        self.hidden_dim = config.get('hidden_dim', 512)
        self.dropout = config.get('dropout', 0.1)
        self.num_heads = config.get('num_heads', 8)
        self.num_landmarks = config.get('num_landmarks', 256)

        # Optimisation settings (see train_epoch). float() lets values that
        # arrive as strings, e.g. "2e-4", be used directly.
        self.learning_rate = float(config.get('learning_rate', 2e-4))
        self.weight_decay = float(config.get('weight_decay', 1e-5))
        self.optimizer_type = config.get('optimizer_type', 'adamw')

    def build_model(self, feature_dim: int, n_classes: int) -> nn.Module:
        """Create a fresh TransMIL network.

        The network honours ``hidden_dim``, ``num_heads``, ``num_landmarks``
        and ``dropout`` from the config. With the default config the result is
        identical (layer names, shapes and construction order) to the
        previously hard-coded 512-wide, 8-head model.

        Args:
            feature_dim: Size of each instance feature vector.
            n_classes: Number of output classes.

        Returns:
            An ``nn.Module`` mapping a bag ``[n, feature_dim]`` (or a batch
            ``[B, n, feature_dim]``) to logits ``[B, n_classes]``.

        Raises:
            ValueError: If the configured sizes cannot form a valid network.
        """
        hidden_dim = int(self.hidden_dim)
        num_heads = int(self.num_heads)
        num_landmarks = int(self.num_landmarks)
        dropout = float(self.dropout)

        if num_heads < 1 or hidden_dim < num_heads:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be >= num_heads ({num_heads}) "
                "and num_heads must be positive"
            )
        if num_landmarks < 1:
            raise ValueError(f"num_landmarks must be positive, got {num_landmarks}")

        class TransLayer(nn.Module):
            """Pre-norm residual block: ``x + NystromAttention(LayerNorm(x))``."""

            def __init__(self, norm_layer=nn.LayerNorm, dim=512):
                super().__init__()
                self.norm = norm_layer(dim)
                self.attn = NystromAttention(
                    dim=dim,
                    dim_head=dim // num_heads,
                    heads=num_heads,
                    num_landmarks=num_landmarks,
                    pinv_iterations=6,
                    residual=True,
                    dropout=dropout,
                )

            def forward(self, x):
                return x + self.attn(self.norm(x))

        class PPEG(nn.Module):
            """Position encoding from depthwise 7x7, 5x5 and 3x3 convolutions.

            The instance tokens are laid out on an ``H x W`` grid, the three
            convolution outputs are summed with the input, and the class token
            is passed through untouched.
            """

            def __init__(self, dim=512):
                super().__init__()
                # groups=dim makes each convolution depthwise; the padding
                # keeps the spatial size unchanged.
                self.proj = nn.Conv2d(dim, dim, 7, 1, 7 // 2, groups=dim)
                self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5 // 2, groups=dim)
                self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3 // 2, groups=dim)

            def forward(self, x, H, W):
                B, _, C = x.shape
                cls_token, feat_token = x[:, 0], x[:, 1:]
                # [B, H*W, C] -> [B, C, H, W]; requires exactly H*W tokens.
                cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
                x = (
                    self.proj(cnn_feat)
                    + cnn_feat
                    + self.proj1(cnn_feat)
                    + self.proj2(cnn_feat)
                )
                x = x.flatten(2).transpose(1, 2)
                return torch.cat((cls_token.unsqueeze(1), x), dim=1)

        class TransMIL(nn.Module):
            """Linear projection -> TransLayer -> PPEG -> TransLayer -> head."""

            def __init__(self, input_dim=1024, n_classes=2, dim=512):
                super().__init__()
                # Attribute names and construction order are kept stable so
                # state_dict keys and seeded initialisation do not change.
                self.pos_layer = PPEG(dim=dim)
                self._fc1 = nn.Sequential(nn.Linear(input_dim, dim), nn.ReLU())
                self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
                self.n_classes = n_classes
                self.layer1 = TransLayer(dim=dim)
                self.layer2 = TransLayer(dim=dim)
                self.norm = nn.LayerNorm(dim)
                self._fc2 = nn.Linear(dim, self.n_classes)

            def forward(self, data):
                if data.dim() == 2:
                    data = data.unsqueeze(0)  # [1, n, input_dim]
                if data.dim() != 3:
                    raise ValueError(
                        "expected a bag of shape [n, input_dim] or "
                        f"[B, n, input_dim], got {tuple(data.shape)}"
                    )

                n_instances = data.shape[1]
                if n_instances == 0:
                    # Fail early with a clear message instead of an obscure
                    # convolution size error inside PPEG.
                    raise ValueError("TransMIL received an empty bag (0 instances)")

                h = self._fc1(data.float())  # [B, n, dim]

                # Pad to the next square number of tokens by repeating the
                # leading instances, so PPEG can reshape them to a grid.
                grid = int(np.ceil(np.sqrt(n_instances)))
                pad = grid * grid - n_instances
                h = torch.cat([h, h[:, :pad, :]], dim=1)  # [B, grid*grid, dim]

                # Prepend the class token on whatever device the input is on
                # (.to is a no-op when it is already there).
                cls_tokens = self.cls_token.expand(h.shape[0], -1, -1).to(h.device)
                h = torch.cat((cls_tokens, h), dim=1)  # [B, 1 + grid*grid, dim]

                h = self.layer1(h)
                h = self.pos_layer(h, grid, grid)
                h = self.layer2(h)

                # Classify from the normalised class token only.
                h = self.norm(h)[:, 0]
                return self._fc2(h)  # [B, n_classes]

        return TransMIL(input_dim=feature_dim, n_classes=n_classes, dim=hidden_dim)

    def prepare_data(self, bags: List[Tuple[np.ndarray, Any]],
                    labels: List[int]) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Convert bags and labels to tensors.

        Args:
            bags: ``(features, name)`` pairs; only the features are used here.
            labels: One integer class label per bag.

        Returns:
            A list with one float32 tensor per bag, and an int64 label tensor.
        """
        bag_tensors = [self._bag_to_tensor(features) for features, _ in bags]
        labels_tensor = torch.as_tensor(labels, dtype=torch.long)
        return bag_tensors, labels_tensor

    def train_epoch(self, train_data: List[torch.Tensor], train_labels: torch.Tensor,
                   val_data: List[torch.Tensor], val_labels: torch.Tensor,
                   epoch: int) -> Tuple[float, float, float, float]:
        """Run one training pass followed by one validation pass.

        Bags are processed one at a time with gradient-norm clipping at 1.0.
        Bags and labels are paired with ``zip``, so any surplus items in the
        longer sequence are ignored; averages are taken over the bags that
        were actually processed.

        Args:
            train_data: Training bags, one tensor per bag.
            train_labels: Label tensor aligned with ``train_data``.
            val_data: Validation bags (may be empty).
            val_labels: Label tensor aligned with ``val_data``.
            epoch: Current epoch index (not used by this implementation).

        Returns:
            ``(train_loss, train_acc, val_loss, val_acc)`` averaged per bag;
            a split with no bags reports ``0.0`` for both of its metrics.
        """
        self._ensure_optimizer()
        device = self._model_device()

        self.model.train()
        train_loss, train_correct, train_total = self._run_epoch_pass(
            train_data, train_labels, device, update=True
        )

        self.model.eval()
        with torch.no_grad():
            val_loss, val_correct, val_total = self._run_epoch_pass(
                val_data, val_labels, device, update=False
            )

        avg_train_loss = train_loss / train_total if train_total else 0.0
        avg_train_acc = train_correct / train_total if train_total else 0.0
        avg_val_loss = val_loss / val_total if val_total else 0.0
        avg_val_acc = val_correct / val_total if val_total else 0.0

        return avg_train_loss, avg_train_acc, avg_val_loss, avg_val_acc

    def predict_bags(self, bags: List[Tuple[np.ndarray, Any]]) -> PredictionResult:
        """Predict a class for every bag.

        Args:
            bags: ``(features, name)`` pairs; features may be an array-like or
                a tensor.

        Returns:
            A ``PredictionResult`` holding, per bag, the arg-max class, the
            softmax probabilities, the highest probability (as ``confidence``)
            and the bag name.
        """
        self.model.eval()
        device = self._model_device()

        predictions = []
        probabilities = []
        confidences = []
        bag_names = []

        with torch.no_grad():
            for bag_features, bag_name in bags:
                bag_tensor = self._bag_to_tensor(bag_features).to(device)

                logits = self.model(bag_tensor)
                probs = torch.softmax(logits, dim=1)
                pred = torch.argmax(logits, dim=1)
                conf = torch.max(probs, dim=1)[0]

                # One bag per forward pass, so index 0 is the only row.
                predictions.append(pred.cpu().numpy()[0])
                probabilities.append(probs.cpu().numpy()[0])
                confidences.append(conf.cpu().numpy()[0])
                bag_names.append(bag_name)

        return PredictionResult(
            predictions=np.array(predictions),
            probabilities=np.array(probabilities),
            confidence=np.array(confidences),
            bag_names=bag_names
        )

    @staticmethod
    def _bag_to_tensor(features: Any) -> torch.Tensor:
        """Return ``features`` as a float32 tensor (array-likes are copied)."""
        if torch.is_tensor(features):
            return features.float()
        return torch.tensor(np.asarray(features), dtype=torch.float32)

    def _model_device(self) -> torch.device:
        """Return the device of the model's parameters (CPU if it has none).

        Inputs follow the model rather than ``torch.cuda.is_available()`` so a
        CPU-resident model on a CUDA machine does not hit a device mismatch.
        """
        first_param = next(self.model.parameters(), None)
        if first_param is None:
            return torch.device('cpu')
        return first_param.device

    def _ensure_optimizer(self) -> None:
        """Create the optimizer and loss if they are missing or stale.

        An optimizer built here is tied to the model it was built for; if
        ``self.model`` has since been replaced, a new optimizer is created so
        that the current parameters are the ones being updated. An optimizer
        assigned from outside this method is left untouched.
        """
        optimizer = getattr(self, 'optimizer', None)
        built_for = getattr(self, '_optimizer_model', None)
        stale = built_for is not None and built_for is not self.model

        if optimizer is None or stale:
            if self.optimizer_type.lower() == 'adamw':
                self.optimizer = optim.AdamW(
                    self.model.parameters(),
                    lr=self.learning_rate,
                    weight_decay=self.weight_decay,
                    betas=(0.9, 0.999)
                )
            else:
                self.optimizer = optim.Adam(
                    self.model.parameters(),
                    lr=self.learning_rate,
                    weight_decay=self.weight_decay
                )
            self.criterion = nn.CrossEntropyLoss()
            self._optimizer_model = self.model
        elif getattr(self, 'criterion', None) is None:
            self.criterion = nn.CrossEntropyLoss()

    def _run_epoch_pass(self, bags: List[torch.Tensor], labels: torch.Tensor,
                        device: torch.device, update: bool) -> Tuple[float, int, int]:
        """Run the model over every bag, optionally updating the weights.

        The caller is responsible for setting train/eval mode and for
        disabling gradients when ``update`` is False.

        Returns:
            ``(summed_loss, n_correct, n_bags)`` over the processed bags.
        """
        total_loss = 0.0
        n_correct = 0
        n_bags = 0

        for bag, label in zip(bags, labels):
            bag = bag.to(device)
            target = label.to(device).unsqueeze(0)  # [1], matches logits [1, C]

            if update:
                self.optimizer.zero_grad()

            logits = self.model(bag)
            loss = self.criterion(logits, target)

            if update:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

            total_loss += loss.item()
            n_correct += (torch.argmax(logits, dim=1) == target).sum().item()
            n_bags += 1

        return total_loss, n_correct, n_bags