"""
CLAM (Clustering-constrained Attention Multiple instance learning) 
Data-efficient and weakly supervised computational pathology on whole-slide images
Nature Biomedical Engineering, 2021
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from typing import List, Tuple, Dict, Any

from .base_classifier import BaseMILClassifier, PredictionResult
from .classifier_factory import register_classifier


class Attn_Net(nn.Module):
    """
    Ungated attention scorer: Linear -> Tanh -> [Dropout] -> Linear.

    Args:
        L: input feature dimension
        D: hidden layer dimension
        dropout: when truthy, a Dropout(p=0.25) layer is placed after the Tanh
        n_classes: number of attention scores produced per instance
    """
    def __init__(self, L=1024, D=256, dropout=False, n_classes=1):
        super().__init__()
        hidden = [nn.Linear(L, D), nn.Tanh()]
        if dropout:
            hidden += [nn.Dropout(0.25)]
        # The scoring layer is always the last stage of the stack.
        self.module = nn.Sequential(*hidden, nn.Linear(D, n_classes))

    def forward(self, x):
        """Return ``(scores, x)``: the attention scores and the untouched input."""
        scores = self.module(x)  # N x n_classes
        return scores, x


class Attn_Net_Gated(nn.Module):
    """
    Gated attention scorer built from three fully connected layers.

    A Tanh branch and a Sigmoid branch are computed from the same input,
    multiplied element-wise, and projected to the attention scores.

    Args:
        L: input feature dimension
        D: hidden layer dimension
        dropout: when truthy, Dropout(p=0.25) is appended to both branches
        n_classes: number of attention scores produced per instance
    """
    def __init__(self, L=1024, D=256, dropout=False, n_classes=1):
        super().__init__()
        tanh_branch = [nn.Linear(L, D), nn.Tanh()]
        gate_branch = [nn.Linear(L, D), nn.Sigmoid()]
        if dropout:
            tanh_branch.append(nn.Dropout(0.25))
            gate_branch.append(nn.Dropout(0.25))

        self.attention_a = nn.Sequential(*tanh_branch)
        self.attention_b = nn.Sequential(*gate_branch)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        """Return ``(scores, x)``: the gated attention scores and the untouched input."""
        gated = self.attention_a(x) * self.attention_b(x)
        scores = self.attention_c(gated)  # N x n_classes
        return scores, x


class CLAM_SB(nn.Module):
    """
    CLAM Single Branch

    A single attention branch pools the instances of a bag into one
    bag-level feature, which a linear layer classifies. One binary
    instance-level classifier per class provides auxiliary supervision on
    the most and least attended instances.

    Args:
        gate: whether to use gated attention network
        size_arg: size of the model (small or big)
        dropout: dropout rate of the first FC layer; any non-zero value also
            switches on the attention network's own dropout
        k_sample: number of positive/negative patches to sample for instance-level training
        n_classes: number of classes
        instance_loss_fn: loss function to supervise instance-level training
        subtyping: whether it's a subtyping problem
        embed_dim: input feature dimension
    """
    def __init__(self, gate=True, size_arg="small", dropout=0., k_sample=8, n_classes=2,
        instance_loss_fn=nn.CrossEntropyLoss(), subtyping=False, embed_dim=1024):
        super(CLAM_SB, self).__init__()
        # [input dim, hidden dim, attention dim] for each model size.
        self.size_dict = {"small": [embed_dim, 512, 256], "big": [embed_dim, 512, 384]}
        size = self.size_dict[size_arg]
        fc = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)]
        if gate:
            attention_net = Attn_Net_Gated(L=size[1], D=size[2], dropout=dropout, n_classes=1)
        else:
            attention_net = Attn_Net(L=size[1], D=size[2], dropout=dropout, n_classes=1)
        fc.append(attention_net)
        self.attention_net = nn.Sequential(*fc)
        self.classifiers = nn.Linear(size[1], n_classes)
        instance_classifiers = [nn.Linear(size[1], 2) for _ in range(n_classes)]
        self.instance_classifiers = nn.ModuleList(instance_classifiers)
        self.k_sample = k_sample
        self.instance_loss_fn = instance_loss_fn
        self.n_classes = n_classes
        self.subtyping = subtyping

    @staticmethod
    def create_positive_targets(length, device):
        """Return a long tensor of ``length`` ones on ``device``."""
        return torch.full((length, ), 1, device=device).long()

    @staticmethod
    def create_negative_targets(length, device):
        """Return a long tensor of ``length`` zeros on ``device``."""
        return torch.full((length, ), 0, device=device).long()

    def _effective_k(self, A):
        """Return ``k_sample`` clamped to the number of instances scored in ``A``."""
        # torch.topk raises when k exceeds the size of the selected dimension,
        # so bags smaller than k_sample must sample fewer instances.
        return min(self.k_sample, A.shape[-1])

    def inst_eval(self, A, h, classifier):
        """
        Instance-level evaluation for the class the bag belongs to.

        The k most attended instances are labelled 1 and the k least
        attended instances are labelled 0, where k is ``k_sample`` clamped
        to the bag size. For bags with fewer than ``2 * k`` instances the
        two selections overlap.

        Args:
            A: attention scores, shape (N,) or (1, N)
            h: instance features, indexed along dim 0
            classifier: binary instance classifier applied to the selected rows of ``h``

        Returns:
            (instance_loss, predicted labels, target labels)
        """
        device = h.device
        if len(A.shape) == 1:
            A = A.view(1, -1)
        k = self._effective_k(A)
        top_p_ids = torch.topk(A, k)[1][-1]
        top_p = torch.index_select(h, dim=0, index=top_p_ids)
        top_n_ids = torch.topk(-A, k, dim=1)[1][-1]
        top_n = torch.index_select(h, dim=0, index=top_n_ids)
        p_targets = self.create_positive_targets(k, device)
        n_targets = self.create_negative_targets(k, device)

        all_targets = torch.cat([p_targets, n_targets], dim=0)
        all_instances = torch.cat([top_p, top_n], dim=0)
        logits = classifier(all_instances)
        all_preds = torch.topk(logits, 1, dim=1)[1].squeeze(1)
        instance_loss = self.instance_loss_fn(logits, all_targets)
        return instance_loss, all_preds, all_targets

    def inst_eval_out(self, A, h, classifier):
        """
        Instance-level evaluation for a class the bag does not belong to.

        The k most attended instances (k is ``k_sample`` clamped to the bag
        size) are all labelled 0 for this class.

        Args:
            A: attention scores, shape (N,) or (1, N)
            h: instance features, indexed along dim 0
            classifier: binary instance classifier applied to the selected rows of ``h``

        Returns:
            (instance_loss, predicted labels, target labels)
        """
        device = h.device
        if len(A.shape) == 1:
            A = A.view(1, -1)
        k = self._effective_k(A)
        top_p_ids = torch.topk(A, k)[1][-1]
        top_p = torch.index_select(h, dim=0, index=top_p_ids)
        p_targets = self.create_negative_targets(k, device)
        logits = classifier(top_p)
        p_preds = torch.topk(logits, 1, dim=1)[1].squeeze(1)
        instance_loss = self.instance_loss_fn(logits, p_targets)
        return instance_loss, p_preds, p_targets

    def forward(self, h, label=None, instance_eval=False, return_features=False, attention_only=False):
        """
        Classify one bag of instance features.

        Args:
            h: instance features of a single bag
            label: bag label tensor; required when ``instance_eval`` is True
            instance_eval: also compute the instance-level loss
            return_features: add the pooled bag feature to the results dict
            attention_only: return only the raw (pre-softmax) attention scores

        Returns:
            The raw attention scores if ``attention_only`` is set, otherwise
            ``(logits, Y_prob, Y_hat, A_raw, results_dict)``. ``results_dict``
            holds 'instance_loss', 'inst_labels' and 'inst_preds' when
            ``instance_eval`` is set, and 'features' when ``return_features``
            is set.

        Raises:
            ValueError: if ``instance_eval`` is True but ``label`` is None.
        """
        A, h = self.attention_net(h)  # NxK
        A = torch.transpose(A, 1, 0)  # KxN
        if attention_only:
            return A
        A_raw = A
        A = F.softmax(A, dim=1)  # softmax over N

        if instance_eval:
            if label is None:
                raise ValueError("label must be provided when instance_eval=True")
            total_inst_loss = 0.0
            all_preds = []
            all_targets = []
            inst_labels = F.one_hot(label, num_classes=self.n_classes).squeeze()
            for i, classifier in enumerate(self.instance_classifiers):
                if inst_labels[i].item() == 1:  # in-the-class
                    instance_loss, preds, targets = self.inst_eval(A, h, classifier)
                elif self.subtyping:  # out-of-the-class, only supervised when subtyping
                    instance_loss, preds, targets = self.inst_eval_out(A, h, classifier)
                else:
                    continue
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())
                total_inst_loss += instance_loss

            if self.subtyping:
                # Every class contributed a term, so report the mean.
                total_inst_loss /= len(self.instance_classifiers)

        M = torch.mm(A, h)  # attention-weighted bag feature
        logits = self.classifiers(M)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=1)

        if instance_eval:
            results_dict = {'instance_loss': total_inst_loss, 'inst_labels': np.array(all_targets),
            'inst_preds': np.array(all_preds)}
        else:
            results_dict = {}
        if return_features:
            results_dict.update({'features': M})
        return logits, Y_prob, Y_hat, A_raw, results_dict


class CLAM_MB(CLAM_SB):
    """
    CLAM Multi Branch

    Uses one attention branch and one single-output bag classifier per
    class. Constructor arguments have the same meaning as for CLAM_SB.
    """
    def __init__(self, gate=True, size_arg="small", dropout=0., k_sample=8, n_classes=2,
        instance_loss_fn=nn.CrossEntropyLoss(), subtyping=False, embed_dim=1024):
        # Initialise nn.Module directly: the parent constructor is bypassed
        # and every layer is built here instead.
        nn.Module.__init__(self)
        self.size_dict = {"small": [embed_dim, 512, 256], "big": [embed_dim, 512, 384]}
        in_dim, hidden_dim, attn_dim = self.size_dict[size_arg]

        layers = [nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout)]
        attention_cls = Attn_Net_Gated if gate else Attn_Net
        layers.append(attention_cls(L=hidden_dim, D=attn_dim, dropout=dropout, n_classes=n_classes))
        self.attention_net = nn.Sequential(*layers)

        # One bag-level scorer and one binary instance-level classifier per class.
        self.classifiers = nn.ModuleList([nn.Linear(hidden_dim, 1) for _ in range(n_classes)])
        self.instance_classifiers = nn.ModuleList([nn.Linear(hidden_dim, 2) for _ in range(n_classes)])

        self.k_sample = k_sample
        self.instance_loss_fn = instance_loss_fn
        self.n_classes = n_classes
        self.subtyping = subtyping

    def forward(self, h, label=None, instance_eval=False, return_features=False, attention_only=False):
        """
        Classify one bag with class-specific attention branches.

        Returns the raw attention scores when ``attention_only`` is set,
        otherwise ``(logits, Y_prob, Y_hat, A_raw, results_dict)``.
        """
        A_raw, h = self.attention_net(h)  # NxK
        A_raw = A_raw.transpose(1, 0)  # KxN
        if attention_only:
            return A_raw
        A = F.softmax(A_raw, dim=1)  # softmax over N

        if instance_eval:
            total_inst_loss = 0.0
            all_preds, all_targets = [], []
            one_hot = F.one_hot(label, num_classes=self.n_classes).squeeze()
            for branch, classifier in enumerate(self.instance_classifiers):
                if one_hot[branch].item() == 1:
                    # Bag belongs to this class.
                    evaluate = self.inst_eval
                elif self.subtyping:
                    # Bag does not belong to this class; only used for subtyping.
                    evaluate = self.inst_eval_out
                else:
                    continue
                # Each class is evaluated with its own attention row.
                instance_loss, preds, targets = evaluate(A[branch], h, classifier)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())
                total_inst_loss += instance_loss

            if self.subtyping:
                total_inst_loss /= len(self.instance_classifiers)

        M = torch.mm(A, h)  # one pooled feature row per class

        # Row c of M is scored by classifier c, filling a float32 1 x n_classes buffer.
        logits = torch.empty((1, self.n_classes), dtype=torch.float32, device=M.device)
        for c in range(self.n_classes):
            logits[0, c] = self.classifiers[c](M[c])

        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=1)

        results_dict = {}
        if instance_eval:
            results_dict['instance_loss'] = total_inst_loss
            results_dict['inst_labels'] = np.array(all_targets)
            results_dict['inst_preds'] = np.array(all_preds)
        if return_features:
            results_dict['features'] = M
        return logits, Y_prob, Y_hat, A_raw, results_dict


@register_classifier('clam_sb')
class CLAMSBClassifier(BaseMILClassifier):
    """
    Classifier wrapper that trains and evaluates a CLAM_SB model one bag at a time.

    Config keys read (with defaults): 'gate' (True), 'size_arg' ("small"),
    'dropout' (0.0), 'k_sample' (8), 'subtyping' (False),
    'learning_rate' (2e-4), 'weight_decay' (1e-5), 'batch_size' (1),
    'bag_weight' (0.7), 'instance_weight' (0.3).
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)

        # Model architecture options, forwarded to the CLAM constructor.
        self.gate = config.get('gate', True)
        self.size_arg = config.get('size_arg', "small")
        self.dropout = config.get('dropout', 0.0)
        self.k_sample = config.get('k_sample', 8)
        self.subtyping = config.get('subtyping', False)

        # Optimisation options; explicitly cast to numeric types.
        self.learning_rate = float(config.get('learning_rate', 2e-4))
        self.weight_decay = float(config.get('weight_decay', 1e-5))
        # Stored but not used by the methods of this class (bags are processed one by one).
        self.batch_size = int(config.get('batch_size', 1))

        # Weights of the bag-level and instance-level terms of the training loss.
        self.bag_weight = config.get('bag_weight', 0.7)
        self.instance_weight = config.get('instance_weight', 0.3)

    def build_model(self, feature_dim: int, n_classes: int) -> nn.Module:
        """Build a CLAM_SB model for ``feature_dim``-dimensional instances and ``n_classes`` classes."""
        return CLAM_SB(
            gate=self.gate,
            size_arg=self.size_arg,
            dropout=self.dropout,
            k_sample=self.k_sample,
            n_classes=n_classes,
            instance_loss_fn=nn.CrossEntropyLoss(),
            subtyping=self.subtyping,
            embed_dim=feature_dim
        )

    def prepare_data(self, bags: List[Tuple[np.ndarray, Any]],
                    labels: List[int]) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Convert each bag's feature array to a float tensor and the labels to a long tensor."""
        bag_tensors = [torch.FloatTensor(bag_features) for bag_features, _ in bags]
        labels_tensor = torch.LongTensor(labels)
        return bag_tensors, labels_tensor

    def _model_device(self) -> torch.device:
        """Return the device holding the model's parameters.

        Inputs are moved to this device rather than unconditionally to CUDA,
        so the data always ends up where the model actually is.
        """
        first_param = next(self.model.parameters(), None)
        if first_param is not None:
            return first_param.device
        # Parameter-free model: fall back to the CUDA-if-available rule.
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _ensure_optimizer(self) -> None:
        """Create the optimizer and criterion on first use, or when the model was replaced."""
        built_for = getattr(self, '_optimizer_model', None)
        # An optimizer created here for a previous model object would keep
        # updating that old object's parameters, so it is rebuilt.
        model_replaced = built_for is not None and built_for is not self.model
        if hasattr(self, 'optimizer') and not model_replaced:
            return
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )
        self.criterion = nn.CrossEntropyLoss()
        self._optimizer_model = self.model

    def train_epoch(self, train_data: List[torch.Tensor], train_labels: torch.Tensor,
                   val_data: List[torch.Tensor], val_labels: torch.Tensor,
                   epoch: int) -> Tuple[float, float, float, float]:
        """
        Run one training pass followed by one validation pass.

        Each bag is a separate optimisation step. The training loss is
        ``bag_weight * bag_loss + instance_weight * instance_loss``.

        Returns:
            (avg_train_loss, avg_train_acc, avg_val_loss, avg_val_acc); the
            averages of an empty split are reported as 0.0.
        """
        self._ensure_optimizer()
        device = self._model_device()

        self.model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for bag_data, label in zip(train_data, train_labels):
            bag_data = bag_data.to(device)
            label = label.to(device)

            self.optimizer.zero_grad()

            logits, _, _, _, results_dict = self.model(
                bag_data, label=label, instance_eval=True
            )

            bag_loss = self.criterion(logits, label.unsqueeze(0))

            total_loss = bag_loss
            if 'instance_loss' in results_dict:
                instance_loss = results_dict['instance_loss']
                total_loss = (self.bag_weight * bag_loss +
                            self.instance_weight * instance_loss)

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            train_loss += total_loss.item()
            pred = torch.argmax(logits, dim=1)
            train_correct += (pred == label.unsqueeze(0)).sum().item()
            train_total += 1

        self.model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for bag_data, label in zip(val_data, val_labels):
                bag_data = bag_data.to(device)
                label = label.to(device)

                logits, _, _, _, _ = self.model(bag_data, instance_eval=False)
                loss = self.criterion(logits, label.unsqueeze(0))

                val_loss += loss.item()
                pred = torch.argmax(logits, dim=1)
                val_correct += (pred == label.unsqueeze(0)).sum().item()
                val_total += 1

        # Average over the bags actually processed; guard both splits against
        # being empty so neither can raise ZeroDivisionError.
        avg_train_loss = train_loss / train_total if train_total > 0 else 0.0
        avg_train_acc = train_correct / train_total if train_total > 0 else 0.0
        avg_val_loss = val_loss / val_total if val_total > 0 else 0.0
        avg_val_acc = val_correct / val_total if val_total > 0 else 0.0

        return avg_train_loss, avg_train_acc, avg_val_loss, avg_val_acc

    def predict_bags(self, bags: List[Tuple[np.ndarray, Any]]) -> PredictionResult:
        """
        Predict a class, class probabilities and a confidence for every bag.

        Args:
            bags: ``(features, name)`` pairs; features may be a numpy array or a tensor.

        Returns:
            PredictionResult with one entry per bag; the confidence is the
            largest class probability.
        """
        self.model.eval()
        device = self._model_device()
        predictions = []
        probabilities = []
        confidences = []
        bag_names = []

        with torch.no_grad():
            for bag_features, bag_name in bags:
                if isinstance(bag_features, np.ndarray):
                    bag_features = torch.FloatTensor(bag_features)
                bag_features = bag_features.to(device)

                _, Y_prob, Y_hat, _, _ = self.model(
                    bag_features, instance_eval=False
                )

                # NOTE(review): Y_hat is 1 x 1, so each entry here has shape (1,)
                # and the stacked predictions are (n_bags, 1) - confirm that
                # consumers do not expect a flat (n_bags,) array.
                predictions.append(Y_hat.cpu().numpy()[0])
                probabilities.append(Y_prob.cpu().numpy()[0])
                confidences.append(torch.max(Y_prob, dim=1)[0].cpu().numpy()[0])
                bag_names.append(bag_name)

        return PredictionResult(
            predictions=np.array(predictions),
            probabilities=np.array(probabilities),
            confidence=np.array(confidences),
            bag_names=bag_names
        )

    def _calculate_epoch_metrics(self, data, labels, phase):
        """
        Compute ``(auc, f1)`` for ``data`` against ``labels``.

        Best-effort: any failure is reported with a printed warning that
        names ``phase``, and ``(0.0, 0.0)`` is returned instead.
        """
        try:
            self.model.eval()
            device = self._model_device()
            with torch.no_grad():
                all_probs = []
                for bag_data in data:
                    bag_data = bag_data.to(device)

                    _, Y_prob, _, _, _ = self.model(
                        bag_data, instance_eval=False
                    )
                    all_probs.append(Y_prob.cpu().numpy())

                y_prob = np.vstack(all_probs)

                if isinstance(labels, torch.Tensor):
                    y_true = labels.cpu().numpy()
                else:
                    y_true = labels

                y_pred = np.argmax(y_prob, axis=1)

                auc = self._calculate_auc(y_true, y_prob)
                f1 = self._calculate_f1(y_true, y_pred)

                return auc, f1

        except Exception as e:
            # Metrics are auxiliary; never let them abort the caller.
            print(f"Warning: Cannot calculate {phase} metrics: {e}")
            return 0.0, 0.0


@register_classifier('clam_mb')
class CLAMMBClassifier(CLAMSBClassifier):
    """Same wrapper as CLAMSBClassifier, but the model it builds is the multi-branch CLAM_MB."""

    def build_model(self, feature_dim: int, n_classes: int) -> nn.Module:
        """Build a CLAM_MB model for ``feature_dim``-dimensional instances and ``n_classes`` classes."""
        model_kwargs = dict(
            gate=self.gate,
            size_arg=self.size_arg,
            dropout=self.dropout,
            k_sample=self.k_sample,
            n_classes=n_classes,
            instance_loss_fn=nn.CrossEntropyLoss(),
            subtyping=self.subtyping,
            embed_dim=feature_dim,
        )
        return CLAM_MB(**model_kwargs)
