import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
import librosa
import soundfile
from transformers import Wav2Vec2Processor, TFWav2Vec2Model

import cfg


def create():
    """Run the CARInA preprocessing steps in order.

    Seeds Python's and NumPy's global random generators with 0, builds the
    metadata frame, and then hands that same frame to the step that writes
    the audio data and metadata and to the step that writes the embeddings.
    """
    # fix both global random generators to the same seed
    for seed_rng in (random.seed, np.random.seed):
        seed_rng(0)

    # the metadata frame drives both export steps
    metadata = get_metadata()
    save_data(metadata)
    save_embeddings(metadata)


def _parse_par_file(par_path, sampling_rate, label_prefix):
    """Collect the labels active in each one-second window of one ``.par`` file.

    Two kinds of lines are interpreted, all others are ignored:

    * ``RET:`` lines - the last space-separated token is read as the number
      of audio samples; one empty window is opened per full second.
    * ``MAU:`` lines - five tab-separated fields, of which the second, third
      and fifth are read as start sample, duration in samples and label.

    Args:
        par_path: Path of the annotation file to read (UTF-8).
        sampling_rate: Samples per second used to convert sample offsets.
        label_prefix: String prepended to every label name.

    Returns:
        A dict mapping ``start_sec`` to ``{label_name: 1}`` for every label
        overlapping the window ``[start_sec, start_sec + 1)``.
    """
    windows = {}
    # Reset per file: a file without a preceding RET: line must not inherit
    # the length of the previously parsed file.
    audio_sec = 0
    with open(par_path, 'r', encoding='utf-8') as f:
        for line in f:
            c_line = line.strip()
            if c_line.startswith('RET:'):
                audio_samples = int(c_line.split(' ')[-1])
                audio_sec = int(np.floor(audio_samples / sampling_rate))
                # recordings shorter than one second open no window at all
                for audio_start in range(audio_sec):
                    windows[audio_start] = {}
            elif c_line.startswith('MAU:'):
                _, label_start, label_duration, _, label = c_line.split('\t')
                # strip the '"' and '%' markers so variants share one column
                label_without_accent = label_prefix + label.replace('"', '')
                label_without_accent = label_without_accent.replace('%', '')
                label_start_sec = int(label_start) / sampling_rate
                label_end_sec = (int(label_start) + int(label_duration)) / sampling_rate

                first_sec = int(np.floor(label_start_sec))
                # never mark windows beyond the last full second of audio
                last_sec = min(int(np.ceil(label_end_sec)), audio_sec)
                for start_sec in range(first_sec, last_sec):
                    if start_sec in windows:
                        windows[start_sec][label_without_accent] = 1
    return windows


def get_metadata(sampling_rate=44100, min_label_count=5000, evaluation_target_fraction=0.2):
    """Build the per-second metadata table from the ``.par`` annotation files.

    Every sub-directory of ``cfg.path_carina`` is treated as one speaker.
    Each of its ``.par`` files is cut into one-second windows, and every
    window becomes one row with 0/1 indicator columns for the labels that
    overlap it.

    Args:
        sampling_rate: Sampling rate (Hz) the sample offsets in the ``.par``
            files refer to. Defaults to 44100 (44.1 kHz).
        min_label_count: Label columns whose sum is below this value are
            dropped. Defaults to 5000.
        evaluation_target_fraction: Upper bound for the share of rows that
            is tagged ``cfg.tag_evaluate``. Defaults to 0.2.

    Returns:
        A ``pandas.DataFrame`` sorted by ``folder``, ``file`` and
        ``start_sec`` with the columns ``folder``, ``file``, ``start_sec``,
        the retained label columns, ``subset`` and ``fname``.

    Raises:
        ValueError: If no annotation window could be extracted at all.
    """
    rows = []
    # sorted() keeps row and column order reproducible between runs
    for folder in sorted(os.listdir(cfg.path_carina)):
        path_speaker = Path(cfg.path_carina, folder)
        # stray files next to the speaker folders are not speakers
        if not path_speaker.is_dir():
            continue

        stems = sorted({os.path.splitext(name)[0] for name in os.listdir(path_speaker)})
        for file in stems:
            par_path = Path(path_speaker, file + '.par')
            # skip stems that have no annotation file
            if not par_path.is_file():
                continue
            windows = _parse_par_file(par_path, sampling_rate, cfg.label_prefix)
            for start_sec, labels in windows.items():
                rows.append({'folder': folder, 'file': file, 'start_sec': start_sec, **labels})
        print(f'Preprocessing CARInA {folder}')

    if not rows:
        raise ValueError(f'No usable .par annotations found under {cfg.path_carina}')

    df = pd.DataFrame(rows)
    # a label that is absent from a window counts as 0
    df = df.fillna(0)
    df = df.sort_values(by=['folder', 'file', 'start_sec'], ignore_index=True)

    # drop rare labels and the '<p:>' label
    label_columns = [col for col in df.columns if col.startswith(cfg.label_prefix)]
    dropped_columns = [
        col for col in label_columns
        if df[col].sum() < min_label_count or col == cfg.label_prefix + '<p:>'
    ]
    df = df.drop(columns=dropped_columns)

    # Assign whole speakers to the evaluation subset, in order of appearance,
    # as long as the accumulated share of rows stays below the target.
    total_rows = len(df)
    evaluation_current_fraction = 0
    tag_by_speaker = {}
    for speaker, nr_samples_speaker in df.groupby('folder', sort=False).size().items():
        speaker_fraction = nr_samples_speaker / total_rows
        if evaluation_current_fraction + speaker_fraction < evaluation_target_fraction:
            evaluation_current_fraction = evaluation_current_fraction + speaker_fraction
            tag_by_speaker[speaker] = cfg.tag_evaluate
        else:
            tag_by_speaker[speaker] = cfg.tag_unlabelled
    df['subset'] = df['folder'].map(tag_by_speaker)

    # file name of the one-second clip: <folder>_<file>_<start>_<end>.wav
    start_text = df['start_sec'].astype(str)
    end_text = (df['start_sec'] + 1).astype(str)
    df['fname'] = df['folder'] + '_' + df['file'] + '_' + start_text + '_' + end_text + '.wav'

    return df


def save_data(df):
    """Write one audio clip per metadata row and save the metadata CSV.

    For every row the source recording
    ``cfg.path_carina/<folder>/<file>.wav`` is loaded at 16 kHz, the second
    starting at ``start_sec`` is cut out and written to
    ``cfg.path_data/carina/data/<fname>``. Afterwards the metadata is written
    to ``cfg.path_data/carina/metadata.csv`` without the ``folder``, ``file``
    and ``start_sec`` columns, with alphabetically sorted columns and all
    numeric columns cast to int.

    Args:
        df: Metadata frame with the columns ``folder``, ``file``,
            ``start_sec`` and ``fname``. Rows of the same recording should be
            adjacent, otherwise the recording is loaded more than once.
    """
    # make sure the target folder exists before the first clip is written
    output_dir = Path(cfg.path_data, 'carina', 'data')
    output_dir.mkdir(parents=True, exist_ok=True)

    # path of the recording that is currently held in memory
    audio_path_old = None
    for index, row in df.iterrows():
        print(f'PREPROCESSING CARInA: Row {index} / {len(df)}')

        # load the recording only when the row belongs to a new file
        audio_path = Path(cfg.path_carina, row['folder'], row['file'] + '.wav')
        if audio_path != audio_path_old:
            audio, sr = librosa.load(audio_path, sr=16000)
            audio_path_old = audio_path

        # slice bounds must be integers even if the row was upcast to float
        start_sec = int(row['start_sec'])
        audio_row = audio[start_sec * sr:(start_sec + 1) * sr]

        soundfile.write(output_dir / row['fname'], audio_row, samplerate=sr)

    # save metadata
    df = df.drop(columns=['folder', 'file', 'start_sec'])  # keep only fname, label, subset columns
    df = df[sorted(df.columns)]  # sort columns
    numeric_cols = df.select_dtypes(include='number').columns
    df[numeric_cols] = df[numeric_cols].astype(int)  # make label columns to int
    df.to_csv(Path(cfg.path_data, 'carina', 'metadata.csv'), index=False)


def _embed_batch(batch, processor, model, sampling_rate):
    """Return the max-pooled model output for a list of equal-length clips.

    The clips are stacked into one array, run through ``processor`` and
    ``model``, and the maximum over axis 1 of ``last_hidden_state`` is taken
    (presumably the sequence axis - confirm against the model documentation).
    """
    batch_array = np.array(batch)
    input_values = processor(batch_array, sampling_rate=sampling_rate, return_tensors="tf").input_values
    hidden_states = model(input_values).last_hidden_state.numpy()
    return np.max(hidden_states, axis=1)


def save_embeddings(df, batch_size=100):
    """Compute one wav2vec2 embedding per clip and save them as one array.

    Every clip ``cfg.path_data/carina/data/<fname>`` is loaded at 16 kHz and
    embedded in batches with the pretrained ``facebook/wav2vec2-base-960h``
    model. The per-clip embeddings are concatenated in row order and written
    to ``cfg.path_data/carina/data_embedding.npy``.

    Args:
        df: Metadata frame with an ``fname`` column; one embedding is
            produced per row, in row order.
        batch_size: Number of clips passed through the model at once.
            Defaults to 100.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or ``df`` has no rows.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1')
    if len(df) == 0:
        raise ValueError('Cannot compute embeddings for an empty metadata frame')

    sampling_rate = 16000

    # load the model
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

    embeddings = []
    batch = []
    for index, row in df.iterrows():
        print(f'PREPROCESSING CARInA: Row {index} / {len(df)}')

        audio_path = Path(cfg.path_data, 'carina', 'data', row['fname'])
        audio, _ = librosa.load(audio_path, sr=sampling_rate)
        batch.append(audio)

        # Flush on the batch size rather than on the index label, so the
        # batching does not depend on how the frame happens to be indexed.
        if len(batch) == batch_size:
            embeddings.append(_embed_batch(batch, processor, model, sampling_rate))
            batch = []

    # embed whatever is left over after the last full batch
    if batch:
        embeddings.append(_embed_batch(batch, processor, model, sampling_rate))

    final_embedding = np.concatenate(embeddings, axis=0)
    np.save(Path(cfg.path_data, 'carina', 'data_embedding.npy'), final_embedding)




# Run the whole preprocessing pipeline when this file is executed as a script.
if __name__ == '__main__':
    create()
