import os
import sys
import time
import datetime
import argparse
import math

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

# --------------------------- Basic config / utils --------------------------- #

# Directory that receives one log file per run (created on demand).
LOG_DIR = "./logs"
# Directory where the best classifier weights are checkpointed (created on demand).
CHECKPOINTS_DIR = "./checkpoints"
CACHE_DIR = "./cache"  # kept for compatibility, not used here

# Dataset names accepted by the command-line interface (compared lower-cased).
SUPPORTED_DATASETS = ["mnist", "cifar10", "cifar100", "tiny-imagenet"]
# Seed applied when deterministic ("bench") mode is requested on the CLI.
DEFAULT_SEED = 42


def select_optimal_device():
    """
    Pick the preferred compute device.

    Preference order: the first CUDA GPU, then Apple MPS (when this torch
    build exposes the backend and reports it available), then the CPU.

    Returns:
        torch.device: "cuda:0", "mps" or "cpu".
    """
    mps_backend = getattr(torch.backends, "mps", None)

    if torch.cuda.is_available():
        device_name = "cuda:0"
    elif mps_backend is not None and mps_backend.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"

    return torch.device(device_name)


def set_deterministic_behavior(seed: int):
    """
    Seed the NumPy and PyTorch random generators and put cuDNN into its
    deterministic configuration (deterministic kernels, autotuner disabled).

    Args:
        seed: value used for every generator that is seeded here.
    """
    np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


class Tee:
    """
    Tee that duplicates all stdout writes to both terminal and a log file,
    flushing on every write so the log file is updated in (near) real time.

    Constructing a Tee immediately replaces ``sys.stdout`` with the instance;
    ``close()`` restores the stream that was active at construction time.
    The object can also be used as a context manager so the redirection is
    undone even when the wrapped code raises:

        with Tee("run.log"):
            print("goes to the terminal and to run.log")
    """

    def __init__(self, filename, mode="w"):
        # Explicit UTF-8 so non-ASCII log lines do not depend on the platform
        # default encoding.
        self.file = open(filename, mode, encoding="utf-8")
        self.stdout = sys.stdout
        sys.stdout = self  # redirect global stdout to self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow exceptions from the with-body

    def write(self, data):
        """Write `data` to both sinks and return the number of characters."""
        self.stdout.write(data)
        # A stale reference may still be written to after close(); in that
        # case only the original stream receives the text.
        if not self.file.closed:
            self.file.write(data)
        # Flush immediately so that the log file is updated in real time
        self.flush()
        return len(data)

    def flush(self):
        self.stdout.flush()
        if not self.file.closed:
            self.file.flush()

    def close(self):
        """Restore the original stdout and close the log file (idempotent)."""
        if self.file.closed:
            return
        sys.stdout = self.stdout
        self.file.close()


def generate_log_and_ckpt_files(model, dataset_name, num_epochs):
    """
    Create the log and checkpoint directories (if missing) and derive the
    file names used by one run.

    The log name embeds a run ID taken from the current Unix time in whole
    seconds; the checkpoint name depends only on model, dataset and epoch
    count, so runs with the same configuration share one checkpoint path.

    Returns:
        (log_id, log_filename, checkpoint_file)
    """
    for directory in (LOG_DIR, CHECKPOINTS_DIR):
        os.makedirs(directory, exist_ok=True)

    log_id = int(time.time())

    checkpoint_file = (
        f"{CHECKPOINTS_DIR}/best_instahide_classifier_{model}_{dataset_name}_{num_epochs}epochs.pth"
    )
    log_filename = f"{LOG_DIR}/instahide_{model}_{dataset_name}_epochs{num_epochs}-ID-{log_id}.log"

    return log_id, log_filename, checkpoint_file


# --------------------------- Datasets & transforms --------------------------- #


def get_transforms(dataset_type: str):
    """
    Build the preprocessing pipeline for one dataset.

    Every supported dataset is resized to 224, converted to a tensor and
    normalized with the ImageNet channel statistics; MNIST is additionally
    expanded to three channels before the tensor conversion.
    NOTE: Resize(224) is heavy; you can reduce to 64/96 if memory is an issue.

    Raises:
        ValueError: if the (lower-cased) dataset name is not supported.
    """
    dataset_type = dataset_type.lower()
    if dataset_type not in ("mnist", "cifar10", "cifar100", "tiny-imagenet"):
        raise ValueError(f"Unsupported dataset type: {dataset_type}")

    steps = [transforms.Resize(224)]
    if dataset_type == "mnist":
        # ResNet expects 3-channel input, so replicate the grey channel
        steps.append(transforms.Grayscale(num_output_channels=3))
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    )
    return transforms.Compose(steps)

def get_dataset(dataset_type: str):
    """
    Instantiate the train and test datasets for `dataset_type`.

    MNIST / CIFAR-10 / CIFAR-100 are downloaded into ./data when missing.
    Tiny-ImageNet is read with ImageFolder from
    ./data/tiny-imagenet-200/{train,test} and is never downloaded.

    Returns:
        (train_dataset, test_dataset, num_classes)

    Raises:
        ValueError: if the (lower-cased) dataset name is not supported.
    """
    dataset_type = dataset_type.lower()
    transform = get_transforms(dataset_type)

    if dataset_type == "tiny-imagenet":
        base_dir = os.path.join("./data", "tiny-imagenet-200")
        # NOTE(review): ImageFolder derives labels from sub-directory names;
        # confirm that the "test" folder is organised per class, otherwise
        # the reported test accuracy is not meaningful.
        train_dataset, test_dataset = (
            torchvision.datasets.ImageFolder(
                root=os.path.join(base_dir, split),
                transform=transform,
            )
            for split in ("train", "test")
        )
        return train_dataset, test_dataset, 200

    # torchvision dataset class name and number of classes per dataset
    downloadable = {
        "mnist": ("MNIST", 10),
        "cifar10": ("CIFAR10", 10),
        "cifar100": ("CIFAR100", 100),
    }
    if dataset_type not in downloadable:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")

    class_name, num_classes = downloadable[dataset_type]
    dataset_cls = getattr(torchvision.datasets, class_name)
    train_dataset = dataset_cls(root="./data", train=True, download=True, transform=transform)
    test_dataset = dataset_cls(root="./data", train=False, download=True, transform=transform)

    return train_dataset, test_dataset, num_classes


# --------------------------- Feature extractor (ResNet) --------------------------- #


def build_resnet_feature_extractor(model_name: str, device: torch.device):
    """
    Load an ImageNet-pretrained ResNet (18/34/50) and keep everything up to
    and including layer3, i.e. drop the last THREE children:
    layer4, avgpool and fc. The result therefore emits spatial feature maps.

    The truncated network is moved to `device` and switched to eval mode.

    Raises:
        ValueError: if `model_name` is not resnet18, resnet34 or resnet50.
    """
    # Name of the torchvision weights enum for each supported architecture
    weight_enums = {
        "resnet18": "ResNet18_Weights",
        "resnet34": "ResNet34_Weights",
        "resnet50": "ResNet50_Weights",
    }

    model_name = model_name.lower()
    enum_name = weight_enums.get(model_name)
    if enum_name is None:
        raise ValueError("Model type not supported. Use resnet18, resnet34, or resnet50.")

    pretrained = getattr(torchvision.models, enum_name).IMAGENET1K_V1
    resnet = getattr(torchvision.models, model_name)(weights=pretrained)

    # children: [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4, avgpool, fc]
    # Slicing off the last three leaves the trunk that ends at layer3.
    trunk = list(resnet.children())[:-3]
    feature_extractor = nn.Sequential(*trunk).to(device)
    feature_extractor.eval()

    return feature_extractor


# --------------------------- Classifier --------------------------- #


class DenseClassifier(nn.Module):
    """
    MLP head applied to flattened feature vectors.

    Two hidden stages (1024 then 512 units), each Linear -> BatchNorm1d ->
    GELU -> Dropout(0.5), followed by a Linear layer producing one logit per
    class. All modules live in `self.layers` (an nn.Sequential).
    """

    def __init__(self, embedding_dim, num_classes):
        super().__init__()
        self.num_classes = num_classes

        stack = []
        in_width = embedding_dim
        for out_width in (1024, 512):
            stack += [
                nn.Linear(in_width, out_width),
                nn.BatchNorm1d(out_width),
                nn.GELU(),
                nn.Dropout(0.5),
            ]
            in_width = out_width
        stack.append(nn.Linear(in_width, num_classes))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the class logits for a batch of feature vectors."""
        return self.layers(x)

    def get_classifier_size(self):
        """Print every layer of the head, numbered from 1."""
        for position, layer in enumerate(self.layers, start=1):
            print(f"Layer {position}: {layer}")


# --------------------------- Feature-space noise helper --------------------------- #


def sample_ball_noise(shape, radius, device=None):
    """
    Sample noise uniformly from a ball of radius `radius` in R^shape[1].

    Each row is built as a uniformly distributed direction (a normalized
    Gaussian vector) scaled by ``radius * U ** (1 / d)`` with U ~ Uniform[0, 1)
    and d = shape[1]. The ``1 / d`` exponent is what makes the samples uniform
    over the ball's volume; scaling by `radius` alone would place every sample
    exactly on the surface of the sphere.

    Args:
        shape: output shape, expected to be (batch, d).
        radius: radius of the ball; 0 yields all-zero noise.
        device: device on which the noise is created.

    Returns:
        Tensor of the requested shape whose rows have L2 norm <= radius.
    """
    direction = torch.randn(*shape, device=device)
    norm = direction.norm(dim=1, keepdim=True)
    # Guard against the (measure-zero) all-zero draw to avoid dividing by 0
    norm = torch.where(norm == 0, torch.ones_like(norm), norm)

    dim = shape[1]
    uniform = torch.rand(norm.shape, device=device)
    length = radius * uniform.pow(1.0 / dim)

    return direction / norm * length


# --------------------------- Feature-map grid visualization --------------------------- #


def get_feature_maps_for_first_n(
    feature_extractor,
    dataset,
    device,
    num_samples=16,
):
    """
    Run the feature extractor on the leading samples of `dataset`.

    At most `num_samples` samples are used (fewer when the dataset is
    smaller); they are loaded as one unshuffled batch. Intended for
    visualization only, hence the small default.

    Returns:
        CPU tensor holding the extractor output for that batch.
    """
    batch_size = min(num_samples, len(dataset))
    first_batch = next(iter(DataLoader(dataset, batch_size=batch_size, shuffle=False)))
    images, _labels = first_batch

    feature_extractor.eval()
    with torch.no_grad():
        feature_maps = feature_extractor(images.to(device))  # (batch_size, C, H, W)

    return feature_maps.cpu()


def feature_map_to_squared_grid(feat_map: torch.Tensor) -> np.ndarray:
    """
    Tile every channel of a (C, H, W) feature map into one 2D image.

    Channels are laid out row-major on a ceil(sqrt(C)) x ceil(sqrt(C)) grid
    with no gaps between tiles. Each channel is min-max scaled to [0, 1] on
    its own (a constant channel becomes all zeros); unused cells stay zero.

    Returns:
        grid_np: float32 numpy array of shape (grid_H, grid_W)
    """
    assert feat_map.dim() == 3, "feat_map must have shape (C, H, W)"
    num_channels, height, width = feat_map.shape

    cells_per_side = int(math.ceil(math.sqrt(num_channels)))
    canvas = torch.zeros(
        cells_per_side * height, cells_per_side * width, dtype=torch.float32
    )

    for position, channel in enumerate(feat_map):
        shifted = channel - channel.min()
        peak = shifted.max()
        # Only rescale when the channel is not constant
        tile = shifted / peak if peak > 0 else shifted

        grid_row, grid_col = divmod(position, cells_per_side)
        top, left = grid_row * height, grid_col * width
        canvas[top:top + height, left:left + width] = tile

    return canvas.cpu().numpy()


def save_featuremap_grid_and_mixup_pairs(
    original_feature_maps: torch.Tensor,
    radius: float,
    output_dir: str = "./mixup_pairs_features",
    num_pairs: int = 5,
):
    """
    Save num_pairs images to disk, each containing:
    - left: squared grid (all channels) from the original feature map
    - right: squared grid (all channels) from a synthetic mixup feature map

    The anchor sample is drawn at random for every pair and mixed with a
    different random sample plus ball noise, using a coefficient in
    [0.3, 0.7]. Only a small subset of feature maps is expected here, so
    memory usage stays tiny.

    Args:
        original_feature_maps: tensor of shape (n, C, H, W).
        radius: noise radius forwarded to `sample_ball_noise`.
        output_dir: directory for the PNG files (created if missing).
        num_pairs: number of images to write; capped at n.

    Edge cases: with no samples nothing is written; with a single sample the
    anchor is mixed with itself, since no distinct partner exists.
    """
    os.makedirs(output_dir, exist_ok=True)

    n, C, H, W = original_feature_maps.shape
    num_pairs = min(num_pairs, n)
    if num_pairs <= 0:
        return

    # Explicit sizes (instead of -1) and reshape() keep this valid for
    # non-contiguous inputs.
    flat_feats = original_feature_maps.reshape(n, C * H * W)  # (n, D)

    for i in range(num_pairs):
        idx = np.random.randint(0, n)
        if n > 1:
            j = (idx + np.random.randint(1, n)) % n  # ensure j != idx
        else:
            j = idx  # randint(1, 1) would raise; the only partner is idx itself

        # Original feature map
        orig_map = original_feature_maps[idx]  # (C, H, W)
        orig_grid = feature_map_to_squared_grid(orig_map)

        # Build a single mixup sample in feature space for visualization
        anchor = flat_feats[idx]
        other = flat_feats[j]
        noise = sample_ball_noise((1, other.numel()), radius, device=other.device).view_as(other)
        lam = 0.3 + 0.4 * torch.rand(1, device=other.device)  # in [0.3, 0.7]
        mix_flat = lam * anchor + (1.0 - lam) * (other + noise)
        mix_map = mix_flat.reshape(C, H, W)
        mix_grid = feature_map_to_squared_grid(mix_map)

        # Plot and save pair
        fig, axes = plt.subplots(1, 2, figsize=(8, 4))

        axes[0].imshow(orig_grid, cmap="viridis", aspect="equal")
        axes[0].set_title(f"Original feat (idx={idx})")
        axes[0].axis("off")

        axes[1].imshow(mix_grid, cmap="viridis", aspect="equal")
        axes[1].set_title(f"Mixup feat (idx={idx}, j={j})")
        axes[1].axis("off")

        # Remove extra spaces around the big grids
        fig.subplots_adjust(left=0.01, right=0.99, top=0.95, bottom=0.05, wspace=0.05, hspace=0.0)

        save_path = os.path.join(output_dir, f"featuremap_grid_pair_{i+1}.png")
        fig.savefig(save_path, dpi=150)
        plt.close(fig)

        print(f"Saved feature-map grid mixup pair {i+1} to: {save_path}")


# --------------------------- Batch-wise feature-space mixup --------------------------- #


def mixup_batch_in_feature_space(
    feature_extractor,
    imgs,
    labels,
    device,
    radius,
    num_classes,
):
    """
    Build mixed features and soft labels for one batch, on the fly.

    Steps:
    - Move the batch to `device` and extract features without gradients.
    - Flatten the feature maps to (B, D).
    - Pair sample i with sample i-1 of the same batch (cyclic shift).
    - Perturb each partner with ball noise of the given radius.
    - Blend with a per-sample coefficient lam drawn from [0.3, 0.7].

    A batch holding a single sample has no partner: the sample is blended
    50/50 with a noisy copy of itself and keeps a one-hot label.

    Returns:
        (mix_feats, soft_labels) with shapes (B, D) and (B, num_classes).
    """
    imgs, labels = imgs.to(device), labels.to(device)

    feature_extractor.eval()
    with torch.no_grad():
        fmap = feature_extractor(imgs)           # (B, C, H, W)
        feats = fmap.view(fmap.size(0), -1)      # (B, D)

    batch_size, feat_dim = feats.shape

    if batch_size == 1:
        lam = torch.full((1, 1), 0.5, device=device)
        noise = sample_ball_noise((1, feat_dim), radius, device=device)
        mix_feats = lam * feats + (1.0 - lam) * (feats + noise)
        soft_labels = torch.zeros(1, num_classes, device=device)
        soft_labels[0, labels[0]] = 1.0
        return mix_feats, soft_labels

    rows = torch.arange(batch_size, device=device)
    # Cyclic shift by one: a fixed pairing in which nobody is paired with itself
    partner = rows.roll(shifts=1)

    # Coefficients stay inside [0.3, 0.7] so neither side dominates
    lam = 0.3 + 0.4 * torch.rand(batch_size, device=device)  # (B,)
    noise = sample_ball_noise((batch_size, feat_dim), radius, device=device)

    keep = lam.view(batch_size, 1)
    mix_feats = keep * feats + (1.0 - keep) * (feats[partner] + noise)

    # Soft labels: lam on the own class, the remainder on the partner's class
    soft_labels = torch.zeros(batch_size, num_classes, device=device)
    soft_labels[rows, labels] += lam
    soft_labels[rows, labels[partner]] += (1.0 - lam)

    return mix_feats, soft_labels


# --------------------------- Training / evaluation --------------------------- #


def _evaluate_accuracy(classifier, feature_extractor, loader, device):
    """Return top-1 accuracy in percent over `loader` (0.0 if it yields no samples)."""
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            feats = feature_extractor(imgs).flatten(start_dim=1)
            preds = torch.argmax(classifier(feats), dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total * 100.0 if total else 0.0


def train_and_eval(
    device,
    feature_extractor,
    classifier,
    optimizer,
    scheduler,
    trainloader,
    testloader,
    num_epochs,
    checkpoint_file,
    radius,
    num_classes,
) -> None:
    """
    Train `classifier` on feature-space mixup batches and evaluate each epoch.

    Per epoch: accelerator caches are released, every training batch is
    turned into mixed features / soft labels by
    `mixup_batch_in_feature_space`, the soft-label cross-entropy is
    minimized, the scheduler is stepped once, and hard-label accuracy on
    `testloader` is measured. The classifier weights are written to
    `checkpoint_file` after the first epoch and whenever the accuracy
    improves, so a checkpoint from this run exists as soon as one epoch
    has completed. A summary line is printed per epoch and at the end.

    Args:
        device: device used for the classifier and all batches.
        feature_extractor: module mapping images to feature maps; only used
            in eval mode and without gradients.
        classifier: module mapping flattened features to class logits.
        optimizer / scheduler: optimizer for the classifier and its LR schedule.
        trainloader / testloader: iterables of (images, labels) batches.
        num_epochs: number of passes over `trainloader`.
        checkpoint_file: path that receives the best classifier state dict.
        radius: noise radius forwarded to the mixup helper.
        num_classes: width of the soft-label vectors.
    """
    best_acc = 0.0
    classifier.to(device)
    feature_extractor.eval()

    for epoch in range(num_epochs):
        epoch_start = time.time()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if (
            hasattr(torch.backends, "mps")
            and torch.backends.mps.is_available()
            # torch.mps.empty_cache is missing on older builds that already ship MPS
            and hasattr(torch, "mps")
            and hasattr(torch.mps, "empty_cache")
        ):
            torch.mps.empty_cache()

        classifier.train()
        loss_sum = 0.0
        samples_seen = 0

        # -------------------- Training loop (feature-space mixup per batch) -------------------- #
        for imgs, labels in trainloader:
            optimizer.zero_grad()

            mix_feats, mix_soft_labels = mixup_batch_in_feature_space(
                feature_extractor,
                imgs,
                labels,
                device,
                radius,
                num_classes,
            )

            outputs = classifier(mix_feats)
            # Training labels are soft (mixup), so use the general cross-entropy
            # form: -sum_k q_k * log p_k, averaged over the batch.
            loss = -(mix_soft_labels * torch.log_softmax(outputs, dim=1)).sum(dim=1).mean()
            loss.backward()
            optimizer.step()

            batch_samples = mix_feats.size(0)
            loss_sum += loss.item() * batch_samples
            samples_seen += batch_samples

        # Average over the samples actually processed, which stays correct
        # for loaders that drop or subsample data.
        train_loss = loss_sum / samples_seen if samples_seen else 0.0
        scheduler.step()

        # -------------------- Evaluation (hard labels, on-the-fly features) -------------------- #
        classifier.eval()
        test_acc = _evaluate_accuracy(classifier, feature_extractor, testloader, device)

        # Always checkpoint the first epoch: otherwise a run whose accuracy
        # never exceeds 0 leaves no (or a stale) checkpoint behind.
        if epoch == 0 or test_acc > best_acc:
            best_acc = test_acc
            torch.save(classifier.state_dict(), checkpoint_file)

        epoch_time = time.time() - epoch_start
        print(
            f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
            f"Test Acc: {test_acc:.2f}% | Time: {epoch_time:.2f}s | "
            f"LR: {optimizer.param_groups[0]['lr']:.8f}"
        )

    print(f"Best test accuracy: {best_acc:.2f}%")


def test_model(classifier, feature_extractor, device, test_loader):
    """
    Measure and print the accuracy of `classifier` on `test_loader`.

    Features are computed batch by batch with `feature_extractor`, flattened
    per sample and classified without gradients. The accuracy (in percent)
    and the elapsed wall-clock time are printed; nothing is returned.
    """
    classifier.eval()
    classifier.to(device)
    feature_extractor.eval()

    started_at = time.time()
    correct, total = 0, 0

    with torch.no_grad():
        for batch_imgs, batch_labels in test_loader:
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)

            fmap = feature_extractor(batch_imgs)
            logits = classifier(fmap.view(fmap.size(0), -1))

            correct += (logits.argmax(dim=1) == batch_labels).sum().item()
            total += batch_imgs.shape[0]

    acc = correct / total * 100.0
    eval_finish = time.time() - started_at
    print(f"Final test accuracy with best model: {acc:.2f}% [{eval_finish:.2f} s]")


# --------------------------- Main workflow --------------------------- #


def augment_workflow(
    feature_extractor_type: str,
    dataset_type: str,
    classifier_type: str,
    device: torch.device,
    num_epochs: int,
    batch_size: int,
    radius: float,
    lr: float,
    bench_mode: bool,
    finish_with_model_eval: bool = True,
    save_mixup_datasets: bool = False,  # kept for API compatibility, not used
) -> None:
    """
    Run the complete experiment: logging setup, data, backbone, preview
    images, classifier training and the optional final evaluation.

    All console output produced while the workflow runs is mirrored into a
    per-run log file through `Tee`; the redirection is undone when the
    workflow ends, whether it finishes normally or raises.

    Args:
        feature_extractor_type: ResNet variant used as frozen backbone.
        dataset_type: name of the dataset to train on.
        classifier_type: head architecture; only "dense" is implemented.
        device: device used for the backbone, the classifier and all batches.
        num_epochs: number of training epochs.
        batch_size: batch size of the train and test loaders.
        radius: noise radius used by the feature-space mixup.
        lr: initial learning rate of the SGD optimizer.
        bench_mode: only reported in the log header.
        finish_with_model_eval: reload the best checkpoint and evaluate it
            once training is done.
        save_mixup_datasets: unused, kept for API compatibility.

    Raises:
        NotImplementedError: if `classifier_type` is not "dense".
    """
    # Local import: platform.node() also works where os.uname() does not
    # exist (e.g. Windows).
    import platform

    log_id, log_filename, checkpoint_file = generate_log_and_ckpt_files(
        feature_extractor_type, dataset_type, num_epochs
    )

    # ALWAYS log to file using Tee
    tee = Tee(log_filename)
    try:
        print(f"Saving logs to: {log_filename}")

        print(f"------------ Augmented Mixup (Feature Space, memory-efficient) ------------")
        print(f"Started on: {datetime.datetime.now().ctime()} (Log ID: #{log_id})")
        print(f"Platform: {platform.node()}")
        print(f"Torch seed: {torch.random.initial_seed()}")
        print(f"Deterministic behavior: {bench_mode}")
        print(f"Using device: {device}")
        print(f"Model: {feature_extractor_type} [cut after layer3]")
        print(f"Dataset: {dataset_type}")
        print(f"Epochs: {num_epochs}")
        print(f"radius (r): {radius}")
        print(f"--------------------------------------------------------------------------")

        workflow_start_time = time.time()

        # 1) Datasets
        train_dataset, test_dataset, num_classes = get_dataset(dataset_type)

        # 2) Feature extractor (ResNet backbone cut after layer3)
        feature_extractor = build_resnet_feature_extractor(
            feature_extractor_type, device
        )

        # 3) Visualize a few pairs of original/mixup feature-map grids BEFORE training
        print("Saving original/mixup feature-map grid pairs before training begins...")
        small_feature_maps = get_feature_maps_for_first_n(
            feature_extractor, train_dataset, device, num_samples=16
        )
        save_featuremap_grid_and_mixup_pairs(
            original_feature_maps=small_feature_maps,
            radius=radius,
            output_dir="./mixup_pairs_features",
            num_pairs=5,
        )

        # 4) Standard image loaders; features are computed on-the-fly inside training/eval
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
        )

        # 5) Classifier on top of feature vectors
        if classifier_type == "dense":
            # Infer embedding dimension from a single sample
            with torch.no_grad():
                sample_img, _ = train_dataset[0]
                sample_img = sample_img.unsqueeze(0).to(device)
                fmap = feature_extractor(sample_img)
                embedding_dim = fmap.view(1, -1).size(1)
            classifier = DenseClassifier(embedding_dim, num_classes).to(device)
        else:
            raise NotImplementedError("Only a DenseClassifier is supported...")

        optimizer = optim.SGD(
            classifier.parameters(),
            lr=lr,
            momentum=0.9,
            weight_decay=1e-4,
        )
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=num_epochs, eta_min=1e-5
        )

        print(f"--------------- Classifier size ---------------")
        classifier.get_classifier_size()
        print(f"-----------------------------------------------")
        print(
            f"Training classifier on feature-space mixup ({feature_extractor_type}, {dataset_type}) "
            f"on {device} for {num_epochs} epochs"
        )

        # 6) Train & evaluate (features and mixup on demand)
        train_and_eval(
            device=device,
            feature_extractor=feature_extractor,
            classifier=classifier,
            optimizer=optimizer,
            scheduler=scheduler,
            trainloader=train_loader,
            testloader=test_loader,
            num_epochs=num_epochs,
            checkpoint_file=checkpoint_file,
            radius=radius,
            num_classes=num_classes,
        )

        workflow_duration = time.time() - workflow_start_time
        print(f"Total execution time: {workflow_duration:.2f} s")
        print(f"-----------------------------------------------")

        # Load best classifier and test (again, features on demand)
        if finish_with_model_eval:
            if os.path.isfile(checkpoint_file):
                classifier.load_state_dict(torch.load(checkpoint_file, map_location=device))
                test_model(classifier, feature_extractor, device, test_loader)
            else:
                # Training may not have written a checkpoint (e.g. zero epochs)
                print(f"No checkpoint found at {checkpoint_file}; skipping final evaluation.")
    finally:
        # ALWAYS close Tee at the end, even on errors, so that sys.stdout is
        # restored and the log file handle is released.
        tee.close()


# --------------------------- CLI --------------------------- #


if __name__ == "__main__":
    # Fixed settings that are not exposed on the command line
    batch_size = 128
    device = select_optimal_device()

    parser = argparse.ArgumentParser(
        prog="Augmented Mixup",
        description="Augmented Mixup (Feature Space, memory-efficient)",
    )
    parser.add_argument(
        "--radius", type=float, default=1.0,
        help="Mixing radius r in feature space.",
    )
    parser.add_argument(
        "-m", "--model", type=str, default="resnet18",
        help="The ResNet architecture to use for feature extraction (resnet18, resnet34, resnet50).",
    )
    parser.add_argument(
        "-d", "--data", type=str, default="mnist",
        help="Dataset to use: mnist, cifar10, cifar100, tiny-imagenet.",
    )
    parser.add_argument(
        "-e", "--epochs", type=int, default=200,
        help="Number of training epochs.",
    )
    parser.add_argument(
        "-b", "--bench", action="store_true",
        help="Set a seed and deterministic PyTorch settings for reproducibility.",
    )
    parser.add_argument(
        "--lr", type=float, default=0.001,
        help="Learning rate for the classifier.",
    )

    args = parser.parse_args()

    # Names are matched case-insensitively
    args.data = args.data.lower()
    args.model = args.model.lower()

    if args.data not in SUPPORTED_DATASETS:
        raise ValueError(
            f"Dataset type not supported. Currently supported datasets: {SUPPORTED_DATASETS}"
        )
    if args.model not in ("resnet18", "resnet34", "resnet50"):
        raise ValueError("Model type not supported. Use resnet18, resnet34, or resnet50.")

    # Seeding is opt-in: it only happens when --bench is given
    if args.bench:
        set_deterministic_behavior(DEFAULT_SEED)

    augment_workflow(
        feature_extractor_type=args.model,
        dataset_type=args.data,
        classifier_type="dense",
        device=device,
        num_epochs=args.epochs,
        batch_size=batch_size,
        radius=args.radius,
        lr=args.lr,
        bench_mode=args.bench,
        save_mixup_datasets=False,
    )
