import numpy as np
from PIL import Image

import torch
import torchvision.transforms as transforms

from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize, CenterCrop

import skimage

from imageio.v3 import imread
from skimage.color import rgb2gray
import skimage.filters as skfilters
from skimage import feature

import siren

import os
import h5py

def arr2imggrid(arr):
    """Tile ``arr`` into a mirrored 4x4 grid and return it as a PIL image.

    A 2x2 "square" is built first: ``arr`` (top-left), its left-right mirror
    (top-right), its up-down mirror (bottom-left) and its 180-degree rotation
    (bottom-right). That square is then repeated 2x2, so the result is four
    times the input height and width.

    Args:
        arr: array-like of shape (H, W) or (H, W, C) with a dtype that
            ``PIL.Image.fromarray`` accepts.

    Returns:
        PIL image built from an array of shape (4H, 4W[, C]).
    """
    arr = np.asarray(arr)
    # Stack explicitly along axes 1 (width) and 0 (height). np.c_ joins on the
    # *last* axis, which for an (H, W, C) array is the channel axis.
    top = np.concatenate([arr, arr[:, ::-1]], axis=1)
    bottom = np.concatenate([arr[::-1, :], arr[::-1, ::-1]], axis=1)
    square = np.concatenate([top, bottom], axis=0)

    # Repeat twice along height and width; never along a trailing channel axis.
    reps = (2, 2) + (1,) * (square.ndim - 2)
    return Image.fromarray(np.tile(square, reps))

def reflection_pad(img, shiftx, shifty):
    """Mirror-pad an image by ``shiftx`` columns and ``shifty`` rows per side.

    The reflection does not repeat the edge pixel (the same indexing as
    ``numpy.pad(..., mode='reflect')``). Because the index map below is
    periodic, pads wider than the image keep reflecting back and forth.

    Args:
        img: PIL image or array of shape (H, W) or (H, W, C).
        shiftx: columns added on each of the left and right sides; expected
            non-negative (periodic_shift passes absolute values).
        shifty: rows added on each of the top and bottom sides; expected
            non-negative.

    Returns:
        PIL image built from an array of shape (H + 2*shifty, W + 2*shiftx[, C]).
    """
    arr = np.asarray(img)
    # numpy arrays are (rows, cols[, channels]) == (height, width[, C]). Taking
    # the first two axes keeps the output correct for non-square and colour
    # images.
    in_h, in_w = arr.shape[:2]
    out_h = in_h + 2 * shifty
    out_w = in_w + 2 * shiftx

    # Reflection without edge repetition has period 2 * (n - 1) along an axis.
    ry, rx = in_h - 1, in_w - 1
    yoffs, xoffs = (out_h - ry) // 2, (out_w - rx) // 2

    Y, X = np.ogrid[:out_h, :out_w]
    # Map every output coordinate to a source coordinate: offset so that the
    # original sits in the middle, wrap with period 2*r, then fold with abs().
    rows = np.abs((Y - yoffs + ry) % (2 * ry) - ry)
    cols = np.abs((X - xoffs + rx) % (2 * rx) - rx)

    return Image.fromarray(arr[rows, cols])


def periodic_shift(img, shiftx, shifty):
    """Translate ``img`` by whole pixels, filling uncovered borders by reflection.

    The image is reflection-padded by ``|shiftx|`` columns and ``|shifty|`` rows
    on every side. A window the size of the original image is then cut back
    out of the padded array:

    * shiftx > 0: the window starts at column 0, so content moves right.
    * shiftx < 0: the window starts at column 2*|shiftx|, so content moves left.
    * shifty > 0: the window starts at row 2*|shifty|, so content moves *up*
      (towards row 0).
    * shifty < 0: the window starts at row 0, so content moves down.

    Args:
        img: PIL image.
        shiftx: horizontal shift in pixels.
        shifty: vertical shift in pixels (positive means up).

    Returns:
        PIL image with the same width and height as ``img``.
    """
    width, height = img.size
    pad_x, pad_y = abs(shiftx), abs(shifty)
    padded_arr = np.asarray(reflection_pad(img, pad_x, pad_y))

    # Top-left corner of the crop inside the padded array.
    col0 = 2 * pad_x if shiftx < 0 else 0
    row0 = 2 * pad_y if shifty > 0 else 0

    window = padded_arr[row0:row0 + height, col0:col0 + width]
    return Image.fromarray(window)
    

def remove_duplicates(coords1, coords2, thresh=5e-3, chunk_size=4096):
    """Flag rows of ``coords2`` that are not near any row of ``coords1``.

    A row of ``coords2`` counts as a duplicate when its Euclidean distance to
    at least one row of ``coords1`` is strictly below ``thresh``.

    Args:
        coords1: torch tensor of shape (N1, D).
        coords2: torch tensor of shape (N2, D).
        thresh: distance below which two points are considered the same.
        chunk_size: number of ``coords2`` rows compared per step. This bounds
            peak memory at roughly (N1, chunk_size, D) elements instead of
            (N1, D, N2); it does not affect the result.

    Returns:
        Bool tensor of shape (N2,), True for rows of ``coords2`` to keep.

    Raises:
        ValueError: if ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be >= 1, got {chunk_size}')

    n2 = coords2.shape[0]
    keep = torch.ones(n2, dtype=torch.bool, device=coords2.device)
    # The result is a boolean mask, so there is no reason to record autograd
    # history for the (large) distance computations.
    with torch.no_grad():
        for start in range(0, n2, chunk_size):
            block = coords2[start:start + chunk_size]
            # (N1, 1, D) - (1, c, D) -> (N1, c) pairwise distances.
            diffs = (coords1[:, None, :] - block[None, :, :]).pow(2).sum(-1).sqrt()
            keep[start:start + chunk_size] = ~(diffs < thresh).any(0)

    return keep


# Named scikit-image sample images that load_image can return without
# touching the filesystem. Colour images are reduced to grayscale with
# rgb2gray; 'camera', 'coins' and 'brick' are stored exactly as returned by
# skimage.data. All entries are loaded eagerly when the module is imported.
image_set = dict(
    astro=rgb2gray(skimage.data.astronaut()),
    camera=skimage.data.camera(),
    coins=skimage.data.coins(),
    coffee=rgb2gray(skimage.data.coffee()),
    cat=rgb2gray(skimage.data.cat()),
    brick=skimage.data.brick(),
)

_IMDIR = '/its/home/gc453/sussex/inr_eda/data/imagenet-sample-images'

def load_image(name, laplace=False):
    """Return the image called ``name``, optionally Laplace-filtered.

    ``name`` is first looked up in ``image_set``. Otherwise, the
    alphabetically first file in ``_IMDIR`` that ends with 'JPEG' and contains
    ``name`` as a substring is read. If that file has a channel axis, it is
    converted to grayscale with ``rgb2gray``.

    Args:
        name: key of ``image_set``, or a substring of an ImageNet sample
            filename in ``_IMDIR``.
        laplace: if True, return ``skimage.filters.laplace`` of the image
            instead of the image itself.

    Returns:
        numpy array holding the image (or its Laplacian).

    Raises:
        ValueError: if ``name`` is not in ``image_set`` and no file matches.
    """
    if name in image_set:
        raw_im = image_set[name]
    else:
        # os.listdir order is arbitrary and filesystem-dependent. Sorting makes
        # the same file get picked on every run when several filenames match.
        matches = sorted(
            fname for fname in os.listdir(_IMDIR)
            if fname.endswith('JPEG') and name in fname
        )
        if not matches:
            raise ValueError(f'{name} not found!')
        raw_im = imread(os.path.join(_IMDIR, matches[0]))
        # Grayscale JPEGs decode to a 2-D array, which current scikit-image
        # rgb2gray rejects. Only convert images that have a channel axis.
        if raw_im.ndim == 3:
            raw_im = rgb2gray(raw_im)

    return skfilters.laplace(raw_im) if laplace else raw_im

class ImageDataset:
    """One image prepared at two resolutions for SIREN fitting experiments.

    The image returned by ``load_image(im_name, laplace=laplace)`` is resized
    (shorter side) and centre-cropped to ``sidelength`` and to
    ``4 * sidelength``, then wrapped in ``siren.ImageFitting`` datasets.

    Attributes:
        im_name, sidelength, laplace: the constructor arguments.
        raw_img: numpy array of the ``sidelength`` crop after ``ToTensor``,
            with singleton axes squeezed and no normalisation applied.
        img: normalised tensor of the ``sidelength`` crop.
        super_img: normalised tensor of the ``4 * sidelength`` crop.
        dset: ``siren.ImageFitting`` built from ``img``.
        super_dset: ``siren.ImageFitting`` for ``super_img``. When
            ``do_remove_duplicates`` is true, samples whose coordinates fall
            within ``remove_duplicates``' default threshold of a ``dset``
            coordinate are dropped.
        grad_mag: ``skimage.filters.farid`` output for ``raw_img``.
        edges: ``skimage.feature.canny(raw_img)`` as a flattened 0/1 int array.
    """

    def __init__(self, im_name, sidelength=64, laplace=False, do_remove_duplicates=True):
        self.im_name = im_name
        self.sidelength = sidelength
        self.laplace = laplace
        self.raw_img = self.get_raw_img(sidelength).squeeze()
        self.img = self.get_img_tensor(sidelength)
        self.super_img = self.get_img_tensor(4 * sidelength)

        self.dset = siren.ImageFitting(img=self.img)
        super_dset = siren.ImageFitting(img=self.super_img)

        if do_remove_duplicates:
            # Keep only high-resolution samples that are not (near-)coincident
            # with a low-resolution sample coordinate.
            inds_to_keep = remove_duplicates(self.dset.coords, super_dset.coords)
            self.super_dset = siren.ImageFitting(
                coords=super_dset.coords[inds_to_keep],
                pixels=super_dset.pixels[inds_to_keep],
            )
        else:
            self.super_dset = super_dset

        self.grad_mag = skfilters.farid(self.raw_img)
        self.edges = feature.canny(self.raw_img).astype('int').flatten()

    def _load_transformed(self, sidelength, normalize):
        """Load the image, resize and centre-crop it, and convert it to a tensor.

        This is the shared pipeline behind ``get_raw_img`` and
        ``get_img_tensor``. When ``normalize`` is true, the ``ToTensor`` output
        is additionally mapped through x -> (x - 0.5) / 0.5.
        """
        pil_img = Image.fromarray(load_image(self.im_name, laplace=self.laplace))
        steps = [
            Resize(sidelength),
            CenterCrop(sidelength),
            ToTensor(),
        ]
        if normalize:
            steps.append(Normalize(torch.Tensor([0.5]), torch.Tensor([0.5])))
        return Compose(steps)(pil_img)

    def get_raw_img(self, sidelength):
        """Return the un-normalised ``ToTensor`` crop at ``sidelength`` as a numpy array."""
        return self._load_transformed(sidelength, normalize=False).numpy()

    def get_img_tensor(self, sidelength):
        """Return the normalised crop at ``sidelength`` as a torch tensor."""
        return self._load_transformed(sidelength, normalize=True)




def save_ntk_comp(path, ntk_comp):
    """Write a sequence of {'ins': arrays, 'outs': arrays} records to HDF5.

    The file at ``path`` is created or overwritten. Each array becomes a
    gzip-compressed dataset laid out as::

        item_<i>/ins/array_<j>
        item_<i>/outs/array_<k>

    where ``i`` is the record's position in ``ntk_comp``, and ``j``/``k`` are
    positions within its 'ins'/'outs' sequences.
    """
    with h5py.File(path, 'w') as hdf:
        for item_idx, record in enumerate(ntk_comp):
            item_group = hdf.create_group(f'item_{item_idx}')
            # Create both sub-groups first, then fill them ('ins' before 'outs').
            subgroups = {field: item_group.create_group(field) for field in ('ins', 'outs')}
            for field, subgroup in subgroups.items():
                for arr_idx, array in enumerate(record[field]):
                    subgroup.create_dataset(f'array_{arr_idx}', data=array, compression='gzip')


def save_ntk_series(path, ntk_series):
    """Write each array of ``ntk_series`` to the HDF5 file at ``path``.

    The file is created or overwritten. The n-th array is stored as the
    gzip-compressed dataset ``array_<n>`` at the file's root.
    """
    with h5py.File(path, 'w') as out_file:
        for position, array in enumerate(ntk_series):
            dataset_name = f'array_{position}'
            out_file.create_dataset(dataset_name, data=array, compression='gzip')


def save_exp_dict(path, exp_dict):
    """Persist an experiment's results into the directory ``path``.

    * Every entry except 'ntk' and 'ntk_comp' is written with ``torch.save``
      to ``<path>/stats.pth``.
    * ``exp_dict['ntk_comp']`` is written with ``save_ntk_comp`` to
      ``<path>/ntk_comp.h5``.
    * An 'ntk' entry, if present, is not written anywhere.

    If ``exp_dict`` has no 'ntk_comp' entry, the lookup fails (KeyError for a
    plain dict) after ``stats.pth`` has already been written.
    """
    excluded = {'ntk', 'ntk_comp'}
    stats = {k: v for k, v in exp_dict.items() if k not in excluded}
    torch.save(stats, os.path.join(path, 'stats.pth'))

    save_ntk_comp(os.path.join(path, 'ntk_comp.h5'), exp_dict['ntk_comp'])


def load_ntk_comp(path):
    """Read ``<path>/ntk_comp.h5`` back into a list of {'ins', 'outs'} dicts.

    This is the inverse of the ``item_<i>/{ins,outs}/array_<j>`` layout written
    by ``save_ntk_comp``. Items and arrays are ordered by their integer
    suffix, so the lists come back in the order they were saved.

    Returns:
        list of dicts, each with keys 'ins' and 'outs' mapping to lists of
        numpy arrays.
    """
    def _suffix_index(key):
        # 'item_12' -> 12, 'array_3' -> 3. A plain string sort would place
        # 'array_10' before 'array_2' and scramble lists of more than 10 arrays.
        return int(key.rsplit('_', 1)[-1])

    fpath = os.path.join(path, 'ntk_comp.h5')
    loaded_data = []
    with h5py.File(fpath, 'r') as hdf:
        for item_key in sorted(hdf.keys(), key=_suffix_index):
            item_group = hdf[item_key]
            loaded_data.append({
                field: [
                    item_group[field][arr_key][:]
                    for arr_key in sorted(item_group[field].keys(), key=_suffix_index)
                ]
                for field in ('ins', 'outs')
            })

    return loaded_data