
# Run parameters. Several are re-parsed further down (int()/float()
# conversions), presumably because overrides may arrive as strings.
# Presumably intended to toggle saving the output figure; check the ggsave
# call at the end of the script.
SAVE = False
# Dataset name and split; used to build the data, alignment and weights paths.
DATA = "IEMOCAP"
SPLIT = "test"
# Presumably the stage-1 / stage-2 training datasets; not referenced by the
# code below.
DATA_S1 = "CommonVoice_LibriSpeech"
DATA_S2 = "IEMOCAP" 
# Latent dimension; appears in the VIB / probe result-file name suffix.
LATENT_DIM = 128
# Backbone name: selects the processor and is part of WEIGHTS_PATH.
MODEL_NAME = "wav2vec2-base" 
# Layer selection per stage ("all" or a layer index, converted to int below);
# part of the VIB / probe result-file name suffix.
LAYER_S1 = "all" 
LAYER_S2 = "all" 
# Layer selection used in the integrated-gradients result-file name.
LAYER = "all"
# Converted to float below; not otherwise used in this script.
LEARNING_RATE = 0.001

import sys, os
# Put the parent of this file's directory on the import path (presumably so the
# project-local `vib`, `probing` and `utils` imports below resolve).
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(sys.modules[__name__].__file__), "..")))

OBJECTIVE = "emotion"
# Normalize parameter types: layer indices become ints unless they are the
# "all" sentinel (stage 2 additionally accepts None).
LAYER_S1 = LAYER_S1 if LAYER_S1 == "all" else int(LAYER_S1)
LAYER_S2 = LAYER_S2 if LAYER_S2 in ["all", None] else int(LAYER_S2)
LEARNING_RATE = float(LEARNING_RATE)
BETA_S1 = "incremental"
BETA_S2 = "incremental"
SEED = 42
SELECTED_GPU = 0
# Root of every input/output path below. os.path.expanduser("~") resolves the
# home directory even when $HOME is unset (where os.environ['HOME'] would raise
# KeyError), and gives the same result when it is set.
PROJECT_DIR = f"{os.path.expanduser('~')}/Projects/acoustic-linguistic/directory"
DATA_PATH = f"{PROJECT_DIR}/data/"
ALIGNMENT_BASE_PATH = f"{PROJECT_DIR}/mfa/{DATA}/"
WEIGHTS_PATH = f"{PROJECT_DIR}/analysis/{OBJECTIVE}/{DATA}/{SPLIT}/{MODEL_NAME}/"
SAVE_FIGURES_PATH = f"{PROJECT_DIR}/figures/"


## Imports
import pickle
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from plotnine import *
import torch
from datasets import load_from_disk, Audio
from transformers import Wav2Vec2Processor, Wav2Vec2Model, HubertModel
from vib.vib import VIB, VIBConfig
from probing.probe import Probe, ProbeConfig
from utils import MODEL_NAME_MAPPER, PROCESSOR_MAPPER, NUM_CLASSES_MAPPER, EMOTION_LABEL_MAPPER, get_frame_boundaries, add_mfa
import IPython.display as ipd
import spacy
from textblob import TextBlob
import parselmouth
# spaCy English pipeline; used to tokenize each utterance's aligned words so
# that each word can be scored with TextBlob polarity in the loop below.
nlp = spacy.load('en_core_web_sm')


## GPU
# Pick the compute device: the configured CUDA GPU when CUDA is available,
# otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device(f"cuda:{SELECTED_GPU}" if cuda_available else "cpu")
if cuda_available:
    print('We will use the GPU:', torch.cuda.get_device_name(SELECTED_GPU))
else:
    print('No GPU available, using the CPU instead.')


def trim(x, desired_size):
    """Pad or truncate a 1-D sequence to exactly ``desired_size`` elements.

    Shorter inputs are right-padded by repeating their last value; longer
    inputs are truncated from the end.

    Args:
        x: 1-D array-like of values (e.g. a per-frame feature track).
        desired_size: Target length; must be non-negative.

    Returns:
        np.ndarray of length ``desired_size``. An empty ``x`` has no last
        value to repeat, so it is padded with zeros instead.

    Raises:
        ValueError: if ``desired_size`` is negative.
    """
    if desired_size < 0:
        raise ValueError(f"desired_size must be non-negative, got {desired_size}")
    x = np.asarray(x)
    extra_dims = desired_size - len(x)
    if extra_dims > 0:
        if len(x) == 0:
            # Nothing to repeat; x[-1] would raise IndexError.
            return np.zeros(desired_size)
        return np.concatenate([x, np.repeat(x[-1], extra_dims)])
    # Covers both the exact-length and the truncation case.
    return x[:desired_size]

def normalize_min_max(x):
    """Linearly rescale values so the minimum maps to 0 and the maximum to 1.

    Args:
        x: 1-D array-like of numbers.

    Returns:
        list of floats in [0, 1]. A constant input (zero range) maps to all
        zeros instead of the NaNs a plain division would produce; an empty
        input returns ``[]``.
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return []
    min_val = np.min(x)
    value_range = np.max(x) - min_val
    if value_range == 0:
        # No spread to normalize; avoid 0/0 which would propagate NaNs
        # into downstream dot products and means.
        return np.zeros_like(x).tolist()
    return ((x - min_val) / value_range).tolist()

def normalize_sum_to_one(x):
    """Scale values proportionally so that they sum to one.

    Args:
        x: 1-D array-like of numbers.

    Returns:
        list of floats proportional to ``x`` with total 1. If the total is
        zero (including an empty or all-zero input), the values are returned
        unscaled rather than divided by zero.
    """
    x = np.asarray(x, dtype=float)
    total = np.sum(x)
    if total == 0:
        return x.tolist()
    return (x / total).tolist()

# Processor for the configured backbone.
# NOTE(review): not referenced again in the visible script.
processor = Wav2Vec2Processor.from_pretrained(PROCESSOR_MAPPER[MODEL_NAME])

# Load the requested split, keep only the needed columns, and decode audio at 16 kHz.
dataset = (
    load_from_disk(f"{DATA_PATH}{DATA}")[SPLIT]
    .select_columns(['audio', 'transcription', 'emotion'])
    .cast_column("audio", Audio(sampling_rate=16000))
)

# Attach the MFA word alignments for this split (the loop below reads the
# resulting 'mfa_intervals' column).
dataset = add_mfa(dataset, ALIGNMENT_BASE_PATH, SPLIT)

# Pre-computed per-example results stored as pickles under WEIGHTS_PATH.
# VIB and probe files share one suffix; integrated gradients use LAYER only.
postfix = f"_dim={LATENT_DIM}_layer={LAYER_S1}_{LAYER_S2}"
vib_preds = pd.read_pickle(f'{WEIGHTS_PATH}vib_preds{postfix}.pkl')
vib_attentions = pd.read_pickle(f'{WEIGHTS_PATH}vib_attentions{postfix}.pkl')
igs = pd.read_pickle(f'{WEIGHTS_PATH}model_igs_layer={LAYER}.pkl')
h_probing_preds = pd.read_pickle(f'{WEIGHTS_PATH}h_probe_preds{postfix}.pkl')

# Append each result series to the dataset as a named column, in this order.
for column_name, column_values in (
    ("vib_prediction", vib_preds['pred_2']),
    ("Textual Attention", vib_attentions['textual_attention']),
    ("Acoustic Attention", vib_attentions['acoustic_attention']),
    ("Integrated Gradient", igs['ig']),
    ("h_prediction", h_probing_preds['pred']),
):
    dataset = dataset.add_column(column_name, column_values)

# Run
def _word_polarities(mfa_intervals, frame_size, audio_time):
    """Build word interval records and a framewise polarity vector for one utterance.

    The aligned words are joined and tokenized with spaCy. Each word is scored
    with the TextBlob polarity of the next spaCy token that is a prefix of it,
    and that score is written to the frames the word spans according to
    ``get_frame_boundaries``.

    Args:
        mfa_intervals: list of dicts with 'word', 'start' and 'end' keys.
        frame_size: number of frames in the output vector.
        audio_time: clip duration in seconds.

    Returns:
        (records, polarities): one dict per word with its times, frame span,
        text and polarity, plus a float array of length ``frame_size``
        (frames not covered by any word stay 0).
    """
    doc = nlp(" ".join(interval['word'] for interval in mfa_intervals))
    doc_index = 0
    records = []
    polarities = np.zeros(frame_size)
    for interval in mfa_intervals:
        word, start, end = interval['word'], interval['start'], interval['end']

        # Advance to the next spaCy token that is a prefix of this word. The
        # search is bounded so a word with no matching token cannot loop forever.
        match = doc_index
        while match < len(doc) and not word.startswith(doc[match].text):
            match += 1
        if match < len(doc):
            token_text = doc[match].text
            doc_index = match + 1
        else:
            # Tokenization mismatch: score the raw word and keep the token
            # cursor in place so later words can still match.
            token_text = word
        polarity = TextBlob(token_text).sentiment.polarity

        s, e = get_frame_boundaries(start, end, frame_size, audio_time)
        polarities[s:e] = polarity
        records.append({'start_time': start, 'end_time': end, 'start_frame': s,
                        'end_frame': e, 'word': word, 'polarity': polarity})
    return records, polarities


def _acoustic_features(samples, sampling_rate, frame_size, audio_time):
    """Return Praat intensity and pitch tracks for one clip, each trimmed/padded to ``frame_size``.

    Args:
        samples: audio samples of the clip.
        sampling_rate: sampling rate of ``samples``.
        frame_size: target number of frames.
        audio_time: clip duration in seconds.

    Returns:
        (intensity_values, pitch_values), each of length ``frame_size``.
    """
    # Praat time steps are in seconds: frame_size + 1 steps of this size span the clip.
    time_step = audio_time / (frame_size + 1)
    sound = parselmouth.Sound(samples, sampling_rate)
    intensity_values = trim(sound.to_intensity(time_step=time_step).values.squeeze(0), frame_size)
    pitch_values = trim(sound.to_pitch(time_step=time_step).selected_array['frequency'], frame_size)
    return intensity_values, pitch_values


updated_intervals = []
polarity_vectors = []
intensity_vectors = []
pitch_vectors = []
for ex in range(dataset.num_rows):
    # Fetch the row once: each dataset[ex] call materializes the full row,
    # including decoding the audio.
    row = dataset[ex]
    audio = row['audio']
    # frame count of the attribution vectors and clip duration in seconds
    frame_size = len(row['Textual Attention'])
    audio_time = len(audio['array']) / audio['sampling_rate']

    # framewise text polarity (absolute value: magnitude only)
    interval, polarities = _word_polarities(row['mfa_intervals'], frame_size, audio_time)
    updated_intervals.append(interval)
    polarity_vectors.append(np.abs(polarities))

    # framewise acoustic features
    intensity_values, pitch_values = _acoustic_features(
        audio['array'], audio['sampling_rate'], frame_size, audio_time)
    intensity_vectors.append(intensity_values)
    pitch_vectors.append(pitch_values)

# update dataset
dataset = dataset.add_column("intervals", updated_intervals)
dataset = dataset.remove_columns(["mfa_intervals"])
dataset = dataset.add_column("Intensity", intensity_vectors)
dataset = dataset.add_column("Pitch", pitch_vectors)
dataset = dataset.add_column("Sentiment Polarity", polarity_vectors)


# compute dot product
# For each (feature, attribution method) pair, average over the selected
# examples the dot product between the min-max-normalized feature track and
# the attribution scores (min-max-normalized, then rescaled to sum to one).
features = ['Sentiment Polarity', 'Intensity', 'Pitch']
methods = ['Integrated Gradient', 'Acoustic Attention', 'Textual Attention']
# Read only the columns used below: leaving out 'audio' avoids decoding every
# clip on each row access. Row filters and method normalizations are then
# computed once per example instead of once per (feature, method) pair.
analysis_dataset = dataset.select_columns(['emotion', 'vib_prediction', 'h_prediction'] + features + methods)
scores = {(f, m): [] for f in features for m in methods}
for row in analysis_dataset:
    # filter those without text polarity
    if not np.sum(row['Sentiment Polarity']) > 0:
        continue
    # filter those misclassified by either the VIB or the probe
    if row['emotion'] != row['vib_prediction'] or row['emotion'] != row['h_prediction']:
        continue
    normalized_methods = {m: normalize_sum_to_one(normalize_min_max(row[m])) for m in methods}
    for f in features:
        a = row[f]
        if f in ['Intensity', 'Pitch']:
            # acoustic tracks are compared through the magnitude of their frame-to-frame change
            a = np.abs(np.diff(a, append=a[-1]))
        a = normalize_min_max(a)
        for m in methods:
            scores[(f, m)].append(np.dot(a, normalized_methods[m]))
records = pd.DataFrame([{'Feature': f, 'Method': m, 'Dot Product': np.mean(scores[(f, m)])}
                        for f in features for m in methods])
records['Method'] = pd.Categorical(records['Method'], categories=methods)

color = "purple"
g = (ggplot(records, aes(x='Feature', y='Method', fill='Dot Product')) +
           geom_tile() +
           scale_fill_gradient(low="white", high=color) +
           theme_minimal() +
           theme(axis_text_x=element_text(rotation=45, ha='right'), text=element_text(size=10))
    )
print(g)

# Honor the SAVE parameter instead of always writing the figure.
if SAVE:
    ggsave(g, os.path.join(SAVE_FIGURES_PATH, 'attribution_dot_product.pdf'))