import torch
from tqdm import tqdm

from mmseg.models.utils import resize


class UncerDetector:
    """Score segmentation predictions with a learned class-pair uncertainty matrix.

    The detector keeps a ``(C, C)`` tensor ``self.params``. For a vector of
    temperature-scaled softmax probabilities ``p`` the uncertainty score is the
    quadratic form ``p @ D @ p``. Here ``D`` is ``params`` with its diagonal
    zeroed and its strict lower triangle mirrored to the upper one, and the
    result is scaled to unit Frobenius norm.

    ``fit`` estimates ``params`` from a labelled dataloader. ``fit_tta`` blends
    in an estimate built from the disagreement between a "device" model and a
    "cloud" model.

    Args:
        model: Segmentation model. The device of its first parameter becomes
            ``self.device``. ``fit`` also uses ``model.data_preprocessor`` and
            ``model.align_corners``.
        checkpoint: Optional path to a tensor saved with ``torch.save``. It is
            loaded onto ``self.device`` and used as the initial ``params``.
            ``None`` leaves ``params`` unset until ``fit`` is called.
        lbd: Weight of the disagreement term. The agreement term is weighted
            by ``1 - lbd`` and subtracted.
        temperature: Divisor applied to logits before every softmax.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, model, checkpoint, lbd=0.5, temperature=1, **kwargs):
        self.model = model
        self.device = next(model.parameters()).device
        self.lbd = lbd
        self.temperature = temperature

        if checkpoint is not None:
            # map_location: a checkpoint saved from a GPU must still load on a
            # CPU-only host, and params must live on the same device as the
            # model outputs they are multiplied with.
            # NOTE(review): torch.load unpickles the file -- only load trusted
            # checkpoints.
            self.params = torch.load(checkpoint, map_location=self.device)
        else:
            self.params = None

    def __repr__(self):
        return f"Uncertainty Detector for {self.model.__class__.__name__}"

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _symmetrize(matrix):
        """Return ``matrix`` with a zero diagonal and its strict lower triangle mirrored upward."""
        lower = torch.tril(matrix, diagonal=-1)
        return lower + lower.T

    @staticmethod
    def _unit_norm(matrix):
        """Scale ``matrix`` to unit Frobenius norm; an all-zero matrix is returned as-is.

        Dividing by a zero norm would turn every later score into NaN, for
        example when ``relu`` clips every entry of the fitted matrix.
        """
        norm = matrix.norm()
        return matrix / norm if norm > 0 else matrix

    def _scoring_params(self):
        """Return the normalised pair matrix ``D`` used by the scoring methods.

        Unlike ``_set_params``, no ``relu`` is applied here. This matches the
        original scoring behaviour for params loaded from a checkpoint.

        Raises:
            RuntimeError: if no checkpoint was loaded and ``fit`` has not run.
        """
        if self.params is None:
            raise RuntimeError(
                "UncerDetector.params is not set: pass a checkpoint or call fit() first"
            )
        return self._unit_norm(self._symmetrize(self.params))

    def _set_params(self, raw):
        """Store ``raw`` symmetrised (zero diagonal), clipped at zero and scaled to unit norm."""
        self.params = self._unit_norm(torch.relu(self._symmetrize(raw)))

    def _masked_class_means(self, logits, mismatch):
        """Split per-image mean class probabilities by a per-pixel mismatch mask.

        Args:
            logits: Logits indexed as ``(N, C, H, W)``; softmax runs over dim 1.
            mismatch: ``(N, H, W)`` mask, 1 where the prediction counts as
                wrong and 0 elsewhere.

        Returns:
            ``(probs_pos, probs_neg)``, each ``(N, C)`` and on ``self.device``.
            ``probs_pos`` is the spatial mean of the probabilities with
            mismatched pixels zeroed. ``probs_neg`` is the spatial mean with
            matched pixels zeroed. Both means divide by all ``H * W`` pixels,
            not only by the masked subset.
        """
        probs = torch.softmax(logits / self.temperature, dim=1)
        mask = mismatch.unsqueeze(1).to(probs.dtype)
        probs_pos = (probs * (1 - mask)).mean(dim=(2, 3))
        probs_neg = (probs * mask).mean(dim=(2, 3))
        return probs_pos.to(self.device), probs_neg.to(self.device)

    def _pair_matrix(self, probs_pos, probs_neg):
        """Build the raw ``(C, C)`` matrix ``-(1-lbd) E[pos pos^T] + lbd E[neg neg^T]``.

        ``E`` is the mean over the N rows of the ``(N, C)`` inputs.
        """
        def mean_outer(p):
            # mean_i p[i]^T p[i], without materialising the (N, C, C) stack
            return torch.einsum("ij,ik->jk", p, p) / p.shape[0]

        return -(1 - self.lbd) * mean_outer(probs_pos) + self.lbd * mean_outer(probs_neg)

    # -------------------------------------------------------------- public API

    def fit(self, train_dataloader, *args, **kwargs):
        """Estimate ``params`` from labelled data.

        Each batch is passed through ``model.data_preprocessor`` and the model.
        The logits are resized to the label resolution, and pixels where the
        argmax differs from the label are treated as errors. The resulting
        matrix is symmetrised, clipped at zero and normalised into
        ``self.params``.

        Raises:
            ValueError: if ``train_dataloader`` yields no batches.
        """
        total_probs_pos = []
        total_probs_neg = []

        for data in tqdm(train_dataloader, desc="Fitting metric"):
            data = self.model.data_preprocessor(data, True)
            inputs = data['inputs'].to(self.device)
            data_samples = data['data_samples']

            # NOTE(review): assumes each gt_sem_seg.data is (1, H, W) so the
            # concat is (N, H, W) -- confirm. Moved to the model device like
            # `inputs` (a no-op if the preprocessor already did it).
            labels = torch.cat(
                [sample.gt_sem_seg.data for sample in data_samples], dim=0
            ).to(self.device)

            with torch.no_grad():
                logits = self.model(inputs, data_samples)  # N, C, H, W
                logits = resize(
                    input=logits,
                    size=labels.shape[-2:],
                    mode='bilinear',
                    align_corners=self.model.align_corners
                )

            # 1 where the model's argmax disagrees with the ground truth
            errors = (labels != logits.argmax(dim=1)).int()
            probs_pos, probs_neg = self._masked_class_means(logits, errors)
            total_probs_pos.append(probs_pos)
            total_probs_neg.append(probs_neg)

        if not total_probs_pos:
            raise ValueError("UncerDetector.fit() received an empty dataloader")

        self._set_params(self._pair_matrix(torch.cat(total_probs_pos, dim=0),
                                           torch.cat(total_probs_neg, dim=0)))

    def __call__(self, logits, *args, **kwds):
        """Return one uncertainty score per image.

        The class probabilities are first averaged over H and W. The result
        has shape ``(N,)``, with entry n equal to ``pm[n] @ D @ pm[n]``.
        """
        probs = torch.softmax(logits / self.temperature, dim=1)  # N, C, H, W
        params = self._scoring_params()
        probs_mean = probs.mean(dim=(2, 3))  # N, C
        # Row-wise quadratic form; equals diag(pm @ D @ pm.T) without the N x N product.
        return torch.einsum("nc,cd,nd->n", probs_mean, params, probs_mean)

    def calculate_uncertainty_map(self, logits, *args, **kwds):
        """Return the per-pixel score ``p(x) @ D @ p(x)`` with shape ``(B, H, W)``."""
        probs = torch.softmax(logits / self.temperature, dim=1)  # (B, C, H, W)
        params = self._scoring_params()
        # Contract one class axis with D, then take the pixel-wise dot product with
        # probs. Memory stays O(B*C*H*W) instead of a (B, H*W, H*W) matrix whose
        # diagonal is the only part used.
        weighted = torch.einsum("bchw,cd->bdhw", probs, params)
        return (weighted * probs).sum(dim=1)

    def fit_tta(self, device_logits, cloud_logits, alpha=0.1, *args, **kwargs):
        """Blend in a pair matrix estimated from device/cloud disagreement.

        The cloud model's argmax serves as pseudo labels. Pixels where the
        device model's argmax differs are the "disagreement" pixels. The new
        estimate only touches the rows and columns of classes the cloud
        predicts on disagreement pixels. The update is
        ``params <- (1 - alpha) * params + alpha * new * mask``, followed by
        symmetrisation, clipping at zero and normalisation.

        Args:
            device_logits: Logits whose argmax over dim 1 gives the device
                prediction. Presumably ``(N, C, H, W)``; must match
                ``cloud_logits`` in shape.
            cloud_logits: Logits whose argmax over dim 1 gives pseudo labels.
            alpha: Weight of the new estimate in the blend.

        Raises:
            RuntimeError: if ``params`` has not been initialised.
        """
        if self.params is None:
            raise RuntimeError(
                "UncerDetector.params is not set: pass a checkpoint or call fit() first"
            )

        pseudo_labels = cloud_logits.argmax(dim=1)
        device_pred = device_logits.argmax(dim=1)
        disagreement = (pseudo_labels != device_pred).int()

        probs_pos, probs_neg = self._masked_class_means(device_logits, disagreement)
        param_new = self._pair_matrix(probs_pos, probs_neg)

        # Classes the cloud predicts where the two models disagree. Class 0 is
        # skipped (the original code assumes it is background -- confirm).
        disputed = torch.unique(pseudo_labels[disagreement.bool()])
        disputed = disputed[disputed > 0].to(self.device)
        class_mask = torch.zeros_like(self.params, device=self.device)  # (C, C)
        class_mask[disputed, :] = 1
        class_mask[:, disputed] = 1

        blended = (1 - alpha) * self.params.to(self.device) + alpha * param_new * class_mask
        self._set_params(blended)




