import torch.utils.data
import torchvision
import torchaudio
import os
import numpy as np
import bz2
import scipy.io
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB
import pandas as pd
import torch.nn.functional as F
import random

import re

class TAU(Dataset):
    """TAU Urban Acoustic Scenes split, filtered to a set of recording devices.

    The split file is chosen by ``is_train``: the literal ``True`` reads
    ``fold1_train.csv``, the literal ``False`` reads ``fold1_evaluate.csv``,
    and any other value (e.g. ``None``) reads ``fold1_test.csv``, whose
    scene labels are joined in from ``meta.csv``.

    Training items are returned as resampled, randomly time-shifted
    waveforms (the mel front end is NOT applied; presumably features are
    computed downstream -- confirm with the training loop). Validation/test
    items are returned as log-mel spectrograms padded/cropped to
    ``n_frames`` time frames. Every item comes with a ``torch.long`` label.

    Args:
        root: dataset root containing ``evaluation_setup/`` (and ``meta.csv``
            for the test split); ``file_list`` entries are joined onto it.
        is_train: split selector, see above.
        domains: comma-separated device ids to keep (e.g. ``'A,B,S1'``),
            matched case-insensitively against the ``-<id>.wav`` suffix.
        sr: target sample rate; audio is resampled to it when it differs.
        n_mels: number of mel bands.
        device: converted with ``torch.device`` and stored as
            ``self.device``; not used inside this class.
        label_set: optional ordered class names defining the label indices.
            Pass the same list to every instance (e.g. one per domain) so
            indices agree; by default the sorted set of labels present in
            this instance's split is used.
        n_frames: time frames val/test spectrograms are padded/cropped to.

    Raises:
        ValueError: if any kept file has no scene label, or a label is not
            in an explicitly given ``label_set``.
    """

    def __init__(self, root, is_train=True, domains='A', sr=16000, n_mels=256, device='cuda',
                 label_set=None, n_frames=350):
        self.root = root
        self.is_train = is_train
        self.device = torch.device(device)
        self.n_frames = n_frames

        # Identity checks on purpose: only the literals True/False select
        # train/val; anything else (e.g. None) selects the test split.
        if is_train is True:
            split_file = 'fold1_train.csv'
        elif is_train is False:
            split_file = 'fold1_evaluate.csv'
        else:
            split_file = 'fold1_test.csv'
        is_test = split_file == 'fold1_test.csv'

        split_path = os.path.join(root, 'evaluation_setup', split_file)
        columns = ["filename"] if is_test else ["filename", "scene_label"]
        # header=None: a header line, if the file has one, is read as a data
        # row; its "filename" value fails the device regex below and is
        # therefore filtered out together with other non-matching rows.
        df_split = pd.read_csv(split_path, sep=r"\s+", header=None, names=columns)

        # Device id is the "-<id>.wav" suffix (a, b, c or s<digit>), upper-cased.
        df_split["device"] = (
            df_split["filename"]
            .str.extract(r'-([abc]|s\d)\.wav$', flags=re.IGNORECASE)[0]
            .str.upper()
        )
        self.devices = [d.strip().upper() for d in domains.split(',')]
        df_split = df_split[df_split["device"].isin(self.devices)]

        if is_test:
            # The test split lists filenames only; labels come from meta.csv.
            meta_path = os.path.join(root, 'meta.csv')
            meta = pd.read_csv(meta_path, sep=r"\s+", engine='python')[["filename", "scene_label"]]
            df_split = df_split.merge(meta, on="filename", how="left")

        # A left merge (or a short CSV row) leaves NaN labels; fail clearly
        # here instead of with a type/key error later.
        missing = df_split["scene_label"].isna()
        if missing.any():
            raise ValueError(
                f"{int(missing.sum())} file(s) from {split_file} have no scene_label, "
                f"e.g. {df_split.loc[missing, 'filename'].iloc[0]!r}"
            )

        self.file_list = df_split["filename"].tolist()
        self.labels = df_split["scene_label"].tolist()

        if label_set is None:
            self.label_set = sorted(set(self.labels))
        else:
            self.label_set = list(label_set)
            unknown = set(self.labels) - set(self.label_set)
            if unknown:
                raise ValueError(f"labels not in label_set: {sorted(unknown)}")
        self.label_idx = {name: i for i, name in enumerate(self.label_set)}

        # Mel spectrogram: 130 ms window (n_fft equals the window length at
        # the default 16 kHz), 30 ms hop, mel bands between 50 and 8000 Hz.
        self.mel_tf = MelSpectrogram(
            sample_rate=sr, n_fft=2080, win_length=int(sr * 0.13), hop_length=int(sr * 0.03),
            n_mels=n_mels, f_min=50, f_max=8000
        )
        # Log scale (log-mel spectrogram in dB).
        self.log_tf = AmplitudeToDB()

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _fix_length(spec, n_frames):
        """Zero-pad or crop the last (time) axis of ``spec`` to ``n_frames``.

        NOTE(review): padding uses 0, i.e. 0 dB after ``AmplitudeToDB``, not
        the spectrogram's floor value -- confirm this is intended.
        """
        cur_len = spec.size(-1)
        if cur_len < n_frames:
            return F.pad(spec, (0, n_frames - cur_len))
        return spec[..., :n_frames]

    def __getitem__(self, idx):
        wav, sr = torchaudio.load(os.path.join(self.root, self.file_list[idx]))
        target_sr = self.mel_tf.sample_rate
        if sr != target_sr:
            wav = torchaudio.functional.resample(wav, sr, target_sr)
        y = torch.tensor(self.label_idx[self.labels[idx]], dtype=torch.long)

        if self.is_train is True:
            # Augmentation: random circular time shift of up to +/-1.5 s.
            max_shift = int(1.5 * target_sr)
            wav = torch.roll(wav, shifts=random.randint(-max_shift, max_shift), dims=1)
            return wav, y

        mel_db = self.log_tf(self.mel_tf(wav))
        return self._fix_length(mel_db, self.n_frames), y



class CombiningDataset(torch.utils.data.Dataset):
    """Concatenate several map-style datasets into one, in order.

    Indices address the datasets back to back: the first
    ``len(datasets[0])`` indices map into ``datasets[0]``, the next
    ``len(datasets[1])`` into ``datasets[1]``, and so on. Lengths are
    re-read on every access. Negative indices count from the end of the
    combined sequence, as for a list.
    """

    def __init__(self, datasets):
        self.datasets = datasets

    def __len__(self):
        return sum(len(dataset) for dataset in self.datasets)

    def __getitem__(self, index):
        total = len(self)
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError(f"index out of range for CombiningDataset of length {total}")
        for dataset in self.datasets:
            size = len(dataset)
            if index < size:
                return dataset[index]
            index -= size
        # Only reachable if a member's length changed during this call.
        raise IndexError("index out of range for CombiningDataset")


class _BaseDataset(torch.utils.data.Dataset):

    def __init__(self, root, is_train=True, transform=None):
        super().__init__()
        self.root = root
        self.is_train = is_train
        self.transform = transform
        self.images = []
        self.labels = []
        self.extract_images_labels()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        x = Image.fromarray(self.images[index])
        y = int(self.labels[index])
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def extract_images_labels(self):
        raise NotImplementedError



def _pacs_datasets(root, domains, is_train):
    """Build one ImageFolder per comma-separated PACS domain folder under ``root``."""
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    if is_train:
        # NOTE(review): `(224, 224)` used to be passed positionally as
        # `scale`, the (min, max) fraction of the source area to crop. Values
        # above 1 make every sampling attempt fail, so the crop silently fell
        # back to a full-image resize. (0.7, 1.0) is a common choice for
        # PACS domain generalization -- confirm against the runs being
        # reproduced.
        augment = [
            transforms.RandomResizedCrop((224, 224), scale=(0.7, 1.0)),
            transforms.RandomHorizontalFlip(0.5),
        ]
    else:
        augment = [transforms.Resize((224, 224))]
    transform = transforms.Compose(augment + [
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, std=imagenet_std),
    ])
    return [torchvision.datasets.ImageFolder(os.path.join(root, domain), transform=transform)
            for domain in domains.split(',')]


def _cifar10c_datasets(root, is_train):
    """Build the CIFAR-10 train set, or 20 CIFAR-10-C test sets (5 levels x 4 corruption groups)."""
    import copy

    preprocess = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5] * 3, [0.5] * 3)])
    if is_train:
        transform = transforms.Compose(
            [transforms.RandomHorizontalFlip(),
             transforms.RandomCrop(32, padding=4),
             preprocess])
        return [torchvision.datasets.CIFAR10(root, train=is_train, transform=transform, download=True)]

    corruption_groups = [
        ['gaussian_noise', 'shot_noise', 'impulse_noise'],
        ['defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur'],
        ['snow', 'frost', 'fog', 'brightness'],
        ['contrast', 'elastic_transform', 'pixelate', 'jpeg_compression']
    ]
    # Load (and verify) the CIFAR-10 test set once; each level/group dataset
    # is a shallow copy whose data/targets are replaced below.
    base = torchvision.datasets.CIFAR10(root, train=is_train, transform=preprocess, download=True)

    c_root = os.path.join(root, 'CIFAR-10-C')
    # Arrays are reshaped to (level, sample, ...); a group's corruptions are
    # concatenated along the sample axis.
    labels = np.load(os.path.join(c_root, 'labels.npy')).reshape((5, 10000))
    group_data, group_labels = [], []
    for group in corruption_groups:
        group_data.append(np.concatenate(
            [np.load(os.path.join(c_root, name + '.npy')).reshape((5, 10000, 32, 32, 3))
             for name in group],
            axis=1))
        group_labels.append(np.concatenate([labels] * len(group), axis=1))

    # Ordered level-major: index = level * n_groups + group.
    datasets = []
    for level in range(5):
        for g in range(len(corruption_groups)):
            dataset = copy.copy(base)
            dataset.data = group_data[g][level]
            dataset.targets = torch.LongTensor(group_labels[g][level])
            datasets.append(dataset)
    return datasets


def get_datasets(task, root, domains, is_train):
    """Build the dataset(s) for ``task``.

    Args:
        task: one of ``'PACS'``, ``'CIFAR10-C'`` or ``'TAU'``.
        root: dataset root directory.
        domains: comma-separated domain names (PACS sub-folders, TAU device
            ids); ignored for CIFAR10-C.
        is_train: truthy builds training data. For TAU, the literal
            ``False`` selects the evaluate split and any value other than
            ``True``/``False`` (e.g. ``None``) the test split.

    Returns:
        A single ``CombiningDataset`` over all parts when ``is_train`` is
        truthy, otherwise the list of per-domain (for CIFAR10-C,
        per level/corruption-group) datasets.

    Raises:
        NotImplementedError: for an unknown ``task``.
    """
    if task == 'PACS':
        datasets = _pacs_datasets(root, domains, is_train)
    elif task == 'CIFAR10-C':
        datasets = _cifar10c_datasets(root, is_train)
    elif task == "TAU":
        datasets = [TAU(root, is_train=is_train, domains=d.strip()) for d in domains.split(',')]
    else:
        raise NotImplementedError(f"unknown task: {task!r}")
    return CombiningDataset(datasets) if is_train else datasets



            