import torch

from dataset_utils import load_dataset, load_model
import random
import os

# Change the working directory to the parent directory at import time; every
# relative path used afterwards (including OUTPUT_DIR_PATH) resolves from there.
# NOTE(review): presumably needed so `dataset_utils` can locate its data/model
# files relative to the repository root — confirm.
os.chdir("../.")


# Directory the JSON index files are written to by the __main__ block.
# NOTE(review): empty by default — confirm the intended output location before running.
OUTPUT_DIR_PATH= ''

# set random seed for reproducibility
# (only Python's `random` module is seeded; it drives the index shuffles in the
# sampling helpers below)
random.seed(42)

# set device: first CUDA device when available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def _divide_batches_into_divisions(per_class_batches):
    """
    per_class_batches: dict where keys 0–9 map to lists of batches for that class
                       (each batch is a list of sample-indices).
    Returns: list of divisions, each division is a list of 10 batches
             (one batch per class).
    """
    # pick out the class keys
    classes = sorted(k for k in per_class_batches.keys() if isinstance(k, int))
    # how many batches each class has (should all be equal)
    num_batches = len(per_class_batches[classes[0]])
    if any(len(per_class_batches[c]) != num_batches for c in classes):
        raise ValueError("All classes must have the same number of batches")
    # build divisions by taking the i-th batch from each class
    divisions = [
        [per_class_batches[c][i] for c in classes]
        for i in range(num_batches)
    ]
    return divisions


def _extract_per_class_correct_samples(dataset, full_model, test_data, samples_per_class, num_classes, device='cpu'):
    # for each class, collect `samples_per_class` correctly classified samples
    all_indices = list(range(len(test_data)))
    random.shuffle(all_indices)
    per_class_samples = {d: [] for d in range(num_classes)}
    full_model.eval()

    for idx in all_indices:
        X, label = test_data[idx]
        if dataset == 'mnist':
            X = X.view(1, -1).to(device)
        if dataset in ['cifar10', 'cifar10-small', 'gtsrb']:
            X = X.unsqueeze(0).to(device)  # add batch dimension

        with torch.no_grad():
            output = full_model(X)
            predicted = torch.argmax(output, dim=1).item()

        if predicted == label and len(per_class_samples[label]) < samples_per_class:
            per_class_samples[label].append(idx)
        # stop early if all classes are filled
        if all(len(per_class_samples[d]) == samples_per_class for d in range(num_classes)):
            break
    return per_class_samples


def extract_mixed_correct_samples(full_model, test_data, num_samples=20, device='cpu'):
    """
    Randomly pick up to `num_samples` correctly classified indices from `test_data`, regardless of label.

    Each sample only gets a leading batch dimension (`X.unsqueeze(0)`); no
    dataset-specific reshaping is applied. The predicted class is the argmax
    over dim 1 of the model output.

    Args:
        full_model: classifier called on a single-sample batch.
        test_data: indexable dataset returning `(X, label)` pairs.
        num_samples: maximum number of indices to return; `<= 0` returns `[]`.
        device: device the inputs are moved to before the forward pass.

    Returns:
        list of indices in the (shuffled) order they were found; shorter than
        `num_samples` if `test_data` runs out of correctly classified samples.
    """
    all_indices = list(range(len(test_data)))
    random.shuffle(all_indices)
    correct_samples = []
    full_model.eval()
    # Guard before the loop: the original appended first and checked afterwards,
    # which returned one sample even when num_samples <= 0.
    if num_samples <= 0:
        return correct_samples

    with torch.no_grad():
        for idx in all_indices:
            X, label = test_data[idx]
            output = full_model(X.unsqueeze(0).to(device))
            predicted = torch.argmax(output, dim=1).item()
            if predicted == label:
                correct_samples.append(idx)
                if len(correct_samples) >= num_samples:
                    break
    return correct_samples



def _split_batches(per_digit_samples, batch_size=3, num_classes=10):
    """
    For each digit, split its samples into batches of size `batch_size`.
    Returns a dict: {digit: [[idx1, idx2, idx3], ...]}
    """
    batches = {}
    for d in range(num_classes):
        samples = per_digit_samples[d]
        batches[d] = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]
    batches["all_batches"] = [batch for digit_batches in batches.values() for batch in digit_batches]
    return batches


def save_dict_to_file(data, file_path):
    """Serialize `data` to `file_path` as compact JSON (no spaces after ',' or ':'), overwriting the file."""
    import json

    compact_separators = (',', ':')
    with open(file_path, 'w') as out_file:
        json.dump(data, out_file, separators=compact_separators)

# --- NEW: Taxinet helpers ---

def extract_taxinet_close_samples(full_model, test_data, num_samples=20, device='cpu', max_abs_err=0.2):
    """
    Randomly sample up to `num_samples` Taxinet examples whose absolute prediction error is <= `max_abs_err`.

    Args:
        full_model: regression model called on a single-sample batch; its output
            is squeezed to one scalar (so it must hold exactly one element).
        test_data: indexable dataset returning `(X, y)` pairs. The target `y`
            may be a single-element tensor or a plain Python number.
        num_samples: maximum number of indices to return; `<= 0` returns `[]`.
        device: device the inputs are moved to before the forward pass.
        max_abs_err: inclusive bound on `abs(prediction - target)`.

    Returns:
        list of indices in the (shuffled) order they were found; shorter than
        `num_samples` if not enough samples meet the error bound.
    """
    all_indices = list(range(len(test_data)))
    random.shuffle(all_indices)
    picked = []
    full_model.eval()
    # Guard before the loop: appending first and checking afterwards returned one
    # sample even when num_samples <= 0.
    if num_samples <= 0:
        return picked

    with torch.no_grad():
        for idx in all_indices:
            X, y = test_data[idx]        # X: (C,H,W), y: (1,) per the original note — TODO confirm
            pred = full_model(X.unsqueeze(0).to(device)).squeeze().item()
            # as_tensor accepts both tensor and plain-number targets
            target = torch.as_tensor(y).squeeze().item()
            if abs(pred - target) <= max_abs_err:
                picked.append(idx)
                if len(picked) >= num_samples:
                    break
    return picked


def _split_list_into_batches(indices, batch_size=3):
    """Split a flat list of indices into batches of size `batch_size`."""
    return [indices[i:i+batch_size] for i in range(0, len(indices), batch_size)]


if __name__ == '__main__':

    def _output_path(filename):
        """Return the path of `filename` inside OUTPUT_DIR_PATH.

        os.path.join is used instead of f'{OUTPUT_DIR_PATH}/{filename}', which
        turned the empty default into a filesystem-root path ('/mnist_batches.json').
        """
        return os.path.join(OUTPUT_DIR_PATH, filename)

    def _build_classification_batches(dataset_name, output_prefix, samples_per_class, num_classes, batch_size=3):
        """Sample correct per-class indices for one dataset, batch them, and write the batch and division JSON files."""
        _, test_data, _ = load_dataset(dataset_name=dataset_name, batch_size=100)
        full_model, _ = load_model(dataset_name=dataset_name, device=device)
        full_model.eval()
        per_class_samples = _extract_per_class_correct_samples(
            dataset_name, full_model, test_data,
            samples_per_class=samples_per_class, num_classes=num_classes, device=device,
        )
        batches = _split_batches(per_class_samples, batch_size=batch_size, num_classes=num_classes)
        divisions = _divide_batches_into_divisions(batches)

        save_dict_to_file(batches, _output_path(f'{output_prefix}_batches.json'))
        save_dict_to_file({'divisions': divisions}, _output_path(f'{output_prefix}_divisions.json'))

    if OUTPUT_DIR_PATH:
        os.makedirs(OUTPUT_DIR_PATH, exist_ok=True)

    # The datasets are processed in a fixed order because they all share the
    # seeded `random` state used for shuffling.

    # --- MNIST: 30 correct samples per digit -> 10 batches of 3 per digit (100 batches) ---
    _build_classification_batches('mnist', 'mnist', samples_per_class=30, num_classes=10)

    # --- CIFAR-10 (small model): same layout as MNIST ---
    _build_classification_batches('cifar10-small', 'cifar10', samples_per_class=30, num_classes=10)

    # --- GTSRB: 9 correct samples per class -> 3 batches of 3 for each of the 43 classes ---
    _build_classification_batches('gtsrb', 'gtsrb', samples_per_class=9, num_classes=43)

    # --- Taxinet: pick samples with abs(pred - gold) <= 0.2 ---
    taxinet_max_abs_err = 0.2
    taxinet_batch_size = 3
    _, test_data, _ = load_dataset(dataset_name='taxinet', batch_size=100)
    full_model, _ = load_model(dataset_name='taxinet', device=device)
    full_model.eval()
    taxinet_samples = extract_taxinet_close_samples(
        full_model, test_data, num_samples=300, device=device, max_abs_err=taxinet_max_abs_err,
    )
    # keep original flat list for backward-compatibility
    save_dict_to_file(
        {'samples': taxinet_samples, 'max_abs_err': taxinet_max_abs_err},
        _output_path('taxinet_samples.json'),
    )
    # also save batches of size 3
    taxinet_batches = _split_list_into_batches(taxinet_samples, batch_size=taxinet_batch_size)
    save_dict_to_file(
        {'batches': taxinet_batches, 'batch_size': taxinet_batch_size, 'max_abs_err': taxinet_max_abs_err},
        _output_path('taxinet_batches.json'),
    )
