from utils.tools import Summary, AverageMeter, ProgressMeter, accuracy, load_config, set_random_seed

import time
import math
from collections import deque
import torch
import torch.nn.functional as F


def _prepare_inputs(images, target, args):
    """Move a batch onto ``args.gpu`` and derive the two image inputs used downstream.

    Args:
        images: Either a list of image tensors (multiple views of the same
            batch, e.g. AugMix) or a single image tensor. A single tensor with
            more than 4 dimensions must have a leading dimension of 1, which
            is squeezed away.
        target: Label tensor; moved to ``args.gpu``.
        args: Namespace providing ``gpu`` (CUDA device passed to ``.cuda``) and
            ``shifter`` (only read when ``images`` is a list).

    Returns:
        ``(image_for_eval, image_for_feature_extraction, target)``, all on
        ``args.gpu``. ``image_for_eval`` is the first view (or the whole tensor).
        ``image_for_feature_extraction`` is every view concatenated along
        dim 0 in list order when ``images`` is a list and ``args.shifter`` is
        truthy; otherwise it is the very same tensor object as
        ``image_for_eval``.

    Raises:
        AssertionError: If a single tensor has more than 4 dimensions and its
            first dimension is not 1.
    """
    gpu = args.gpu

    if isinstance(images, list):
        # Transfer the first view exactly once and reuse it below; copying it
        # a second time only duplicates host-to-device traffic and GPU memory.
        image_for_eval = images[0].cuda(gpu, non_blocking=True)
        if args.shifter:
            remaining_views = [img.cuda(gpu, non_blocking=True) for img in images[1:]]
            # Same layout as concatenating all views: first view first, then the rest.
            image_for_feature_extraction = torch.cat([image_for_eval, *remaining_views], dim=0)
        else:
            image_for_feature_extraction = image_for_eval
    else:
        if images.dim() > 4:
            assert images.size()[0] == 1, "Batch size must be 1 if image has >4 dimensions."
            images = images.squeeze(0)
        # One transfer serves both roles; the two outputs hold identical data.
        image_for_eval = images.cuda(gpu, non_blocking=True)
        image_for_feature_extraction = image_for_eval

    target = target.cuda(gpu, non_blocking=True)
    return image_for_eval, image_for_feature_extraction, target


def create_C1_tensor(num_classes: int, initial_value: float = 30000) -> torch.Tensor:
    """Return a 1-D float32 tensor of length ``num_classes`` filled with ``initial_value``."""
    shape = (num_classes,)
    constant_vector = torch.full(shape, initial_value, dtype=torch.float32)
    return constant_vector


def test_time_adapt_eval(val_loader, model, shifter, optimizer, scaler, args, num_samples):
    """Evaluate ``model`` on ``val_loader`` with graph-based test-time adaptation.

    For every batch the function
      1. builds a top-k, row-normalised class graph from the current class
         prototypes and per-class reliability statistics, and averages it over
         a sliding window of recent graphs;
      2. computes softmax probabilities from the model logits;
      3. propagates the probabilities over the averaged graph and fuses them
         with the raw probabilities;
      4. updates the prototypes and reliability buffers of the classes that
         were predicted with fused confidence ``>= theta``;
      5. accumulates top-1 / top-5 accuracy on the fused probabilities.

    Args:
        val_loader: Sized iterable yielding ``(images, target)`` batches.
        model: Must expose ``text_features`` (one row per class), ``eval()``,
            ``encode_image(images)`` and ``get_logits(image_feats, text_feats)``.
        shifter, optimizer, scaler, num_samples: Not used by this function;
            kept so existing callers keep working.
        args: Namespace with ``gpu`` and ``print_freq`` (and ``shifter``, read by
            ``_prepare_inputs``). Optional hyperparameters, read with defaults:
            ``theta`` (0.3), ``k`` (3*log(C)), ``rel_window`` (falls back to
            ``L``, then 5), ``A_window`` (``rel_window``), ``sigma_max`` (0.5),
            ``clamp_negative`` (True).

    Returns:
        ``[top1.avg, top5.avg]`` from the accuracy meters.

    Raises:
        AssertionError: If ``args.gpu`` is ``None`` when a batch is processed.
    """
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)

    device = args.gpu
    num_dataset_classes = model.text_features.size(0)

    progress = ProgressMeter(len(val_loader), [batch_time, top1, top5], prefix='Test: ')
    model.eval()

    # --- Hyperparameters (with safe defaults) ---
    theta = getattr(args, 'theta', 0.3)
    k = getattr(args, 'k', max(1, int(3 * math.log(max(2, num_dataset_classes)))))  # heuristic: k = 3*log(C)
    rel_window = getattr(args, 'rel_window', getattr(args, 'L', 5))
    A_window = getattr(args, 'A_window', rel_window)  # reuse same window length if not specified
    sigma_max = getattr(args, 'sigma_max', 0.5)
    clamp_negative = getattr(args, 'clamp_negative', True)  # whether to clamp negative similarity values to zero

    # --- Trackers ---
    # Class prototypes, initialised from the (normalised) text features.
    adapted_text_features_tracker = model.text_features.clone().detach().to(device)
    adapted_text_features_tracker = F.normalize(adapted_text_features_tracker, dim=-1)
    adapted_text_features_tracker.requires_grad_(False)

    # Per-class number of image features merged into the prototype so far.
    proto_counts = torch.zeros(num_dataset_classes, dtype=torch.long, device=device)

    # Per-class reliability buffers (most recent confidence scores).
    rel_buffers = [deque(maxlen=rel_window) for _ in range(num_dataset_classes)]

    # Cached mean / std of every reliability buffer. A class whose buffer is
    # empty uses the cold-start values (1.0, 0.0). Entries are refreshed only
    # for classes whose buffer just changed, instead of rebuilding all C
    # statistics on every batch.
    rel_mu = torch.ones(num_dataset_classes, device=device)
    rel_sigma = torch.zeros(num_dataset_classes, device=device)

    # Sliding window of recent adjacency matrices.
    A_buffer = deque(maxlen=A_window)

    end = time.time()

    # ===== Helper functions =====
    def build_graph_from_prototypes_and_reliability(text_feats, mu, sigma):
        """Return ``(A, W)``: top-k row-normalised adjacency and the dense weighted similarity."""
        # S: cosine similarity between prototypes with a zeroed diagonal.
        T = F.normalize(text_feats, dim=-1)  # [C, D]
        S = T @ T.t()                        # [C, C]
        S.fill_diagonal_(0.0)
        if clamp_negative:
            S = torch.clamp(S, min=0.0)

        # R: class-wise reliability, high mean confidence and low spread -> close to 1.
        R = torch.clamp(mu * (1.0 - sigma / sigma_max), 0.0, 1.0)  # [C]

        # Edge weight = similarity scaled by the reliability of both endpoints.
        W = S * torch.outer(R, R)            # [C, C]
        W.fill_diagonal_(0.0)

        # Keep the k strongest outgoing edges per row (at most C-1, i.e. no self edge slot).
        topk_vals, topk_idx = torch.topk(W, k=min(k, W.size(1) - 1), dim=1)  # [C, k]
        A = torch.zeros_like(W)
        A.scatter_(1, topk_idx, topk_vals)

        # Row-normalise; the clamp avoids dividing by zero for all-zero rows.
        A = A / A.sum(dim=1, keepdim=True).clamp_min(1e-12)
        return A, W

    def refresh_reliability_stats(cls):
        """Recompute the cached mean / std of class ``cls`` from its buffer."""
        if not rel_buffers[cls]:
            # Nothing stored (e.g. a zero-length window): keep the cold-start values.
            return
        vals = torch.tensor(list(rel_buffers[cls]), dtype=torch.float32, device=device)
        rel_mu[cls] = vals.mean()
        rel_sigma[cls] = vals.std(unbiased=False)  # biased std is sufficient here

    # ========= Main loop =========
    for i, (images, target) in enumerate(val_loader):
        assert args.gpu is not None, "GPU is not specified. Please set args.gpu."
        image_for_eval, image_for_feature_extraction, target = _prepare_inputs(images, target, args)
        batch_size = image_for_eval.size(0)

        with torch.no_grad():
            # 1) Build the graph from the current prototypes + reliability statistics.
            A_now, _ = build_graph_from_prototypes_and_reliability(
                adapted_text_features_tracker, rel_mu, rel_sigma
            )

            # Sliding-window averaging of adjacency
            A_buffer.append(A_now)
            barA = torch.stack(list(A_buffer), dim=0).mean(dim=0)  # [C, C]

            # 2) Forward pass: raw logits and probabilities
            image_feats = model.encode_image(image_for_eval)                # [B, D]
            logits_raw = model.get_logits(image_feats, adapted_text_features_tracker)  # [B, C]
            p = torch.softmax(logits_raw, dim=-1)                           # [B, C]

            # 3) Graph propagation: p_graph(i) = sum_j barA[j,i]*p(j) / sum_j barA[j,i]
            denom = barA.sum(dim=0, keepdim=True).clamp_min(1e-12)          # [1, C], column sums (incoming weights)
            p_graph = (p @ barA) / denom                                    # [B, C]

            # 4) Fusion: \hat{p} ∝ p + p_graph
            hat_p = p + p_graph
            hat_p = hat_p / hat_p.sum(dim=1, keepdim=True).clamp_min(1e-12)

            # --- scores for evaluation ---
            scores_for_eval = hat_p  # if accuracy() requires logits, use torch.log(hat_p.clamp_min(1e-12))

            # 5) Pseudo-labels + high-confidence updates (prototypes & reliability)
            conf, y_star = torch.max(hat_p, dim=1)  # [B], [B]
            high_conf_mask = conf >= theta
            if high_conf_mask.any():
                feats_update = model.encode_image(image_for_feature_extraction)
                feats_update = F.normalize(feats_update, dim=-1)
                if feats_update.size(0) != batch_size:
                    # Multi-view input: _prepare_inputs concatenates the views
                    # along dim 0 (view-major), so the features are [V*B, D]
                    # while the masks below are [B]. Average the views of each
                    # sample so that row b again belongs to sample b.
                    feats_update = feats_update.reshape(
                        -1, batch_size, feats_update.size(-1)
                    ).mean(dim=0)                                            # [B, D]
                    feats_update = F.normalize(feats_update, dim=-1)

                for cls in torch.unique(y_star[high_conf_mask]).tolist():
                    # cls comes from the masked labels, so the mask selects >= 1 sample.
                    cls_mask = high_conf_mask & (y_star == cls)
                    n_new = int(cls_mask.sum().item())
                    f_sum = feats_update[cls_mask].sum(dim=0)                 # [D]
                    t_old = adapted_text_features_tracker[cls]                # [D]
                    n_old = proto_counts[cls].item()

                    # Count-weighted running mean (equivalent to n_new sequential updates).
                    # NOTE: counts start at 0, so on a class's first update the
                    # initial text prototype gets zero weight and is replaced by
                    # the mean image feature.
                    t_new = (n_old * t_old + f_sum) / (n_old + n_new)
                    adapted_text_features_tracker[cls] = F.normalize(t_new, dim=-1)
                    proto_counts[cls] = n_old + n_new

                    # Record the confidences of this class and refresh its cached statistics.
                    rel_buffers[cls].extend(conf[cls_mask].tolist())
                    refresh_reliability_stats(cls)

        # === Update metrics ===
        acc1, acc5 = accuracy(scores_for_eval, target, topk=(1, 5))
        top1.update(acc1[0], image_for_eval.size(0))
        top5.update(acc5[0], image_for_eval.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % args.print_freq == 0:
            print('\n')
            progress.display(i)

    progress.display_summary()
    return [top1.avg, top5.avg]
