import argparse
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score


class MLP(nn.Module):
    """Two-hidden-layer perceptron that emits one logit per input row.

    The stack is Linear -> ReLU -> Linear -> ReLU -> Linear(…, 1), stored
    under ``self.net``.
    """

    def __init__(self, in_dim: int, hidden: int = 512):
        """Build the layer stack.

        Args:
            in_dim: Number of input features per sample.
            hidden: Width of both hidden layers.
        """
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits with the trailing size-1 dimension removed."""
        logits = self.net(x)
        return logits.squeeze(1)


@torch.no_grad()
def _infer_feature_dim(model: nn.Module, sample: torch.Tensor) -> int:
    """Return the flattened feature width produced by the model's backbone.

    The backbone is every direct child of ``model`` except the last one,
    chained in order. ``sample`` is pushed through it once and the output is
    flattened from dimension 1 onwards.

    The probe runs in eval mode so that it neither updates BatchNorm running
    statistics nor fails on a single-sample batch; each module's previous
    ``training`` flag is restored before returning.

    Args:
        model: Network to probe; an ``nn.DataParallel`` wrapper is unwrapped.
        sample: Example input batch, already on the same device as ``model``.

    Returns:
        Size of dimension 1 of the flattened backbone output.
    """
    if isinstance(model, nn.DataParallel):
        model = model.module
    backbone = nn.Sequential(*list(model.children())[:-1])

    # Snapshot the per-module flags: the Sequential wrapper shares the
    # original child modules, so eval() below would otherwise leak out.
    previous_modes = [(module, module.training) for module in backbone.modules()]
    backbone.eval()
    try:
        features = torch.flatten(backbone(sample), 1)
    finally:
        # Assign the flag directly (not .train()) so the restore does not
        # recurse and overwrite children that had a different mode.
        for module, was_training in previous_modes:
            module.training = was_training
    return int(features.size(1))


def _extract_features(model: nn.Module, loader, device: str) -> torch.Tensor:
    """Run the model's backbone over a loader and gather flattened features.

    The backbone is every direct child of ``model`` except the last one. It
    is moved to ``device`` and switched to eval mode (both affect the
    original child modules, which the wrapper shares).

    Args:
        model: Network to use; an ``nn.DataParallel`` wrapper is unwrapped.
        loader: Iterable yielding ``(images, _)`` pairs; the second item is
            ignored.
        device: Device on which the forward passes are run.

    Returns:
        A CPU tensor with all batches concatenated along dimension 0, each
        flattened from dimension 1 onwards; ``torch.empty(0)`` when the
        loader yields nothing.
    """
    net = model.module if isinstance(model, nn.DataParallel) else model
    trunk = nn.Sequential(*list(net.children())[:-1]).to(device)
    trunk.eval()

    with torch.no_grad():
        chunks = [
            torch.flatten(trunk(batch.to(device)), 1).detach().cpu()
            for batch, _ in loader
        ]

    return torch.cat(chunks, dim=0) if chunks else torch.empty(0)


def train_domain_classifier_and_iw(model: nn.Module, source_loader, target_loader, device: str = "cpu",
                                   max_steps: int = 2000, lr: float = 1e-3,
                                   batch_size: int = 256) -> Tuple[float, np.ndarray, float]:
    """Train a small domain classifier D on frozen features from model.

    Features are extracted from both loaders with ``_extract_features``,
    labelled 0 (source) / 1 (target), and an ``MLP`` is fitted with Adam and
    binary cross-entropy on shuffled mini-batches.

    Args:
        model: Feature extractor; it is moved to ``device`` and put in eval
            mode.
        source_loader: Loader yielding source-domain batches.
        target_loader: Loader yielding target-domain batches.
        device: Device to run on; a falsy value selects CUDA when available,
            otherwise CPU.
        max_steps: Total number of optimizer steps.
        lr: Adam learning rate.
        batch_size: Mini-batch size for classifier training.

    Returns:
        ``(auc, importance_weights, domain_error)``. AUC and error are
        measured on the same source+target features the classifier was
        trained on, so they are optimistic. ``importance_weights`` holds one
        value per source sample, the classifier odds ``p / (1 - p)`` with
        ``p`` clamped to ``[1e-6, 1 - 1e-6]``; the odds are not rescaled by
        the source/target sample-count ratio. If either loader yields no
        features the result is ``(nan, empty array, nan)``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # Collect features (returned on CPU by _extract_features)
    Xs = _extract_features(model, source_loader, device)
    Xt = _extract_features(model, target_loader, device)
    if Xs.numel() == 0 or Xt.numel() == 0:
        # Keep the 3-tuple shape promised by the signature so callers can
        # always unpack (auc, iw, domain_error).
        return float("nan"), np.array([], dtype=np.float32), float("nan")

    n_source = Xs.size(0)
    D = MLP(Xs.size(1)).to(device)
    opt = optim.Adam(D.parameters(), lr=lr)
    bce = nn.BCEWithLogitsLoss()

    # Build tensors and labels: source = 0, target = 1
    ys = torch.zeros(n_source, dtype=torch.float32)
    yt = torch.ones(Xt.size(0), dtype=torch.float32)
    X = torch.cat([Xs, Xt], dim=0).to(device)
    y = torch.cat([ys, yt], dim=0).to(device)
    n_total = X.size(0)

    # Mini-batch training over reshuffled epochs until max_steps is reached
    num_steps = 0
    D.train()
    while num_steps < max_steps:
        perm = torch.randperm(n_total, device=device)
        for start in range(0, n_total, batch_size):
            if num_steps >= max_steps:
                break
            batch_idx = perm[start:start + batch_size]
            opt.zero_grad()
            loss = bce(D(X[batch_idx]), y[batch_idx])
            loss.backward()
            opt.step()
            num_steps += 1

    # AUC and error on entire set (approximate: this is the training data)
    D.eval()
    with torch.no_grad():
        probs = torch.sigmoid(D(X)).detach().cpu().numpy()
    labels = y.detach().cpu().numpy()
    auc = roc_auc_score(labels, probs)
    preds = (probs >= 0.5).astype(np.float32)
    domain_error = 1.0 - (preds == labels).mean()

    # Importance weights for source samples: w = p_t(x) / p_s(x) ~ p(y=1|x)/(1-p(y=1|x))
    with torch.no_grad():
        # The first n_source rows of X are the source features, already on device.
        logits_s = D(X[:n_source])
        # Clamp away from 0 and 1 so the odds stay finite.
        ps = torch.sigmoid(logits_s).clamp(min=1e-6, max=1 - 1e-6).detach().cpu().numpy()
    iw = ps / (1.0 - ps)
    return float(auc), iw, float(domain_error)


