import h5py
import torch
import random
import numpy as np
import torchvision
import torchaudio.transforms as T
from torchvision import transforms
from torch.utils.data import Dataset

class Lighting:
    """Random per-channel colour offset applied to an image tensor.

    A 3-vector ``alpha`` is drawn from a normal distribution with standard
    deviation ``alphastd``; the offset added to the image is
    ``eigvec @ (alpha * eigval)``, one value per channel.

    Args:
        alphastd: Standard deviation of the random weights. ``0`` disables
            the transform entirely.
        eigval: Tensor of 3 eigenvalues.
        eigvec: 3x3 tensor of eigenvectors.
    """

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd, self.eigval, self.eigvec = alphastd, eigval, eigvec

    def __call__(self, x):
        """Return ``x`` shifted by a random per-channel offset.

        When ``alphastd`` is 0 the very same object is returned, without
        clamping or casting. Otherwise the result is clamped to [0, 255] and
        cast to ``torch.uint8``.

        NOTE(review): assumes ``x`` is laid out as (3, H, W) on a 0-255
        scale -- confirm against the callers.
        """
        if self.alphastd != 0:
            weights = torch.randn(3) * self.alphastd
            offset = torch.matmul(self.eigvec, weights * self.eigval)
            # Reshape to (3, 1, 1) so the offset broadcasts over the two
            # trailing (spatial) dimensions.
            shifted = x + offset.reshape(-1, 1, 1)
            return torch.clamp(shifted, 0, 255).to(torch.uint8)
        return x

def get_imagenet_transform(transform_type):
    """Build the training and validation image pipelines for a preset.

    Args:
        transform_type: Name of the preset: ``'basic'``, ``'inception'`` or
            ``'mobile'``. ``'basic'`` and ``'inception'`` currently share the
            same settings.

    Returns:
        A dict with the keys ``'train'`` and ``'val'``, each mapping to a
        ``transforms.Compose`` pipeline.

    Raises:
        ValueError: If ``transform_type`` is not one of the known presets.
    """
    # preset name -> (normalisation mean, normalisation std, min crop scale)
    presets = {
        'basic': ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], 0.08),
        'inception': ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], 0.08),
        'mobile': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], 0.25),
    }
    # Validate before building anything so an unknown name fails with a
    # clear message instead of an UnboundLocalError further down.
    if transform_type not in presets:
        raise ValueError(
            f'unknown transform_type {transform_type!r}; '
            f'expected one of {sorted(presets)}'
        )
    imagenet_mean, imagenet_std, crop_scale = presets[transform_type]

    # Eigen-decomposition constants handed to the Lighting augmentation.
    imagenet_eigval = torch.tensor([0.2175, 0.0188, 0.0045])
    imagenet_eigvec = torch.tensor([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203]
    ])

    # Training: random crop + colour/lighting jitter + flip, then convert to
    # float32 and normalise.
    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(224, (crop_scale, 1.0)),
        transforms.ColorJitter(0.4, 0.4, 0.4),
        Lighting(0.1, imagenet_eigval, imagenet_eigvec),
        transforms.RandomHorizontalFlip(),
        transforms.ConvertImageDtype(torch.float32),
        transforms.Normalize(imagenet_mean, imagenet_std)
    ])
    # Validation: deterministic resize + centre crop, same normalisation.
    val_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ConvertImageDtype(torch.float32),
        transforms.Normalize(imagenet_mean, imagenet_std)
    ])

    return {'train': train_tf, 'val': val_tf}

class ImageNetPair(Dataset):
    """Dataset that yields two independently sampled, transformed images.

    Files are collected from ``img_pth/mode/*/*.JPEG``. The index passed to
    ``__getitem__`` is ignored: every access draws two files at random.

    Args:
        img_pth: Root directory of the images; must support ``/`` and
            ``.glob`` (e.g. ``pathlib.Path``).
        mode: ``'train'`` or ``'val'``; selects both the sub-directory and
            the transform pipeline.
    """

    # Upper bound on consecutive unreadable files before giving up, so a
    # broken dataset fails with a clear error instead of recursing forever.
    _MAX_READ_ATTEMPTS = 100

    def __init__(self, img_pth, mode='train'):
        self.fns = list((img_pth/mode).glob('*/*.JPEG'))
        self.tf = get_imagenet_transform('basic')[mode]

    def __len__(self):
        """Return the number of image files found."""
        return len(self.fns)

    def _read_random_image(self):
        """Decode one randomly chosen file, skipping unreadable ones.

        Raises:
            RuntimeError: If there are no files to sample from, or if
                ``_MAX_READ_ATTEMPTS`` consecutive reads all failed.
        """
        if not self.fns:
            raise RuntimeError('no image files available to sample from')
        last_error = None
        for _ in range(self._MAX_READ_ATTEMPTS):
            fn = str(random.choice(self.fns))
            try:
                return torchvision.io.read_image(fn)
            except Exception as err:
                # Best effort: an unreadable or corrupt file is skipped and
                # another one is drawn. KeyboardInterrupt and SystemExit are
                # deliberately not caught.
                last_error = err
        raise RuntimeError(
            f'failed to read an image after {self._MAX_READ_ATTEMPTS} attempts'
        ) from last_error

    @staticmethod
    def _to_three_channels(img):
        """Return ``img`` with 3 channels: replicate 1, truncate 4."""
        channels = img.shape[0]
        if channels == 1:
            return torch.cat([img]*3, 0)
        if channels == 4:
            return img[:3]
        return img

    def read_image(self):
        """Return two randomly selected, decoded images as a tuple.

        Raises:
            RuntimeError: See ``_read_random_image``.
        """
        return self._read_random_image(), self._read_random_image()

    def __getitem__(self, idx):
        """Return a pair of transformed images; ``idx`` is ignored."""
        img_1, img_2 = self.read_image()
        img_1 = self.tf(self._to_three_channels(img_1))
        img_2 = self.tf(self._to_three_channels(img_2))
        return img_1, img_2

class ImageAudioPair(Dataset):
    """Dataset pairing a random image with a stack of random STFT windows.

    Images come from ``img_pth/mode/*/*.JPEG``; audio comes from the
    ``'stft'`` dataset of ``audio_pth/'stft.h5'``. The index passed to
    ``__getitem__`` is ignored: every access samples afresh.

    Args:
        img_pth: Root directory of the images; must support ``/`` and
            ``.glob`` (e.g. ``pathlib.Path``).
        audio_pth: Directory containing ``stft.h5``.
        max_audio_len: Number of STFT windows returned per item.
        mode: ``'train'`` or ``'val'``.
    """

    # Fraction of the time axis whose window starts belong to training.
    train_p = 0.9

    # Upper bound on consecutive unreadable files before giving up, so a
    # broken dataset fails with a clear error instead of recursing forever.
    _MAX_READ_ATTEMPTS = 100

    def __init__(self, img_pth, audio_pth, max_audio_len=10, mode='train'):
        self.mode = mode
        # image dataset
        self.fns = list((img_pth/mode).glob('*/*.JPEG'))
        self.tf = get_imagenet_transform('basic')[mode]

        # audio dataset; the HDF5 file is opened on first use in read_audio
        self.audio_pth = audio_pth/'stft.h5'
        self.max_audio_len = max_audio_len
        self.audio_dataset = None
        self.time_n = None
        self.sz = 224

    def read_image(self):
        """Decode one randomly chosen image, skipping unreadable files.

        Raises:
            RuntimeError: If there are no files to sample from, or if
                ``_MAX_READ_ATTEMPTS`` consecutive reads all failed.
        """
        if not self.fns:
            raise RuntimeError('no image files available to sample from')
        last_error = None
        for _ in range(self._MAX_READ_ATTEMPTS):
            fn = str(random.choice(self.fns))
            try:
                return torchvision.io.read_image(fn)
            except Exception as err:
                # Best effort: an unreadable or corrupt file is skipped and
                # another one is drawn. KeyboardInterrupt and SystemExit are
                # deliberately not caught.
                last_error = err
        raise RuntimeError(
            f'failed to read an image after {self._MAX_READ_ATTEMPTS} attempts'
        ) from last_error

    def read_audio(self):
        """Return ``max_audio_len`` randomly positioned STFT windows.

        Each window is ``sz`` steps long along axis 2 of the stored dataset.
        In ``'train'`` mode the window starts are drawn from
        ``[0, train_time_n - sz]``; in ``'val'`` mode from
        ``[train_time_n - sz + 1, time_n - sz]``, where
        ``train_time_n = int(time_n * train_p)``.

        NOTE(review): the first validation windows therefore overlap the
        tail of the training range by up to ``sz - 1`` steps -- confirm that
        this is intended.

        Returns:
            Array of shape (max_audio_len, 2, 224, 224) per the original
            author's comment; this presumes the stored dataset is laid out
            as (N, 224, time, 2) -- TODO confirm.

        Raises:
            ValueError: If ``mode`` is neither ``'train'`` nor ``'val'``, or
                if the recording is too short to fit a window.
        """
        if self.audio_dataset is None:
            self.audio_dataset = h5py.File(self.audio_pth, 'r')['stft']
            self.time_n = self.audio_dataset.shape[2]

        train_time_n = int(self.time_n * self.train_p)
        last_train_start = train_time_n - self.sz
        last_start = self.time_n - self.sz

        if self.mode == 'train':
            lo, hi = 0, last_train_start
        elif self.mode == 'val':
            lo, hi = last_train_start + 1, last_start
        else:
            raise ValueError(
                f"unsupported mode {self.mode!r}; expected 'train' or 'val'"
            )
        if hi < lo:
            raise ValueError(
                f'audio of length {self.time_n} is too short for windows of '
                f'size {self.sz} in mode {self.mode!r}'
            )

        stfts = []
        for _ in range(self.max_audio_len):
            st_pos = random.randint(lo, hi)
            stfts.append(self.audio_dataset[0, :, st_pos:st_pos+self.sz])

        # (max_audio_len, 2, 224, 224)
        return np.stack(stfts, 0).transpose(0, 3, 1, 2)

    def __len__(self):
        """Return the number of image files found."""
        return len(self.fns)

    def __getitem__(self, idx):
        """Return ``(spec, img)``; ``idx`` is ignored."""
        img = self.read_image()
        # Force 3 channels: replicate a single channel, drop a fourth one.
        if img.shape[0] == 1:
            img = torch.cat([img]*3, 0)
        elif img.shape[0] == 4:
            img = img[:3]
        img = self.tf(img)
        spec = self.read_audio()
        return spec, img

class SpecDataset(Dataset):
    """Dataset of randomly positioned spectrogram excerpts from an HDF5 file.

    The ``'spectrogram'`` dataset of the file at ``pth`` is opened on the
    first ``__getitem__`` call. Each item is a contiguous excerpt of
    ``n_mels * max_len`` steps along axis 2, cut into ``max_len`` equal
    pieces that are stacked on a new trailing axis.

    Args:
        pth: Path of the HDF5 file.
        max_len: Number of pieces per item.
    """

    def __init__(self, pth, max_len=10):
        self.pth, self.max_len = pth, max_len
        # Filled in lazily by __getitem__.
        self.dataset = None
        self.time_n = None
        # Length of one piece along the time axis.
        self.n_mels = 224

    def __len__(self):
        """Return the fixed number of items that make up one epoch."""
        return 10000

    def __getitem__(self, idx):
        """Return a random excerpt rescaled as ``x / 8000 * 2 - 1``.

        ``idx`` is ignored; the start position is drawn at random.
        """
        if self.dataset is None:
            self.dataset = h5py.File(self.pth, 'r')['spectrogram']
            self.time_n = self.dataset.shape[2]

        window = self.n_mels * self.max_len
        start = random.randint(0, self.time_n - window)
        excerpt = self.dataset[0, :, start:start + window]
        # Cut the excerpt into max_len equal pieces along the last axis and
        # stack them on a new trailing axis.
        pieces = np.split(excerpt, self.max_len, -1)
        stacked = np.stack(pieces, -1)
        return stacked / 8000 * 2 - 1