import os
import numpy as np
from tqdm import tqdm
from typing import List

from source.constants import DOTS_PATH


def _separate_centers(rng, centers, low, high, min_dist, max_iters):
    """Redraw, in place, dot centers that lie closer than ``min_dist`` to another.

    This is best effort: at most ``max_iters`` rounds are run, and the centers
    drawn in the final round are not re-checked, so violations may remain.
    """
    for _ in range(max_iters):
        # pairwise Euclidean distances between all centers
        gaps = np.linalg.norm(centers[:, None, :] - centers[None, :, :], axis=2)
        np.fill_diagonal(gaps, np.inf)  # a center never conflicts with itself

        # every center that takes part in at least one too-close pair
        too_close = np.unique(np.where(gaps < min_dist)[0])
        if too_close.size == 0:  # no violations, exit early
            break

        # resample all violating centers at once
        centers[too_close] = rng.uniform(low, high, (too_close.size, 2))

    return centers


def _render_dots(grid, centers, widths, intensities):
    """Sum ``intensity * exp(-distance / width)`` over all dots, clipped to [0, 1]."""
    img = np.zeros(grid.shape[:2])
    # accumulate dot by dot so the summation order (and thus rounding) is fixed
    for center, width, intensity in zip(centers, widths, intensities):
        radius = np.linalg.norm(grid - center, axis=-1)
        img += intensity * np.exp(-radius / width)
    return np.clip(img, 0, 1)


def _generate_data(
    n_samples: int = 82_000,
    pixels: int = 32,
    dot_range: List = list(range(10, 51)),
    min_dist: float = 3 * np.sqrt(2),
    var_range: List[float] = [0.5, 1.3],
    intensity_range: List[float] = [0.4, np.sqrt(2)],
    max_iters: int = 100,
    seed: int = 2357,
):
    """Generate images of blurry dots together with the number of dots in each.

    Samples are produced in order of increasing position in ``dot_range``: the
    first ``n_samples // len(dot_range)`` images use ``dot_range[0]`` dots, the
    next batch ``dot_range[1]``, and so on; any remainder uses the last entry.

    Args:
        n_samples: Number of images to generate.
        pixels: Side length of the square images.
        dot_range: Dot counts to cycle through, in output order.
        min_dist: Minimum distance aimed for between any two dot centers.
        var_range: ``[low, high]`` bounds for the per-dot decay width.
        intensity_range: ``[low, high]`` bounds for the per-dot peak intensity.
        max_iters: Maximum number of resampling rounds used to separate centers.
        seed: Seed for the random generator.

    Returns:
        ``(samples, labels)``: a float array of shape
        ``(n_samples, pixels, pixels)`` with values in [0, 1], and a float array
        of shape ``(n_samples,)`` holding the dot count of each image.

    Raises:
        ValueError: If ``dot_range`` is empty while ``n_samples`` is positive.
    """
    # NOTE: the list defaults above are shared between calls; they are only
    # read here and must never be mutated.
    samples = np.zeros((n_samples, pixels, pixels))
    labels = np.zeros((n_samples))

    rng = np.random.default_rng(seed)

    n_labels = len(dot_range)
    if n_samples > 0 and n_labels == 0:
        raise ValueError("dot_range must contain at least one dot count")
    # At least one sample per label, so that n_samples < len(dot_range) does
    # not end up dividing by zero below.
    per_label = max(1, n_samples // max(1, n_labels))

    # pixel coordinate grid; identical for every image, so build it once
    x, y = np.meshgrid(np.arange(pixels), np.arange(pixels), indexing="ij")
    grid = np.stack((x, y), axis=-1)

    for n in tqdm(range(n_samples)):
        # select number of dots consecutively from dot_range, approximately balanced
        n_dots = dot_range[min(n // per_label, n_labels - 1)]

        # Draw order (centers, widths, intensities, then resampled centers)
        # determines the output for a given seed; do not reorder.
        mean = rng.uniform(0, pixels, (n_dots, 2))
        var = rng.uniform(*var_range, n_dots)
        intensity = rng.uniform(*intensity_range, n_dots)
        mean = _separate_centers(rng, mean, 0, pixels, min_dist, max_iters)

        samples[n] = _render_dots(grid, mean, var, intensity)
        labels[n] = n_dots

    return samples, labels


def generate_and_save_data(save_path: str = os.path.join(DOTS_PATH, "dots.npz")):
    """Generate the dots dataset with the standard settings and save it as ``.npz``.

    The archive contains two arrays: ``samples`` (images scaled to 0..255 and
    stored as uint8) and ``labels`` (dot counts stored as uint8).

    Args:
        save_path: Destination file. Its parent directory is created if needed.
            Note that ``_load_data`` always reads ``dots.npz`` inside
            ``DOTS_PATH``, so data saved elsewhere is not picked up by it.
    """
    # Create the directory the archive is actually written to, so that a custom
    # save_path works too. A bare file name has no directory part to create.
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    samples, labels = _generate_data(
        n_samples=82_000,
        pixels=32,
        dot_range=list(range(10, 51)),
        min_dist=3 * np.sqrt(2),
        var_range=[0.5, 1.3],
        intensity_range=[0.4, np.sqrt(2)],
        max_iters=100,
        seed=2357,
    )

    # save as uint8; the cast truncates the scaled values rather than rounding
    samples *= 255
    samples = samples.astype(np.uint8)
    labels = labels.astype(np.uint8)

    np.savez(save_path, samples=samples, labels=labels)


def _load_data():
    """Load ``dots.npz`` from ``DOTS_PATH`` and return model-ready arrays.

    Returns:
        ``(samples, labels)``: ``samples`` is float32, divided by 255 and given
        an extra axis at position 1 (``(N, H, W)`` becomes ``(N, 1, H, W)``);
        ``labels`` is float32, divided by 50.
    """
    # np.load keeps the archive's file handle open until it is closed, so read
    # both arrays inside a context manager.
    with np.load(os.path.join(DOTS_PATH, "dots.npz")) as data:
        samples = data["samples"].astype(np.float32)
        labels = data["labels"].astype(np.int32)

    # normalize samples
    samples /= 255
    # unsqueeze samples
    samples = samples[:, None, :, :]

    # scale labels; 50 matches the largest dot count in the default dot_range
    # used elsewhere in this file -- TODO confirm if the dataset is regenerated
    # with a different range
    labels = labels.astype(np.float32) / 50

    return samples, labels


def get_standard_data(seed=2357):
    """Return a random train/val/test split of the dots dataset.

    All sample indices are shuffled with a generator seeded by ``seed``. The
    first 60,000 shuffled indices form the training set, the next 10,000 the
    validation set, and the rest the test set.

    Returns:
        The tuple ``(train_samples, train_labels, val_samples, val_labels,
        test_samples, test_labels)``.
    """
    samples, labels = _load_data()

    order = np.arange(len(samples))
    np.random.default_rng(seed).shuffle(order)

    # cut the shuffled order at the train/val and val/test boundaries
    partitions = np.split(order, [60_000, 70_000])

    # emit (samples, labels) for train, then val, then test
    return tuple(
        array[part] for part in partitions for array in (samples, labels)
    )


def get_gap_data(seed=2357):
    """Return a split whose test set is a contiguous block from the middle.

    Samples at positions 29,000 to 50,999 (in stored order) form the test set.
    The remaining samples are shuffled with a generator seeded by ``seed``; the
    first 10,000 of them become the validation set and the rest the training
    set.

    NOTE(review): if the stored data is ordered by dot count, as
    ``_generate_data`` produces it, the test block corresponds to a band of
    intermediate dot counts -- confirm against the saved file.

    Returns:
        The tuple ``(train_samples, train_labels, val_samples, val_labels,
        test_samples, test_labels)``.
    """
    samples, labels = _load_data()

    positions = np.arange(len(samples))
    below, held_out, above = np.split(positions, [29_000, 51_000])

    # everything outside the gap is shuffled and divided into val / train
    pool = np.concatenate([below, above])
    np.random.default_rng(seed).shuffle(pool)
    val_part, train_part = np.split(pool, [10_000])

    # emit (samples, labels) for train, then val, then test
    return tuple(
        array[part]
        for part in (train_part, val_part, held_out)
        for array in (samples, labels)
    )


def get_tail_data(seed=2357):
    """Return a split whose test set is the end of the stored dataset.

    Samples from position 59,000 onwards (in stored order) form the test set.
    The first 59,000 samples are shuffled with a generator seeded by ``seed``;
    the first 10,000 of them become the validation set and the rest the
    training set.

    NOTE(review): if the stored data is ordered by dot count, as
    ``_generate_data`` produces it, the test set holds the highest dot counts
    -- confirm against the saved file.

    Returns:
        The tuple ``(train_samples, train_labels, val_samples, val_labels,
        test_samples, test_labels)``.
    """
    samples, labels = _load_data()

    tail_start = 59_000
    n_val = 10_000

    positions = np.arange(len(samples))
    pool, held_out = positions[:tail_start], positions[tail_start:]

    # pool is a view of positions; shuffling it leaves held_out untouched
    np.random.default_rng(seed).shuffle(pool)
    val_part, train_part = pool[:n_val], pool[n_val:]

    # emit (samples, labels) for train, then val, then test
    return tuple(
        array[part]
        for part in (train_part, val_part, held_out)
        for array in (samples, labels)
    )


def get_gap_tail_data(seed=2357):
    """Return a split whose test set is a middle block plus the end of the data.

    Samples at positions 19,000 to 34,999 and from 59,000 onwards (in stored
    order) form the test set, in that order. The remaining samples are shuffled
    with a generator seeded by ``seed``; the first 10,000 of them become the
    validation set and the rest the training set.

    NOTE(review): if the stored data is ordered by dot count, as
    ``_generate_data`` produces it, the test set combines a band of
    intermediate dot counts with the highest ones -- confirm against the saved
    file.

    Returns:
        The tuple ``(train_samples, train_labels, val_samples, val_labels,
        test_samples, test_labels)``.
    """
    samples, labels = _load_data()

    positions = np.arange(len(samples))
    low, gap, middle, tail = np.split(positions, [19_000, 35_000, 59_000])

    held_out = np.concatenate([gap, tail])

    # the two kept segments are pooled, shuffled and divided into val / train
    pool = np.concatenate([low, middle])
    np.random.default_rng(seed).shuffle(pool)
    val_part, train_part = np.split(pool, [10_000])

    # emit (samples, labels) for train, then val, then test
    return tuple(
        array[part]
        for part in (train_part, val_part, held_out)
        for array in (samples, labels)
    )
