import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torchvision
import torchvision.transforms as transforms
from datetime import datetime

# ---- Logging ----
class Logger(object):
    """Tee-style stdout replacement: echo every write to the terminal and append it to a log file.

    Install with ``sys.stdout = Logger(path)``. The stream that was ``sys.stdout`` at
    construction time is kept as ``terminal``; call ``close()`` when done to release
    the log file (the terminal stream is left open).
    """

    def __init__(self, log_path):
        self.terminal = sys.stdout
        # Append so repeated runs accumulate in one log; an explicit encoding keeps the
        # log readable independent of the platform's locale.
        self.log = open(log_path, "a", encoding="utf-8")

    def write(self, message):
        """Write ``message`` to both streams; return the number of characters written."""
        self.terminal.write(message)
        self.log.write(message)
        # Flush eagerly so the log is complete even if the process dies mid-run.
        self.log.flush()
        return len(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; safe to call more than once. The terminal stays open."""
        if not self.log.closed:
            self.log.close()

    def __getattr__(self, name):
        # Only reached for attributes not defined above. Delegate stream attributes
        # (encoding, isatty, fileno, ...) to the terminal so code that inspects
        # sys.stdout keeps working. Guard our own fields to avoid infinite recursion
        # on partially-initialized instances (e.g. during copy/unpickling).
        if name in ("terminal", "log"):
            raise AttributeError(name)
        return getattr(self.terminal, name)

def sample_pairwise_distances(images, num_samples=1000, device='cpu', batch_size=100):
    """Estimate pairwise Euclidean distances between images by random sampling.

    Draws ``num_samples`` index pairs uniformly with replacement, discards pairs that
    point at the same image, and returns the L2 distance of each remaining pair,
    computed on the flattened images.

    Args:
        images: array whose first axis indexes images; all other axes are flattened.
        num_samples: number of pairs drawn *before* self-pairs are discarded, so the
            result can contain fewer distances than this.
        device: unused; kept so existing callers keep working.
        batch_size: number of pairs processed at once (bounds temporary memory).

    Returns:
        1-D array of distances; empty when ``num_samples`` is 0 or every drawn pair
        was a self-pair.

    Raises:
        ValueError: if ``images`` contains fewer than two images or ``batch_size < 1``.
    """
    n = images.shape[0]
    if n < 2:
        raise ValueError(f"need at least 2 images to sample pairwise distances, got {n}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    # Both index vectors are drawn up front, in this order, so seeded runs are reproducible.
    idx1 = np.random.randint(0, n, size=num_samples)
    idx2 = np.random.randint(0, n, size=num_samples)
    # Self-pairs have distance 0 and would bias the estimate downward.
    distinct = idx1 != idx2
    idx1, idx2 = idx1[distinct], idx2[distinct]
    dists = []
    for start in range(0, len(idx1), batch_size):
        first = images[idx1[start:start + batch_size]]
        second = images[idx2[start:start + batch_size]]
        diff = first.reshape(len(first), -1) - second.reshape(len(second), -1)
        dists.append(np.linalg.norm(diff, axis=1))
    if not dists:
        return np.empty(0)
    return np.concatenate(dists)

def sample_k_random_vectors(dim, radius, k):
    """Draw ``k`` vectors in R^dim with random direction and a norm below ``radius``.

    Each direction is a normalized standard-normal sample (uniform on the unit
    sphere). Each length is drawn uniformly from [0, radius). Because the *lengths*
    are uniform, rather than the points, samples sit closer to the origin than a
    volume-uniform draw from the ball would.

    Returns:
        array of shape (k, dim).
    """
    gaussian = np.random.randn(k, dim)
    unit_dirs = gaussian / np.linalg.norm(gaussian, axis=1, keepdims=True)
    lengths = radius * np.random.rand(k, 1)
    return unit_dirs * lengths

def mixup_images_fn(original_images, k, mf, universal_radius):
    """Build one k-way mixup per image, mixing the image with noise-perturbed partners.

    For each target image i:
      * k-1 distinct partner images (never i itself) are chosen uniformly without
        replacement;
      * Dirichlet(1, ..., 1) weights are drawn for [i, partners...]. The target's
        weight is capped at 1/k, and the partners' weights are rescaled so the row
        still sums to 1;
      * each partner is perturbed by its own random vector whose norm is below
        ``mf * universal_radius`` (drawn by ``sample_k_random_vectors``). The target
        itself is mixed in unperturbed.

    Args:
        original_images: array of shape (n, c, h, w).
        k: images per mixup (target plus k-1 partners); requires 1 <= k <= n.
        mf: noise multiplier applied to ``universal_radius``.
        universal_radius: base radius of the partner noise.

    Returns:
        (mixup_images, W): the mixed images, with the same shape and dtype as
        ``original_images``, and an (n, n) float32 matrix where W[i, j] is the weight
        image j received in mixup i. The added noise is not represented in W.

    Raises:
        ValueError: if there are images to mix and k is outside [1, n].
    """
    n, c, h, w = original_images.shape
    if n > 0 and not 1 <= k <= n:
        raise ValueError(f"k must satisfy 1 <= k <= n ({n}), got {k}")
    dim = c * h * w
    flat = original_images.reshape(n, dim)
    all_indices = np.arange(n)
    mixup_images = np.zeros_like(original_images)
    W = np.zeros((n, n), dtype=np.float32)
    for i in range(n):
        # Population is every index except i; sampled without replacement.
        partners = np.random.choice(np.delete(all_indices, i), k - 1, replace=False)
        weights = np.random.dirichlet([1.0] * k)
        # Cap the target's own weight at 1/k so the partners jointly carry >= (k-1)/k.
        max_weight = 1.0 / k
        if weights[0] > max_weight:
            weights[0] = max_weight
            rest_sum = weights[1:].sum()
            if rest_sum > 0:
                weights[1:] *= (1.0 - max_weight) / rest_sum
            else:
                weights[1:] = (1.0 - max_weight) / (k - 1)
        # One noise vector per partner, drawn after the indices/weights (RNG order matters).
        noises = sample_k_random_vectors(dim, mf * universal_radius, k - 1)
        mix_img = np.zeros(dim, dtype=np.float32)
        mix_img += weights[0] * flat[i]
        W[i, i] += weights[0]
        for idx, coeff, noise in zip(partners, weights[1:], noises):
            mix_img += coeff * (flat[idx] + noise)
            W[i, idx] += coeff
        mixup_images[i] = mix_img.reshape(c, h, w)
    return mixup_images, W

def recover_images_gd(mixup_images, W, ridge_lambda=1e-2, num_epochs=1000, lr=1e-2, lambda_tv=1e-5, lambda_l2=1e-4, verbose=False, device='cpu'):
    """Estimate original images X from mixups Y ~= W @ X by regularized gradient descent.

    Minimizes, with Adam,
        mean((W @ X - Y)^2) + ridge_lambda * mean(X^2) + lambda_tv * TV(X)
    starting from standard-normal noise (torch's global RNG) and clamping X to
    [-2.5, 2.5] after each step. TV(X) is the mean absolute difference between
    vertically adjacent pixels plus the same between horizontally adjacent pixels.

    Args:
        mixup_images: array of shape (n, c, h, w) holding the mixed images Y.
        W: (n, n) mixing matrix; row i holds the weights of mixup i.
        ridge_lambda: weight of the mean-squared (L2) penalty on X.
        num_epochs: number of Adam steps.
        lr: Adam learning rate.
        lambda_tv: weight of the total-variation penalty.
        lambda_l2: accepted but NOT used; the L2 term is weighted by ``ridge_lambda``.
            NOTE(review): confirm which of the two was intended.
        verbose: print the losses every 100 steps and on the final step.
        device: torch device the optimization runs on.

    Returns:
        float32 numpy array of shape (n, c, h, w) with the recovered images.
    """
    n, c, h, w = mixup_images.shape
    target = torch.tensor(mixup_images, device=device, dtype=torch.float32)
    mix_matrix = torch.tensor(W, device=device, dtype=torch.float32)
    estimate = torch.nn.Parameter(torch.randn(n, c, h, w, device=device))
    optimizer = torch.optim.Adam([estimate], lr=lr)
    last_epoch = num_epochs - 1

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        predicted = (mix_matrix @ estimate.view(n, -1)).view(n, c, h, w)
        recon_term = (predicted - target).pow(2).mean()
        tv_term = (
            (estimate[:, :, :-1, :] - estimate[:, :, 1:, :]).abs().mean()
            + (estimate[:, :, :, :-1] - estimate[:, :, :, 1:]).abs().mean()
        )
        l2_term = estimate.pow(2).mean()
        total = recon_term + ridge_lambda * l2_term + lambda_tv * tv_term
        total.backward()
        optimizer.step()
        with torch.no_grad():
            # NOTE(review): with the mean/std used in get_dataset, valid normalized
            # values reach roughly [-2.12, 2.64], so this clamp can cut the brightest
            # blue-channel values — confirm the bound is intended.
            estimate.clamp_(-2.5, 2.5)
        if verbose and (epoch % 100 == 0 or epoch == last_epoch):
            print(f"GD Epoch {epoch}, Loss: {total.item():.4f}, Recon: {recon_term.item():.4f}, TV: {tv_term.item():.4f}")

    return estimate.detach().cpu().numpy()

def get_dataset(name, input_size=224, root='./data'):
    """Load a torchvision training split, resized and normalized to a common layout.

    Every image is resized to ``input_size`` x ``input_size``, converted to a tensor
    and normalized with the ImageNet mean/std constants. MNIST images are also
    replicated to 3 channels so all datasets yield the same (3, H, W) layout.

    Args:
        name: one of "MNIST", "CIFAR10", "CIFAR100".
        input_size: side length of the resized square images.
        root: directory the dataset is downloaded to / read from.

    Returns:
        (dataset, class_names): the torchvision training dataset (downloaded if
        missing) and its list of class names.

    Raises:
        ValueError: if ``name`` is not one of the supported datasets.
    """
    supported = ("MNIST", "CIFAR10", "CIFAR100")
    # Validate first so a bad name fails fast without touching torchvision.
    if name not in supported:
        raise ValueError(f"Unknown dataset name: {name!r} (expected one of {supported})")

    steps = [transforms.Resize((input_size, input_size))]
    if name == "MNIST":
        # MNIST is single-channel; replicate to 3 channels to match the CIFAR inputs.
        steps.append(transforms.Grayscale(num_output_channels=3))
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    transform = transforms.Compose(steps)

    dataset_cls = {
        "MNIST": torchvision.datasets.MNIST,
        "CIFAR10": torchvision.datasets.CIFAR10,
        "CIFAR100": torchvision.datasets.CIFAR100,
    }[name]
    dataset = dataset_cls(root=root, train=True, download=True, transform=transform)

    if name == "MNIST":
        class_names = [str(i) for i in range(10)]
    else:
        class_names = dataset.classes
    return dataset, class_names

def _evaluate_dataset(dset_name, k, mf_values, num_samples, input_size):
    """Run mixup + recovery on a random subset of one dataset; return {mf: mean L2 error}."""
    print(f"\n=== Processing {dset_name} ===")
    dataset, _ = get_dataset(dset_name, input_size=input_size)
    indices = np.random.choice(len(dataset), num_samples, replace=False)
    images_np = torch.stack([dataset[i][0] for i in indices]).numpy()  # [N, C, H, W]
    flat_originals = images_np.reshape(len(images_np), -1)
    # Noise scale reference: mean distance over 1000 random image pairs.
    sampled_dists = sample_pairwise_distances(images_np, num_samples=1000, device='cpu', batch_size=100)
    universal_radius = np.mean(sampled_dists)
    print(f"Universal radius (average over 1000 pairs): {universal_radius:.4f}")
    avg_distances = {}
    for mf in mf_values:
        print(f"  -- mf={mf} --")
        mixup_imgs, W = mixup_images_fn(images_np, k, mf, universal_radius)
        # NOTE(review): recover_images_gd weights its L2 penalty by ridge_lambda and does
        # not appear to use lambda_l2; it is passed only to keep the call unchanged.
        recovered_imgs = recover_images_gd(
            mixup_imgs, W, ridge_lambda=1e-2, num_epochs=1000, lr=1e-2,
            lambda_tv=1e-5, lambda_l2=1e-4, verbose=False, device='cpu'
        )
        # Mean over the batch of per-image Euclidean distances.
        flat_recovered = recovered_imgs.reshape(len(recovered_imgs), -1)
        dist = np.linalg.norm(flat_originals - flat_recovered, axis=1).mean()
        print(f"    Average Euclidean distance (recovered vs original): {dist:.4f}")
        avg_distances[mf] = dist
    return avg_distances


def _plot_recovery_curves(avg_distances_per_dataset, dataset_names, k, outdir):
    """Plot mean recovery error vs. mf for every dataset at one k and save it as a PNG."""
    fig = plt.figure(figsize=(7, 5))
    for dset_name in dataset_names:
        per_mf = avg_distances_per_dataset[dset_name]
        mf_list = sorted(per_mf)
        plt.plot(mf_list, [per_mf[mf] for mf in mf_list], marker='o', label=dset_name)
    plt.xlabel("mf (noise multiplier)")
    plt.ylabel("Average Euclidean distance (original vs recovered)")
    plt.title(f"Average Recovery Error vs. mf (k={k})")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    fig_path = os.path.join(outdir, f"average_recovery_vs_mf_k{k}.png")
    plt.savefig(fig_path, dpi=200)
    print(f"Plot saved to {fig_path}")
    plt.show()
    # Release the figure; otherwise each k leaves another open figure behind.
    plt.close(fig)


def main():
    """Run the mixup-recovery experiment for every k, dataset and noise multiplier.

    All printed output is also appended to results/all_datasets.log. The original
    stdout is restored and the log file closed when the run ends, including on error.
    """
    mf_values = [1, 2, 4, 8, 16, 32]
    dataset_names = ["MNIST", "CIFAR10", "CIFAR100"]
    outdir = "results"
    os.makedirs(outdir, exist_ok=True)
    log_path = os.path.join(outdir, "all_datasets.log")
    original_stdout = sys.stdout
    logger = Logger(log_path)
    sys.stdout = logger
    try:
        print(f"Experiment started at {datetime.now()}")
        print(f"mf_values: {mf_values}")
        print(f"Datasets: {dataset_names}")

        num_samples = 32   # Number of images per dataset for mixup/recovery and distance computation
        input_size = 224
        np.random.seed(42)
        torch.manual_seed(42)

        for k in [4, 6]:
            print(f"\n==================== k = {k} ====================")
            avg_distances_per_dataset = {}  # dataset_name -> {mf: avg_distance}
            for dset_name in dataset_names:
                avg_distances_per_dataset[dset_name] = _evaluate_dataset(
                    dset_name, k, mf_values, num_samples, input_size
                )
            _plot_recovery_curves(avg_distances_per_dataset, dataset_names, k, outdir)

        print(f"Experiment finished at {datetime.now()}")
    finally:
        # Undo the global stdout redirection and release the log file handle.
        sys.stdout = original_stdout
        logger.log.close()

# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == "__main__":
    main()
