import torch
from torchvision import datasets, transforms

import os

from qtorch import CTYPE
from qtorch.quantumstate.measurements import getZDistribution

def get_first_mnist_digits(img_size:int=8,
                           mnist_path:str='./data',
                           download:bool=False)->torch.Tensor:
    """
    Return one example image per digit from the MNIST training set.

    Every image is resized to (img_size, img_size) and converted to a tensor
    by `transforms.ToTensor` (pixel values scaled to [0, 1]). The dataset is
    scanned in order and the first image seen for each label is kept.

    Args:
        img_size: Side length, in pixels, of the resized square images.
        mnist_path: Root directory handed to `datasets.MNIST`.
        download: Forwarded to `datasets.MNIST`; whether the dataset may be
            downloaded when it is not found under `mnist_path`.

    Returns:
        Tensor of shape (10, img_size, img_size); entry `d` is the first
        training image labelled `d`.
    """
    num_digits = 10

    # Resize first, then convert to a tensor.
    preprocess = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ])
    dataset = datasets.MNIST(root=mnist_path, train=True,
                             download=download, transform=preprocess)

    # label -> first image encountered with that label
    exemplars = {}
    for image, digit in dataset:
        if digit in exemplars:
            continue
        # Keep only channel 0, dropping the leading channel dimension.
        exemplars[digit] = image[0]
        if len(exemplars) == num_digits:
            # Every digit has been seen; no need to scan further.
            break

    # Order by digit value so that index == label in the stacked result.
    return torch.stack([exemplars[digit] for digit in range(num_digits)])

def get_amplitude_encoding(images:torch.Tensor)->torch.Tensor:
    """
    Amplitude-encode a batch of square images as L2-normalized state vectors.

    Each image is flattened row-major into a vector, divided by its
    Euclidean norm and cast to `CTYPE`.

    Args:
        images: Tensor of shape (num_images, size, size).

    Returns:
        Tensor of shape (num_images, size * size) with dtype `CTYPE`; each
        row has unit L2 norm.

    Raises:
        AssertionError: If `images` is not 3-dimensional or the images are
            not square.

    Note:
        An all-zero image has zero norm, so its row comes out as NaN.
    """
    assert images.dim() == 3, "expected a (num_images, size, size) tensor"
    assert images.shape[1] == images.shape[2], "images must be square"
    # `flatten` (unlike `view`) also accepts non-contiguous input and an
    # empty batch; it still returns a view whenever that is possible.
    flat = images.flatten(start_dim=1)
    norms = flat.norm(p=2, dim=1, keepdim=True)
    return (flat / norms).to(CTYPE)

def get_image_from_amplitude_encoding(state:torch.Tensor)->torch.Tensor:
    """
    Turn an amplitude-encoded state back into a square image.

    The side length is 2**(b // 2), where b is the bit length of
    (state.shape[-1] - 1); for a state of length 4**k this is 2**k. The
    image is the element-wise square root of `getZDistribution(state)`,
    reshaped to (side, side).

    NOTE(review): `getZDistribution` presumably returns the measurement
    probabilities of the state in the computational basis - confirm against
    qtorch.
    """
    num_amplitudes = state.shape[-1]
    index_bits = (num_amplitudes - 1).bit_length()
    # Half of the index bits address rows, the other half columns.
    side = 2 ** (index_bits // 2)
    distribution = getZDistribution(state)
    return torch.sqrt(distribution).reshape(side, side)


if __name__ == '__main__':
    # Demo: load one image per digit, amplitude-encode them and plot the
    # encoded amplitudes as images.
    # NOTE: the dataset is loaded with the default `download=False`, so MNIST
    # must already be present under './data' for this demo to run.
    img_size = 8
    images_tensor = get_first_mnist_digits(img_size)
    print(images_tensor.shape)
    states = get_amplitude_encoding(images_tensor)

    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(1, 10, figsize=(12, 2))
    for digit, ax in enumerate(axes):
        # The encoded amplitudes are real values cast to CTYPE, so the image
        # content is in the real part; the imaginary part carries nothing.
        ax.imshow(states[digit].reshape([img_size, img_size]).real, cmap='gray')
        ax.set_title(str(digit))
    plt.tight_layout()
    plt.show()

def make_ref_images(data_path:str, mnist_path:str, image_size:int
                    )->tuple[torch.Tensor,str]:
    """
    Load the cached reference images, creating the cache when needed.

    The cache file is `<data_path>/ref_images_<image_size>.pt`. It is
    (re)built from MNIST via `get_first_mnist_digits` when the file does not
    exist or when the last dimension of the stored tensor differs from
    `image_size`.

    Args:
        data_path: Directory holding the cache file; created if missing.
        mnist_path: Root directory of the MNIST dataset (downloaded there if
            the images have to be rebuilt).
        image_size: Expected side length of the reference images in pixels.

    Returns:
        `(ref_images, img_path)`: the reference image tensor and the path of
        the cache file it was loaded from / saved to.
    """
    img_path = os.path.join(data_path, f'ref_images_{image_size}.pt')

    try:
        ref_images: torch.Tensor = torch.load(img_path, weights_only=True)
    except FileNotFoundError:
        print(f"File `{img_path}` not found, creating new images")
        remake_file = True
    else:
        # Explicit comparison instead of `assert`, so the check still runs
        # when Python is started with -O.
        remake_file = ref_images.shape[-1] != image_size
        if remake_file:
            print(f"Expected image size to be {image_size}, `{img_path}` has size {ref_images.shape[-1]}. Making new images and overwriting file.")

    if remake_file:
        # An empty `data_path` means the current directory; nothing to create.
        if data_path:
            os.makedirs(data_path, exist_ok=True)
        ref_images = get_first_mnist_digits(image_size, mnist_path,
                                            download=True)
        torch.save(ref_images, img_path)

    print('Finished making/loading images')
    return ref_images, img_path

if __name__ == '__main__':
    # Command-line entry point: build (or load) the cached reference images.
    import argparse
    parser = argparse.ArgumentParser(description='Create (or load the cached) '
                                     'MNIST reference images, one per digit, '
                                     'resized to the requested size.')

    # `--data-path` is a directory: the cache file name is derived from the
    # image size inside `make_ref_images`.
    parser.add_argument('--data-path', type=str, default='./data/',
                        help='Directory where the reference images file is '
                             'stored')
    parser.add_argument('--mnist-path', type=str, default='./data/',
                        help='Where to load the MNIST dataset from')
    parser.add_argument('--image-size', '-i', type=int, default=16,
                        help='Height and width of the resized images in pixels')
    args = parser.parse_args()

    make_ref_images(args.data_path, args.mnist_path, args.image_size)
