import json
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, set_seed
from datasets import load_dataset
from tqdm import tqdm
import argparse
import os
import torch
import gc


import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Fix the RNG seed so that the sampled generations are reproducible.
set_seed(42)
# torch.set_float32_matmul_precision('high')
parser = argparse.ArgumentParser(description='model name as input(hugging face id)')
# The model name is mandatory: everything below depends on it, so fail with a
# clear argparse error instead of an AttributeError on None further down.
parser.add_argument('--model_name', type=str, required=True, help='hugging face model name')
parser.add_argument('--device_id', type=int, nargs='+', help='GPU ID', default=[0])

args = parser.parse_args()
model_name = args.model_name
gpu_id = args.device_id

# Expose only the requested GPUs to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, gpu_id))

device = 'cuda'

# Last path component of the model id ("org/name" -> "name"). Taking the last
# component also works for ids without a slash and for local paths, where
# indexing with [1] raised IndexError or picked the wrong component.
model_data_path = model_name.rstrip('/').split('/')[-1]

# Output directory for the generations. The previous value was an empty
# string, which makes os.makedirs raise FileNotFoundError and would otherwise
# place the output file at the filesystem root.
# NOTE(review): using the per-model name is an assumption about the intended
# location -- adjust if results belong elsewhere.
direc_name = model_data_path or '.'
os.makedirs(direc_name, exist_ok=True)

generations_prot_save_path = os.path.join(direc_name, 'base_p10.jsonl')

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map='auto'
)
# Reuse EOS as the padding token and pad on the left so that, in a padded
# batch, generated tokens directly follow the prompt tokens.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model.eval()
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)

    

# Load the MBPP benchmark ("full" configuration).
mbpp = load_dataset("Muennighoff/mbpp", "full")

# Evaluation results of a reference run; only the task ids (the keys of its
# "eval" mapping) are used, to restrict generation to the same problems.
falcon_prot_file_path = 'mbpp_results/samples_mbpp_p10_falcon_new-sanitized.eval_results.json'

with open(falcon_prot_file_path, 'r') as results_file:
    falcon_prot = json.load(results_file)

# The number after the last "/" of each key is the task id; subtracting one
# turns it into a row index of the test split.
# NOTE(review): assumes row k of the test split holds task id k + 1 -- confirm.
mbpp_ref = [int(key.split('/')[-1]) - 1 for key in falcon_prot['eval']]
mbpp_subset = mbpp['test'].select(mbpp_ref)





def build_icl_prompt(test_problem):
    """Return the generation prompt for a single task description.

    The prompt consists of a fixed instruction sentence followed by
    ``test_problem`` and a ``[BEGIN]`` marker on its own line. No worked
    examples are included despite the "icl" in the name.

    Args:
        test_problem: Task description to embed in the prompt.

    Returns:
        The prompt string, ending with ``"[BEGIN]\\n"``.
    """
    return f"You are an expert Python programmer, and here is your task: {test_problem}\n[BEGIN]\n"


test_problems = mbpp_subset

# One prompt per selected problem; task_ids[k] is the id belonging to prompts[k].
prompts, task_ids = [], []
for item in test_problems:
    task_ids.append(str(item["task_id"]))
    prompts.append(build_icl_prompt(item["text"]))

batch_size = 1
num_return_sequences = 10
# task id -> list of sampled solutions for that task.
completions = {}

# The output file is opened up front and every batch is written as soon as it
# has been generated. A failure part-way through a long run (e.g. running out
# of GPU memory) therefore keeps the samples produced so far instead of losing
# everything that was only held in memory.
with open(generations_prot_save_path, "w") as f:
    for i in tqdm(range(0, len(prompts), batch_size), desc="Generating (batched)"):
        batch_prompts = prompts[i:i+batch_size]
        batch_task_ids = task_ids[i:i+batch_size]

        outs = generator(
            batch_prompts,
            max_new_tokens=512,
            return_full_text=False,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            num_return_sequences=num_return_sequences,
            batch_size=len(batch_prompts),
        )

        # outs: List[List[Dict]] → [batch_size][num_return_sequences]
        for j, tid in enumerate(batch_task_ids):
            sols = [o["generated_text"] for o in outs[j]]
            completions[tid] = sols
            # One JSONL record per sampled solution.
            for s in sols:
                f.write(json.dumps({"task_id": f"Mbpp/{tid}", "solution": s}) + "\n")

        # Push the finished batch to disk before starting the next one.
        f.flush()


# with open("samples_mbpp_p10_gemma3.jsonl") as fin, open("samples_mbpp_p10_gemma3_new_org.jsonl", "w") as fout:
#     [fout.write(json.dumps({"task_id": d["task_id"], "solution": s}) + "\n") for d in map(json.loads, fin) for s in d["solution"]]

print(f"Saved all completions to {generations_prot_save_path}")

# Drop every reference to the model so its memory can actually be reclaimed.
# The pipeline was built with model=model and keeps its own reference, so
# deleting only `model` would leave the weights alive.
del generator
del model


def flush():
    """Collect garbage and release cached CUDA memory.

    Runs Python's garbage collector and then, only when CUDA is usable,
    empties the CUDA cache and resets the peak-memory statistics. On a
    host without CUDA the CUDA steps are skipped, because resetting the
    statistics there raises instead of being a no-op.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


flush()
