from utils.basic_utils import load_jsonl, load_json, save_jsonl, save_json
import os
from tqdm import tqdm


class Processor:
    """Convert per-video annotation JSON files into per-query JSONL records.

    For every phase, ``<root_dir>/<phase>.json`` is read as a mapping from a
    video id to an item holding parallel ``'timestamps'`` and ``'sentences'``
    lists. Each (timestamp, sentence) pair becomes one record with a running
    integer ``qid``. Every record is also annotated with the qids of the other
    queries of the same video, split into "negative" (temporal IoU below
    ``thd``) and "positive" (IoU at or above ``thd``) peers.
    """

    def __init__(self, opt):
        """Copy the settings this converter needs from ``opt``.

        Args:
            opt: Object exposing ``dset_name``, ``phase``, ``root_dir`` and
                ``thd`` attributes.
        """
        super().__init__()
        # Running query id; shared across all phases handled by this instance.
        self.idx_counter = 0
        self.dset_name = opt.dset_name
        self.phase = opt.phase
        self.root_dir = opt.root_dir
        # IoU threshold below which two queries count as negatives.
        self.thd = opt.thd
        # Divisor applied to raw timestamps; 1 means "leave as is".
        self.fps = 1

    def _add_negative_query(self, data_list):
        """Annotate each record with its intra-video negative/positive qids.

        Every record is compared against every other record in ``data_list``
        using the first entry of ``'relevant_windows'``. Records are updated
        in place with ``'intra_neg_query_ids'`` and ``'intra_pos_query_ids'``.

        Args:
            data_list: Records that belong to the same video.

        Returns:
            The same ``data_list`` object, for convenience.
        """
        for anchor in data_list:
            anchor_window = anchor['relevant_windows'][0]
            neg_qids = []
            pos_qids = []

            for other in data_list:
                if anchor['qid'] == other['qid']:
                    continue  # a query is never its own peer

                other_window = other['relevant_windows'][0]
                if self.check_if_negative_query(anchor_window, other_window):
                    neg_qids.append(other['qid'])
                else:
                    pos_qids.append(other['qid'])

            anchor.update(intra_neg_query_ids=neg_qids,
                          intra_pos_query_ids=pos_qids)
        return data_list

    def check_if_negative_query(self, window1, window2):
        """Return True when the IoU of the two windows is below ``self.thd``.

        Args:
            window1: Sequence whose first two items are start and end.
            window2: Sequence whose first two items are start and end.
        """
        st1, ed1 = window1[0], window1[1]
        st2, ed2 = window2[0], window2[1]
        return self.get_iou(st1, ed1, st2, ed2) < self.thd

    def process_data(self, data, phase):
        """Flatten one phase's annotations into per-query records.

        Args:
            data: Mapping from video id to an item with ``'timestamps'`` and
                ``'sentences'`` plus either ``'duration'`` or the
                dataset-specific fields (``'video_duration'`` for
                ``charades-CD``; ``'fps'`` and ``'num_frames'`` for ``tacos``).
            phase: Phase name; only used in the progress-bar description.

        Returns:
            List of record dicts with keys ``qid``, ``query``, ``duration``,
            ``vid``, ``relevant_windows``, ``intra_neg_query_ids`` and
            ``intra_pos_query_ids``.

        Raises:
            ValueError: If an item has no ``'duration'`` key and the dataset is
                neither ``charades-CD`` nor ``tacos``, so no duration can be
                derived. Previously this either raised ``UnboundLocalError``
                or silently reused the previous video's duration.
        """
        results = []
        progress = tqdm(data.items(), total=len(data),
                        desc='process {} {}'.format(self.dset_name, phase))
        for vid, data_item in progress:
            if 'duration' in data_item:
                duration = data_item['duration']
            elif self.dset_name == 'charades-CD':
                duration = data_item['video_duration']
            elif self.dset_name == 'tacos':
                if vid.endswith('.avi'):
                    vid = vid[0:-4]
                # NOTE(review): fps is stored on the instance and therefore
                # carries over to later videos; a later item that already has
                # 'duration' would be scaled by this video's fps. Presumably
                # tacos items never carry 'duration' - confirm.
                self.fps = float(data_item['fps'])
                duration = float(data_item['num_frames']) / self.fps
            else:
                raise ValueError(
                    "cannot determine duration of video {!r}: no 'duration' key "
                    "and dataset {!r} has no fallback".format(vid, self.dset_name))

            results_per_vid = []
            for timestamp, sentence in zip(data_item['timestamps'], data_item['sentences']):
                # Normalise reversed annotations so that start <= end.
                if timestamp[1] < timestamp[0]:
                    timestamp = [timestamp[1], timestamp[0]]

                # Presumably converts frame indices to seconds - confirm.
                if self.fps != 1:
                    timestamp = [timestamp[0] / self.fps, timestamp[1] / self.fps]

                results_per_vid.append(dict(qid=self.idx_counter,
                                            query=sentence,
                                            duration=duration,
                                            vid=vid,
                                            relevant_windows=[timestamp]))
                self.idx_counter += 1

            # Pairwise labelling only needs to run once, after all queries of
            # the video are known; running it per appended query gave the same
            # final labels at a much higher cost.
            results.extend(self._add_negative_query(results_per_vid))
        return results

    def get_iou(self, a, b, c, d):
        """Return the temporal IoU of intervals ``[a, b]`` and ``[c, d]``.

        Disjoint intervals yield ``0.0``. When both intervals collapse to the
        same single point the union has zero length; they are then treated as
        identical and ``1.0`` is returned instead of dividing by zero.
        """
        if b < c or d < a:
            return 0.0

        lo, inner_lo, inner_hi, hi = sorted([a, b, c, d])
        union = hi - lo
        if union == 0:
            return 1.0
        return (inner_hi - inner_lo) / union

    def convert(self):
        """Convert every configured phase and write the outputs to ``root_dir``.

        For each phase ``<phase>.json`` is loaded, processed and saved as
        ``<phase>.jsonl``, and a one-line summary is printed. Finally a
        qid-indexed lookup of window/query/vid for all phases is saved as
        ``meta_by_qid.json``.
        """
        meta = {}
        # A bare string would otherwise be iterated character by character.
        phases = [self.phase] if isinstance(self.phase, str) else self.phase
        for phase in phases:
            data = load_json(os.path.join(self.root_dir, f'{phase}.json'))
            results = self.process_data(data, phase)
            save_jsonl(results, os.path.join(self.root_dir, f'{phase}.jsonl'))

            intra_neg_cnt = 0
            intra_pos_cnt = 0
            for result in results:
                meta[result['qid']] = {'relevant_windows': result['relevant_windows'],
                                       'query': result['query'],
                                       'vid': result['vid']}
                if len(result['intra_pos_query_ids']) > 0:
                    intra_pos_cnt += 1
                if len(result['intra_neg_query_ids']) > 0:
                    intra_neg_cnt += 1

            total = len(results)
            # Guard the percentages so an empty phase does not abort the run
            # before meta_by_qid.json is written.
            pos_pct = round(intra_pos_cnt / total * 100, 2) if total else 0.0
            neg_pct = round(intra_neg_cnt / total * 100, 2) if total else 0.0
            print(f'Phase: {phase}, total: {total}, '
                  f'intra_pos: {intra_pos_cnt} ({pos_pct}%), '
                  f'intra_neg: {intra_neg_cnt} ({neg_pct}%)')
        save_json(meta, os.path.join(self.root_dir, 'meta_by_qid.json'))

        print('done.')



