# from huggingface_hub import notebook_login
from tqdm import tqdm
import json
import argparse
import pandas as pd
import os
from collections import Counter
import numpy as np
from datasets import load_dataset, concatenate_datasets
import argparse
import re
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


parser = argparse.ArgumentParser(description='model name as input(hugging face id)')
# required=True: every later step derives paths and weights from the model name.
parser.add_argument('--model_name', type=str, required=True, help='hugging face model name')
parser.add_argument('--device_id', type=int, nargs='+', help='GPU ID', default=[0])
parser.add_argument('--output_dir', type=str, default=None,
                    help='directory for the saved .npy/.json outputs '
                         '(default: last path component of --model_name)')


args = parser.parse_args()
model_name = args.model_name
gpu_id = args.device_id


# Restrict the visible GPUs; this must happen before CUDA is first used
# (the model load below) to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, gpu_id))


# torch.set_float32_matmul_precision('high')
device = 'cuda'
# Last path component works for "org/name" hub ids, bare names and local
# paths alike (split('/')[1] raised IndexError when there was no slash).
model_data_path = os.path.basename(model_name.rstrip('/'))
# An empty directory name makes os.makedirs('') raise FileNotFoundError, so
# fall back to the model's short name when --output_dir is not given.
direc_name = args.output_dir or model_data_path

os.makedirs(direc_name, exist_ok=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    output_hidden_states=True,
    torch_dtype=torch.float16,
    device_map='auto',
)
# Batched tokenisation needs a pad token; reuse EOS for it.
tokenizer.pad_token = tokenizer.eos_token

def get_embeddings(texts, batch_size=64, max_length=1024, model_name=None,
                   padding="max_length"):
    """Embed each text as the attention-masked mean of the model's last hidden state.

    Uses the module-level ``tokenizer`` and ``model``; tokenized inputs are
    moved to ``"cuda"``.

    Args:
        texts: sliceable sequence of strings to embed.
        batch_size: number of texts per forward pass.
        max_length: maximum tokens per text; longer texts are truncated.
        model_name: unused; kept so existing callers keep working.
        padding: forwarded to the tokenizer. ``"max_length"`` (default, the
            historical behaviour) pads every batch to ``max_length``;
            ``"longest"`` pads only to the longest text in the batch.

    Returns:
        float32 numpy array with one pooled embedding row per input text.
    """
    device = "cuda"
    embeddings = []

    for start in tqdm(range(0, len(texts), batch_size)):
        # list() so column objects from newer `datasets` releases reach the
        # tokenizer as a plain list of str.
        batch_texts = list(texts[start:start + batch_size])

        inputs = tokenizer(batch_texts, padding=padding, truncation=True,
                           max_length=max_length, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        # Upcast before pooling: this script loads the model in float16, and
        # summing up to max_length activations in float16 can exceed the
        # float16 maximum (~65504), turning the pooled vector into inf.
        last_hidden = outputs.hidden_states[-1].float()
        mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)

        # Padding positions are zeroed by the mask; clamp guards against a
        # division by zero for an all-padding row.
        token_counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        mean_pooled = torch.sum(last_hidden * mask, dim=1) / token_counts

        embeddings.append(mean_pooled.cpu().numpy())
        # Drop the per-layer hidden states before the next batch is allocated.
        del outputs, last_hidden

    if not embeddings:
        # np.vstack([]) raises; return an empty (0, hidden) matrix instead.
        return np.empty((0, getattr(model.config, "hidden_size", 0)), dtype=np.float32)
    return np.vstack(embeddings)


def create_encodings(ds, p='problem', s='solution'):
    """Build the combined query + solution prompt for one dataset row.

    Args:
        ds: a single row (mapping); ``ds[p]`` and ``ds[s]`` must be strings.
        p: key holding the query / problem text.
        s: key holding the code solution text.

    Returns:
        ``{'Encoding1': prompt}`` where the two fixed headers and the two
        row fields are joined with two spaces.
    """
    pieces = (
        "This is the query being assigned:",
        ds[p],
        "The following is the code solution to the query",
        ds[s],
    )
    return {'Encoding1': "  ".join(pieces)}


def remove_lang(example, idx=None):
    """Extract the code inside the first ``` fenced block of ``example['solution']``.

    The first line inside the fence (normally the language tag, e.g.
    ``python``) is dropped. If the fence contents contain no newline there
    is no tag line, and the contents are returned as-is (previously this
    raised IndexError). With no fenced block, the whole solution is returned.

    Args:
        example: mapping with a ``'solution'`` string.
        idx: row index passed by ``Dataset.map(..., with_indices=True)``; unused.

    Returns:
        ``{'Encoding3': code}``.
    """
    solution = example['solution']
    match = re.search(r"```(.*?)```", solution, re.DOTALL)
    if not match:
        return {'Encoding3': solution}

    fenced = match.group(1)
    # Split off the first line (the language tag). A fence without a
    # newline has no tag line to drop.
    _tag, newline, body = fenced.partition('\n')
    code = body if newline else fenced
    return {'Encoding3': code}

                
            



def check_inf_vals(embed, lang_ids, drop_nan=False):
    """Report non-finite / all-zero rows of an embedding matrix and drop bad rows.

    Prints whether any NaN or inf is present, the indices of all-zero rows
    and the indices of rows containing inf.

    Args:
        embed: 2-D array, one embedding per row.
        lang_ids: per-row labels (any array-like), same length as ``embed``.
        drop_nan: if True, rows containing NaN are dropped as well as rows
            containing inf. Default False keeps NaN rows (historical behaviour).

    Returns:
        ``(clean_arr, clean_labels, dropped_indices)``: the kept rows, their
        labels, and the indices of the dropped rows. With the default
        ``drop_nan=False`` these are exactly the rows containing inf.

    Raises:
        ValueError: if ``embed`` and ``lang_ids`` differ in length.
    """
    feat = np.asarray(embed)
    labels = np.asarray(lang_ids)
    if feat.shape[0] != labels.shape[0]:
        raise ValueError(
            f"embed has {feat.shape[0]} rows but lang_ids has {labels.shape[0]} entries"
        )

    nan_mask = np.isnan(feat).any(axis=1)
    inf_mask = np.isinf(feat).any(axis=1)
    zero_indices = np.where(np.all(feat == 0, axis=1))[0]
    inf_indices = np.where(inf_mask)[0]

    print("Contains NaN:", nan_mask.any())
    print("Contains inf:", inf_mask.any())
    print("Zero vector rows found at indices:", zero_indices)
    print("Indices with inf values:", inf_indices)

    drop_mask = (inf_mask | nan_mask) if drop_nan else inf_mask
    keep = ~drop_mask
    return feat[keep], labels[keep], np.where(drop_mask)[0]



def create_language_mapping(dataset, lang_column='lang'):
    """Attach an integer ``lang_id`` field derived from ``lang_column``.

    Ids are assigned by enumerating the distinct values of ``lang_column``
    in sorted order, so the mapping is deterministic for a given language set.

    Args:
        dataset: object supporting ``dataset[column]`` and ``dataset.map(fn)``
            (e.g. a Hugging Face ``Dataset``).
        lang_column: name of the column holding the language label.

    Returns:
        ``(updated_dataset, lang_to_id)``: the result of ``dataset.map`` with
        ``lang_id`` set on every row, and the language -> id dict.
    """
    lang_to_id = {
        lang: lang_id
        for lang_id, lang in enumerate(sorted(set(dataset[lang_column])))
    }

    def _attach_id(row):
        row['lang_id'] = lang_to_id[row[lang_column]]
        return row

    return dataset.map(_attach_id), lang_to_id


# Training corpus: Magicoder OSS-Instruct; its 'lang' column provides the labels.
ds = load_dataset("ise-uiuc/Magicoder-OSS-Instruct-75K")['train']
# Evaluation prompts: MBPP (its 'test' split is embedded below).
mbpp = load_dataset("Muennighoff/mbpp", trust_remote_code=True)

# Optional: down-sample the python rows to 7000 before embedding.
# python_ds = ds.filter(lambda x: x['lang'] == 'python')
# non_python_ds = ds.filter(lambda x: x['lang'] != 'python')
# python_ds_small = python_ds.shuffle(seed=42).select(range(7000))
# ds = concatenate_datasets([non_python_ds, python_ds_small])

ds = ds.map(create_encodings)
# Adds 'Encoding3' (solution code without its fence / language tag); only
# 'Encoding1' is embedded in this run.
ds = ds.map(remove_lang, with_indices=True)
mbpp = mbpp.map(lambda x: create_encodings(x, p='text', s='code'))


mapped_ds, lang_mapping = create_language_mapping(ds)
# Read the label column directly instead of iterating (and decoding) every row.
lang_ids = np.array(list(mapped_ds['lang_id']))

train_embedding = get_embeddings(ds['Encoding1'], batch_size=16, max_length=2048, model_name=model_name)
encoding1_arr, encoding1_labels, encoding1_inf_ind = check_inf_vals(train_embedding, lang_ids)

print("Cleaned encoding1 array shape: ", encoding1_arr.shape)
print("Cleaned encoding1 labels shape: ", encoding1_labels.shape)
print("Cleaned encoding1 inf inds: ", encoding1_inf_ind)

# os.path.join avoids writing to the filesystem root if direc_name is empty.
np.save(os.path.join(direc_name, 'encoding1_arr.npy'), encoding1_arr)
np.save(os.path.join(direc_name, 'encoding1_labels.npy'), encoding1_labels)
np.save(os.path.join(direc_name, 'encoding1_inf_ind.npy'), encoding1_inf_ind)

# Release the training embeddings and cached GPU blocks before the second pass.
del train_embedding
gc.collect()
torch.cuda.empty_cache()

test_embedding = get_embeddings(mbpp['test']['Encoding1'], batch_size=32, max_length=1024, model_name=model_name)
torch.cuda.empty_cache()

print("Test embeddings shape: ", test_embedding.shape)
np.save(os.path.join(direc_name, 'test_encoding.npy'), test_embedding)

file_path = os.path.join(direc_name, 'lang_mapping.json')
with open(file_path, "w", encoding="utf-8") as json_file:
    json.dump(lang_mapping, json_file, indent=4)


del model



def flush():
    """Run garbage collection and release cached CUDA memory.

    The CUDA calls are skipped when CUDA is unavailable: without that guard,
    ``torch.cuda.reset_peak_memory_stats`` raises because it has to
    initialise a CUDA device.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


flush()
