import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torchvision
import torchvision.transforms as transforms
from datetime import datetime

# ---- Logging ----
class Logger(object):
    """Tee for ``sys.stdout``: echo every write to the terminal and a log file.

    Install with ``sys.stdout = Logger(path)``. Every write is flushed to the
    log file immediately so the log survives a crash. Attributes that are not
    defined here (``isatty``, ``encoding``, ``fileno``, ...) are forwarded to
    the wrapped terminal stream, so code that inspects ``sys.stdout`` keeps
    working after the swap.
    """

    def __init__(self, log_path):
        # Capture whatever stdout is *now*, before the caller replaces it.
        self.terminal = sys.stdout
        # Explicit UTF-8 so non-ASCII output does not depend on the locale.
        self.log = open(log_path, "a", encoding="utf-8")

    def write(self, message):
        """Write ``message`` to the terminal and append it to the log file."""
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()

    def flush(self):
        """Flush both the terminal stream and the log file."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file (idempotent); the terminal stream is left open."""
        if not self.log.closed:
            self.log.close()

    def __getattr__(self, name):
        # Only reached for attributes not found through normal lookup. Read
        # ``terminal`` via ``__dict__`` so a half-initialised instance raises
        # AttributeError instead of recursing into __getattr__ forever.
        terminal = self.__dict__.get("terminal")
        if terminal is None:
            raise AttributeError(name)
        return getattr(terminal, name)

def sample_pairwise_distances(images, num_samples=1000, device='cpu', batch_size=100):
    """Return Euclidean distances between randomly chosen pairs of images.

    Draws ``num_samples`` index pairs uniformly with replacement from the
    global NumPy RNG, drops pairs that point at the same image, and returns the
    distance between the flattened images of every remaining pair.

    Args:
        images: Array whose first axis indexes images; all remaining axes are
            flattened before the distance is taken. Non-floating input
            (integer/bool) is promoted to float64 so subtraction cannot wrap.
        num_samples: Number of index pairs drawn, before self-pairs are removed.
        device: Unused; accepted for call-site compatibility.
        batch_size: Number of pairs processed per vectorised step. This bounds
            the size of the temporary difference array.

    Returns:
        1-D array with one distance per kept pair. It is empty when no distinct
        pair is available, e.g. fewer than two images or ``num_samples <= 0``.
    """
    images = np.asarray(images)
    if not np.issubdtype(images.dtype, np.inexact):
        # uint8 differences wrap around and bool subtraction raises.
        images = images.astype(np.float64)
    n = images.shape[0]
    if n < 2 or num_samples <= 0:
        return np.empty(0, dtype=images.dtype)

    idx1 = np.random.randint(0, n, size=num_samples)
    idx2 = np.random.randint(0, n, size=num_samples)
    distinct = idx1 != idx2
    idx1, idx2 = idx1[distinct], idx2[distinct]

    dists = []
    for start in range(0, len(idx1), batch_size):
        stop = min(start + batch_size, len(idx1))
        a = images[idx1[start:stop]].reshape(stop - start, -1)
        b = images[idx2[start:stop]].reshape(stop - start, -1)
        dists.append(np.linalg.norm(a - b, axis=1))
    if not dists:
        # Every drawn pair was a self-pair; np.concatenate([]) would raise.
        return np.empty(0, dtype=images.dtype)
    return np.concatenate(dists)

def sample_k_random_vectors(dim, radius, k):
    """Draw ``k`` random vectors in ``dim`` dimensions with length below ``radius``.

    Each direction is a normalised standard-Gaussian sample, i.e. uniform on
    the unit sphere. Each length is drawn uniformly from ``[0, radius)``.
    Note that the resulting points are not uniform over the ball's volume.
    Both draws use the global NumPy RNG, Gaussian first.

    Returns:
        Array of shape ``(k, dim)``.
    """
    gaussian = np.random.randn(k, dim)
    unit = gaussian / np.linalg.norm(gaussian, axis=1, keepdims=True)
    lengths = np.random.rand(k, 1) * radius
    return unit * lengths

def mixup_images_fn(original_images, k, mf, universal_radius):
    """Mix every image with ``k - 1`` randomly chosen, noise-perturbed partners.

    For each image ``i``, the function:

    * draws ``k - 1`` distinct partner indices (never ``i`` itself);
    * draws Dirichlet(1, ..., 1) mixing weights;
    * caps the weight of image ``i`` at ``1 / k``, rescaling the partners'
      weights so the total stays 1;
    * adds to each partner a random vector of norm at most
      ``mf * universal_radius``;
    * forms the weighted sum.

    The target image itself is mixed in unperturbed. All randomness comes
    from the global NumPy RNG.

    Args:
        original_images: Array of shape ``(n, c, h, w)``.
        k: Images per mixture (the target plus ``k - 1`` partners). It must
            satisfy ``1 <= k <= n``.
        mf: Multiplier applied to ``universal_radius`` to obtain the noise radius.
        universal_radius: Base noise radius.

    Returns:
        A tuple ``(mixup_images, W)``:

        * ``mixup_images`` has the same shape and dtype as ``original_images``.
        * ``W`` is an ``(n, n)`` float32 matrix. Row ``i`` holds the mixing
          coefficient of every source image used for mixture ``i``; the
          added noise is not represented in ``W``.

    Raises:
        ValueError: If ``k`` is not in ``[1, n]``.
    """
    n, c, h, w = original_images.shape
    if not 1 <= k <= n:
        raise ValueError(f"k must be between 1 and the number of images ({n}), got {k}")
    dim = c * h * w
    flat_images = original_images.reshape(n, -1)
    max_weight = 1.0 / k
    mixup_images = np.zeros_like(original_images)
    W = np.zeros((n, n), dtype=np.float32)
    for i in range(n):
        partners = np.random.choice(np.delete(np.arange(n), i), k - 1, replace=False)
        weights = np.random.dirichlet([1.0] * k)
        # Cap the target's own share so it never dominates its mixture and
        # redistribute the surplus proportionally over the partners.
        if weights[0] > max_weight:
            weights[0] = max_weight
            rest_sum = weights[1:].sum()
            if rest_sum > 0:
                weights[1:] *= (1.0 - max_weight) / rest_sum
            else:
                weights[1:] = (1.0 - max_weight) / (k - 1)
        noises = sample_k_random_vectors(dim, mf * universal_radius, k - 1)
        mix_img = np.zeros(dim, dtype=np.float32)
        # Target first (unperturbed), then each noisy partner, in draw order.
        mix_img += weights[0] * flat_images[i]
        W[i, i] += weights[0]
        for idx, coeff, noise in zip(partners, weights[1:], noises):
            mix_img += coeff * (flat_images[idx] + noise)
            W[i, idx] += coeff
        mixup_images[i] = mix_img.reshape(c, h, w)
    return mixup_images, W

def recover_images_gd(mixup_images, W, ridge_lambda=1e-2, num_epochs=1000, lr=1e-2, lambda_tv=1e-5, lambda_l2=1e-4, verbose=False, device='cpu'):
    """Estimate the source images behind linear mixtures by gradient descent.

    Minimises the objective below with Adam, where ``Y`` are the mixtures and
    ``X`` the unknown images::

        mean((W @ X - Y)^2) + ridge_lambda * mean(X^2) + lambda_tv * TV(X)

    ``TV(X)`` is the mean absolute vertical difference plus the mean absolute
    horizontal difference. ``X`` starts from ``torch.randn``, so the global
    torch RNG determines the result. After every step ``X`` is clamped to
    ``[-2.5, 2.5]``.

    Args:
        mixup_images: Array of shape ``(n, c, h, w)`` holding the mixtures.
        W: ``(n, n)`` mixing matrix; row ``i`` weights the sources of mixture ``i``.
        ridge_lambda: Weight of the mean-squared (L2) penalty on ``X``.
        num_epochs: Number of Adam steps.
        lr: Adam learning rate.
        lambda_tv: Weight of the total-variation penalty.
        lambda_l2: Currently unused, because the L2 penalty is weighted by
            ``ridge_lambda``. Kept for call-site compatibility.
        verbose: Print the losses every 100 steps and on the final step.
        device: Torch device used for the optimisation.

    Returns:
        NumPy array of shape ``(n, c, h, w)`` with the recovered images.
    """
    n, c, h, w = mixup_images.shape
    target = torch.tensor(mixup_images, device=device, dtype=torch.float32)
    mixing = torch.tensor(W, device=device, dtype=torch.float32)
    estimate = torch.nn.Parameter(torch.randn(n, c, h, w, device=device))
    optimizer = torch.optim.Adam([estimate], lr=lr)
    report_every = 100
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        mixed = (mixing @ estimate.view(n, -1)).view(n, c, h, w)
        loss_recon = (mixed - target).pow(2).mean()
        # Built vertical-then-horizontal, same as before, so autograd
        # accumulates gradients in the same order.
        loss_tv = (estimate[:, :, :-1, :] - estimate[:, :, 1:, :]).abs().mean() + \
                  (estimate[:, :, :, :-1] - estimate[:, :, :, 1:]).abs().mean()
        loss_l2 = estimate.pow(2).mean()
        loss = loss_recon + ridge_lambda * loss_l2 + lambda_tv * loss_tv
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            estimate.clamp_(-2.5, 2.5)
        is_last = epoch == num_epochs - 1
        if verbose and (epoch % report_every == 0 or is_last):
            print(f"GD Epoch {epoch}, Loss: {loss.item():.4f}, Recon: {loss_recon.item():.4f}, TV: {loss_tv.item():.4f}")
    return estimate.detach().cpu().numpy()

def get_dataset(name, input_size=224, root='./data', train=True):
    """Load a torchvision classification dataset with ImageNet-style preprocessing.

    Images are resized to ``input_size x input_size``, converted to tensors and
    normalised with the ImageNet per-channel mean/std. MNIST is additionally
    replicated to 3 channels, so every supported dataset yields 3-channel
    images. The dataset is downloaded into ``root`` if it is missing.

    Args:
        name: One of ``"MNIST"``, ``"CIFAR10"``, ``"CIFAR100"``.
        input_size: Side length, in pixels, of the square resize.
        root: Directory where the dataset is stored or downloaded.
        train: Load the training split when True, the test split otherwise.

    Returns:
        ``(dataset, class_names)``: the torchvision dataset and a list of class
        names. For MNIST these are the digits ``"0"``..``"9"``; otherwise they
        come from ``dataset.classes``.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    supported = ("MNIST", "CIFAR10", "CIFAR100")
    if name not in supported:
        raise ValueError(f"Unknown dataset name: {name!r} (expected one of {supported})")

    steps = [transforms.Resize((input_size, input_size))]
    if name == "MNIST":
        # MNIST is single-channel; replicate it so the 3-channel
        # normalisation below applies unchanged.
        steps.append(transforms.Grayscale(num_output_channels=3))
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    # The supported names coincide with the torchvision dataset class names.
    dataset_cls = getattr(torchvision.datasets, name)
    dataset = dataset_cls(root=root, train=train, download=True, transform=transforms.Compose(steps))
    if name == "MNIST":
        class_names = [str(i) for i in range(10)]
    else:
        class_names = dataset.classes
    return dataset, class_names

def unnormalize(img):
    """Map a normalised ``(3, H, W)`` image back to ``[0, 1]`` for display.

    Inverts ``(x - mean) / std`` using the same per-channel constants as the
    dataset transforms, then clips to ``[0, 1]``.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])[:, None, None]
    channel_std = np.array([0.229, 0.224, 0.225])[:, None, None]
    restored = img * channel_std + channel_mean
    return np.clip(restored, 0, 1)

def _evaluate_dataset(dset_name, mf_values, num_samples, input_size, k):
    """Run mixup + recovery on one dataset; return ``{mf: (original, recovered)}``.

    For every mixing factor the stored pair is the sample whose recovery
    distance is closest to the mean recovery distance. It is a representative
    example, not the best one.
    """
    print(f"\n=== Processing {dset_name} ===")
    dataset, _class_names = get_dataset(dset_name, input_size=input_size)
    indices = np.random.choice(len(dataset), num_samples, replace=False)
    images_np = torch.stack([dataset[i][0] for i in indices]).numpy()  # [N, C, H, W]
    # Compute universal radius using 1000 random pairs
    sampled_dists = sample_pairwise_distances(images_np, num_samples=1000, device='cpu', batch_size=100)
    universal_radius = np.mean(sampled_dists)
    print(f"Universal radius (average over 1000 pairs): {universal_radius:.4f}")

    flat_originals = images_np.reshape(num_samples, -1)
    representatives = {}
    for mf in mf_values:
        print(f" -- mf={mf} --")
        mixup_imgs, W = mixup_images_fn(images_np, k, mf, universal_radius)
        recovered_imgs = recover_images_gd(
            mixup_imgs, W, ridge_lambda=1e-2, num_epochs=1000, lr=1e-2,
            lambda_tv=1e-5, lambda_l2=1e-4, verbose=False, device='cpu'
        )
        # Per-image Euclidean distance between each original and its recovery.
        dists = np.linalg.norm(flat_originals - recovered_imgs.reshape(num_samples, -1), axis=1)
        avg_dist = np.mean(dists)
        avg_idx = np.argmin(np.abs(dists - avg_dist))
        representatives[mf] = (images_np[avg_idx], recovered_imgs[avg_idx])
        print(f" Average recovery distance: {dists[avg_idx]:.4f} (target average: {avg_dist:.4f})")
    return representatives


def _plot_recovery_grid(pairs, dataset_names, mf_values, out_path):
    """Draw the original/recovered grid (two rows per dataset, one column per mf) and save it."""
    n_rows = 2 * len(dataset_names)
    n_cols = len(mf_values)
    # squeeze=False keeps ``axes`` 2-D even with a single row or column.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2.5 * n_cols, 2 * n_rows), squeeze=False)
    for dset_idx, dset_name in enumerate(dataset_names):
        for mf_idx, mf in enumerate(mf_values):
            orig, rec = pairs[dset_name][mf]
            for row, img in ((2 * dset_idx, orig), (2 * dset_idx + 1, rec)):
                ax = axes[row, mf_idx]
                ax.imshow(np.transpose(unnormalize(img), (1, 2, 0)))
                # Hide ticks and frame rather than calling ax.axis('off'):
                # turning the axis off also hides axis labels, which would
                # swallow the row labels set below.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
            if dset_idx == 0:
                axes[0, mf_idx].set_title(f"mf={mf}", fontsize=12)

    # Row labels on the left, alternating "Original" / "Recovered".
    for row in range(n_rows):
        label = "Original" if row % 2 == 0 else "Recovered"
        axes[row, 0].set_ylabel(label, fontsize=13, fontweight='bold', rotation=0, labelpad=40, va='center')

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.05, hspace=0.15)
    fig.savefig(out_path, dpi=200)
    print(f"Best recovery grid saved to {out_path}")
    plt.show()


def main():
    """Run the mixup-recovery experiment on every dataset and plot the results.

    Output is teed to ``results/all_datasets.log``. The representative
    original/recovered grid is saved to ``results/best_recovery_grid.png``.
    ``sys.stdout`` is restored and the log file closed even if the run fails.
    """
    mf_values = [1, 2, 4, 8, 16, 32]
    dataset_names = ["MNIST", "CIFAR10", "CIFAR100"]
    outdir = "results"
    os.makedirs(outdir, exist_ok=True)
    logger = Logger(os.path.join(outdir, "all_datasets.log"))
    sys.stdout = logger
    try:
        print(f"Experiment started at {datetime.now()}")
        print(f"mf_values: {mf_values}")
        print(f"Datasets: {dataset_names}")

        num_samples = 32
        input_size = 224
        np.random.seed(42)
        torch.manual_seed(42)
        k = 4  # images per mixture: the target plus k - 1 partners

        # Datasets are processed sequentially in list order, which keeps the
        # seeded RNG streams reproducible.
        pairs = {
            dset_name: _evaluate_dataset(dset_name, mf_values, num_samples, input_size, k)
            for dset_name in dataset_names
        }
        _plot_recovery_grid(pairs, dataset_names, mf_values, os.path.join(outdir, "best_recovery_grid.png"))
        print(f"Experiment finished at {datetime.now()}")
    finally:
        # Put the real stdout back and release the log file handle.
        sys.stdout = logger.terminal
        logger.log.close()

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
