from argparse import Namespace
import json
import os

from .constants import *

class Config:
    """Experiment configuration built from parsed command-line arguments.

    A ``Config`` collects run settings from an ``argparse.Namespace``,
    derives an experiment id plus output file paths from them, and can be
    saved to / restored from ``<exp_dir>/config.json``.
    """

    # Name of the JSON file, inside the experiment directory, holding the config.
    CONFIG_FNAME = 'config.json'

    def __init__(self, args: Namespace = None):
        """Populate settings from *args*.

        Every field is read with ``getattr``, so *args* may omit any of them;
        the default shown is used only when the attribute is absent.  When
        *args* is falsy (e.g. ``None``) the instance is left without
        attributes -- ``load_config`` relies on this to build a blank object
        and fill it from JSON.
        """
        if not args:
            return

        # Run settings; each default applies only when *args* lacks the attribute.
        self.task_type = getattr(args, 'task_type', 'prediction')
        self.max_tokens = getattr(args, 'max_tokens', 1024)
        self.max_state_tokens = getattr(args, 'max_state_tokens', 25)
        self.temperature = getattr(args, 'temperature', -1)

        self.few_shot_k = getattr(args, 'few_shot_k', 1)

        self.model_type = getattr(args, 'model_type', None)
        self.model_path = getattr(args, 'model_path', None)
        self.profile_path = getattr(args, 'profile_path', None)
        self.chromadb_index = getattr(args, 'chromadb_index', None)
        self.agent_type = getattr(args, 'agent_type', None)
        self.dataset = getattr(args, 'dataset', None)
        self.dataset_range = getattr(args, 'dataset_range', None)
        self.model_name = getattr(args, 'model_name', None)

        self.action_types = _get_actions(self.dataset)

        # Last path component of the dataset string, e.g. 'a/b' -> 'b'.
        self.dataset_name = self.dataset.split('/')[-1]
        self.dataset_split = getattr(args, 'split', None)

        # Derived experiment id and output locations.
        self.exp_id = self._get_experiment_id()
        self.exp_dir = os.path.join(EXPERIMENTS_DIR, self.dataset_name, self.exp_id)
        self._file_path = os.path.join(self.exp_dir, Config.CONFIG_FNAME)
        self.results_file = os.path.join(self.exp_dir, 'results.tsv')
        self.predictions_file = os.path.join(self.exp_dir, 'predictions.out')

        # Fallback paths are computed only when the attribute is absent, so an
        # explicitly supplied agent_file works even when agent_type is None.
        if hasattr(args, 'prompt_file'):
            self.prompt_file = args.prompt_file
        else:
            self.prompt_file = os.path.join(PROMPT_DIR, self.dataset + '.py')
        if hasattr(args, 'agent_file'):
            self.agent_file = args.agent_file
        elif self.agent_type is not None:
            self.agent_file = os.path.join(SPEC_DIR, self.agent_type + '.txt')
        else:
            self.agent_file = None

    def _get_experiment_id(self):
        """Return an '_'-joined id built from the main run settings.

        Components, in order: dataset, agent_type, few_shot_k, task_type,
        model_name, then dataset_range (only if truthy) and dataset_split
        (only if not None).  Each component is converted with ``str``.
        """
        # NOTE(review): the full dataset string is embedded, so a dataset
        # containing '/' produces nested directories under exp_dir -- confirm
        # this layout is intended.
        parts = [
            self.dataset,
            self.agent_type,
            self.few_shot_k,
            self.task_type,
            self.model_name,
        ]
        if self.dataset_range:
            parts.append(self.dataset_range)
        if self.dataset_split is not None:
            parts.append(self.dataset_split)
        return '_'.join(str(part) for part in parts)

    def write_config(self):
        """Save all instance attributes as indented JSON to ``self._file_path``.

        The parent directory is created when missing.  The JSON text is
        produced before the file is opened, so a non-serializable attribute
        raises ``TypeError`` without truncating an existing config file.
        """
        payload = json.dumps(self.__dict__, indent=4)
        parent = os.path.dirname(self._file_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(self._file_path, 'w', encoding='utf-8') as f:
            f.write(payload)

    @classmethod
    def load_config(cls, exp_dir):
        """Rebuild a config from ``<exp_dir>/config.json``.

        An empty instance is created (no args), and every key of the stored
        JSON object is set on it as an attribute.  Raises
        ``FileNotFoundError`` if the file does not exist.
        """
        config = cls()
        file_path = os.path.join(exp_dir, Config.CONFIG_FNAME)
        with open(file_path, 'r', encoding='utf-8') as f:
            stored_fields = json.load(f)
        for field, value in stored_fields.items():
            setattr(config, field, value)
        return config

def _get_actions(dataset : str):
    if dataset.startswith('gsm8k'):
        return ['Calculator']
    elif dataset.startswith('fever') or dataset.startswith('hotpot'):
        return ['Search', 'Lookup']
    else: return []