import numpy as np
import os
import pandas as pd
from natsort import natsorted

## Getting ground truth labels and concept vectors for training

# Directory that is scanned for survey CSV files.
surveys_dir = '../../AMT_Survey/surveys'

# Read the first two columns of every survey CSV as Id / Label, visiting the
# files in natural sort order. skiprows=1 drops each file's first line so the
# explicit column names are used instead.
surveys = []
for survey_name in natsorted(os.listdir(surveys_dir)):
    if not survey_name.endswith('csv'):
        continue
    survey_path = os.path.join(surveys_dir, survey_name)
    surveys.append(
        pd.read_csv(survey_path, skiprows=1, names=['Id', 'Label'], usecols=range(2))
    )
surveys_df = pd.concat(surveys, axis=0, ignore_index=True)

# Collapse repeated Ids to a single row, keeping the first non-null Label.
aggregation_functions = {'Label': 'first'}
survey_labels = surveys_df.groupby('Id').agg(aggregation_functions)

# Pair every id stored in the pickle with its concept vector.
load_file = pd.read_pickle(r'../Extract_Concepts/20210525_lg_mpcg_ids_and_pam_by_cummulative_mi.pkl')
concept_ids = pd.DataFrame({
    'Id': list(load_file['unique_ids']),
    'Concepts': list(load_file['merged_pam']),
})

# Right join: every row of concept_ids is kept, with a missing Label when no
# survey row has a matching Id.
labels_df = survey_labels.merge(concept_ids, on='Id', how='right')

# Drop the rows whose Label is exactly one of these two strings. Rows with a
# missing Label are not matched and therefore stay in the result.
excluded_labels = [
    'none',
    ' it could be called a strike because the pitch landed in the strike zone before being hit',
]
is_excluded = labels_df['Label'].isin(excluded_labels)
labels_df_filtered = labels_df[~is_excluded].reset_index(drop=True)

print(labels_df_filtered['Label'].value_counts())

labels_df_filtered.to_pickle("labels_df_filtered_100.pkl")


## Representative Concept based on frequency

concepts = pd.read_csv('../Extract_Concepts/20210525_lg_concept_groupings_by_cummulative_mi.csv')

# For each final_id keep only its highest-freq row: after the descending sort,
# drop_duplicates retains the first (largest-freq) row of every final_id.
by_frequency = concepts.sort_values(by='freq', ascending=False)
top_per_final_id = by_frequency.drop_duplicates(subset=['final_id'])
max_concepts = top_per_final_id.sort_values(by='final_id')

# Rows with final_id == -1 are discarded before the index is renumbered.
# NOTE(review): -1 presumably marks concepts without a group - confirm.
has_valid_id = max_concepts['final_id'] != -1
max_concepts = max_concepts.loc[has_valid_id].reset_index(drop=True)

max_concepts.to_csv('concepts_100.csv')


## Map Video Ids and Survey Explanations

# Parallel lists, one entry per successfully parsed survey row.
ids = []
labels = []
explanations = []

for f in natsorted(os.listdir(surveys_dir)):
    if not f.endswith('csv'):
        continue

    # The context manager closes the file even when the header check raises.
    with open(os.path.join(surveys_dir, f), "r") as survey:
        for i, line in enumerate(survey):
            if i == 0:
                # The first line must be a header whose first column is 'elapsed'.
                head_tokens = line.split(',')
                if head_tokens[0] != 'elapsed':
                    raise ValueError(f"Header not valid with:\n\t{head_tokens}")
                continue

            # Split on the first two commas only: field 0 is the id, field 1
            # the label, and everything after the second comma (further commas
            # included) is the free-text explanation.
            line_tokens = line.strip().split(',', 2)
            if len(line_tokens) < 2:
                # Row without a label field (e.g. a blank line): report the
                # file name and skip the row instead of aborting the run.
                print(f)
                continue

            ids.append(line_tokens[0])
            labels.append(line_tokens[1])
            explanations.append(line_tokens[2].strip() if len(line_tokens) > 2 else '')

df = pd.DataFrame({'Id': ids, 'Label': labels, 'Explanation': explanations})

# Keep the first Label / Explanation recorded for each Id.
aggregation_functions = {'Label': 'first', 'Explanation': 'first'}
survey_exp = df.groupby('Id').aggregate(aggregation_functions)

# Right join keeps every row of concept_ids; ids without a survey row get a
# missing Explanation.
# NOTE(review): Ids parsed here are strings - confirm concept_ids['Id'] holds
# strings as well, otherwise no rows will match.
exp_df = pd.merge(survey_exp, concept_ids, on='Id', how='right')
exp_df = exp_df[['Id', 'Explanation']]

exp_df.to_pickle("Explanations.pkl")