import os
import sys
import time
import datetime
import argparse
import math
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

# --------------------------- Basic config / utils --------------------------- #

# Directory that receives the run log files (created on demand by
# generate_unified_log_filename).
LOG_DIR = "./logs"
# Directory that receives saved classifier checkpoints (created on demand by
# checkpoint_path and generate_unified_log_filename).
CHECKPOINTS_DIR = "./checkpoints"
# Seed used for RNG seeding and for the party split when --bench is passed.
DEFAULT_SEED = 42


def select_optimal_device():
    """
    Pick the compute device to run on.

    Preference order: the first CUDA GPU ("cuda:0"), then the MPS backend
    (when this torch build exposes it and reports it available), then CPU.
    """
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device_name = "cuda:0"
    elif mps_backend is not None and mps_backend.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)


def set_deterministic_behavior(seed: int):
    """
    Seed the torch (CPU and, when available, all CUDA devices) and NumPy
    random generators with `seed`, and switch cuDNN to deterministic mode
    with autotuning (benchmark) disabled.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn_backend = getattr(torch.backends, "cudnn", None)
    if cudnn_backend is not None:
        cudnn_backend.deterministic = True
        cudnn_backend.benchmark = False


class Tee:
    """
    Tee that duplicates all stdout writes to both terminal and a log file,
    flushing on every write so the log file is updated in (near) real time.

    Creating an instance opens `filename` and installs the instance as
    `sys.stdout`; `close()` restores the previous stdout and closes the file.
    The object can also be used as a context manager (`with Tee(path): ...`).

    Args:
        filename: Path of the log file to write.
        mode: Text mode passed to `open` (default "w").
        encoding: Encoding of the log file. Defaults to UTF-8 so that
            non-ASCII output does not depend on the platform default.
    """

    def __init__(self, filename, mode="w", encoding="utf-8"):
        self._closed = False
        # Remember the stream we replace before doing anything that can fail.
        self.stdout = sys.stdout
        self.file = open(filename, mode, encoding=encoding)
        sys.stdout = self  # redirect global stdout to self

    def write(self, data):
        """Write `data` to the terminal and the log file; return len(data)."""
        self.stdout.write(data)
        if not self._closed:
            self.file.write(data)
        # Flush immediately so that the log file is updated in real time
        self.flush()
        # Text streams are expected to report how many characters were written.
        return len(data)

    def flush(self):
        """Flush the terminal stream and, while open, the log file."""
        self.stdout.flush()
        if not self._closed:
            self.file.flush()

    def close(self):
        """Restore the previous stdout and close the log file (idempotent)."""
        if self._closed:
            return
        self._closed = True
        # Only undo our own redirection; leave stdout alone if someone else
        # replaced it in the meantime.
        if sys.stdout is self:
            sys.stdout = self.stdout
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def __getattr__(self, name):
        # Called only when normal lookup fails: forward stream attributes such
        # as isatty/encoding/fileno to the wrapped stdout so code that probes
        # sys.stdout keeps working while the tee is installed.
        try:
            wrapped = self.__dict__["stdout"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(wrapped, name)


def generate_unified_log_filename(tag: str, model: str, dataset_name: str, num_epochs: int) -> str:
    """
    Build the path of a new log file under LOG_DIR.

    Ensures LOG_DIR and CHECKPOINTS_DIR exist. The file name embeds the given
    tag, model, dataset name and epoch count, followed by an ID taken from the
    current Unix time truncated to whole seconds.
    """
    for directory in (LOG_DIR, CHECKPOINTS_DIR):
        os.makedirs(directory, exist_ok=True)
    stamp = int(time.time())
    return (
        f"{LOG_DIR}/federated_instahide_{tag}_{model}_{dataset_name}"
        f"_epochs{num_epochs}-ID-{stamp}.log"
    )


def checkpoint_path(prefix: str, model: str, dataset_name: str, num_epochs: int, num_parties: int) -> str:
    """
    Return the checkpoint file path inside CHECKPOINTS_DIR for the given run
    description, creating the directory first if it does not exist.
    """
    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
    file_name = f"{prefix}_{model}_{dataset_name}_{num_epochs}epochs_{num_parties}parties.pth"
    return os.path.join(CHECKPOINTS_DIR, file_name)


# --------------------------- Datasets & transforms --------------------------- #

def get_transforms():
    """
    Preprocessing pipeline for ImageNet-pretrained ResNet backbones:
    resize to 224, convert to tensor, normalize with ImageNet statistics.

    NOTE: Resize(224) is heavy; reduce to 64/96 if memory is an issue.
    """
    pipeline = [
        transforms.Resize(224),
        transforms.ToTensor(),
        # Per-channel ImageNet mean / std.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(pipeline)


def get_cifar10():
    """
    Load CIFAR-10 from ./data (downloading when missing) with the shared
    preprocessing from get_transforms().

    Returns:
        (train_dataset, test_dataset, num_classes) with num_classes == 10.
    """
    shared_kwargs = dict(root="./data", download=True, transform=get_transforms())
    train_dataset = torchvision.datasets.CIFAR10(train=True, **shared_kwargs)
    test_dataset = torchvision.datasets.CIFAR10(train=False, **shared_kwargs)
    return train_dataset, test_dataset, 10


# --------------------------- Feature extractor (ResNet-18) --------------------------- #

def build_resnet18_feature_extractor(device: torch.device):
    """
    Build a ResNet-18 backbone and remove the last THREE layers:
    - layer4
    - avgpool
    - fc

    So the output is taken after layer3, which gives spatial feature maps.

    Args:
        device: Device the extractor is moved to.

    Returns:
        An nn.Sequential in eval mode holding the ImageNet-pretrained
        ResNet-18 children up to and including layer3.
    """
    weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
    resnet = torchvision.models.resnet18(weights=weights)

    # children: [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc]
    # We keep up to layer3 => cut last 3 modules (layer4, avgpool, fc).
    # The slice must be [:-3]; [:-2] would drop only avgpool and fc and
    # silently keep layer4, contradicting the documented cut point.
    modules = list(resnet.children())[:-3]
    feature_extractor = nn.Sequential(*modules).to(device)
    feature_extractor.eval()
    return feature_extractor


# --------------------------- Classifier --------------------------- #

class DenseClassifier(nn.Module):
    """
    Fully connected classification head on flattened features.

    Two hidden stages (1024 then 512 units), each Linear -> BatchNorm1d ->
    GELU -> Dropout(0.5), followed by a final Linear producing `num_classes`
    logits.
    """

    def __init__(self, embedding_dim, num_classes):
        super().__init__()
        self.num_classes = num_classes

        stack = []
        in_features = embedding_dim
        for out_features in (1024, 512):
            stack += [
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
                nn.GELU(),
                nn.Dropout(0.5),
            ]
            in_features = out_features
        stack.append(nn.Linear(in_features, num_classes))
        # Attribute name and module order define the state_dict keys.
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the logits for a batch of flattened feature vectors."""
        return self.layers(x)

    def get_classifier_size(self):
        """Print every layer of the head, numbered from 1."""
        for position, (_, module) in enumerate(self.layers.named_children(), start=1):
            print(f"Layer {position}: {module}")


# --------------------------- Feature-space noise helper --------------------------- #

def sample_ball_noise(shape, radius, device=None):
    """
    Sample noise uniformly from a ball of radius `radius` in R^shape[1].

    A Gaussian draw normalized along dim 1 gives a uniformly distributed
    direction; multiplying by radius * U^(1/d), with U ~ Uniform[0, 1) and
    d = shape[1], makes the points uniform inside the ball rather than
    concentrated on its surface.

    Args:
        shape: Shape of the returned tensor; dim 1 is the vector dimension.
        radius: Radius of the ball.
        device: Device on which the noise is created.

    Returns:
        Tensor of the given shape whose vectors along dim 1 have norm <= radius.
    """
    direction = torch.randn(*shape, device=device)
    norm = direction.norm(dim=1, keepdim=True)
    # Guard against division by zero for an (improbable) all-zero draw.
    norm = torch.where(norm == 0, torch.ones_like(norm), norm)
    # The radial CDF of a uniform d-ball is (r / R)^d, hence r = R * U^(1/d).
    radial = torch.rand_like(norm).pow(1.0 / shape[1])
    return direction / norm * (radius * radial)


# --------------------------- Feature-map grid visualization --------------------------- #

def get_feature_maps_for_first_n(
    feature_extractor,
    dataset,
    device,
    num_samples=16,
):
    """
    Run the feature extractor on the leading samples of `dataset`.

    At most `num_samples` samples (fewer if the dataset is smaller) are read
    in dataset order as a single batch. The extractor is put in eval mode and
    run without gradients; the resulting maps are returned on the CPU.
    Intended for visualization, so the batch is kept small.
    """
    count = min(num_samples, len(dataset))
    single_batch_loader = DataLoader(dataset, batch_size=count, shuffle=False)
    images, _labels = next(iter(single_batch_loader))

    feature_extractor.eval()
    with torch.no_grad():
        maps = feature_extractor(images.to(device))  # (count, C, H, W)
        return maps.cpu()


def feature_map_to_squared_grid(feat_map: torch.Tensor) -> np.ndarray:
    """
    Tile the channels of a (C, H, W) feature map into one square image.

    Each channel is min-max scaled on its own (constant channels become all
    zeros) and placed row-major into a grid with ceil(sqrt(C)) tiles per
    side; unused tiles stay zero. Returns a float32 NumPy array.
    """
    assert feat_map.dim() == 3, "feat_map must have shape (C, H, W)"
    num_channels, height, width = feat_map.shape
    tiles_per_side = int(math.ceil(math.sqrt(num_channels)))
    canvas = torch.zeros(tiles_per_side * height, tiles_per_side * width, dtype=torch.float32)

    for channel_idx, channel in enumerate(feat_map):
        shifted = channel - channel.min()
        peak = shifted.max()
        # Skip the division for constant channels to avoid dividing by zero.
        tile = shifted / peak if peak > 0 else shifted
        tile_row, tile_col = divmod(channel_idx, tiles_per_side)
        top, left = tile_row * height, tile_col * width
        canvas[top:top + height, left:left + width] = tile

    return canvas.cpu().numpy()


def save_featuremap_grid_and_mixup_pairs(
    original_feature_maps: torch.Tensor,
    radius: float,
    output_dir: str = "./mixup_pairs_features",
    num_pairs: int = 5,
):
    """
    Save num_pairs images (original vs a synthetic mixup feature map) for visualization only.

    For every pair a random anchor map and a random partner map are picked,
    the partner is perturbed with noise from sample_ball_noise, both are
    blended with a random weight in [0.3, 0.7], and the anchor grid and the
    mixed grid are written side by side as a PNG in `output_dir`.

    Args:
        original_feature_maps: Tensor of shape (n, C, H, W).
        radius: Noise radius passed to sample_ball_noise.
        output_dir: Directory for the PNG files (created if missing).
        num_pairs: Number of images to save; capped at n.
    """
    os.makedirs(output_dir, exist_ok=True)
    n, C, H, W = original_feature_maps.shape
    num_pairs = min(num_pairs, n)
    # reshape (not view) so non-contiguous inputs are accepted as well.
    flat_feats = original_feature_maps.reshape(n, -1)  # (n, D)

    for i in range(num_pairs):
        idx = np.random.randint(0, n)
        if n > 1:
            j = (idx + np.random.randint(1, n)) % n  # ensure j != idx
        else:
            # Only one sample: randint(1, 1) would raise, so pair it with itself.
            j = idx

        orig_map = original_feature_maps[idx]  # (C, H, W)
        orig_grid = feature_map_to_squared_grid(orig_map)

        anchor = flat_feats[idx]
        other = flat_feats[j]
        noise = sample_ball_noise((1, other.numel()), radius, device=other.device).view_as(other)
        lam = 0.3 + 0.4 * torch.rand(1, device=other.device)  # [0.3, 0.7]
        mix_flat = lam * anchor + (1.0 - lam) * (other + noise)
        mix_map = mix_flat.view(C, H, W)
        mix_grid = feature_map_to_squared_grid(mix_map)

        save_path = os.path.join(output_dir, f"featuremap_grid_pair_{i+1}.png")
        fig, axes = plt.subplots(1, 2, figsize=(8, 4))
        try:
            axes[0].imshow(orig_grid, cmap="viridis", aspect="equal")
            axes[0].set_title(f"Original feat (idx={idx})")
            axes[0].axis("off")
            axes[1].imshow(mix_grid, cmap="viridis", aspect="equal")
            axes[1].set_title(f"Mixup feat (idx={idx}, j={j})")
            axes[1].axis("off")
            fig.subplots_adjust(left=0.01, right=0.99, top=0.95, bottom=0.05, wspace=0.05, hspace=0.0)
            fig.savefig(save_path, dpi=150)
        finally:
            # Release the figure even when drawing or saving fails.
            plt.close(fig)
        print(f"Saved feature-map grid mixup pair {i+1} to: {save_path}")


# --------------------------- Batch-wise feature-space mixup --------------------------- #

def mixup_batch_in_feature_space(
    feature_extractor,
    imgs,
    labels,
    device,
    radius,
    num_classes,
):
    """
    Feature-space mixup for one batch, computed on the fly (no global
    feature store).

    Features are extracted without gradients and flattened to (B, D). Sample
    i is blended with the noise-perturbed features of sample i-1 (cyclic),
    using a per-sample weight drawn from [0.3, 0.7]; the soft label puts the
    weight on the sample's own class and the remainder on the partner's.
    A batch of one is blended with a noisy copy of itself (weight 0.5) and
    keeps a one-hot label.

    Returns:
        (mix_feats, soft_labels) with shapes (B, D) and (B, num_classes).
    """
    imgs, labels = imgs.to(device), labels.to(device)

    feature_extractor.eval()
    with torch.no_grad():
        fmap = feature_extractor(imgs)               # (B, C, H, W)
        feats = fmap.view(fmap.size(0), -1)          # (B, D)

    batch, feat_dim = feats.shape
    targets = torch.zeros(batch, num_classes, device=device)

    if batch == 1:
        # No partner available: mix the sample with a perturbed copy of itself.
        half = torch.tensor([0.5], device=device).view(1, 1)
        noise = sample_ball_noise((1, feat_dim), radius, device=device)
        blended = half * feats + (1.0 - half) * (feats + noise)
        targets[0, labels[0]] = 1.0
        return blended, targets

    rows = torch.arange(batch, device=device)
    weights = 0.3 + 0.4 * torch.rand(batch, device=device)  # (B,)
    noise = sample_ball_noise((batch, feat_dim), radius, device=device)

    # Rolling by one pairs sample i with sample i-1 (the last wraps to the first).
    partner_feats = feats.roll(shifts=1, dims=0)
    partner_labels = labels.roll(shifts=1, dims=0)

    weight_col = weights.view(batch, 1)
    blended = weight_col * feats + (1.0 - weight_col) * (partner_feats + noise)

    targets[rows, labels] += weights
    targets[rows, partner_labels] += (1.0 - weights)

    return blended, targets


# --------------------------- Helpers for federated split --------------------------- #

def split_dataset_equally(dataset, num_parties: int, seed: int = 0) -> List[Subset]:
    """
    Partition `dataset` at random into `num_parties` disjoint Subsets whose
    sizes differ by at most one (the first shards receive the extras).

    The permutation comes from a dedicated torch.Generator seeded with
    `seed`, so the same seed reproduces the same split.
    """
    total = len(dataset)
    base_size, extras = divmod(total, num_parties)
    sizes = [base_size + (1 if rank < extras else 0) for rank in range(num_parties)]

    rng = torch.Generator()
    rng.manual_seed(seed)
    shuffled = torch.randperm(total, generator=rng).tolist()

    shards = []
    cursor = 0
    for size in sizes:
        shards.append(Subset(dataset, shuffled[cursor:cursor + size]))
        cursor += size
    return shards


# --------------------------- Training / evaluation core --------------------------- #

def cross_entropy_with_soft_targets(logits: torch.Tensor, soft_targets: torch.Tensor) -> torch.Tensor:
    """
    Cross-entropy between `logits` and per-sample target distributions
    `soft_targets` (both indexed as (batch, classes)), averaged over the batch.
    """
    log_probs = torch.log_softmax(logits, dim=1)
    per_sample = (soft_targets * log_probs).sum(dim=1)
    return -per_sample.mean()


def evaluate_classifier(classifier, feature_extractor, device, testloader) -> float:
    """
    Return the top-1 accuracy in percent of `classifier` on `testloader`.

    Both models are switched to eval mode; features come from
    `feature_extractor`, are flattened per sample and classified without
    gradients. An empty loader yields 0.0.
    """
    classifier.eval()
    feature_extractor.eval()

    hits, seen = 0, 0
    with torch.no_grad():
        for imgs, labels in testloader:
            imgs, labels = imgs.to(device), labels.to(device)
            fmap = feature_extractor(imgs)
            logits = classifier(fmap.view(fmap.size(0), -1))
            hits += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)

    # max(1, ...) keeps the empty-loader case from dividing by zero.
    return 100.0 * hits / max(1, seen)


def train_classifier_standard(
    device,
    feature_extractor,
    classifier,
    optimizer,
    scheduler,
    trainloader,
    testloader,
    num_epochs,
) -> float:
    """
    Standard training (NO mixup), features computed on-the-fly.

    The feature extractor stays in eval mode and is only run without
    gradients; only `classifier` is optimized (hard-label cross-entropy).
    The scheduler is stepped once per epoch and the classifier is evaluated
    on `testloader` after every epoch.

    Args:
        device: Device the batches and the classifier are moved to.
        feature_extractor: Frozen backbone producing feature maps.
        classifier: Head trained on the flattened features.
        optimizer: Optimizer over the classifier parameters.
        scheduler: LR scheduler stepped once per epoch.
        trainloader: Iterable of (imgs, labels) training batches.
        testloader: Iterable of (imgs, labels) evaluation batches.
        num_epochs: Number of passes over `trainloader`.

    Returns:
        Best test accuracy (percent) seen over all epochs.
    """
    best_acc = 0.0
    classifier.to(device)
    feature_extractor.eval()
    # Loop-invariant: build the loss module once instead of once per batch.
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            torch.mps.empty_cache()

        classifier.train()
        running_loss = 0.0
        processed = 0

        for imgs, labels in trainloader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                fmap = feature_extractor(imgs)
                feats = fmap.view(fmap.size(0), -1)

            optimizer.zero_grad()
            outputs = classifier(feats)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            num_in_batch = labels.size(0)
            running_loss += loss.item() * num_in_batch
            processed += num_in_batch

        scheduler.step()
        # Average over the samples actually seen (same convention as the mixup
        # trainer): stays correct with drop_last and never divides by zero.
        train_loss = running_loss / max(1, processed)
        test_acc = evaluate_classifier(classifier, feature_extractor, device, testloader)
        best_acc = max(best_acc, test_acc)
        print(
            f"[Std] Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
            f"Test Acc: {test_acc:.2f}% | LR: {optimizer.param_groups[0]['lr']:.8f}"
        )
    return best_acc


def train_classifier_mixup_union(
    device: torch.device,
    feature_extractor,
    classifier,
    optimizer,
    scheduler,
    party_loaders: List[DataLoader],
    testloader: DataLoader,
    num_epochs: int,
    radius: float,
    num_classes: int,
) -> float:
    """
    Train the classifier on mixup batches from every party.

    Each epoch walks through the party loaders one after another; every batch
    is mixed only with samples of the same batch (hence the same party) via
    mixup_batch_in_feature_space and optimized with the soft-target
    cross-entropy. The scheduler steps once per epoch, followed by an
    evaluation on `testloader`. Returns the best test accuracy (percent).
    """
    classifier.to(device)
    feature_extractor.eval()
    top_accuracy = 0.0

    for epoch in range(num_epochs):
        started_at = time.time()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            torch.mps.empty_cache()

        classifier.train()
        loss_sum = 0.0
        samples_seen = 0

        # Visit the parties in order, exhausting each loader before the next.
        all_batches = (batch for loader in party_loaders for batch in loader)
        for imgs, labels in all_batches:
            optimizer.zero_grad()
            mixed_feats, soft_targets = mixup_batch_in_feature_space(
                feature_extractor, imgs, labels, device, radius, num_classes
            )
            batch_loss = cross_entropy_with_soft_targets(classifier(mixed_feats), soft_targets)
            batch_loss.backward()
            optimizer.step()

            batch_len = mixed_feats.size(0)
            loss_sum += batch_loss.item() * batch_len
            samples_seen += batch_len

        scheduler.step()
        train_loss = loss_sum / max(1, samples_seen)
        test_acc = evaluate_classifier(classifier, feature_extractor, device, testloader)
        top_accuracy = max(top_accuracy, test_acc)

        epoch_time = time.time() - started_at
        print(
            f"[Mixup-Union] Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
            f"Test Acc: {test_acc:.2f}% | Time: {epoch_time:.2f}s | LR: {optimizer.param_groups[0]['lr']:.8f}"
        )

    return top_accuracy


# --------------------------- Main experiment runner --------------------------- #

def _build_loader(dataset, batch_size: int, shuffle: bool, pin_memory: bool) -> DataLoader:
    """Create a DataLoader with the worker settings shared by all loaders of an experiment."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=2,
        pin_memory=pin_memory,
    )


def _build_classifier_setup(embedding_dim: int, num_classes: int, lr: float, num_epochs: int, device):
    """Create a fresh DenseClassifier with its SGD optimizer and cosine-annealing LR scheduler."""
    classifier = DenseClassifier(embedding_dim, num_classes).to(device)
    optimizer = optim.SGD(classifier.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)
    return classifier, optimizer, scheduler


def augment_workflow(
    device: torch.device,
    num_epochs: int,
    batch_size: int,
    radius: float,
    lr: float,
    bench_mode: bool,
    num_parties: int,
    finish_with_model_eval: bool = True,
    tee: Optional[Tee] = None,
) -> None:
    """
    Runs a single experiment for a given num_parties.
    Logging is assumed to be already redirected by a Tee created outside.

    Steps: load CIFAR-10, build the feature extractor, save a few
    original/mixup feature-map images, split the training set across the
    parties, train a baseline classifier on party 0 only (no mixup), train a
    second classifier on the mixup batches of all parties, save both
    checkpoints and print a summary.

    Args:
        device: Device used for the extractor, the classifiers and the batches.
        num_epochs: Training epochs for each of the two classifiers.
        batch_size: Batch size of every data loader.
        radius: Noise radius used by the feature-space mixup.
        lr: Initial SGD learning rate.
        bench_mode: When True the party split uses DEFAULT_SEED; otherwise a
            time-derived seed.
        num_parties: Number of disjoint shards the training set is split into.
        finish_with_model_eval: Re-evaluate both in-memory models at the end.
        tee: Not used by this function; accepted so callers can pass the
            active Tee.
    """
    print(f"\n============ Experiment START (num_parties={num_parties}) ============")
    print(f"Started on: {datetime.datetime.now().ctime()}")
    try:
        uname = os.uname().nodename
    except (AttributeError, OSError):
        # os.uname is missing on some platforms; the host name is informational only.
        uname = "unknown"
    print(f"Platform: {uname}")
    print(f"Torch seed: {torch.random.initial_seed()}")
    print(f"Deterministic behavior: {bench_mode}")
    print(f"Using device: {device}")
    print("Model: resnet18 [cut after layer3]")
    print("Dataset: cifar10")
    print(f"Epochs: {num_epochs}")
    print(f"radius (r): {radius}")
    print(f"num_parties: {num_parties}")
    print("----------------------------------------------------")

    workflow_start_time = time.time()

    # 1) Dataset
    train_dataset, test_dataset, num_classes = get_cifar10()

    # 2) Feature extractor
    feature_extractor = build_resnet18_feature_extractor(device)

    # 3) Optional visualization (small)
    print("Saving original/mixup feature-map grid pairs before training begins...")
    small_feature_maps = get_feature_maps_for_first_n(
        feature_extractor, train_dataset, device, num_samples=16
    )
    vis_out_dir = f"./mixup_pairs_features_np{num_parties}"
    save_featuremap_grid_and_mixup_pairs(
        original_feature_maps=small_feature_maps,
        radius=radius,
        output_dir=vis_out_dir,
        num_pairs=5,
    )

    # 4) Federated split
    seed_for_split = DEFAULT_SEED if bench_mode else int(time.time()) % (2**31 - 1)
    party_subsets = split_dataset_equally(train_dataset, num_parties=num_parties, seed=seed_for_split)

    # Pinned host memory only helps CUDA transfers; requesting it on CPU/MPS
    # merely triggers a DataLoader warning.
    pin_memory = torch.device(device).type == "cuda"

    # We'll use party 0 for the baseline (no mixup)
    party0_loader = _build_loader(party_subsets[0], batch_size, shuffle=True, pin_memory=pin_memory)

    # All parties for the mixup-union experiment
    party_loaders = [
        _build_loader(subset, batch_size, shuffle=True, pin_memory=pin_memory)
        for subset in party_subsets
    ]

    # Global test loader
    test_loader = _build_loader(test_dataset, batch_size, shuffle=False, pin_memory=pin_memory)

    # Discover embedding dimension once
    with torch.no_grad():
        sample_img, _ = train_dataset[0]
        sample_img = sample_img.unsqueeze(0).to(device)
        fmap = feature_extractor(sample_img)
        embedding_dim = fmap.view(1, -1).size(1)

    # -------------------- Baseline: train on ONE shard (party 0), NO mixup -------------------- #
    print("\n==================== Baseline: single-party (no mixup) ====================")
    classifier_baseline, optimizer_baseline, scheduler_baseline = _build_classifier_setup(
        embedding_dim, num_classes, lr, num_epochs, device
    )

    best_acc_baseline = train_classifier_standard(
        device=device,
        feature_extractor=feature_extractor,
        classifier=classifier_baseline,
        optimizer=optimizer_baseline,
        scheduler=scheduler_baseline,
        trainloader=party0_loader,
        testloader=test_loader,
        num_epochs=num_epochs,
    )

    baseline_ckpt = checkpoint_path(
        prefix="single_party_no_mixup",
        model="resnet18",
        dataset_name="cifar10",
        num_epochs=num_epochs,
        num_parties=num_parties,
    )
    torch.save(classifier_baseline.state_dict(), baseline_ckpt)
    print(f"[Baseline] Best test accuracy (party 0, no mixup): {best_acc_baseline:.2f}%")
    print(f"[Baseline] Saved checkpoint: {baseline_ckpt}")

    # -------------------- Mixup-Union: train on UNION of parties (feature-space) -------------------- #
    print("\n==================== Mixup-Union: all parties (feature-space) ====================")
    classifier_union, optimizer_union, scheduler_union = _build_classifier_setup(
        embedding_dim, num_classes, lr, num_epochs, device
    )

    best_acc_union = train_classifier_mixup_union(
        device=device,
        feature_extractor=feature_extractor,
        classifier=classifier_union,
        optimizer=optimizer_union,
        scheduler=scheduler_union,
        party_loaders=party_loaders,
        testloader=test_loader,
        num_epochs=num_epochs,
        radius=radius,
        num_classes=num_classes,
    )

    union_ckpt = checkpoint_path(
        prefix="mixup_union",
        model="resnet18",
        dataset_name="cifar10",
        num_epochs=num_epochs,
        num_parties=num_parties,
    )
    torch.save(classifier_union.state_dict(), union_ckpt)
    print(f"[Mixup-Union] Best test accuracy: {best_acc_union:.2f}%")
    print(f"[Mixup-Union] Saved checkpoint: {union_ckpt}")

    # Summary for this experiment
    print("\n==================== Experiment Summary ====================")
    print(f"num_parties: {num_parties}")
    print(f"Baseline (party 0, no mixup) best test acc: {best_acc_baseline:.2f}%")
    print(f"Mixup-Union (all parties) best test acc:   {best_acc_union:.2f}%")
    print("============================================================")

    workflow_duration = time.time() - workflow_start_time
    print(f"Experiment (num_parties={num_parties}) execution time: {workflow_duration:.2f} s")

    # Optional final eval (uses already-trained models in memory)
    if finish_with_model_eval:
        print("\nRe-evaluating both models once more (no checkpoint reload)...")
        acc_base_final = evaluate_classifier(classifier_baseline, feature_extractor, device, test_loader)
        acc_union_final = evaluate_classifier(classifier_union, feature_extractor, device, test_loader)
        print(f"Baseline final test accuracy:   {acc_base_final:.2f}%")
        print(f"Mixup-Union final test accuracy:{acc_union_final:.2f}%")

    print(f"============ Experiment END (num_parties={num_parties}) ============\n")


# --------------------------- CLI --------------------------- #

def _build_arg_parser():
    """Create the command-line parser for the experiment script."""
    parser = argparse.ArgumentParser(
        prog="Augmented Mixup (Single-Party Baseline vs Mixup-Union)",
        description="Runs three experiments over num_parties ∈ {10, 20, 30} (no CLI arg for num_parties).",
    )
    parser.add_argument("--radius", type=float, default=1.0, help="Mixing radius r in feature space.")
    parser.add_argument("-e", "--epochs", type=int, default=5, help="Number of training epochs.")
    parser.add_argument(
        "-b",
        "--bench",
        action="store_true",
        help="Set a seed and deterministic PyTorch settings for reproducibility.",
    )
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the classifier.")
    return parser


def main():
    """
    Script entry point: parse the CLI, redirect stdout into one shared log
    file and run the experiment once for each party count in [10, 20, 30].
    """
    batch_size = 128
    device = select_optimal_device()
    args = _build_arg_parser().parse_args()

    if args.bench:
        set_deterministic_behavior(DEFAULT_SEED)

    # One log file for all experiments; from here on stdout is duplicated into it.
    unified_log = generate_unified_log_filename(
        tag="federated_multi",
        model="resnet18",
        dataset_name="cifar10",
        num_epochs=args.epochs,
    )
    tee = Tee(unified_log)
    print(f"Saving ALL experiments to a single log file: {unified_log}")
    print(f"Run started on: {datetime.datetime.now().ctime()}")

    try:
        experiments = [10, 20, 30]
        print(f"Planned experiments over num_parties: {experiments}")

        for party_count in experiments:
            augment_workflow(
                device=device,
                num_epochs=args.epochs,
                batch_size=batch_size,
                radius=args.radius,
                lr=args.lr,
                bench_mode=args.bench,
                num_parties=party_count,
                finish_with_model_eval=True,
                tee=tee,  # logging already redirected globally
            )

        print("\n==================== ALL EXPERIMENTS COMPLETED ====================")
        print(f"Experiments finished on: {datetime.datetime.now().ctime()}")
        print("===================================================================\n")
    finally:
        # Restore stdout and close the log file whether or not a run failed.
        tee.close()


if __name__ == "__main__":
    main()
