import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.linalg as LA
import numpy as np

import time
from tqdm import tqdm
import sys
import os
import datetime

import argparse
from typing import Tuple

# local imports
from utils import select_optimal_device, set_deterministic_behavior, cka, summarize, FeatureGenerator, Tee
from utils import DEFAULT_SEED, SUPPORTED_DATASETS, LOG_DIR, CHECKPOINTS_DIR, CACHE_DIR, inf

# For CUDA: ask the caching allocator to use expandable segments, which
# reduces memory fragmentation. `setdefault` keeps any value the user exported
# explicitly instead of silently overwriting it.
# NOTE(review): the allocator reads this when CUDA is first initialised;
# `torch` (and `utils`) are imported above — confirm neither allocates on CUDA
# at import time, otherwise this setting has no effect.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")


def generate_log_and_ckpt_files(model, dataset_name, k_val, num_epochs):
    """Create the log/checkpoint directories and build this run's file paths.

    Returns:
        ``(log_id, log_filename, checkpoint_file)``. ``log_id`` is the current
        Unix time truncated to whole seconds and is embedded in the log file
        name. The checkpoint name does not contain it, so runs that share the
        same model, dataset, k and epoch count reuse one checkpoint path.
    """
    log_dir, checkpoint_dir = LOG_DIR, CHECKPOINTS_DIR
    for directory in (log_dir, checkpoint_dir):
        os.makedirs(directory, exist_ok=True)

    log_id = int(time.time())
    return (
        log_id,
        f"{log_dir}/instahide_{model}_{dataset_name}_k{k_val}_epochs{num_epochs}-ID-{log_id}.log",
        f'{checkpoint_dir}/best_instahide_classifier_{model}_{dataset_name}_k{k_val}_{num_epochs}epochs.pth',
    )


class DenseClassifier(nn.Module):
    """MLP head that maps flat embeddings to class logits.

    Two hidden stages of Linear -> BatchNorm1d -> GELU -> Dropout(0.5), with
    widths 1024 and 512, followed by a final Linear layer producing
    ``num_classes`` logits. All layers live in ``self.layers`` (an
    ``nn.Sequential``), so state_dict keys are ``layers.<index>.*``.
    """

    def __init__(self, embedding_dim, num_classes):
        super().__init__()
        self.num_classes = num_classes

        # Modules are created in the same order as a literal Sequential would,
        # so parameter initialisation consumes the RNG identically.
        stages = []
        in_features = embedding_dim
        for width in (1024, 512):
            stages += [
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.GELU(),
                nn.Dropout(0.5),
            ]
            in_features = width
        stages.append(nn.Linear(in_features, num_classes))
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        return self.layers(x)

    def get_classifier_size(self):
        """Print every layer of the head, numbered from 1."""
        for position, layer in enumerate(self.layers, start=1):
            print(f'Layer {position}: {layer}')


def compare_features(mixup_features, train_features, train_labels,
                     n_max=5, batch_size=128, bins=25,
                     output_path="features-plot.png"):
    """Compare mixup embeddings with the clean embeddings at the same indices.

    Both datasets are iterated unshuffled, so row ``i`` of a mixup batch lines
    up with training sample ``i``. Over the first ``n_max`` batches this prints
    summaries of the per-row L1 and L2 distances and the cosine similarity, and
    a single global CKA value. It also saves histograms of the three per-row
    metrics to ``output_path``.

    Args:
        mixup_features: dataset yielding ``(embedding, label)`` pairs, aligned
            index-by-index with ``train_features``.
        train_features: clean embeddings.
        train_labels: labels for ``train_features``.
        n_max: number of batches to compare.
        batch_size: batch size used for both loaders.
        bins: number of histogram bins.
        output_path: file the figure is written to.

    Raises:
        ValueError: if no batch could be collected (e.g. ``n_max <= 0``).
    """
    from itertools import islice

    mixup_dataloader = DataLoader(
        mixup_features, batch_size=batch_size, shuffle=False)
    original_dataloader = DataLoader(
        EmbeddingDataset(train_features, train_labels),
        batch_size=batch_size, shuffle=False)

    mixup_embeddings = []
    original_embeddings = []
    # islice stops after exactly n_max batches, so no extra (discarded)
    # mixup batch is generated.
    paired_batches = islice(zip(mixup_dataloader, original_dataloader), n_max)
    n_batches = max(0, min(n_max, len(mixup_dataloader), len(original_dataloader)))
    for (mixup_batch, _), (original_batch, _) in tqdm(
            paired_batches, total=n_batches, desc='Creating comparison data'):
        mixup_embeddings.append(mixup_batch)
        original_embeddings.append(original_batch)

    if not mixup_embeddings:
        raise ValueError("No batches available for feature comparison "
                         f"(n_max={n_max}).")

    mixup_embeddings = torch.cat(mixup_embeddings, dim=0).detach().cpu()
    original_embeddings = torch.cat(original_embeddings, dim=0).detach().cpu()

    difference = mixup_embeddings - original_embeddings
    l1_norm = LA.norm(difference, dim=1, ord=1)
    l2_norm = LA.norm(difference, dim=1, ord=2)
    cka_dist = cka(mixup_embeddings, original_embeddings)
    cos = torch.cosine_similarity(mixup_embeddings, original_embeddings, dim=1)
    summarize("cos sim", cos)
    summarize("L1 Norm", l1_norm)
    summarize("L2 Norm", l2_norm)
    print("CKA (global):", cka_dist.item())

    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))
    panels = [
        (axs[0, 0], l1_norm, "L1 norm"),
        (axs[0, 1], l2_norm, "L2 norm"),
        (axs[1, 0], cos, "Cosine similarity"),
    ]
    for ax, values, title in panels:
        ax.hist(values.numpy(), bins=bins, rwidth=0.8)
        ax.set_title(title)
        # dashed red line marks the mean of the metric
        ax.axvline(values.mean().item(), color='red', linestyle='--')
        ax.set_xlabel("Value")
        ax.set_ylabel("Count")
    # Only three metrics are plotted; hide the unused fourth panel.
    axs[1, 1].axis("off")

    plt.tight_layout()
    plt.savefig(output_path, dpi=350)
    plt.close(fig)


def estimate_average_pairwise_distance_batched(
    embeddings: torch.Tensor,
    batch_size: int,
    num_samples: int = 1_000_000,
    deterministic: bool = False,
    seed: int = 42,
) -> float:
    """
    Estimate the average pairwise Euclidean distance between embeddings
    from ``num_samples`` uniformly random ordered pairs ``(i, j)`` with
    ``i != j``, processed ``batch_size`` pairs at a time.

    Args:
        embeddings: (N, D) tensor
        batch_size: number of pairs to process per batch (must be >= 1)
        num_samples: number of random pairs to sample
        deterministic: if True, pairs are drawn from a private generator
            seeded with ``seed``, so the estimate is reproducible. The global
            torch RNG is never reseeded by this function.
        seed: seed for the private generator (used only if ``deterministic``)

    Returns:
        Estimated average pairwise distance (float). Returns 0.0 when
        N < 2 or ``num_samples <= 0``.

    Raises:
        ValueError: if ``batch_size < 1`` while sampling is required (the
            loop could never make progress).
    """
    N = embeddings.shape[0]
    if N < 2 or num_samples <= 0:
        return 0.0
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # A local generator keeps determinism without clobbering the global RNG.
    generator = torch.Generator().manual_seed(seed) if deterministic else None

    total_dist = 0.0
    total_count = 0
    samples_remaining = num_samples
    while samples_remaining > 0:
        current_batch_size = min(batch_size, samples_remaining)

        # Draw i uniformly, then j = (i + offset) mod N with offset uniform in
        # [1, N-1]. For every i, j is uniform over the N-1 other indices,
        # which is the same distribution as rejection sampling, without a
        # resampling loop.
        idx1 = torch.randint(0, N, (current_batch_size,), generator=generator)
        offset = torch.randint(1, N, (current_batch_size,), generator=generator)
        idx2 = (idx1 + offset) % N

        dists = torch.norm(embeddings[idx1] - embeddings[idx2], dim=1)
        total_dist += dists.sum().item()
        total_count += current_batch_size
        samples_remaining -= current_batch_size

    return total_dist / total_count

    
def sample_k_vectors(dim, radius, k):
    """Draw ``k`` random vectors of length ``dim`` with norm at most ``radius``.

    Each row starts as a standard-normal sample. Rows whose Euclidean norm
    exceeds ``radius`` are rescaled onto the sphere of that radius; shorter
    rows are returned unchanged. This is a norm-clipped Gaussian, not a
    uniform sample from the ball.

    Returns:
        A ``(k, dim)`` tensor.
    """
    gaussian = torch.randn(k, dim)
    # per-row factor: radius / ||row||, capped at 1 so short rows are kept as-is
    shrink = torch.clamp(radius / gaussian.norm(dim=1, keepdim=True), max=1.0)
    return gaussian * shrink

class MixupEmbeddingBallDataset(Dataset):
    """k-way mixup over precomputed embeddings, regenerated on every access.

    Item ``idx`` mixes the anchor ``embeddings[idx]`` (used unchanged) with
    ``k - 1`` partners drawn at random without replacement from the whole
    dataset. The anchor's own index is not excluded, so it may also appear as
    a partner. Every partner is shifted by a noise vector from
    ``sample_k_vectors`` whose norm is at most
    ``multiplicative_factor * universal_radius``. Mixing weights come from a
    symmetric Dirichlet(1, ..., 1), and the soft label assigns each weight to
    the class of the corresponding member.

    Randomness comes from the global NumPy RNG (partner indices, weights) and
    the global torch RNG (noise).
    """

    def __init__(self, embeddings, labels, multiplicative_factor, universal_radius, k, num_classes):
        if isinstance(embeddings, torch.Tensor):
            self.embeddings = embeddings
        else:
            self.embeddings = torch.tensor(embeddings)
        if isinstance(labels, torch.Tensor):
            self.labels = labels
        else:
            self.labels = torch.tensor(labels, dtype=torch.long)
        self.universal_radius = universal_radius
        self.multiplicative_factor = multiplicative_factor
        self.k = k
        self.n = len(self.embeddings)
        self.num_classes = num_classes
        # noise and mixing weights are created on the embeddings' device/dtype
        self.device = self.embeddings.device
        self.dtype = self.embeddings.dtype

    def __len__(self):
        return self.n

    @staticmethod
    def _sample_distinct_indices(n, m):
        """Draw ``m`` distinct indices from ``range(n)`` uniformly at random.

        ``np.random.choice(n, m, replace=False)`` permutes all ``n`` indices
        on every call (O(n) per item). For the small ``m`` used here,
        sequential rejection sampling is O(m) in expectation and yields the
        same distribution. It falls back to ``np.random.choice`` when ``m`` is
        more than half of ``n``, where rejection would be slow.

        Raises:
            ValueError: if ``m < 0`` or ``m > n``.
        """
        if m < 0 or m > n:
            raise ValueError(
                f"Cannot draw {m} distinct indices from {n} embeddings")
        if m == 0:
            return np.empty(0, dtype=np.int64)
        if 2 * m > n:
            return np.random.choice(n, m, replace=False).astype(np.int64)
        chosen = []
        seen = set()
        while len(chosen) < m:
            candidate = int(np.random.randint(n))
            if candidate not in seen:
                seen.add(candidate)
                chosen.append(candidate)
        return np.asarray(chosen, dtype=np.int64)

    def __getitem__(self, idx):
        """Return ``(mix_embedding, mix_label)`` for anchor ``idx``.

        ``mix_embedding`` has the shape of one embedding row and the
        embeddings' dtype and device. ``mix_label`` is a float vector of
        length ``num_classes`` holding the mixing weight of each member's class.
        """
        partners = self._sample_distinct_indices(self.n, self.k - 1)
        mix_coeffs = np.random.dirichlet([1.0] * self.k)

        # Position 0 is always the anchor. The anchor is identified by
        # position rather than by index value, so a partner that happens to
        # equal idx is still perturbed like every other partner.
        member_idx = torch.tensor([int(idx), *partners.tolist()], dtype=torch.long)

        anchor = self.embeddings[idx].unsqueeze(0)
        if partners.size > 0:
            noises = sample_k_vectors(
                self.embeddings.shape[1],
                self.multiplicative_factor * self.universal_radius,
                self.k - 1)
            perturbed = self.embeddings[member_idx[1:]] + noises.to(
                device=self.device, dtype=self.dtype)
            members = torch.cat([anchor, perturbed], dim=0)
        else:
            members = anchor

        # Weighted sum over members; weights broadcast over trailing dims.
        weights = torch.as_tensor(mix_coeffs, dtype=self.dtype, device=self.device)
        weights = weights.reshape(-1, *([1] * (members.dim() - 1)))
        mix_embedding = (weights * members).sum(dim=0)

        # index_add_ accumulates, so members sharing a class add their weights.
        mix_label = torch.zeros(self.num_classes)
        member_labels = self.labels[member_idx].reshape(-1).long().cpu()
        mix_label.index_add_(
            0, member_labels, torch.as_tensor(mix_coeffs, dtype=mix_label.dtype))
        return mix_embedding, mix_label


class EmbeddingDataset(Dataset):
    """Map-style dataset that pairs each embedding row with its label.

    Tensor inputs are stored as given, without copying. Any other array-like
    input is converted with ``torch.tensor``; labels are converted to
    ``torch.long``.
    """

    def __init__(self, embeddings, labels):
        self.embeddings = (embeddings if isinstance(embeddings, torch.Tensor)
                           else torch.tensor(embeddings))
        self.labels = (labels if isinstance(labels, torch.Tensor)
                       else torch.tensor(labels, dtype=torch.long))

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        embedding, label = self.embeddings[idx], self.labels[idx]
        return embedding, label


def train_and_eval(
    device,
    model,
    optimizer,
    scheduler,
    trainloader,
    testloader,
    num_epochs,
    checkpoint_file,
) -> None:
    """Train ``model`` on soft (mixup) labels and keep its best checkpoint.

    Each epoch does three things:
      * minimises ``-(labels * log_softmax(outputs)).sum(dim=1).mean()`` over
        ``trainloader``, whose labels are per-class weight vectors;
      * steps ``scheduler`` once;
      * measures top-1 accuracy on ``testloader``, whose labels are class
        indices.

    The model's ``state_dict`` is saved to ``checkpoint_file`` after the first
    epoch and again whenever test accuracy improves. The file therefore always
    comes from this run, even if accuracy never rises above 0%.

    Args:
        device: device the model and batches are moved to.
        model: the classifier to train (moved to ``device``).
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch.
        trainloader: yields ``(features, soft_labels)`` batches.
        testloader: yields ``(features, class_indices)`` batches.
        num_epochs: number of epochs to run.
        checkpoint_file: path for the best ``state_dict``.
    """
    best_acc = None  # None until the first epoch has been evaluated
    model.to(device)

    for epoch in range(num_epochs):
        epoch_start = time.time()

        # Release cached allocator memory between epochs.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()

        # Training
        model.train()
        train_loss = 0
        for features, labels in trainloader:
            features, labels = features.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(features)
            loss = -(labels * torch.log_softmax(outputs, dim=1)
                     ).sum(dim=1).mean()  # Mixup loss
            train_loss += loss.item() * features.size(0)
            loss.backward()
            optimizer.step()

        train_loss /= len(trainloader.dataset)
        scheduler.step()

        # Evaluation
        model.eval()
        correct = 0
        with torch.no_grad():
            for features, labels in testloader:
                features, labels = features.to(device), labels.to(device)
                outputs = model(features)
                preds = torch.argmax(outputs, dim=1)
                correct += (preds == labels).sum().item()

        test_acc = correct / len(testloader.dataset) * 100.0

        # Always save on the first epoch so a stale checkpoint from an earlier
        # run at the same path is never picked up afterwards.
        if best_acc is None or test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), checkpoint_file)

        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | "
              f"Test Acc: {test_acc:.2f}% | Time: {epoch_time:.2f}s | "
              f"LR: {optimizer.param_groups[0]['lr']:.8f}")

    if best_acc is None:
        best_acc = 0.0
    print(f"Best test accuracy: {best_acc:.2f}%")


def test_model(model, device, test_loader):
    model.eval()
    model.to(device)

    eval_start = time.time()
    with torch.no_grad():
        total = 0
        correct = 0
        for x, y_true in test_loader:
            x, y_true = x.to(device), y_true.to(device)
            y = model(x)

            preds = torch.argmax(y, dim=1)
            correct += (preds == y_true).sum().item()
            total += x.shape[0]

    acc = correct/total*100.0
    eval_finish = time.time() - eval_start
    print(
        f"Final test accuracy with best model: {acc:.2f}% [{eval_finish:.2f} s]")


def prepare_and_validate_features(
        feature_extractor_type: str,
        remove_last_k_layers: int,
        dataset_type: str,
        device: torch.device,
        use_cache: bool) -> Tuple[Tuple[str, str,], Tuple[str, str], int]:
    """
    Build (or reuse) embeddings for ``dataset_type`` through ``FeatureGenerator``
    and report where they are stored.

    - Per the original notes, the feature extractor is a ResNet with its
      classifier layers removed; ``remove_last_k_layers`` controls how many
      trailing layers are dropped.
    - Datasets include CIFAR10, CIFAR100, and so on.
    - This function performs no checks of its own; any validation is left to
      ``FeatureGenerator``.

    Returns:
        ``((train_features_path, train_labels_path),
        (test_features_path, test_labels_path), num_classes)``, read from the
        ``FeatureGenerator`` instance.
    """
    generator = FeatureGenerator(
        feature_extractor_type=feature_extractor_type,
        remove_last_k_layers=remove_last_k_layers,
        dataset_type=dataset_type,
        device=device,
        use_cache=use_cache,
    )
    train_paths = (generator.train_features_path, generator.train_labels_path)
    test_paths = (generator.test_features_path, generator.test_labels_path)
    return train_paths, test_paths, generator.num_classes


def augment_workflow(
    feature_extractor_type: str,
    dataset_type: str,
    classifier_type: str,
    device: torch.device,
    num_epochs: int,
    batch_size: int,
    k_val: int,
    multiplicative_factor: float,
    lr: float,
    bench_mode: bool,
    finish_with_model_eval: bool = True,
    save_logs: bool = False,
    save_mixup_datasets: bool = False,
    compare_features_only: bool = False,
) -> None:
    """Run the augmented-mixup experiment end to end.

    Steps:
      1. Load embeddings from ``prepare_and_validate_features`` (last 2 layers
         of ``feature_extractor_type`` removed) and flatten them on the CPU.
      2. Estimate the average pairwise distance of the training embeddings
         (the "universal radius").
      3. Build a ``MixupEmbeddingBallDataset`` for training and an
         ``EmbeddingDataset`` for testing.
      4. Train a ``DenseClassifier`` with AdamW (weight decay 1e-4) and cosine
         annealing over ``num_epochs``.
      5. Optionally reload the best checkpoint for a final evaluation.

    Args:
        save_logs: mirror stdout to a log file via ``Tee``. stdout is restored
            and the log closed even if the run raises or exits early.
        save_mixup_datasets: ``torch.save`` the dataset objects to CACHE_DIR.
        compare_features_only: run ``compare_features`` and exit the process
            with status 0 instead of training.
        finish_with_model_eval: reload the best checkpoint and run
            ``test_model`` after training.

    Raises:
        NotImplementedError: if ``classifier_type`` is not ``"dense"``.
    """
    log_id, log_filename, checkpoint_file = generate_log_and_ckpt_files(
        feature_extractor_type, dataset_type, k_val, num_epochs)
    tee = None
    if save_logs:
        tee = Tee(log_filename)
        sys.stdout = tee
        print(f'Saving logs to: {log_filename}')
    try:
        remove_last_k_layers = 2
        print(f"------------ Augmented Mixup ------------")
        print(f'Started on: {datetime.datetime.now().ctime()} (Log ID: #{log_id})')
        print(f'Platform: {os.uname().nodename}')
        print(f'Torch seed: {torch.random.initial_seed()}')
        print(f'Deterministic behavior: {bench_mode}')
        print(f"Using device: {device}")
        print(
            f'Model: {feature_extractor_type} [last {remove_last_k_layers} layers removed]')
        print(f'Dataset: {dataset_type}')
        print(f'Epochs: {num_epochs}')
        print(f'k_val: {k_val}')
        print(f'Multiplicative factor (mf): {multiplicative_factor}')
        print(f"-----------------------------------------------")

        # NOTE(review): the "k" in these file names is remove_last_k_layers,
        # not k_val; mf is not part of the name either.
        mixup_train_dataset_file = f'{CACHE_DIR}/mixup-{feature_extractor_type}-k{remove_last_k_layers}-{dataset_type}-train_dataset.pth'
        mixup_test_dataset_file = f'{CACHE_DIR}/mixup-{feature_extractor_type}-k{remove_last_k_layers}-{dataset_type}-test_dataset.pth'

        workflow_start_time = time.time()
        train_files, test_files, num_classes = prepare_and_validate_features(
            feature_extractor_type=feature_extractor_type,
            remove_last_k_layers=remove_last_k_layers,
            dataset_type=dataset_type,
            device=device,
            use_cache=True)  # use cache by default

        train_features = torch.load(train_files[0])
        train_labels = torch.load(train_files[1])
        test_features = torch.load(test_files[0])
        test_labels = torch.load(test_files[1])

        # Flatten each sample to a vector and do the feature transformations
        # on the CPU. `reshape` (unlike `view`) also accepts non-contiguous
        # saved tensors.
        transformation_device = torch.device("cpu")
        train_features = train_features.reshape(
            train_features.shape[0], -1).contiguous().to(transformation_device)
        test_features = test_features.reshape(
            test_features.shape[0], -1).contiguous().to(transformation_device)
        train_labels = train_labels.to(transformation_device)
        test_labels = test_labels.to(transformation_device)

        # Generate Synthetic Embeddings
        universal_radius = estimate_average_pairwise_distance_batched(
            train_features, batch_size)
        print(f'Computed (average) universal radius: {universal_radius}')

        train_dataset = MixupEmbeddingBallDataset(
            embeddings=train_features,
            labels=train_labels,
            multiplicative_factor=multiplicative_factor,
            universal_radius=universal_radius,
            k=k_val,
            num_classes=num_classes)

        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True)

        test_dataset = EmbeddingDataset(test_features, test_labels)
        test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False)

        if save_mixup_datasets:
            torch.save(train_dataset, mixup_train_dataset_file)
            torch.save(test_dataset, mixup_test_dataset_file)

        if compare_features_only:
            compare_features(train_dataset, train_features, train_labels)
            sys.exit(0)  # SystemExit still runs the `finally` below

        if classifier_type == "dense":
            classifier = DenseClassifier(
                train_features.shape[1], num_classes).to(device)
        else:
            raise NotImplementedError(
                "Only a DenseClassifier is supported...")

        optimizer = optim.AdamW(classifier.parameters(), lr=lr, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=num_epochs)

        print(f'--------------- Classifier size ---------------')
        classifier.get_classifier_size()
        print(f'-----------------------------------------------')
        print(
            f'Training {feature_extractor_type}({dataset_type}) on {device} for {num_epochs} epochs')
        train_and_eval(device, classifier, optimizer, scheduler,
                       train_loader, test_loader, num_epochs, checkpoint_file)

        workflow_duration = time.time() - workflow_start_time
        print(f"Total execution time: {workflow_duration:.2f} s")
        print(f"-----------------------------------------------")
        if finish_with_model_eval:
            classifier.load_state_dict(
                torch.load(checkpoint_file, map_location=device))
            test_model(classifier, device, test_loader)
    finally:
        # Restore stdout before closing the log so nothing is written to a
        # closed file. This also runs on exceptions and on sys.exit().
        if tee is not None:
            sys.stdout = tee.stdout
            tee.file.close()


if __name__ == "__main__":
    # ========== CONFIGURATION ==========
    batch_size = 128

    # ========== ARGUMENTS ==========
    parser = argparse.ArgumentParser(
        prog='Augmented Mixup',
        description='Augmented Mixup')
    parser.add_argument("-k", "--kval", type=int, default=4,
                        help="The number of random samples that will be used for generating the mixup dataset.")
    parser.add_argument("--mf", type=float, default=2.0,
                        help="Multiplicative factor. It will be used to scale the radius when sampling synthetic embeddings. Check `sample_k_vectors`")
    parser.add_argument("-m", "--model", type=str, default="resnet18",
                        help='The architecture of the neural network to be used for feature (embedding) extraction.')
    parser.add_argument("-d", "--data", type=str, default="mnist",
                        help='The dataset used for generating samples and training.')
    parser.add_argument("-e", "--epochs", type=int, default=200,
                        help="Number of training epochs.")
    parser.add_argument("-b", "--bench", action="store_true",
                        help="Set a seed and other useful PyTorch settings for the best reproducibility of experiments.")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="The learning rate that will be used during training.")
    # NOTE: despite its name, passing -v/--verbose *disables* saving logs
    # (store_false); args.verbose is True, i.e. logs are saved, by default.
    parser.add_argument("-v", "--verbose", action="store_false",
                        help="Do not save the logs to a file when performing the training and benchmark.")

    args = parser.parse_args()
    # model and dataset to lowercase
    args.data = str(args.data).lower()
    args.model = str(args.model).lower()

    # ========== SAFETY CHECKS ==========
    if args.data not in SUPPORTED_DATASETS:
        raise ValueError(
            f"Dataset type not supported. Currently supported datasets: {SUPPORTED_DATASETS}")
    if args.model not in ["resnet18", "resnet34", "resnet50"]:
        raise ValueError(
            f"Model type not supported.")
    # parser.error (prints usage, exit status 2) instead of `assert`, which
    # `python -O` strips.
    if args.kval < 1:
        parser.error("The value of k should be strictly positive.")
    # With 0 epochs no checkpoint is written, so the final evaluation would fail.
    if args.epochs < 1:
        parser.error("The number of epochs should be strictly positive.")
    if args.mf < 0:
        parser.error("The multiplicative factor (--mf) should be non-negative.")

    # Select the device only after the arguments are validated, so it is not
    # called for `--help` or invalid arguments.
    device = select_optimal_device()

    # ========== REPRODUCIBILITY ==========
    if args.bench:
        seed = DEFAULT_SEED  # change this as needed
        set_deterministic_behavior(seed)

    # Only the "dense" classifier is currently implemented.
    augment_workflow(feature_extractor_type=args.model,
                     dataset_type=args.data,
                     classifier_type="dense",
                     device=device,
                     num_epochs=args.epochs,
                     batch_size=batch_size,
                     k_val=args.kval,
                     multiplicative_factor=args.mf,
                     lr=args.lr,
                     bench_mode=args.bench,
                     save_logs=args.verbose,
                     save_mixup_datasets=False,
                     compare_features_only=False)
