from transformers import pipeline
from datasets import load_dataset
import torch
from time import time
import json


# PATH = 'results/magicoder/'
# FILE = 't5_transfer'

# import matplotlib.pyplot as plt
# import numpy as np

# for postfix in ['_he_mbpp', '_mbpp_he']:

#     with open(PATH + f'{FILE}{postfix}.json', 'r') as f:
#         res = json.load(f)

#     print(postfix, len(res))
#     for key, vals in res.items():
#         x = np.arange(len(vals))
#         plt.plot(x, vals, label=key)
#         print(f'{key}: {vals[-1]}')

#     plt.legend()
#     plt.savefig(PATH + f'{FILE}{postfix}_metrics.png')
#     plt.close()

# with open(PATH + FILE, 'r') as f:
#     results = json.load(f)

# import matplotlib.pyplot as plt
# import numpy as np

# x = np.arange(100)
# metrics = {m_name: np.empty((5, 100)) for m_name in results[0]}
# for i, res in enumerate(results):
#     for m_name in res:
#         # plt.plot(x, res[m_name], label=m_name)
#         metrics[m_name][i, :] = res[m_name]

# for m_name, vals in metrics.items():
#     mean = vals.mean(axis=0)
#     std = vals.std(axis=0)
#     plt.plot(x, mean, label=m_name)
#     plt.fill_between(x, mean - std, mean + std, alpha=0.1)

# plt.legend()
# plt.savefig(PATH + FILE.replace('json', 'png'))

# for m_name, vals in metrics.items():
#     print(f'{m_name}: {vals.mean(axis=0)[-1]} +- {vals.std(axis=0)[-1]}')


# PATH = 'results/qwencoder/'
# FILE = 't5_coder_metrics_he.json'

# with open(PATH + FILE, 'r') as f:
#     results = json.load(f)

# import matplotlib.pyplot as plt
# import numpy as np

# x = np.arange(100)
# metrics = {m_name: np.empty((5, 100)) for m_name in results[0]}
# for i, res in enumerate(results):
#     for m_name in res:
#         # plt.plot(x, res[m_name], label=m_name)
#         metrics[m_name][i, :] = res[m_name]

# for m_name, vals in metrics.items():
#     mean = vals.mean(axis=0)
#     std = vals.std(axis=0)
#     plt.plot(x, mean, label=m_name)
#     plt.fill_between(x, mean - std, mean + std, alpha=0.1)

# plt.legend()
# plt.savefig(PATH + FILE.replace('json', 'png'))

# for m_name, vals in metrics.items():
#     print(f'{m_name}: {vals.mean(axis=0)[-1]} +- {vals.std(axis=0)[-1]}')

# Generate Qwencoder code: run Qwen2.5-Coder-7B over every MBPP test task with a
# one-shot "[BEGIN] ... [DONE]" prompt and store the raw pipeline output per task id.

pipe = pipeline("text-generation", model="Qwen/Qwen2.5-Coder-7B", device='cuda', torch_dtype=torch.bfloat16)
ds = load_dataset("google-research-datasets/mbpp", "full")

results = dict()

for item in ds['test']:
    task_id = item['task_id']
    text = item['text']
    test_list = item['test_list']
    # One-shot prompt: a worked example (similar_elements), then the current task
    # text and its asserts, ending with '[BEGIN]' so the model presumably
    # continues with the solution code.
    message = [{
        "role": "user",
        "content": f"""
            You are an expert Python programmer, and here is your task: Write a function to find the similar elements from the given two tuple lists. Your code should pass these tests:

            assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)
            assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)
            assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)
            [BEGIN]
            def similar_elements(test_tup1, test_tup2):
              res = tuple(set(test_tup1) & set(test_tup2))
              return (res)
            [DONE]

            {text} Your code should pass these tests:
        """ + '\n'.join(test_list) + '\n[BEGIN]'
    },]
    t = time()
    with torch.no_grad():
        # For chat-style input this is the message list (user prompt + reply).
        result = pipe(message)[0]['generated_text']
    results[task_id] = result
    print(time() - t)  # wall-clock generation time for this task, in seconds

# json.dump converts the integer task ids into string keys; readers must use str(task_id).
with open('/app/results_qwen_25_7b_mbpp.json', 'w', encoding='utf-8') as f:
    json.dump(results, f)

# Parse generation results: for each MBPP test task take the assistant reply from
# the saved pipeline output and strip any echo of the prompt from it.

with open('/app/results_qwen_25_7b_mbpp.json', 'r', encoding='utf-8') as f:
    results = json.load(f)

# Debug output: the assistant message of the first stored conversation, then all task ids.
print(results[next(iter(results))][1])
print(list(results))

results_new = dict()

ds = load_dataset("google-research-datasets/mbpp", "full")
for item in ds['test']:
    task_id = item['task_id']
    text = item['text']
    test_list = item['test_list']
    # Must match the prompt used at generation time exactly, otherwise the
    # .replace() below cannot remove an echoed copy of it.
    prompt = f"""
            You are an expert Python programmer, and here is your task: Write a function to find the similar elements from the given two tuple lists. Your code should pass these tests:

            assert similar_elements((3, 4, 5, 6),(5, 7, 4, 10)) == (4, 5)
            assert similar_elements((1, 2, 3, 4),(5, 4, 3, 7)) == (3, 4)
            assert similar_elements((11, 12, 14, 13),(17, 15, 14, 13)) == (13, 14)
            [BEGIN]
            def similar_elements(test_tup1, test_tup2):
              res = tuple(set(test_tup1) & set(test_tup2))
              return (res)
            [DONE]

            {text} Your code should pass these tests:
        """ + '\n'.join(test_list) + '\n[BEGIN]'
    # JSON object keys are strings, hence str(task_id); index 1 is the reply
    # that follows the single user message.
    res = results[str(task_id)][1]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if res['role'] != 'assistant':
        raise ValueError(
            f"expected an assistant message for task {task_id}, got role {res['role']!r}"
        )
    results_new[str(task_id)] = res['content'].replace(prompt, '')


from eval_results import check_correctness


def collect_results(string: str, result: dict, task_id: int, index: int,
                    *, checker=None, timeout=30) -> int:
    """Extract ``[BEGIN]``/``[END]``-delimited programs from *string* and evaluate them.

    Every ``[END]`` marker opens a segment that runs to the next ``[END]`` and is
    expected to look like::

        <task description line>
        <assert lines>
        [BEGIN]
        <code>
        [END]

    The closing ``[END]`` of one segment is also the opening marker of the next.
    Text before the first ``[END]`` is never parsed as a segment. Segments without
    ``[BEGIN]`` or with empty code are skipped. Each remaining segment is evaluated
    with *checker* (code followed by its asserts) and stored in *result* under
    consecutive integer keys starting at *index*. Only the first stored entry is
    tagged with *task_id*; later entries get ``task_id=None``.

    Args:
        string: Raw model completion to scan.
        result: Dict that receives the parsed entries (mutated in place).
        task_id: Id attached to the first extracted entry.
        index: First key to use in *result*.
        checker: Callable ``(program, timeout, task, completion)`` returning a
            mapping with a ``'result'`` key. Defaults to ``check_correctness``
            imported from ``eval_results``.
        timeout: Second positional argument forwarded to *checker* (previously
            hard-coded as 30; presumably seconds — confirm against eval_results).

    Returns:
        The next unused index, i.e. *index* plus the number of entries stored.
    """
    if checker is None:
        checker = check_correctness

    end_tag = '[END]'
    begin_tag = '[BEGIN]'
    # Each marker is assumed to be followed by one separator char (normally '\n'),
    # hence the +1 when skipping past it (6 for [END], 8 for [BEGIN]).
    end_skip = len(end_tag) + 1
    begin_skip = len(begin_tag) + 1

    current_position = 0
    while current_position < len(string) - 1:
        end_pos = string.find(end_tag, current_position)
        # Test the raw find() result: adding the offset first turned -1 into 5,
        # which restarted the scan near the top of the string and could loop forever.
        if end_pos == -1:
            break
        pos1 = end_pos + end_skip
        pos2 = string.find(end_tag, pos1)
        if pos2 == -1:
            break

        # Resume at the closing [END] so the segment that follows it is examined next
        # (resuming after it skipped every other consecutive segment).
        next_position = pos2

        pos3 = string.find(begin_tag, pos1, pos2)
        if pos3 == -1:
            current_position = next_position
            continue

        code = string[pos3 + begin_skip:pos2].strip()
        if not code:
            # Skip empty programs before paying for an evaluation run.
            current_position = next_position
            continue

        # First line of the segment is the task description; the following lines,
        # up to [BEGIN], are the asserts. Without a newline there are no asserts.
        pos_new_line = string.find('\n', pos1, pos3)
        if pos_new_line == -1:
            pos_new_line = pos3
        task = string[pos1:pos_new_line].strip()
        asserts = string[pos_new_line + 1:pos3].strip()

        check_program = code + '\n' + asserts
        correctness = checker(check_program, timeout, task, string)

        result[index] = {
            'task': task,
            'task_id': task_id,
            'code': code,
            'tests': asserts,
            'result': correctness['result'],
        }
        current_position = next_position
        index += 1
        task_id = None  # only the first extracted entry keeps the real task id

    return index


# Evaluate every parsed completion. Entries are numbered with one running
# counter across all tasks, so INDEX ends up as the total entry count.
result = {}
INDEX = 0
for tid_str in results_new:
    completion = results_new[tid_str]
    INDEX = collect_results(completion, result, int(tid_str), INDEX)

# NOTE(review): the name says "qwen27" while the input file is "qwen_25" —
# possibly a typo; confirm before renaming, downstream readers may rely on it.
with open('/app/results_qwen27_7b-mbpp_evaluated.json', 'w') as out_file:
    json.dump(result, out_file)

print(INDEX)
