import os
import json


class DataLoder:
    """Load the question/answer datasets supported by :meth:`load_data`.

    Each dataset file is read from ``root_path/<data_path>/<file name>``. Every
    dataset has its own on-disk layout, which :meth:`load_data` normalises to a
    tuple of three parallel lists.
    """

    def __init__(self, root_path):
        # Directory containing one sub-directory per dataset.
        self.root_path = root_path

    def load_data(self, data_path=None, data_name_1=None, data_name_2=None, data_name_3=None):
        """Load one dataset file and return its questions and targets.

        Args:
            data_path: Dataset sub-directory name including the trailing slash,
                e.g. ``'SVAMP/'``; must be one of the keys in ``loaders`` below.
            data_name_1: File name inside ``data_path`` to read.
            data_name_2: Unused; accepted for backward compatibility.
            data_name_3: Unused; accepted for backward compatibility.

        Returns:
            A tuple of three parallel lists. For ``'AQuA/'`` it is
            ``(questions, rationales, answers)``; for every other dataset it is
            ``(questions, equations, solutions)``, where ``equations`` is a list
            of the string ``'None'`` for datasets that carry no equations.

        Raises:
            ValueError: If ``data_path`` is not a supported dataset (previously
                this silently returned ``None``).
        """
        data_path_ = os.path.join(self.root_path, data_path, data_name_1)
        print(f"[INFO] Now loading {data_path}, and its data path is {data_path_}")
        loaders = {
            'AddSub/': self._load_addsub_like,
            'MultiArith/': self._load_addsub_like,
            'SingleEq/': self._load_addsub_like,
            'SVAMP/': self._load_svamp,
            'AQuA/': self._load_aqua,
            'grade-school-math/': self._load_grade_school_math,
            'StrategyQA/': self._load_strategyqa,
            'coin_flip/': self._load_coin_flip,
            'CommonsenseQA/': self._load_commonsenseqa,
            'last_letters/': self._load_last_letters,
            'GSM-IC2/': self._load_gsm_ic,
            'GSM-ICM/': self._load_gsm_ic,
            'SingleOp/': self._load_singleop,
        }
        loader = loaders.get(data_path)
        if loader is None:
            raise ValueError(
                f"Unsupported data_path {data_path!r}; expected one of {sorted(loaders)}"
            )
        return loader(data_path_)

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _read_json(path):
        """Return the single JSON document stored in ``path`` (UTF-8)."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _read_jsonl(path):
        """Return one decoded object per non-blank line of ``path`` (UTF-8).

        ``raw_decode`` is kept so text trailing an object on the same line is
        ignored. Lines are stripped first because ``raw_decode`` rejects
        leading whitespace, and blank lines (e.g. an extra empty line at EOF)
        are skipped instead of raising ``JSONDecodeError``.
        """
        decoder = json.JSONDecoder()
        records = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    records.append(decoder.raw_decode(line)[0])
        return records

    @staticmethod
    def _report(questions):
        """Print the mean space-separated token count and the question count."""
        count = len(questions)
        # Guard the empty-dataset case, which would otherwise divide by zero.
        avg_tokens_each_question_ = (
            sum(len(q.split(' ')) for q in questions) / count if count else 0.0
        )
        print(f'[INFO] Average tokens in each question: {avg_tokens_each_question_}, and the number of questions is: {count}')

    # ------------------------------------------------------------------ #
    # Per-dataset loaders: each takes the full file path, returns a 3-tuple
    # ------------------------------------------------------------------ #

    def _load_addsub_like(self, path):
        """AddSub / MultiArith / SingleEq: JSON list with sQuestion, lEquations, lSolutions."""
        data = self._read_json(path)
        equations = [d.get('lEquations')[0] for d in data]
        solutions = [d.get('lSolutions')[0] for d in data]
        questions = [d.get('sQuestion') for d in data]
        self._report(questions)
        return questions, equations, solutions

    def _load_svamp(self, path):
        """SVAMP: JSON list; the question is 'Body' + ' ' + 'Question'."""
        data = self._read_json(path)
        equations = [d.get('Equation') for d in data]
        solutions = [d.get('Answer') for d in data]
        questions = [d.get('Body') + ' ' + d.get('Question') for d in data]
        self._report(questions)
        return questions, equations, solutions

    def _load_aqua(self, path):
        """AQuA: JSONL; returns (questions with answer choices, rationales, answers)."""
        data = self._read_jsonl(path)
        questions = [d.get('question') for d in data]
        rationales = [d.get('rationale') for d in data]
        answers = [d.get('correct') for d in data]
        # Stats are reported on the bare question text, before choices are appended.
        self._report(questions)
        # Re-join the option strings with '(' and pad parentheses with spaces,
        # exactly as the original formatting did.
        options = [
            ('(' + '('.join(d.get('options'))).replace('(', ' (').replace(')', ') ')
            for d in data
        ]
        questions = [f'{q} Answer Choices: {o}' for q, o in zip(questions, options)]
        return questions, rationales, answers

    def _load_grade_school_math(self, path):
        """grade-school-math: JSONL; 'answer' is split into (equation, solution)."""
        data = self._read_jsonl(path)
        questions = [d.get('question') for d in data]
        # NOTE(review): splitting on '### ' means an answer marked with '#### '
        # leaves a trailing '#' on the equation text. Kept to preserve existing
        # output; confirm whether '#### ' was intended.
        parts = [d.get('answer').split('### ') for d in data]
        equations = [p[0] for p in parts]
        solutions = [p[1] for p in parts]
        self._report(questions)
        return questions, equations, solutions

    def _load_strategyqa(self, path):
        """StrategyQA: JSON with 'examples'; solution is the key scored 1."""
        data = self._read_json(path)['examples']
        questions = ['Yes or No: ' + d.get('input') for d in data]
        equations = [d.get('target') for d in data]
        # Exactly one solution per example (first key scored 1, else None) so
        # the three lists stay aligned even for malformed target_scores.
        solutions = [
            next((key for key, value in d.get('target_scores').items() if value == 1), None)
            for d in data
        ]
        self._report(questions)
        return questions, equations, solutions

    def _load_coin_flip(self, path):
        """coin_flip: JSON with 'examples' of question/answer pairs."""
        data = self._read_json(path)['examples']
        questions = ['Yes or No: ' + d.get('question') for d in data]
        equations = ['None'] * len(data)
        solutions = [d.get('answer') for d in data]
        self._report(questions)
        return questions, equations, solutions

    def _load_commonsenseqa(self, path):
        """CommonsenseQA: JSONL; question stem followed by labelled choices."""
        data = self._read_jsonl(path)
        questions, solutions = [], []
        for item in data:
            question = item.get('question')
            choices_str = "Answer Choices:" + ''.join(
                f" ({choice.get('label')}) {choice.get('text')}"
                for choice in question.get('choices')
            )
            questions.append(question.get('stem').strip() + " " + choices_str)
            solutions.append(item.get('answerKey'))
        equations = ['None'] * len(questions)
        self._report(questions)
        return questions, equations, solutions

    def _load_last_letters(self, path):
        """last_letters: JSON with 'examples' of question/answer pairs."""
        data = self._read_json(path)['examples']
        questions = [d.get('question') for d in data]
        equations = ['None'] * len(data)
        solutions = [d.get('answer') for d in data]
        self._report(questions)
        return questions, equations, solutions

    def _load_gsm_ic(self, path):
        """GSM-IC2 / GSM-ICM: JSON list; equation is the filled sentence_template."""
        data = self._read_json(path)
        equations = [
            d.get('sentence_template').format(role=d.get('role'), number=d.get('number'))
            for d in data
        ]
        solutions = [d.get('answer') for d in data]
        questions = [d.get('new_question') for d in data]
        self._report(questions)
        return questions, equations, solutions

    def _load_singleop(self, path):
        """SingleOp: JSONL of input/target pairs."""
        data = self._read_jsonl(path)
        questions = [d.get('input') for d in data]
        equations = ['None'] * len(data)
        solutions = [d.get('target') for d in data]
        self._report(questions)
        return questions, equations, solutions
