import os
import pandas as pd
import json
import copy
import codecs


class CNNReader:
    """Reader for the CNN summarisation annotations stored as JSON lines.

    The file is loaded into ``self.df``; ``get_gen_ref_met`` reads the
    ``decoded``, ``references`` and ``expert_annotations`` columns from it.
    """

    def __init__(self, fpath='data/cnn', fname='model_annotations.aligned.json'):
        """Load ``fpath/fname`` (one JSON object per line) into a DataFrame."""
        annotations_path = os.path.join(fpath, fname)
        self.df = pd.read_json(annotations_path, lines=True)

    def average_dict(self, input_dic):
        """Return the key-wise mean of a non-empty list of score dicts.

        The key set is taken from the first dict; a key that appears only in
        a later dict raises ``KeyError``.
        """
        totals = dict.fromkeys(input_dic[0], 0)
        for annotation in input_dic:
            for name, score in annotation.items():
                totals[name] += score
        n_annotations = len(input_dic)
        return {name: total / n_annotations for name, total in totals.items()}

    def get_gen_ref_met(self):
        """Flatten the annotations into aligned generated/reference lists.

        Returns:
            A tuple ``(sent_generated, ref_sent, metric_dicts)`` where the
            generated summary of each row is repeated once per reference,
            ``ref_sent`` is the concatenation of all references, and
            ``metric_dicts`` maps each of the four metrics to one averaged
            expert score per row.
        """
        generated = []
        references = []
        scores = {'coherence': [], 'consistency': [], 'fluency': [], 'relevance': []}
        for _, record in self.df.iterrows():
            record_refs = record['references']
            generated.extend([record['decoded']] * len(record_refs))
            # Every row is expected to carry exactly 11 references.
            assert len(record_refs) == 11
            references.extend(record_refs)
            mean_scores = self.average_dict(record['expert_annotations'])
            for name, mean_score in mean_scores.items():
                scores[name].append(mean_score)
        return generated, references, scores


class CocoReader:
    """Reader for MS-COCO captioning submissions, references and scores.

    Attributes are plain data set up in ``__init__``:

    * ``coco_results_dict`` maps a leaderboard team name to a list of five
      scores copied from the leaderboard linked below (the meaning of each
      column is not recorded here -- TODO confirm against the leaderboard).
    * ``path_dict`` maps a submission directory name (as found under
      ``self.fpath``) to the team name used as key in ``coco_results_dict``.
    """

    def __init__(self, fpath='coco', fname='captions_val2014.json'):  # human_all_captions_coco
        """Store the data location (``data/<fpath>``) and reference file name."""
        self.fpath = os.path.join('data', fpath)
        self.fname = fname
        # https://cocodataset.org/#captions-leaderboard
        self.coco_results_dict = {'Human': [0.638, 0.675, 4.836, 3.428, 0.352],
                                  'Google': [0.273, 0.317, 4.107, 2.742, 0.233],
                                  'MSR': [0.268, 0.322, 4.137, 2.662, 0.234],
                                  'Montreal/Toronto': [0.262, 0.272, 3.932, 2.832, 0.197],
                                  'MSR Captivator': [0.250, 0.301, 4.149, 2.565, 0.233],
                                  'Berkeley LRCN': [0.246, 0.268, 3.924, 2.786, 0.204],
                                  'm-RNN': [0.223, 0.252, 3.897, 2.595, 0.202],
                                  'Nearest Neighbor': [0.216, 0.255, 3.801, 2.716, 0.196],
                                  'PicSOM': [0.202, 0.250, 3.965, 2.552, 0.182],
                                  'Brno University': [0.194, 0.213, 3.079, 3.482, 0.154],
                                  'm-RNN (Baidu/ UCLA)': [0.190, 0.241, 3.831, 2.548, 0.195],
                                  'MIL': [0.168, 0.197, 3.349, 2.915, 0.159],
                                  'MLBL': [0.167, 0.196, 3.659, 2.420, 0.156],
                                  'NeuralTalk': [0.166, 0.192, 3.436, 2.742, 0.147],
                                  'ACVT': [0.154, 0.190, 3.516, 2.599, 0.155],
                                  'Tsinghua Bigeye': [0.100, 0.146, 3.510, 2.163, 0.116],
                                  'Random': [0.007, 0.020, 1.084, 3.247, 0.013]}
        self.path_dict = {
            # 'MSR_Captivator':'MSR Captivator',
            'OriolVinyals': 'Google',
            'human': 'Human',
            # 'NearestNeighbor':'Nearest Neighbor',
            # 'mmitchell':'MSR',
            'rakshithShetty': 'PicSOM',
            'Q.Wu': 'ACVT',
            'karpathy': 'NeuralTalk',
            'myamaguchi': 'MIL',
            'ryank': 'MLBL',
            'mRNN_share.JMao': 'm-RNN',
            'kolarmartin': 'Brno University',
            'jeffdonahue': 'Berkeley LRCN',
            'junhua.mao': 'm-RNN (Baidu/ UCLA)',
            'TsinghuaBigeye': 'Tsinghua Bigeye',
            'kelvin_xu': 'Montreal/Toronto'
        }

    def process_gen(self):
        """Load the generated captions of every known submission directory.

        Only sub-directories of ``self.fpath`` whose name is a key of
        ``self.path_dict`` are read. In each, the results file is the first
        file (in sorted order) whose name contains both ``'results.json'``
        and ``'val'``.

        Returns:
            ``{team name: {image_id: caption}}``.

        Raises:
            IndexError: a known directory contains no matching results file.
        """
        all_dict_results = {}
        # os.listdir returns entries in arbitrary order; sort so that both the
        # chosen results file and the output key order are reproducible.
        for dir_ in sorted(os.listdir(self.fpath)):
            if dir_ not in self.path_dict:
                continue
            path = os.path.join(self.fpath, dir_)
            candidates = sorted(
                f for f in os.listdir(path)
                if os.path.isfile(os.path.join(path, f))
                and 'results.json' in f and 'val' in f
            )
            # JSON is read as UTF-8 explicitly instead of relying on the
            # platform default encoding.
            with open(os.path.join(path, candidates[0]), 'r', encoding='utf-8') as file:
                refs = json.load(file)
            all_dict_results[self.path_dict[dir_]] = {
                ref['image_id']: ref['caption'] for ref in refs
            }
        return all_dict_results

    def get_references(self):
        """Load the reference captions grouped by image.

        Returns:
            ``{int(image_id): [caption, ...]}`` built from the ``annotations``
            list of ``self.fpath/self.fname``, captions in file order.
        """
        with open(os.path.join(self.fpath, self.fname), 'r', encoding='utf-8') as file:
            ref = json.load(file)
        cleaned_ref = {}
        for annot in ref['annotations']:
            cleaned_ref.setdefault(int(annot['image_id']), []).append(annot['caption'])
        return cleaned_ref

    def get_gen_ref_met(self):
        """Return ``(generated captions, reference captions, leaderboard scores)``.

        The third element is ``self.coco_results_dict`` itself, not a copy.
        """
        sent_generated = self.process_gen()
        ref_sent = self.get_references()
        metric_dicts = self.coco_results_dict
        return sent_generated, ref_sent, metric_dicts


class TAC2009Reader:
    """Reader for the TAC 2009 multi-document summarisation annotations."""

    def __init__(self, fpath='tac2009', fname="tac.09.mds.gen.resp-pyr"):
        """Store the data location (``data/<fpath>``) and file name."""
        self.fpath = os.path.join('data', fpath)
        self.fname = fname

    def normalize_responsiveness(self, dataset):
        """Scale every ``responsiveness`` score by the largest one, in place.

        Args:
            dataset: mapping whose values each hold an ``'annotations'`` list
                of dicts with a ``'responsiveness'`` entry.

        Returns:
            The same ``dataset`` object. When no score is positive (empty
            dataset, or every score <= 0) it is returned unchanged instead of
            dividing by zero.
        """
        scores = [annot['responsiveness']
                  for entry in dataset.values()
                  for annot in entry['annotations']]
        max_resp = max(scores, default=0.)
        if max_resp <= 0:
            return dataset
        scale = float(max_resp)
        for entry in dataset.values():
            for annot in entry['annotations']:
                annot['responsiveness'] /= scale
        return dataset

    def merge_datasets(self, lst_datasets):
        """Merge several datasets into one new dict of deep copies.

        On duplicate keys the entry from the later dataset wins.
        """
        merged_dataset = {}
        for dataset in lst_datasets:
            merged_dataset.update(copy.deepcopy(dataset))
        return merged_dataset

    def get_gen_ref_met(self):
        """Load the annotation file and normalise its responsiveness scores.

        Returns:
            ``(items, None, None)`` where ``items`` is the list of
            ``(key, entry)`` pairs of the loaded dataset.
        """
        def load_json(filepath):
            with open(filepath, 'r', encoding='utf-8') as f:
                return json.load(f)

        tac_09_mds_gen_resp_pyr = self.normalize_responsiveness(
            load_json(os.path.join(self.fpath, self.fname)))
        pyr_data = self.merge_datasets([tac_09_mds_gen_resp_pyr])
        # TODO: the return shape is not uniform with the other readers
        # (the data travels in the first slot only), but unifying it is
        # complicated.
        return list(pyr_data.items()), None, None


class Data2TextReader:
    """Reader for the data-to-text human ratings CSV (BAGEL / SFHOT / SFRES).

    The CSV is expected to provide the columns ``system``, ``dataset``,
    ``sys_ref``, ``orig_ref``, ``informativeness``, ``naturalness`` and
    ``quality``. Only rows whose ``system`` is ``'LOLS'`` are used.
    """

    _DATASETS = ('BAGEL', 'SFHOT', 'SFRES')
    _METRICS = ('informativeness', 'naturalness', 'quality')
    _PAIR_COLUMNS = ('sys_ref', 'orig_ref', 'dataset', 'system')
    # Each (sys_ref, orig_ref) pair occupies this many consecutive rows.
    _RATINGS_PER_PAIR = 3
    # The summed ratings of a pair are divided by this value; per the
    # original note this scales three ratings on a 1-6 scale into 0-1.
    _SCORE_SCALE = 18

    def __init__(self, fpath='.', fname='data2text_gen.csv'):
        """Store the data location (``data/<fpath>``) and file name."""
        self.fpath = os.path.join('data', fpath)
        self.fname = fname

    @classmethod
    def _empty_result(cls):
        """Return a fresh, empty per-dataset result structure."""
        return {
            name: {'metric_dicts': {metric: [] for metric in cls._METRICS},
                   'sys_ref': [],
                   'orig_ref': [],
                   'system': []}
            for name in cls._DATASETS
        }

    def _average_rating_triples(self, df):
        """Collapse each run of three rows into one scored pair per dataset.

        ``df`` must have a 0..n-1 index. A trailing incomplete run is ignored.
        """
        per_pair = self._empty_result()
        totals = dict.fromkeys(self._METRICS, 0)
        anchor = None  # first row of the run currently being accumulated
        for position, row in df.iterrows():
            if anchor is None:
                anchor = row
            for metric in self._METRICS:
                totals[metric] += row[metric]
            if (position + 1) % self._RATINGS_PER_PAIR != 0:
                # Rows before the closing one must describe the same pair as
                # the first row of the run; the closing row is not compared
                # and supplies the values stored below.
                for column in self._PAIR_COLUMNS:
                    assert anchor[column] == row[column], (
                        'row %d: %s differs within a rating triple' % (position, column))
                continue
            entry = per_pair[row['dataset']]
            for metric in self._METRICS:
                entry['metric_dicts'][metric].append(totals[metric] / self._SCORE_SCALE)
            entry['sys_ref'].append(row['sys_ref'])
            entry['orig_ref'].append(row['orig_ref'])
            entry['system'].append(row['system'])
            totals = dict.fromkeys(self._METRICS, 0)
            anchor = None
        return per_pair

    def _group_by_sys_ref(self, per_pair):
        """Group consecutive pairs sharing the same ``sys_ref`` into lists."""
        grouped = self._empty_result()
        for name, entry in per_pair.items():
            sys_refs = entry['sys_ref']
            if not sys_refs:
                # Dataset absent from the CSV: leave its lists empty.
                continue
            starts = [0] + [
                i for i, (previous, current) in enumerate(zip(sys_refs, sys_refs[1:]), start=1)
                if current != previous
            ]
            # The final stop closes the last run, which would otherwise be lost.
            stops = starts[1:] + [len(sys_refs)]
            target = grouped[name]
            for start, stop in zip(starts, stops):
                for field in ('sys_ref', 'orig_ref', 'system'):
                    target[field].append(entry[field][start:stop])
                for metric in self._METRICS:
                    target['metric_dicts'][metric].append(
                        entry['metric_dicts'][metric][start:stop])
        return grouped

    def get_gen_ref_met(self):
        """Load the ratings and group them per dataset and system output.

        Returns:
            ``(result, None, None)``. ``result`` maps each of ``'BAGEL'``,
            ``'SFHOT'`` and ``'SFRES'`` to a dict with the keys ``'sys_ref'``,
            ``'orig_ref'``, ``'system'`` and ``'metric_dicts'``. Each holds one
            inner list per run of consecutive identical ``sys_ref`` values,
            with one element per (sys_ref, orig_ref) pair in that run.

        Raises:
            AssertionError: rows inside a rating triple disagree on the pair
                they describe.
            KeyError: a row names a dataset other than the three above.
        """
        df_data2text = pd.read_csv(os.path.join(self.fpath, self.fname))
        # Reset the index after filtering: the triple arithmetic below is
        # positional and must not depend on where LOLS rows sit in the file.
        df_data2text = df_data2text[df_data2text.system == 'LOLS'].reset_index(drop=True)
        print('Processing')
        per_pair = self._average_rating_triples(df_data2text)
        final_processed_dict = self._group_by_sys_ref(per_pair)
        return final_processed_dict, None, None


# Registry mapping a dataset name to the reader class that loads it.
readers = dict(
    cnn=CNNReader,
    coco=CocoReader,
    tac2009=TAC2009Reader,
    hotel=Data2TextReader,
)

if __name__ == '__main__':
    # Smoke test: load the CNN annotations through the module-level
    # ``readers`` registry (previously re-declared here as an identical copy).
    reader = readers['cnn']()
    sent_generated, ref_sent, metric_dicts = reader.get_gen_ref_met()
