import csv
import os
import random
import subprocess

import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torchvggish import vggish, vggish_input
from tqdm import tqdm

__all__ = ['ESC50']

class ESC50(Dataset):
    """Few-shot episode sampler over cached VGGish embeddings of ESC-50 clips.

    Each item is a randomly drawn N-way task: ``nways`` classes are sampled
    from the classes of the chosen split, and for every class ``shots``
    support samples and ``query_shots`` query samples are drawn without
    overlap. The ``index`` passed to ``__getitem__`` is ignored; ``__len__``
    only bounds how many episodes an iterator will draw.

    Args:
        root: Directory holding (or receiving) the ESC-50 checkout.
        split: ``'train'`` selects classes 0-19, ``'valid'`` classes 20-34,
            and any other value classes 35-49.
        name: Dataset name; stored upper-cased in ``self.name``.
        model: Model identifier. ``'mlti_protonet'`` or any string containing
            ``'settaskinterpolator'`` enables multi-task items on the train
            split (see ``__getitem__``).
        nways: Number of classes per task.
        shots: Support samples per class.
        query_shots: Query samples per class.
        num_tasks: Tasks stacked per item when multi-task items are enabled.
        niterations: Together with ``batch_size`` defines ``len(self)``.
        batch_size: See ``niterations``.
        visualize: Stored on the instance; not used inside this class.
        download: If true, clone the dataset when its metadata is missing.
    """

    # Upstream repository cloned when the dataset is not present locally.
    _REPO_URL = 'git@github.com:karolpiczak/ESC-50.git'
    # Size of the per-class table built by load_esc50 (indexed by target id).
    _NUM_CLASSES = 50

    def __init__(self, root, split, name='esc50', model='', nways=5, shots=5, query_shots=5, num_tasks=1, niterations=10000, batch_size=4, visualize=False, download=True):
        self.root = root
        self.split = split
        self.nways = nways
        self.shots = shots
        self.num_tasks = num_tasks
        self.query_shots = query_shots
        self.visualize = visualize

        if download:
            self._download()

        # Multi-task items are only produced for interpolation-style models,
        # and only while training.
        uses_interpolation = (model == 'mlti_protonet') or ('settaskinterpolator' in model)
        self.interp = uses_interpolation and self.split == 'train'
        self.name = '{}'.format(name).upper()
        self.datalen = niterations * batch_size

        self.data = self.load_esc50()

    def _download(self):
        """Clone the ESC-50 repository into ``self.root`` if metadata is missing.

        The check is on the metadata file rather than on the directory so that
        a previously failed clone (which leaves an empty directory behind) is
        retried instead of being mistaken for a complete dataset.

        Raises:
            RuntimeError: If ``git clone`` exits with a non-zero status.
        """
        if os.path.exists(os.path.join(self.root, 'meta/esc50.csv')):
            return
        os.makedirs(self.root, exist_ok=True)
        # Argument list (no shell) so that paths containing spaces or shell
        # metacharacters are passed to git verbatim.
        result = subprocess.run(['git', 'clone', self._REPO_URL, self.root])
        if result.returncode != 0:
            raise RuntimeError(
                'git clone of {} into {!r} failed with exit code {}'.format(
                    self._REPO_URL, self.root, result.returncode))

    def _embedding_path(self, filename):
        """Return the cache path of the embedding for audio file ``filename``."""
        stem = filename.split('/')[-1].replace('.wav', '.pth')
        return os.path.join(self.root, 'embeddings', stem)

    def create_task(self, random_classes=None):
        """Sample one few-shot task.

        Args:
            random_classes: Optional sequence (or 1-D tensor) of class indices
                into ``self.data``. When omitted, ``self.nways`` classes are
                drawn at random.

        Returns:
            Tuple ``(support, slabel, query, qlabel)`` where ``support`` is
            ``ways x shots x fdim``, ``query`` is ``ways x query_shots x fdim``
            and the label tensors hold the episode-local class id
            (``0..ways-1``) of every sample as ``long``.

        Raises:
            ValueError: If a selected class has fewer than
                ``shots + query_shots`` samples.
        """
        if random_classes is None:
            random_classes = torch.randperm(len(self.data))[:self.nways]
        # Derive the way count from the classes actually used so labels always
        # line up with the sampled tensors, even for caller-supplied classes.
        nways = len(random_classes)
        per_class = self.shots + self.query_shots

        support, query = [], []
        for random_class in random_classes:
            samples = self.data[int(random_class)]
            if len(samples) < per_class:
                raise ValueError(
                    'class {} has {} samples but {} are required '
                    '(shots + query_shots)'.format(int(random_class), len(samples), per_class))
            rand_index = torch.randperm(len(samples))[:per_class].tolist()
            # Every sample is flattened to a single feature vector.
            flat = [samples[i].reshape(-1) for i in rand_index]
            support.append(torch.stack(flat[:self.shots]))
            query.append(torch.stack(flat[self.shots:]))
        support = torch.stack(support)  # ways x shots x fdim
        query = torch.stack(query)      # ways x query_shots x fdim

        class_ids = torch.arange(0, nways).view(nways, 1).long()
        slabel = class_ids.repeat(1, self.shots)
        qlabel = class_ids.repeat(1, self.query_shots)
        return support, slabel, query, qlabel

    def __getitem__(self, index):
        """Return a freshly sampled task; ``index`` is ignored.

        When ``self.interp`` is set, ``self.num_tasks`` independent tasks are
        sampled and stacked along a new leading dimension of each returned
        tensor.
        """
        if not self.interp:
            return self.create_task()

        tasks = [self.create_task() for _ in range(self.num_tasks)]
        supports, slabels, querys, qlabels = zip(*tasks)
        return torch.stack(supports), torch.stack(slabels), torch.stack(querys), torch.stack(qlabels)

    def get_batch(self, batch_size):
        """Sample ``batch_size`` items and stack each component along dim 0."""
        items = [self[i] for i in range(batch_size)]
        supports, slabels, querys, qlabels = zip(*items)
        return torch.stack(supports), torch.stack(slabels), torch.stack(querys), torch.stack(qlabels)

    def __len__(self):
        """Return the nominal number of episodes (``niterations * batch_size``)."""
        return self.datalen

    def load_esc50(self):
        """Load per-class embeddings for ``self.split``, computing missing ones.

        Reads ``meta/esc50.csv`` under ``self.root`` (first row is a header;
        column 0 is the audio filename and column 2 the integer class id),
        makes sure every clip has a cached embedding under ``embeddings/``,
        and returns the classes belonging to the requested split.

        Returns:
            A list with one entry per class of the split; each entry is a list
            of detached embedding tensors for that class.
        """
        metadata = os.path.join(self.root, 'meta/esc50.csv')

        with open(metadata, 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            # Blank rows (e.g. a trailing newline) carry no clip and are dropped.
            clips = [(row[0], int(row[2])) for row in reader if row]

        os.makedirs(os.path.join(self.root, 'embeddings'), exist_ok=True)

        # Check file by file instead of trusting the directory's existence, so
        # an interrupted earlier run is completed rather than left half-cached.
        missing = [filename for filename, _ in clips
                   if not os.path.exists(self._embedding_path(filename))]
        if missing:
            embedding_model = vggish()
            embedding_model.eval()

            # Inference only: no autograd graph is needed for cached features.
            with torch.no_grad():
                for filename in tqdm(missing, total=len(missing), ncols=75, leave=False):
                    audiopath = os.path.join(self.root, 'audio/', filename)
                    example = vggish_input.wavfile_to_examples(audiopath)
                    embeddings = embedding_model(example)
                    # Write to a temp name and rename so a crash mid-write
                    # never leaves a truncated file that looks like a hit.
                    target_path = self._embedding_path(filename)
                    tmp_path = target_path + '.tmp'
                    torch.save(embeddings, tmp_path)
                    os.replace(tmp_path, target_path)

        data = [[] for _ in range(self._NUM_CLASSES)]
        for filename, target in clips:
            data[target].append(torch.load(self._embedding_path(filename)).detach())

        # Class-disjoint splits: 20 train / 15 valid / 15 for everything else.
        if self.split == 'train':
            return data[:20]
        elif self.split == 'valid':
            return data[20:35]
        else:
            return data[35:]

if __name__ == '__main__':
    # Manual smoke run: build the training split rooted at ./ESC50, fetching
    # the dataset and computing embeddings if they are not already on disk.
    t = ESC50(
        root='ESC50',
        split='train',
    )
