# --------------------------------------------------------
# mcan-vqa (Deep Modular Co-Attention Networks)
# Licensed under The MIT License [see LICENSE for details]
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------

import os

class PATH:
    """Filesystem layout for the VQA datasets, extracted features and run outputs.

    Every path is derived from the root directory ``self.dir``. Constructing
    an instance also ensures the output directories (results, predictions,
    cache, logs, checkpoints) exist under that root; see :meth:`init_path`.
    """

    def __init__(self):

        # vqav2 dataset root path (machine-specific absolute path; edit for
        # your environment)
        self.dir = '/work/06792/kt22354/maverick2/VQA'
        self.DATASET_PATH = self.dir + '/datasets/vqa/'

        # bottom up features root path
        self.FEATURE_PATH = self.dir + '/datasets/coco_extract/'

        self.init_path()


    def init_path(self):
        """Build the per-split path tables and create the output directories.

        Reads ``self.dir``, ``self.DATASET_PATH`` and ``self.FEATURE_PATH``
        and sets:

        - ``IMG_FEAT_PATH``: feature directory per split (train/val/test).
        - ``QUESTION_PATH``: question JSON per split (train/val/test/vg).
        - ``ANSWER_PATH``: annotation JSON per split (train/val/vg).
        - ``RESULT_PATH``, ``PRED_PATH``, ``CACHE_PATH``, ``LOG_PATH``,
          ``CKPTS_PATH``: output directories, created if missing.

        Existing directories are left untouched, so calling this more than
        once is safe.
        """

        self.IMG_FEAT_PATH = {
            'train': self.FEATURE_PATH + 'train2014/',
            'val': self.FEATURE_PATH + 'val2014/',
            'test': self.FEATURE_PATH + 'test2015/',
        }

        self.QUESTION_PATH = {
            'train': self.DATASET_PATH + 'v2_OpenEnded_mscoco_train2014_questions.json',
            'val': self.DATASET_PATH + 'v2_OpenEnded_mscoco_val2014_questions.json',
            'test': self.DATASET_PATH + 'v2_OpenEnded_mscoco_test2015_questions.json',
            'vg': self.DATASET_PATH + 'VG_questions.json',
        }

        self.ANSWER_PATH = {
            'train': self.DATASET_PATH + 'v2_mscoco_train2014_annotations.json',
            'val': self.DATASET_PATH + 'v2_mscoco_val2014_annotations.json',
            'vg': self.DATASET_PATH + 'VG_annotations.json',
        }

        self.RESULT_PATH = self.dir + '/results/result_test/'
        self.PRED_PATH = self.dir + '/results/pred/'
        self.CACHE_PATH = self.dir + '/results/cache/'
        self.LOG_PATH = self.dir + '/results/log/'
        self.CKPTS_PATH = self.dir + '/ckpts/'

        # makedirs also creates missing parents (e.g. <dir>/results), which a
        # listdir-based existence check cannot cope with, and exist_ok avoids
        # a check-then-create race.
        for out_dir in (self.RESULT_PATH, self.PRED_PATH, self.CACHE_PATH,
                        self.LOG_PATH, self.CKPTS_PATH):
            os.makedirs(out_dir, exist_ok=True)


    def check_path(self):
        """Verify that every configured feature directory and JSON file exists.

        Tables are checked in order: image features, questions, answers.
        On the first missing path, prints ``<path> NOT EXIST`` and calls
        ``sys.exit(-1)``, which raises ``SystemExit``. Otherwise prints
        ``Finished``.
        """
        import sys

        print('Checking dataset ...')

        for path_table in (self.IMG_FEAT_PATH, self.QUESTION_PATH,
                           self.ANSWER_PATH):
            for path in path_table.values():
                if not os.path.exists(path):
                    print(path + ' NOT EXIST')
                    # sys.exit works even when the site-provided exit()
                    # builtin is unavailable (e.g. under `python -S`).
                    sys.exit(-1)

        print('Finished')
        print('')

