import json
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, set_seed
from datasets import load_dataset
from tqdm import tqdm
import argparse
import os
import torch
import gc
from sklearn.metrics.pairwise import cosine_similarity
import random
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

parser = argparse.ArgumentParser(description='model name as input(hugging face id)')
# required=True: every later step needs a model id; without it the script used
# to crash with an opaque AttributeError on None.
parser.add_argument('--model_name', type=str, required=True, help='hugging face model name')
parser.add_argument('--device_id', type=int, nargs='+', help='GPU ID', default=[0])

args = parser.parse_args()
model_name = args.model_name
gpu_id = args.device_id

# Restrict the visible GPUs before anything touches the CUDA API (seeding below
# registers CUDA RNG seeds), so CUDA is guaranteed to see only these devices.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, gpu_id))

# Fix all RNG seeds for reproducible sampling.
random.seed(42)
np.random.seed(42)
set_seed(42)

device = 'cuda'
# Last path component of the model id ("org/model" -> "model"). Unlike
# split('/')[1], this also works for ids without a namespace and for local paths.
model_data_path = model_name.rstrip('/').split('/')[-1]


# Directory holding the precomputed encoding arrays; generations are written here too.
# NOTE(review): empty in this copy, so the paths below resolve to "/<file>" —
# fill in before running.
direc_name = ''

with open(f'{direc_name}/lang_mapping.json', 'r', encoding='utf-8') as f:
    lang_mapping = json.load(f)

# Training-pool vectors/labels and the test-task vectors; train_vec and test_vec
# are compared by Euclidean distance for nearest-neighbour retrieval later on.
train_vec = np.load(f'{direc_name}/encoding1_arr.npy')
train_lab = np.load(f'{direc_name}/encoding1_labels.npy')
test_vec = np.load(f'{direc_name}/test_encoding.npy')
# os.makedirs('') raises FileNotFoundError, so only create a named directory.
if direc_name:
    os.makedirs(direc_name, exist_ok=True)

generations_prot_save_path = f'{direc_name}/sim_p10.jsonl'


# Tokenizer setup: reuse EOS as the pad token and pad on the left, which is the
# usual choice for generation with a causal LM.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# fp16 weights; device placement is left to device_map='auto'.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='auto',
    torch_dtype=torch.float16,
)
model.eval()

generator = pipeline('text-generation', model=model, tokenizer=tokenizer)

    
# Retrieval pool (Magicoder OSS-Instruct) and the MBPP benchmark.
ds = load_dataset("ise-uiuc/Magicoder-OSS-Instruct-75K")['train']
mbpp = load_dataset("Muennighoff/mbpp", "full")

# Evaluation results whose 'eval' keys ("Mbpp/<id>") define which tasks to run.
falcon_prot_file_path = 'mbpp_results/samples_mbpp_p10_falcon_new-sanitized.eval_results.json'

with open(falcon_prot_file_path, 'r', encoding='utf-8') as file:
    falcon_prot = json.load(file)

# Look rows up by their actual task_id instead of assuming row == task_id - 1,
# which silently selects the wrong problems if the split does not start at
# task 1 or has gaps. Order follows the keys of falcon_prot['eval'].
mbpp_test = mbpp['test']
row_by_task_id = {tid: row for row, tid in enumerate(mbpp_test['task_id'])}
mbpp_ref = []
for task_id in falcon_prot['eval']:
    iid = int(task_id.split('/')[-1])
    if iid not in row_by_task_id:
        raise KeyError(f"{task_id} not found in the MBPP test split")
    mbpp_ref.append(row_by_task_id[iid])
mbpp_subset = mbpp_test.select(mbpp_ref)

# Indices saved alongside the encodings (loaded here; not used further in this script).
del_ids_path = f"{direc_name}/encoding1_inf_ind.npy"
del_arr = np.load(del_ids_path)

# A row is usable only if every component is finite (no NaN / +-Inf).
good_mask = np.isfinite(train_vec).all(axis=1)
bad_idx = np.flatnonzero(~good_mask)

# Drop the unusable rows from the vectors, labels and the text dataset alike.
# NOTE(review): assumes row i of train_vec corresponds to row i of ds — confirm.
train_vec_f = train_vec[good_mask]
train_lab_f = train_lab[good_mask]
new_ds = ds.select(np.flatnonzero(good_mask))


def build_icl_prompt(few_shots, test_problem):
    """Assemble an MBPP-style in-context-learning prompt.

    Each ``(problem, solution)`` pair becomes a solved example between
    ``[BEGIN]`` and ``[DONE]`` markers followed by a blank line. The prompt then
    ends with ``test_problem`` and an open ``[BEGIN]`` for the model to continue.

    Args:
        few_shots: iterable of ``(problem, solution)`` pairs.
        test_problem: the task description to be solved.

    Returns:
        The concatenated prompt string.
    """
    pieces = [
        f"You are an expert Python programmer, and here is your task: {prob}\n[BEGIN]\n{sol}\n[DONE]\n\n"
        for prob, sol in few_shots
    ]
    pieces.append(f"You are an expert Python programmer, and here is your task: {test_problem}\n[BEGIN]\n")
    return "".join(pieces)


test_problems = mbpp_subset
completions = {}

# NOTE(review): assumes test_vec[i] encodes the i-th row of mbpp_subset (same
# order) — confirm against the script that produced test_encoding.npy.
for i, item in tqdm(enumerate(test_problems), total=len(test_problems)):
    task_id = str(item['task_id'])
    test_problem = item['text']

    # Nearest finite training example by Euclidean distance.
    dists = np.linalg.norm(train_vec_f - test_vec[i], axis=1)
    idx = int(np.argmin(dists))

    # Fetch the single row. new_ds['problem'][idx] would materialise the whole
    # column (~75K entries) on every iteration in common `datasets` versions.
    nearest = new_ds[idx]
    few_shot_example = [(nearest['problem'], nearest['solution'])]
    icl_prompt = build_icl_prompt(few_shot_example, test_problem)

    # Sample 10 completions per task; only the generated continuation is kept.
    output = generator(
        icl_prompt,
        max_new_tokens=512,
        return_full_text=False,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        num_return_sequences=10,
    )

    completions_list = [o['generated_text'] for o in output]
    completions[task_id] = completions_list

# One JSON object per completion, keyed as "Mbpp/<task_id>".
with open(generations_prot_save_path, "w", encoding="utf-8") as f:
    for task_id, sols in completions.items():
        for s in sols:
            f.write(json.dumps({"task_id": f"Mbpp/{task_id}", "solution": s}) + "\n")

print(f"Saved all completions to {generations_prot_save_path}")

# The pipeline keeps its own reference to the model, so both names must be
# dropped before the weights can be garbage-collected and GPU memory freed.
del generator
del model


def flush():
    """Run the garbage collector and release cached CUDA memory.

    Also resets CUDA peak-memory statistics. The CUDA calls are skipped when no
    GPU is available, since reset_peak_memory_stats raises without CUDA.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


flush()
