from datasets import load_dataset
import json

def load_data(args, tokenizer):
    """Build the prompt list for the dataset selected by ``args.dataset_name``.

    Dispatches to the matching ``load_*`` function in this module and forwards
    ``tokenizer`` and ``args`` to it unchanged. Any name that contains the
    substring ``"mgsm"`` is routed to ``load_mgsm``; every other supported name
    must match exactly.

    Args:
        args: Object exposing ``dataset_name``. The selected loader may read
            further attributes from it (for example ``data_cache_dir`` or
            ``model_name``).
        tokenizer: Passed through to the selected loader.

    Returns:
        Whatever the selected loader returns.

    Raises:
        ValueError: If ``args.dataset_name`` is not a supported dataset name.
    """
    if args.dataset_name == "gsm":
        return load_gsm(tokenizer=tokenizer, args=args)
    elif args.dataset_name == "c4":
        return load_c4(tokenizer=tokenizer, args=args)
    elif args.dataset_name == "c4_ko":
        return load_c4_ko(tokenizer=tokenizer, args=args)
    elif args.dataset_name == "c4_ja":
        return load_c4_ja(tokenizer=tokenizer, args=args)
    elif args.dataset_name == "c4_de":
        return load_c4_de(tokenizer=tokenizer, args=args)
    elif args.dataset_name == "c4_es":
        return load_c4_es(tokenizer=tokenizer, args=args)
    elif args.dataset_name == "c4_fr":
        return load_c4_fr(tokenizer=tokenizer, args=args)
    elif "mgsm" in args.dataset_name:
        # Substring match: the language suffix (e.g. "mgsm_xx") is parsed by load_mgsm.
        return load_mgsm(tokenizer=tokenizer, args=args)
    elif args.dataset_name == "gsm_ko":
        return load_gsm_ko(tokenizer=tokenizer, args=args)
    elif args.dataset_name == "human_eval":
        return load_human_eval(tokenizer=tokenizer, args=args)
    elif args.dataset_name == "cnn_dm":
        return load_cnn_dm(tokenizer=tokenizer, args=args)
    else:
        # Raising a bare string is itself a TypeError in Python 3; raise a real
        # exception that also names the offending value.
        raise ValueError(f"dataset name err: unsupported dataset_name {args.dataset_name!r}")

def load_gsm(tokenizer, args):
    """Turn every GSM8K test question into a 4-shot step-by-step chat prompt.

    Loads the ``"main"`` configuration of ``openai/gsm8k`` (cached under
    ``args.data_cache_dir``) and, for each row of its ``test`` split, renders a
    conversation made of four fixed worked examples followed by the row's
    question.

    Args:
        tokenizer: Object whose ``apply_chat_template(messages, tokenize=False)``
            renders a list of role/content dicts to a string.
        args: Object exposing ``data_cache_dir``.

    Returns:
        A list with one rendered prompt per test row, in dataset order.
    """
    dataset = load_dataset("openai/gsm8k", "main", cache_dir=args.data_cache_dir)

    # (user turn, assistant turn) pairs used as fixed demonstrations.
    demonstrations = (
        (
            "Question: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
            "Let's think step by step:\nThere are originally 3 cars.\n2 more cars arrive.\n3 + 2 = 5.\nThe answer is 5.",
        ),
        (
            "Question: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
            "Let's think step by step:\nOriginally, Leah had 32 chocolates.\nHer sister had 42.\nSo in total they had 32 + 42 = 74.\nAfter eating 35, they had 74 - 35 = 39.\nThe answer is 39.",
        ),
        (
            "Question: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
            "Let's think step by step:\nThere were originally 9 computers.\nFor each of 4 days, 5 more computers were added.\nSo 5 * 4 = 20 computers were added.\n9 + 20 is 29.\nThe answer is 29.",
        ),
        (
            "Question: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
            "Let's think step by step:\nOlivia had 23 dollars.\n5 bagels for 3 dollars each will be 5 x 3 = 15 dollars.\nSo she has 23 - 15 dollars left.\n23 - 15 is 8.\nThe answer is 8.",
        ),
    )

    prompts = []
    for row in dataset['test']:
        # A fresh message list is built for every prompt.
        conversation = []
        for user_turn, assistant_turn in demonstrations:
            conversation.append({"role": "user", "content": user_turn})
            conversation.append({"role": "assistant", "content": assistant_turn})

        query = "Question: " + row['question'] + "\nLet's think step by step:\n"
        conversation.append({"role": "user", "content": query})

        prompts.append(tokenizer.apply_chat_template(conversation, tokenize=False))

    return prompts

def load_c4(tokenizer, args):
    """Build continuation prompts from the first 500 entries of the C4 sample file.

    Reads ``./data/c4/c4_sample.json``. Each entry is split on whitespace and
    everything except its final 200 words becomes the prompt.

    Args:
        tokenizer: Object whose ``apply_chat_template(messages, tokenize=False)``
            renders a list of role/content dicts to a string. Not called when
            ``args.model_name`` is ``"opt"``.
        args: Object exposing ``model_name``.

    Returns:
        A list of 500 strings: the bare prompt for ``"opt"``, otherwise the
        prompt wrapped in a system + user chat template.

    Raises:
        AssertionError: If an entry has fewer than 300 whitespace-separated words.
        IndexError: If the file holds fewer than 500 entries.
    """
    with open("./data/c4/c4_sample.json", "r", encoding="utf-8") as handle:
        documents = json.load(handle)

    prompts = []
    for idx in range(500):
        tokens = documents[idx].split()
        assert len(tokens) >= 300, "C4 Word length Err"

        # Drop the trailing 200 words; the remainder is the context shown to the model.
        text = " ".join(tokens[:-200])

        if args.model_name != "opt":
            text = tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": "You are an advanced writing assistant. Continue the given text in a natural and coherent way."},
                    {"role": "user", "content": text},
                ],
                tokenize=False,
            )
        prompts.append(text)

    return prompts

def load_tqa(tokenizer, args):
    """Turn every TruthfulQA validation question into a 4-shot Q/A chat prompt.

    Loads the ``"generation"`` configuration of ``truthfulqa/truthful_qa``
    (cached under ``args.data_cache_dir``) and, for each row of its
    ``validation`` split, renders four fixed question/answer exchanges followed
    by the row's question.

    Args:
        tokenizer: Object whose ``apply_chat_template(messages, tokenize=False)``
            renders a list of role/content dicts to a string.
        args: Object exposing ``data_cache_dir``.

    Returns:
        A list with one rendered prompt per validation row, in dataset order.
    """
    dataset = load_dataset("truthfulqa/truthful_qa", "generation", cache_dir=args.data_cache_dir)

    # (user turn, assistant turn) pairs used as fixed demonstrations.
    demonstrations = (
        (
            "Q: What is human life expectancy in the United States?",
            "A: Human life expectancy in the United States is 78 years.",
        ),
        (
            "Q: Who was president of the United States in 1955?",
            "A: Dwight D. Eisenhower was president of the United States in 1955.",
        ),
        (
            "Q: How does a telescope work?",
            "A: Telescopes use lenses or mirrors to focus light and make objects appear closer.",
        ),
        (
            "Q: Where were the 1992 Olympics held?",
            "A: The 1992 Olympics were held in Barcelona, Spain.",
        ),
    )

    prompts = []
    for row in dataset['validation']:
        # A fresh message list is built for every prompt.
        conversation = []
        for user_turn, assistant_turn in demonstrations:
            conversation.append({"role": "user", "content": user_turn})
            conversation.append({"role": "assistant", "content": assistant_turn})

        query = "Q: " + row['question'] + "\nA: "
        conversation.append({"role": "user", "content": query})

        prompts.append(tokenizer.apply_chat_template(conversation, tokenize=False))

    return prompts

def load_c4_ko(tokenizer, args):
    """Build continuation prompts from the first 500 entries of the ``c4_ko`` sample file.

    Reads ``./data/mul_c4/c4_ko_sample.json``. Each entry is split on
    whitespace and the leading half of its words (rounded down) becomes the
    prompt.

    Args:
        tokenizer: Object whose ``apply_chat_template(messages, tokenize=False)``
            renders a list of role/content dicts to a string. Not called when
            ``args.model_name`` is ``"opt"``.
        args: Object exposing ``model_name``.

    Returns:
        A list of 500 strings: the bare prompt for ``"opt"``, otherwise the
        prompt wrapped in a system + user chat template.

    Raises:
        IndexError: If the file holds fewer than 500 entries.
    """
    with open("./data/mul_c4/c4_ko_sample.json", "r", encoding="utf-8") as handle:
        samples = json.load(handle)

    prompts = []
    for idx in range(500):
        tokens = samples[idx].split()
        # Keep the first half of the words (floor division matches int(n * 0.5)).
        text = " ".join(tokens[:len(tokens) // 2])

        if args.model_name != "opt":
            text = tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": "You are an advanced writing assistant. Continue the given text in a natural and coherent way."},
                    {"role": "user", "content": text},
                ],
                tokenize=False,
            )
        prompts.append(text)

    return prompts

def load_c4_ja(tokenizer, args):
    """Build continuation prompts from the first 500 entries of the ``c4_ja`` sample file.

    Reads ``./data/mul_c4/c4_ja_sample.json``. Each entry is split on
    whitespace and the leading half of its words (rounded down) becomes the
    prompt.

    Args:
        tokenizer: Object whose ``apply_chat_template(messages, tokenize=False)``
            renders a list of role/content dicts to a string. Not called when
            ``args.model_name`` is ``"opt"``.
        args: Object exposing ``model_name``.

    Returns:
        A list of 500 strings: the bare prompt for ``"opt"``, otherwise the
        prompt wrapped in a system + user chat template.

    Raises:
        IndexError: If the file holds fewer than 500 entries.
    """
    def render(text):
        # "opt" receives the raw text; every other model gets a chat-formatted prompt.
        if args.model_name == "opt":
            return text
        messages = [
            {"role": "system", "content": "You are an advanced writing assistant. Continue the given text in a natural and coherent way."},
            {"role": "user", "content": text},
        ]
        return tokenizer.apply_chat_template(messages, tokenize=False)

    with open("./data/mul_c4/c4_ja_sample.json", "r", encoding="utf-8") as handle:
        corpus = json.load(handle)

    rendered = []
    for position in range(500):
        pieces = corpus[position].split()
        cutoff = int(len(pieces) * 0.5)
        rendered.append(render(" ".join(pieces[:cutoff])))

    return rendered

def load_mgsm(tokenizer, args):
    """Build 8-shot chat prompts for the first 250 test questions of an MGSM configuration.

    The dataset configuration name is ``args.dataset_name`` with every
    ``"mgsm_"`` occurrence removed. The first eight rows of the ``train`` split
    supply the question/answer demonstrations; each of the first 250 rows of
    the ``test`` split supplies the final user question.

    Args:
        tokenizer: Object whose ``apply_chat_template(messages, tokenize=False)``
            renders a list of role/content dicts to a string.
        args: Object exposing ``dataset_name`` and ``data_cache_dir``.

    Returns:
        A list of 250 rendered prompt strings, in test-split order.
    """
    config_name = args.dataset_name.replace("mgsm_", "")
    dataset = load_dataset("juletxara/mgsm", config_name, cache_dir=args.data_cache_dir)

    # The demonstrations are the same for every prompt, so read them once.
    demonstrations = []
    for shot in range(8):
        example = dataset['train'][shot]
        demonstrations.append((example['question'], example['answer']))

    prompts = []
    for idx in range(250):
        # A fresh message list is built for every prompt.
        conversation = []
        for shot_question, shot_answer in demonstrations:
            conversation.append({"role": "user", "content": shot_question})
            conversation.append({"role": "assistant", "content": shot_answer})
        conversation.append({"role": "user", "content": dataset['test'][idx]['question']})

        prompts.append(tokenizer.apply_chat_template(conversation, tokenize=False))

    return prompts


def load_gsm_ko(tokenizer, args):
    """Build 8-shot chat prompts for the first 250 test questions of ``kuotient/gsm8k-ko``.

    The first eight rows of the ``train`` split supply the demonstrations; each
    demonstration answer is prefixed with a fixed Korean lead-in string. Each
    of the first 250 rows of the ``test`` split supplies the final user
    question.

    Args:
        tokenizer: Object whose ``apply_chat_template(messages, tokenize=False)``
            renders a list of role/content dicts to a string.
        args: Object exposing ``data_cache_dir``.

    Returns:
        A list of 250 rendered prompt strings, in test-split order.
    """
    dataset = load_dataset("kuotient/gsm8k-ko", cache_dir=args.data_cache_dir)

    # The demonstrations are the same for every prompt, so read them once.
    demonstrations = []
    for shot in range(8):
        example = dataset['train'][shot]
        demonstrations.append((example['question'], "단계별 답변: " + example['answer']))

    prompts = []
    for idx in range(250):
        # A fresh message list is built for every prompt.
        conversation = []
        for shot_question, shot_answer in demonstrations:
            conversation.append({"role": "user", "content": shot_question})
            conversation.append({"role": "assistant", "content": shot_answer})
        conversation.append({"role": "user", "content": dataset['test'][idx]['question']})

        prompts.append(tokenizer.apply_chat_template(conversation, tokenize=False))

    return prompts

def load_c4_de(tokenizer, args):
    """Build continuation prompts from the first 500 entries of the ``c4_de`` sample file.

    Reads ``./data/mul_c4/c4_de_sample.json``. Each entry is split on
    whitespace and the leading half of its words (rounded down) becomes the
    prompt.

    As in ``load_c4``, ``load_c4_ko`` and ``load_c4_ja``, the prompt is
    returned without a chat template when ``args.model_name`` is ``"opt"``.

    Args:
        tokenizer: Object whose ``apply_chat_template(messages, tokenize=False)``
            renders a list of role/content dicts to a string. Not called when
            ``args.model_name`` is ``"opt"``.
        args: Object optionally exposing ``model_name``.

    Returns:
        A list of 500 strings: the bare prompt for ``"opt"``, otherwise the
        prompt wrapped in a system + user chat template.

    Raises:
        IndexError: If the file holds fewer than 500 entries.
    """
    with open("./data/mul_c4/c4_de_sample.json", "r", encoding="utf-8") as f:
        ds = json.load(f)

    # getattr keeps callers whose args has no model_name on the chat-template path.
    use_raw_prompt = getattr(args, "model_name", None) == "opt"

    input_list = []
    for i in range(500):
        words = ds[i].split()
        prompt = " ".join(words[:len(words) // 2])

        if use_raw_prompt:
            input_list.append(prompt)
            continue

        _input = tokenizer.apply_chat_template([
            {"role": "system", "content": "You are an advanced writing assistant. Continue the given text in a natural and coherent way."},
            {"role": "user", "content": prompt}],
            tokenize=False)
        input_list.append(_input)

    return input_list

def load_c4_fr(tokenizer, args):
    """Build continuation prompts from the first 500 entries of the ``c4_fr`` sample file.

    Reads ``./data/mul_c4/c4_fr_sample.json``. Each entry is split on
    whitespace and the leading half of its words (rounded down) becomes the
    prompt.

    As in ``load_c4``, ``load_c4_ko`` and ``load_c4_ja``, the prompt is
    returned without a chat template when ``args.model_name`` is ``"opt"``.

    Args:
        tokenizer: Object whose ``apply_chat_template(messages, tokenize=False)``
            renders a list of role/content dicts to a string. Not called when
            ``args.model_name`` is ``"opt"``.
        args: Object optionally exposing ``model_name``.

    Returns:
        A list of 500 strings: the bare prompt for ``"opt"``, otherwise the
        prompt wrapped in a system + user chat template.

    Raises:
        IndexError: If the file holds fewer than 500 entries.
    """
    with open("./data/mul_c4/c4_fr_sample.json", "r", encoding="utf-8") as f:
        ds = json.load(f)

    # getattr keeps callers whose args has no model_name on the chat-template path.
    use_raw_prompt = getattr(args, "model_name", None) == "opt"

    input_list = []
    for i in range(500):
        words = ds[i].split()
        prompt = " ".join(words[:len(words) // 2])

        if use_raw_prompt:
            input_list.append(prompt)
            continue

        _input = tokenizer.apply_chat_template([
            {"role": "system", "content": "You are an advanced writing assistant. Continue the given text in a natural and coherent way."},
            {"role": "user", "content": prompt}],
            tokenize=False)
        input_list.append(_input)

    return input_list

def load_c4_es(tokenizer, args):
    """Build continuation prompts from the first 500 entries of the ``c4_es`` sample file.

    Reads ``./data/mul_c4/c4_es_sample.json``. Each entry is split on
    whitespace and the leading half of its words (rounded down) becomes the
    prompt.

    As in ``load_c4``, ``load_c4_ko`` and ``load_c4_ja``, the prompt is
    returned without a chat template when ``args.model_name`` is ``"opt"``.

    Args:
        tokenizer: Object whose ``apply_chat_template(messages, tokenize=False)``
            renders a list of role/content dicts to a string. Not called when
            ``args.model_name`` is ``"opt"``.
        args: Object optionally exposing ``model_name``.

    Returns:
        A list of 500 strings: the bare prompt for ``"opt"``, otherwise the
        prompt wrapped in a system + user chat template.

    Raises:
        IndexError: If the file holds fewer than 500 entries.
    """
    with open("./data/mul_c4/c4_es_sample.json", "r", encoding="utf-8") as f:
        ds = json.load(f)

    # getattr keeps callers whose args has no model_name on the chat-template path.
    use_raw_prompt = getattr(args, "model_name", None) == "opt"

    input_list = []
    for i in range(500):
        words = ds[i].split()
        prompt = " ".join(words[:len(words) // 2])

        if use_raw_prompt:
            input_list.append(prompt)
            continue

        _input = tokenizer.apply_chat_template([
            {"role": "system", "content": "You are an advanced writing assistant. Continue the given text in a natural and coherent way."},
            {"role": "user", "content": prompt}],
            tokenize=False)
        input_list.append(_input)

    return input_list

def load_human_eval(tokenizer, args):
    """Build code-completion prompts from the first 250 lines of the HumanEval JSONL file.

    Reads ``./data/human_eval/test.jsonl``; each line is parsed as JSON and its
    ``"prompt"`` field is used.

    Args:
        tokenizer: Object whose ``apply_chat_template(messages, tokenize=False)``
            renders a list of role/content dicts to a string. Not called when
            ``args.model_name`` is ``"opt"``.
        args: Object exposing ``model_name``.

    Returns:
        A list with one string per processed line: the bare ``"prompt"`` value
        for ``"opt"``, otherwise a chat prompt made of a system instruction,
        two worked completion examples and the task wrapped in a Python code
        fence.
    """
    def as_chat(task):
        # System instruction + two fixed demonstrations + the actual task.
        messages = [
            {"role": "system", "content": "You are a helpful programming assistant. Complete the given Python function following the docstring requirements."},
            {"role": "user", "content": "Complete this function:\n\n```python\ndef fibonacci(n):\n    \"\"\"Return the nth Fibonacci number.\n    >>> fibonacci(0)\n    0\n    >>> fibonacci(1)\n    1\n    >>> fibonacci(10)\n    55\n    \"\"\"\n```"},
            {"role": "assistant", "content": "```python\ndef fibonacci(n):\n    \"\"\"Return the nth Fibonacci number.\n    >>> fibonacci(0)\n    0\n    >>> fibonacci(1)\n    1\n    >>> fibonacci(10)\n    55\n    \"\"\"\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)\n```"},
            {"role": "user", "content": "Complete this function:\n\n```python\ndef is_prime(n):\n    \"\"\"Check if a number is prime.\n    >>> is_prime(2)\n    True\n    >>> is_prime(4)\n    False\n    >>> is_prime(17)\n    True\n    \"\"\"\n```"},
            {"role": "assistant", "content": "```python\ndef is_prime(n):\n    \"\"\"Check if a number is prime.\n    >>> is_prime(2)\n    True\n    >>> is_prime(4)\n    False\n    >>> is_prime(17)\n    True\n    \"\"\"\n    if n < 2:\n        return False\n    for i in range(2, int(n**0.5) + 1):\n        if n % i == 0:\n            return False\n    return True\n```"},
            {"role": "user", "content": f"Complete this function:\n\n```python\n{task}\n```"},
        ]
        return tokenizer.apply_chat_template(messages, tokenize=False)

    with open("./data/human_eval/test.jsonl", "r", encoding="utf-8") as handle:
        records = list(handle)[:250]

    prompts = []
    for raw_line in records:
        task = json.loads(raw_line.strip())['prompt']
        prompts.append(task if args.model_name == "opt" else as_chat(task))

    return prompts

def load_cnn_dm(tokenizer, args):
    """
    CNN/DailyMail dataset loader.
    Builds prompts for the task of generating a summary from an article.

    Reads ``./data/cnn_dm/cnn_dm_articles.json`` and uses at most its first 250
    articles. An article longer than 800 whitespace-separated words is cut to
    its first 800 words (re-joined with single spaces); shorter articles are
    used verbatim.

    Args:
        tokenizer: Object whose ``apply_chat_template(messages, tokenize=False)``
            renders a list of role/content dicts to a string. Not called when
            ``args.model_name`` is ``"opt"``.
        args: Object exposing ``model_name``.

    Returns:
        A list with one prompt per article: a plain "Article: ... Summarize"
        string for ``"opt"``, otherwise a chat prompt with a system
        instruction, one worked example and the article.
    """
    with open("./data/cnn_dm/cnn_dm_articles.json", "r", encoding="utf-8") as handle:
        corpus = json.load(handle)

    prompts = []

    for text in corpus[:250]:
        tokens = text.split()
        if len(tokens) > 800:
            # Cap very long articles at 800 words.
            text = " ".join(tokens[:800])

        if args.model_name == "opt":
            prompts.append(f"Article: {text}\n\nSummarize the article above:")
            continue

        messages = [
            {"role": "system", "content": "You are a helpful assistant that summarizes news articles. Create a concise and accurate summary that captures the main points of the article."},
            {"role": "user", "content": "Please summarize this article:\n\nBreaking: Scientists have discovered a new species of butterfly in the Amazon rainforest. The butterfly, named Amazonica brillianta, has unique blue and gold wings and is found only in a specific region of Peru. Researchers believe this discovery could help in conservation efforts for the rainforest ecosystem."},
            {"role": "assistant", "content": "Scientists discovered a new butterfly species called Amazonica brillianta in the Amazon rainforest. The butterfly has distinctive blue and gold wings and is found exclusively in a specific area of Peru. The discovery may contribute to rainforest conservation efforts."},
            {"role": "user", "content": f"Please summarize this article:\n\n{text}"},
        ]
        prompts.append(tokenizer.apply_chat_template(messages, tokenize=False))

    return prompts