import json
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, set_seed
from datasets import load_dataset
from tqdm import tqdm
import argparse
import os
import torch
import gc

import random
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

parser = argparse.ArgumentParser(description='model name as input(hugging face id)')
# Required: the id is passed to from_pretrained and split on '/' below.
parser.add_argument('--model_name', type=str, required=True, help='hugging face model name')
parser.add_argument('--device_id', type=int, nargs='+', help='GPU ID', default=[0])
# Working directory for the encodings, projection weights, prototypes and outputs.
# NOTE(review): with the default '' every path below resolves under '/'.
parser.add_argument('--direc_name', type=str, default='',
                    help='directory holding encoding1_arr.npy, encoding1_labels.npy, '
                         'model.pth, prot.npy, lang_mapping.json; outputs are written here')

args = parser.parse_args()
model_name = args.model_name
gpu_id = args.device_id

# Restrict the GPUs this process can see; this must happen before anything
# initializes CUDA, so do it ahead of the (torch-touching) seeding below.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, gpu_id))

# Seed the Python, NumPy and torch RNGs so prototype matching and sampling repeat.
random.seed(42)
np.random.seed(42)
set_seed(42)

device = 'cuda'
# Short model name ("org/name" -> "name"); [-1] also accepts ids with no org prefix.
model_data_path = model_name.split('/')[-1]

direc_name = args.direc_name

train_vec = np.load(f'{direc_name}/encoding1_arr.npy')
train_lab = np.load(f'{direc_name}/encoding1_labels.npy')
model_save_path = f'{direc_name}/model.pth'
prots_path = f'{direc_name}/prot.npy'

# Loaded for completeness; not referenced elsewhere in this script.
with open(f'{direc_name}/lang_mapping.json', 'r') as file:
    lang_mapping = json.load(file)

# os.makedirs('') raises FileNotFoundError, so only create a non-empty path.
if direc_name:
    os.makedirs(direc_name, exist_ok=True)

generations_prot_save_path = f'{direc_name}/modified_prot_p10.jsonl'


# Tokenizer setup: reuse EOS as the pad token and pad on the left, so every
# prompt ends immediately before the tokens the model will generate.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Causal LM in fp16; device_map='auto' lets transformers place the weights
# on the available devices.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
model.eval()

# NOTE: `model` is rebound to the projection head further down the script;
# after that point the LM is reachable only through this pipeline object.
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)

    
# Magicoder instruction data: the pool the few-shot examples are drawn from.
ds = load_dataset("ise-uiuc/Magicoder-OSS-Instruct-75K")['train']
mbpp = load_dataset("Muennighoff/mbpp", "full")

falcon_prot_file_path = 'mbpp_results/samples_mbpp_p10_falcon_new-sanitized.eval_results.json'
with open(falcon_prot_file_path, 'r') as file:
    falcon_prot = json.load(file)

# Keys under 'eval' look like "<prefix>/<id>"; each id becomes a row position
# in the MBPP test split.
# NOTE(review): this assumes row == id - 1. The original MBPP test split is
# commonly documented as starting at task_id 11, so confirm the offset against
# mbpp['test']['task_id'] before trusting the selected subset.
mbpp_ref = [int(task_key.split('/')[-1]) - 1 for task_key in falcon_prot['eval']]
mbpp_subset = mbpp['test'].select(mbpp_ref)

# Loaded but not used below.
del_ids_path = f"{direc_name}/encoding1_inf_ind.npy"
del_arr = np.load(del_ids_path)

# A row is kept when every entry is finite (no NaN / +-Inf); train_vec is
# treated as 2-D (rows x features) here and below.
good_mask = np.isfinite(train_vec).all(axis=1)
bad_idx = np.flatnonzero(~good_mask)

# Filtered copies (not used by the projection/NN search below).
train_vec_f = train_vec[good_mask]
train_lab_f = train_lab[good_mask]
new_ds = ds.select(np.flatnonzero(good_mask))
############################### getting new prot inds #####################

# The projection pass uses the UNFILTERED arrays, so output row i still
# corresponds to train_vec row i (and, presumably, ds row i).
feat = train_vec
lab = train_lab

feat_img = list(feat)
labels = list(lab)

class DuelCNNWrapper(nn.Module):
    """Projection head: Linear(vec_len -> vec_len), then InstanceNorm1d, then ReLU.

    The layers are held in ``self.additional_layer`` (an ``nn.Sequential``), so
    saved checkpoints use the keys ``additional_layer.0.weight`` and
    ``additional_layer.0.bias``.
    """

    def __init__(self, vec_len):
        super().__init__()
        layers = [
            nn.Linear(vec_len, vec_len),
            # For a 2-D input, InstanceNorm1d treats it as an unbatched (C, L)
            # tensor, i.e. each row is normalized over its own entries.
            nn.InstanceNorm1d(vec_len),
            nn.ReLU(),
        ]
        self.additional_layer = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` passed through the projection head."""
        return self.additional_layer(x)
        
# Feature width of each encoding row; the head maps vec_len -> vec_len.
vec_len = feat.shape[1]

# Read the trained head's weights, then build the head on the GPU and load them.
# NOTE: this rebinds `model` (previously the causal LM).
state_dict = torch.load(model_save_path)
model = DuelCNNWrapper(vec_len).to('cuda')
model.load_state_dict(state_dict)

class CustomDataset(Dataset):
    """Map-style dataset over parallel ``img`` and ``labels`` sequences.

    ``transform``, when truthy, is applied to the image item (not the label)
    before the ``(image, label)`` pair is returned.
    """

    def __init__(self, img, labels, transform=None):
        self.img = img
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        sample, target = self.img[idx], self.labels[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target

train_dataset = CustomDataset(feat_img, labels)
indices = list(range(len(train_dataset)))

# NOTE(review): the shuffled subset is never used below (the loader reads
# train_dataset in order). The randperm call is kept so the global torch RNG
# stream is unchanged.
indices_shuffled = torch.randperm(len(indices)).tolist()
shuffled_dataset = torch.utils.data.Subset(train_dataset, indices_shuffled)

# shuffle=False keeps output row i aligned with input row i. drop_last must be
# False: with drop_last=True the trailing partial batch was never projected, so
# those rows could never be chosen as nearest neighbours.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=4, drop_last=False)
new_feat = []
new_labels = []

model.eval()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for img, label in tqdm(train_loader):
        # The head's parameters are float32; cast in case the encodings were stored
        # in another float dtype (a no-op when they are already float32).
        img = img.to('cuda').float()
        new_arr = model(img).cpu().numpy()
        new_feat.extend(new_arr)
        new_labels.extend(label)

new_arr = np.array(new_feat)
new_labels = np.array(new_labels)

# For each prototype vector, find the closest projected row (L2 distance).
inds = []
nn_human = np.load(prots_path)

for vec in nn_human:
    dists = np.linalg.norm(new_arr - vec, axis=1)
    finite = np.isfinite(dists)
    if not finite.any():
        raise ValueError('every projected row has a non-finite distance to a prototype')
    # Non-finite encodings give NaN distances, and np.argmin returns the index of
    # the first NaN; push those rows to +inf so they can never be selected.
    dists = np.where(finite, dists, np.inf)
    closest_idx = np.argmin(dists)
    inds.append(closest_idx)
#########################getting new prot inds

# Keep the neighbours whose Magicoder example is tagged as Python.
# Index a row (ds[i]) instead of a column (ds['lang'][i]): column access
# materializes the whole column on every lookup.
pyth_ind = [ind for ind in inds if ds[int(ind)]['lang'] == 'python']


def _problem_solution(idx):
    """Return the (problem, solution) pair stored at row ``idx`` of ``ds``."""
    row = ds[int(idx)]
    return row['problem'], row['solution']


if model_name == 'tiiuae/Falcon3-1B-Base':
    # NOTE(review): this branch used to reference an undefined name `new_inds`
    # and always raised NameError. It now uses every nearest neighbour without
    # the Python-only filter; confirm this is the intended selection for Falcon.
    few_shots = [_problem_solution(ind) for ind in inds]
else:
    few_shots = [_problem_solution(ind) for ind in pyth_ind]

def build_icl_prompt(few_shots, test_problem):
    """Build a few-shot prompt in MBPP ``[BEGIN]``/``[DONE]`` style.

    Each ``(problem, solution)`` pair in ``few_shots`` becomes a solved example
    under a language-neutral "expert programmer" header. The prompt ends with
    ``test_problem`` under an "expert Python programmer" header and an open
    ``[BEGIN]`` for the model to complete.
    """
    shots = [
        f"You are an expert programmer, and here is your task: {prob}\n[BEGIN]\n{sol}\n[DONE]\n\n"
        for prob, sol in few_shots
    ]
    query = f"You are an expert Python programmer, and here is your task: {test_problem}\n[BEGIN]\n"
    return "".join(shots) + query


# One prompt per selected MBPP test problem, paired with its task_id.
test_problems = mbpp_subset
pairs = [(str(row["task_id"]), build_icl_prompt(few_shots, row["text"])) for row in test_problems]
task_ids = [tid for tid, _ in pairs]
prompts = [prompt for _, prompt in pairs]

batch_size = 1   # tune for your VRAM
num_return_sequences = 10
completions = {}

for start in tqdm(range(0, len(prompts), batch_size), desc="Generating (batched)"):
    stop = start + batch_size
    batch_prompts = prompts[start:stop]
    batch_task_ids = task_ids[start:stop]

    sampling_kwargs = dict(
        max_new_tokens=512,
        return_full_text=False,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        num_return_sequences=num_return_sequences,
        batch_size=len(batch_prompts),  # must match current batch length
    )
    outs = generator(batch_prompts, **sampling_kwargs)

    # outs[j] holds the sampled continuations for batch_prompts[j].
    for j, tid in enumerate(batch_task_ids):
        completions[tid] = [o["generated_text"] for o in outs[j]]

# One JSON object per line: {"task_id": "Mbpp/<id>", "solution": <completion>}.
with open(generations_prot_save_path, "w") as f:
    f.writelines(
        json.dumps({"task_id": f"Mbpp/{tid}", "solution": sol}) + "\n"
        for tid, sols in completions.items()
        for sol in sols
    )

print(f"Saved all completions to {generations_prot_save_path}")

# Drop both handles to the large models before flushing. `model` was rebound to
# the projection head, so the causal LM is still referenced by `generator`.
del model, generator


def flush():
    """Run the garbage collector and release cached CUDA memory.

    The CUDA calls are skipped when CUDA is unavailable, because
    ``torch.cuda.reset_peak_memory_stats`` requires a CUDA device.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


flush()
