import json
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, set_seed
from datasets import load_dataset
from tqdm import tqdm
import argparse
import os
import torch
import gc

import random
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Parse CLI arguments first so the GPU selection can be applied before any
# library call that might initialize CUDA.
parser = argparse.ArgumentParser(description='model name as input(hugging face id)')
parser.add_argument('--model_name', type=str, required=True, help='hugging face model name')
parser.add_argument('--device_id', type=int, nargs='+', help='GPU ID', default=[0])

args = parser.parse_args()
model_name = args.model_name
gpu_id = args.device_id

# CUDA_VISIBLE_DEVICES is only honored if it is set before CUDA is
# initialized, so set it before seeding (set_seed may also seed CUDA RNGs).
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, gpu_id))

random.seed(42)
np.random.seed(42)
set_seed(42)

device = 'cuda'
# Last path component of the model id ("org/model" -> "model"); unlike
# split('/')[1] this also works for ids without an org prefix and local paths.
model_data_path = model_name.rstrip('/').split('/')[-1]


# TODO: set to the directory holding encoding1_arr.npy, encoding1_labels.npy,
# encoding1_inf_ind.npy and lang_mapping.json; generations are written here too.
direc_name = ''

# os.path.join keeps an empty direc_name relative to the CWD instead of
# producing filesystem-root paths such as "/encoding1_arr.npy".
train_vec = np.load(os.path.join(direc_name, 'encoding1_arr.npy'))
train_lab = np.load(os.path.join(direc_name, 'encoding1_labels.npy'))

with open(os.path.join(direc_name, 'lang_mapping.json'), 'r', encoding='utf-8') as file:
    lang_mapping = json.load(file)

# os.makedirs('') raises FileNotFoundError, so only create a real directory.
if direc_name:
    os.makedirs(direc_name, exist_ok=True)

generations_prot_save_path = os.path.join(direc_name, 'div_p10.jsonl')


tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the padding token and pad on the left, so that batched
# prompts end exactly where generation begins.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Half-precision weights, placed automatically across the visible devices.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map='auto',
)
model.eval()

# Text-generation pipeline reused for every prompt in the sampling loop below.
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)

    
ds = load_dataset("ise-uiuc/Magicoder-OSS-Instruct-75K")['train']
mbpp = load_dataset("Muennighoff/mbpp", "full")

falcon_prot_file_path = 'mbpp_results/samples_mbpp_p10_falcon_new-sanitized.eval_results.json'

with open(falcon_prot_file_path, 'r', encoding='utf-8') as file:
    falcon_prot = json.load(file)

# Resolve each evaluated task (keys look like "<prefix>/<task_id>") to its row
# in the MBPP test split via an explicit task_id -> row map, rather than
# assuming row == task_id - 1, which silently misaligns if the split is not a
# contiguous, 1-based run of task ids.
mbpp_row_of_task = {int(tid): row for row, tid in enumerate(mbpp['test']['task_id'])}
mbpp_ref = []
for task_id in falcon_prot['eval']:
    iid = int(task_id.split('/')[-1])
    if iid not in mbpp_row_of_task:
        raise KeyError(f"MBPP task {task_id!r} not found in the Muennighoff/mbpp test split")
    mbpp_ref.append(mbpp_row_of_task[iid])
mbpp_subset = mbpp['test'].select(mbpp_ref)

# Precomputed bad-row indices. Loaded but not used below: bad rows are
# recomputed from train_vec. NOTE(review): confirm whether this file is needed.
del_ids_path = os.path.join(direc_name, "encoding1_inf_ind.npy")
del_arr = np.load(del_ids_path)

# Keep only rows whose encoding is entirely finite (no NaN / Inf).
good_mask = np.isfinite(train_vec).all(axis=1)
bad_idx = np.flatnonzero(~good_mask)

train_vec_f = train_vec[good_mask]
train_lab_f = train_lab[good_mask]
# Assumes row i of train_vec corresponds to row i of ds — TODO confirm.
new_ds = ds.select(np.flatnonzero(good_mask))

python_vecs = train_vec_f[train_lab_f == lang_mapping['python']]
if len(python_vecs) == 0:
    # An empty mean would be NaN and argmin would silently return row 0.
    raise ValueError("No finite 'python' encodings found; cannot pick a few-shot example")
mean_vec = python_vecs.mean(axis=0)
# NOTE(review): the nearest row is searched over all labels, not only python.
dists = np.linalg.norm(train_vec_f - mean_vec, axis=1)
closest_idx = np.argmin(dists)

# Fetch the single row instead of materializing whole 'problem'/'solution' columns.
closest_row = new_ds[int(closest_idx)]
prob = closest_row['problem']
sol = closest_row['solution']

few_shots = [(prob, sol)]


def build_icl_prompt(few_shots, test_problem):
    """Build an in-context-learning prompt in the [BEGIN]/[DONE] format.

    Args:
        few_shots: iterable of ``(problem, solution)`` pairs, each rendered as
            a completed demonstration terminated by ``[DONE]`` and a blank line.
        test_problem: the task text for which the model should write code.

    Returns:
        The demonstrations followed by the test task, ending right after
        ``[BEGIN]\\n`` so the model's continuation is the solution.
    """
    demonstrations = "".join(
        f"You are an expert Python programmer, and here is your task: {demo_problem}\n[BEGIN]\n{demo_solution}\n[DONE]\n\n"
        for demo_problem, demo_solution in few_shots
    )
    return demonstrations + f"You are an expert Python programmer, and here is your task: {test_problem}\n[BEGIN]\n"


test_problems = mbpp_subset

# One few-shot prompt per MBPP task; task_ids[k] labels prompts[k].
id_prompt_pairs = [
    (str(row["task_id"]), build_icl_prompt(few_shots, row["text"]))
    for row in test_problems
]
task_ids = [tid for tid, _ in id_prompt_pairs]
prompts = [text for _, text in id_prompt_pairs]

batch_size = 1
num_return_sequences = 10
completions = {}

for start in tqdm(range(0, len(prompts), batch_size), desc="Generating (batched)"):
    end = start + batch_size
    chunk = prompts[start:end]

    outs = generator(
        chunk,
        max_new_tokens=512,
        return_full_text=False,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        num_return_sequences=num_return_sequences,
        batch_size=len(chunk),
    )

    # outs[k] is indexed as the list of sampled continuations for chunk[k].
    for k, tid in enumerate(task_ids[start:end]):
        completions[tid] = [sample["generated_text"] for sample in outs[k]]

# One JSON object per line: every sampled solution for every task.
with open(generations_prot_save_path, "w") as f:
    f.writelines(
        json.dumps({"task_id": f"Mbpp/{tid}", "solution": text}) + "\n"
        for tid, samples in completions.items()
        for text in samples
    )

print(f"Saved all completions to {generations_prot_save_path}")

# The pipeline holds its own reference to the model, so both names must be
# dropped before gc / empty_cache can actually release the weights.
del generator, model


def flush():
    """Run Python garbage collection and release cached CUDA memory.

    The CUDA steps (emptying the caching allocator and resetting the
    peak-memory counters) are skipped when no CUDA device is available,
    because ``torch.cuda.reset_peak_memory_stats`` needs a CUDA context and
    raises otherwise.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

# Final cleanup: garbage-collect and clear PyTorch's CUDA cache now that the model has been deleted.
flush()
