import random
from pathlib import Path

import librosa
import numpy as np
import pandas as pd
import soundfile
from transformers import Wav2Vec2Processor, TFWav2Vec2Model


import cfg


def create():
    """Run the complete UrbanSound8K preprocessing pipeline.

    Seeds Python's and NumPy's global random generators with 0, builds the
    metadata table, writes the audio clips together with the metadata CSV,
    and finally computes and stores the embeddings.
    """
    # fix both global random generators to the same seed
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(0)

    # build the metadata table once and hand it to both export steps
    metadata = get_metadata()

    # audio clips + metadata.csv must exist before embeddings are computed,
    # because save_embeddings reads the clips written by save_data
    save_data(metadata)
    save_embeddings(metadata)


def get_metadata():
    """Load the UrbanSound8K metadata CSV and derive the working columns.

    Returns:
        A DataFrame holding the original CSV columns plus
        ``fname`` (copy of ``slice_file_name``), ``class_clean`` (class name
        without underscores), one integer indicator column per class
        (prefixed with ``cfg.label_prefix``) and ``subset``.
    """
    metadata = pd.read_csv(cfg.path_urbansound8k_metadata)

    # file name column used by the rest of the pipeline
    metadata['fname'] = metadata['slice_file_name']

    # class names without underscores, then one 0/1 column per class
    metadata['class_clean'] = metadata['class'].str.replace('_', '', regex=False)
    one_hot = pd.get_dummies(metadata['class_clean'], prefix=cfg.label_prefix)
    metadata = pd.concat([metadata, one_hot.astype(int)], axis=1)

    # folds 1 and 2 are tagged for evaluation, all other folds as unlabelled
    is_eval_fold = metadata['fold'].isin([1, 2])
    metadata['subset'] = np.where(is_eval_fold, cfg.tag_evaluate, cfg.tag_unlabelled)

    return metadata

def save_data(df):
    """Write every clip as a 1 s, 16 kHz audio file and save the metadata CSV.

    Each source file is read from
    ``<cfg.path_urbansound8k>/audio/fold<fold>/<fname>`` at 16 kHz, brought to
    exactly 16000 samples (shorter clips are zero-padded on both sides,
    longer clips keep only their first 16000 samples) and written to
    ``<cfg.path_data>/urbansound8k/data/<fname>``. Afterwards a reduced copy
    of ``df`` is stored as ``<cfg.path_data>/urbansound8k/metadata.csv``.
    The passed-in ``df`` itself is not modified.

    Args:
        df: metadata frame containing ``fold`` and ``fname`` as well as the
            raw UrbanSound8K columns that are dropped before saving.
    """
    sample_rate = 16000
    clip_length = sample_rate  # 1 s of audio

    out_dir = Path(cfg.path_data, 'urbansound8k')
    data_dir = out_dir / 'data'
    # soundfile.write does not create missing directories
    data_dir.mkdir(parents=True, exist_ok=True)

    n_rows = len(df)

    # iterate over all samples in df
    for index, fold, fname in zip(df.index, df['fold'], df['fname']):
        print(f'PREPROCESSING UrbanSound8K: Row {index} / {n_rows}')

        # get audio path
        audio_path = Path(cfg.path_urbansound8k, 'audio', f'fold{fold}', fname)

        # read audio
        audio, sr = librosa.load(audio_path, sr=sample_rate)

        # create 1s audio
        if len(audio) < clip_length:
            # centre the clip: an odd remainder goes to the right side
            total_pad = clip_length - len(audio)
            pad_left = total_pad // 2
            pad_right = total_pad - pad_left
            audio = np.pad(audio, (pad_left, pad_right), mode='constant')
        else:
            audio = audio[:clip_length]

        # save audio
        soundfile.write(data_dir / fname, audio, samplerate=sr)

    # save metadata: keep only fname, label and subset columns
    df = df.drop(
        ['slice_file_name', 'fsID', 'start', 'end', 'salience', 'fold', 'classID', 'class', 'class_clean'],
        axis=1,
    )
    df = df[sorted(df.columns)]  # sort columns alphabetically
    numeric_cols = df.select_dtypes(include='number').columns
    df[numeric_cols] = df[numeric_cols].astype(int)  # make label columns int
    df.to_csv(out_dir / 'metadata.csv', index=False)


def save_embeddings(df):
    """Compute max-pooled wav2vec2 embeddings for all clips and save them.

    The clips are read from ``<cfg.path_data>/urbansound8k/data/<fname>`` in
    the row order of ``df``, passed through the pretrained
    ``facebook/wav2vec2-base-960h`` model in batches, max-pooled over the
    sequence axis and stored as one array (one row per clip) in
    ``<cfg.path_data>/urbansound8k/data_embedding.npy``.

    Args:
        df: metadata frame with a ``fname`` column.

    Raises:
        ValueError: if ``df`` contains no rows.
    """
    n_rows = len(df)
    if n_rows == 0:
        # fail before the (expensive) model load instead of inside np.concatenate
        raise ValueError('save_embeddings: df contains no rows')

    sample_rate = 16000
    batch_size = 100
    data_dir = Path(cfg.path_data, 'urbansound8k', 'data')

    # load the model
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

    embeddings = []
    batch = []
    # count rows by position, not by index label, so that batching also works
    # for frames whose index is not 0..n-1 (e.g. after filtering)
    for index, fname in enumerate(df['fname']):
        print(f'PREPROCESSING UrbanSound8k: Row {index} / {n_rows}')

        # load audio and append it to the current batch
        audio, _ = librosa.load(data_dir / fname, sr=sample_rate)
        batch.append(audio)

        # keep collecting until the batch is full or the last row is reached
        if len(batch) < batch_size and index != n_rows - 1:
            continue

        # preprocess the audio
        batch_array = np.array(batch)
        input_values = processor(batch_array, sampling_rate=sample_rate, return_tensors="tf").input_values

        # extract embeddings, pool max across sequence length
        hidden_states = model(input_values).last_hidden_state
        embeddings.append(np.max(hidden_states.numpy(), axis=1))

        # reset batch
        batch = []

    final_embedding = np.concatenate(embeddings, axis=0)
    np.save(Path(cfg.path_data, 'urbansound8k', 'data_embedding.npy'), final_embedding)


# run the full preprocessing pipeline when this file is executed as a script
if __name__ == '__main__':
    create()