import json
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, set_seed
from datasets import load_dataset
from tqdm import tqdm
import argparse
import os
import torch
import gc

import random
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

random.seed(42)
np.random.seed(42)
set_seed(42)
# torch.set_float32_matmul_precision('high')
parser = argparse.ArgumentParser(description='model name as input(hugging face id)')
# Mandatory: the value is split on '/' below and passed to from_pretrained; a
# missing value previously surfaced as an AttributeError on None.
parser.add_argument('--model_name', type=str, required=True, help='hugging face model name')
parser.add_argument('--device_id',type=int,nargs = '+',help='GPU ID',default = [0])
# In this script the hyperparameters below are not used for computation: they
# only build `param_str` (via format_param_name), which selects the saved
# prototype / projection-model files of one ablation run.
# NOTE(review): --N_alpha and --delta_manifold have int defaults but type=float,
# so the default encodes as e.g. "N_alpha4" while passing "--N_alpha 4" encodes
# as "N_alpha4_000" -- confirm which spelling the training run produced.
parser.add_argument('--momentum_constant',type=float,help='momentum constant for proxies update',default=0.99)
parser.add_argument('--manifold_m',type=int,help='manifold dims and n neighbors',default=3)
parser.add_argument('--N_beta',type=float,help='similarity calculation',default=0.5)
parser.add_argument('--N_alpha',type=float,help='similarity calculation',default=4)
parser.add_argument('--delta_manifold',type=float,help='similarity calculation',default=2)
parser.add_argument('--reconstruction_threshold',type=float,help='reconstruction threshold',default=0.9)
parser.add_argument('--alpha',type=int,help='alpha proxy anchor loss',default=32)
parser.add_argument('--delta_pca',type=float,help='delta for proxy anchor loss',default=0.1)


args = parser.parse_args()
model_name = args.model_name
gpu_id = args.device_id
momentum = args.momentum_constant
m = args.manifold_m
N_beta = args.N_beta
N_alpha = args.N_alpha
delta = args.delta_manifold
reconstruction_threshold = args.reconstruction_threshold
alpha = args.alpha
delta_pca = args.delta_pca

# Restrict the visible GPUs before any CUDA context is created.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, gpu_id))

device = 'cuda'

# Last path component of the model id ("org/name" -> "name"); unlike
# split('/')[1] this also works for ids without a '/'.
model_data_path = model_name.rstrip('/').split('/')[-1]


# TODO: set to the directory holding encoding1_arr.npy, encoding1_labels.npy
# and lang_mapping.json (os.makedirs('') would otherwise fail obscurely).
direc_name = f''
if not direc_name:
    raise SystemExit('direc_name is empty: set it to the directory containing encoding1_arr.npy')
new_direc_name = direc_name+f'/abilations'
os.makedirs(direc_name,exist_ok = True)
os.makedirs(new_direc_name,exist_ok = True)
train_vec = np.load(f'{direc_name}/encoding1_arr.npy')
train_lab = np.load(f'{direc_name}/encoding1_labels.npy')

with open(f'{direc_name}/lang_mapping.json', 'r', encoding='utf-8') as file:
    lang_mapping = json.load(file)

###################### paths #############################################
def format_param_name(params: dict, precision: int = 3) -> str:
    """Encode a hyperparameter dict as a filesystem-friendly string.

    Each item becomes ``<key><value>``. Floats are rendered with ``precision``
    decimal places and their ``.`` replaced by ``_``; any other value is
    rendered with ``str()``. Items are joined with ``_`` in dict order.

    Example: ``{"momentum": 0.99, "m": 3}`` -> ``"momentum0_990_m3"``.
    """

    def _encode(value) -> str:
        # Only real floats get fixed precision; ints (and bools) keep str().
        if not isinstance(value, float):
            return str(value)
        return f"{value:.{precision}f}".replace(".", "_")

    return "_".join(f"{key}{_encode(value)}" for key, value in params.items())


# Hyperparameters identifying one ablation run; insertion order fixes the
# order of the fields in the encoded file-name suffix.
params = dict(
    momentum=momentum,
    m=m,
    N_beta=N_beta,
    N_alpha=N_alpha,
    delta=delta,
    reconstruction_threshold=reconstruction_threshold,
    alpha=alpha,
    delta_pca=delta_pca,
)

param_str = format_param_name(params)

# Artifacts of that run, all under <direc_name>/abilations:
# prototypes to read, projection-model weights to read, generations to write.
prots_path = f"{direc_name}/abilations/prot_{param_str}.npy"
model_save_path = f"{direc_name}/abilations/model_{param_str}.pth"
generations_prot_save_path = f"{direc_name}/abilations/prot_p10_{param_str}.jsonl"

###############################paths ##############################################################
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the padding token and pad on the left, so that in batched
# generation the new tokens follow each prompt directly.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Half-precision causal LM, placed across the visible devices by accelerate.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
model.eval()

generator = pipeline('text-generation', tokenizer=tokenizer, model=model)

    
# Candidate pool for few-shot demonstrations, and the MBPP evaluation set.
ds = load_dataset("ise-uiuc/Magicoder-OSS-Instruct-75K")['train']
mbpp = load_dataset("Muennighoff/mbpp", "full")

falcon_prot_file_path = 'mbpp_results/samples_mbpp_p10_falcon_new-sanitized.eval_results.json'

with open(falcon_prot_file_path, 'r', encoding='utf-8') as file:
    falcon_prot = json.load(file)

# Evaluate on the same task ids as the reference results file (keys look like
# "<prefix>/<task_id>"). Rows are looked up by their task_id column instead of
# assuming row == task_id - 1, which only holds if the test split's ids start
# at 1 with no gaps; unknown ids now fail loudly instead of selecting the
# wrong problem.
mbpp_test = mbpp['test']
row_of_task = {int(tid): row for row, tid in enumerate(mbpp_test['task_id'])}
mbpp_ref = []
for task_id in falcon_prot['eval']:
    iid = int(task_id.split('/')[-1])
    if iid not in row_of_task:
        raise KeyError(f"task {task_id} not found in the MBPP test split")
    mbpp_ref.append(row_of_task[iid])
mbpp_subset = mbpp_test.select(mbpp_ref)



############################### getting new prot inds #####################

# Aliases for the loaded training encodings and their labels.
feat, lab = train_vec, train_lab

# Per-row sequences consumed by CustomDataset below.
feat_img, labels = list(feat), list(lab)

class DuelCNNWrapper(nn.Module):
    """Single projection layer: Linear(vec_len -> vec_len), InstanceNorm1d, ReLU.

    Despite the name, there is no convolution. The submodule stays named
    ``additional_layer`` (an ``nn.Sequential`` with the Linear at index 0), so
    saved state dicts (``additional_layer.0.weight`` / ``.bias``) still load.

    NOTE: for a 2-D input of shape (batch, vec_len), ``InstanceNorm1d`` treats
    the tensor as unbatched (C, L) data. Each row is therefore normalised over
    its own ``vec_len`` features, without affine parameters (``affine=False`` by
    default). Recent PyTorch versions may warn that the size at dim 0 does not
    match ``num_features``.
    """

    def __init__(self, vec_len):
        super().__init__()
        linear = nn.Linear(vec_len, vec_len)
        norm = nn.InstanceNorm1d(vec_len)
        activation = nn.ReLU()
        self.additional_layer = nn.Sequential(linear, norm, activation)

    def forward(self, x):
        return self.additional_layer(x)
        
vec_len = feat.shape[1]

# Projection network applied to the stored encodings before the
# nearest-neighbour search. NOTE: this rebinds `model`; the causal LM loaded
# earlier is still reachable through `generator`.
model = DuelCNNWrapper(vec_len).to(device)
# map_location moves every tensor onto `device`, so a checkpoint saved from a
# different GPU index (or from CPU) still loads under the remapped
# CUDA_VISIBLE_DEVICES.
model.load_state_dict(torch.load(model_save_path, map_location=device))
# Inference only. No current layer behaves differently in eval mode, but this
# keeps the forward pass correct if dropout/batch-norm are ever added.
model.eval()

class CustomDataset(Dataset):
    """Map-style dataset pairing ``img[idx]`` with ``labels[idx]``.

    When ``transform`` is truthy, it is applied to the sample (not the label)
    before the pair is returned.
    """

    def __init__(self, img, labels, transform=None):
        self.img = img
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # Length follows the samples; labels are expected to be aligned with them.
        return len(self.img)

    def __getitem__(self, idx):
        sample, target = self.img[idx], self.labels[idx]
        return (self.transform(sample) if self.transform else sample), target

train_dataset = CustomDataset( feat_img, labels)
# Deterministic, complete inference pass:
#  - shuffle=False keeps row i of the projected features aligned with row i of
#    `feat` (and presumably with row i of `ds` -- TODO confirm that
#    encoding1_arr.npy was built over `ds` in its original order);
#  - drop_last=False keeps the final partial batch. With drop_last=True the
#    last len(feat) % 128 encodings were discarded and could never be
#    retrieved as nearest neighbours.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=4, drop_last=False)

new_feat = []
new_labels = []

model.eval()
with torch.no_grad():  # inference only: no autograd graph is needed
    for img, label in tqdm(train_loader):
        projected = model(img.to(device))
        new_feat.extend(projected.cpu().numpy())  # one 1-D array per sample
        new_labels.extend(label)

new_arr = np.array(new_feat)
new_labels = np.array(new_labels)

# For each stored prototype, pick the closest projected encoding (L2 distance).
nn_human = np.load(prots_path)
inds = [int(np.argmin(np.linalg.norm(new_arr - vec, axis=1))) for vec in nn_human]

# Fetch each selected row once; ds['problem'] may materialise the whole
# column on every lookup.
selected_rows = [ds[ind] for ind in inds]
few_shots = [(row['problem'], row['solution']) for row in selected_rows]
#########################getting new prot inds ######################################





def build_icl_prompt(few_shots, test_problem):
    """Assemble a few-shot code-generation prompt.

    Each ``(problem, solution)`` pair in ``few_shots`` becomes a demonstration
    closed by ``[DONE]`` and a blank line. The prompt ends with ``test_problem``
    and an open ``[BEGIN]`` for the model to complete. The demonstrations say
    "expert programmer", while the final query says "expert Python programmer".
    """
    demonstrations = "".join(
        f"You are an expert programmer, and here is your task: {problem}\n[BEGIN]\n{solution}\n[DONE]\n\n"
        for problem, solution in few_shots
    )
    query = f"You are an expert Python programmer, and here is your task: {test_problem}\n[BEGIN]\n"
    return demonstrations + query


test_problems = mbpp_subset
completions = {}

# One prompt per MBPP problem; every prompt shares the retrieved demonstrations.
pairs = [(str(row["task_id"]), build_icl_prompt(few_shots, row["text"])) for row in test_problems]
task_ids = [tid for tid, _ in pairs]
prompts = [prompt for _, prompt in pairs]

batch_size = 1   # tune for your VRAM
num_return_sequences = 10

generation_kwargs = dict(
    max_new_tokens=512,
    return_full_text=False,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    num_return_sequences=num_return_sequences,
)

for start in tqdm(range(0, len(prompts), batch_size), desc="Generating (batched)"):
    stop = start + batch_size
    chunk_prompts = prompts[start:stop]
    chunk_ids = task_ids[start:stop]

    # batch_size must equal the length of the current (possibly shorter, final) chunk.
    outs = generator(chunk_prompts, batch_size=len(chunk_prompts), **generation_kwargs)

    # outs is indexed [prompt][sample]; each sample dict carries "generated_text".
    for offset, tid in enumerate(chunk_ids):
        completions[tid] = [sample["generated_text"] for sample in outs[offset]]


# One JSON line per sampled solution.
with open(generations_prot_save_path, "w") as f:
    f.writelines(
        json.dumps({"task_id": f"Mbpp/{tid}", "solution": text}) + "\n"
        for tid, texts in completions.items()
        for text in texts
    )

print(f"Saved all completions to {generations_prot_save_path}")

# `model` is bound to the projection MLP by this point. The causal LM is only
# reachable through `generator` (the name `model` was rebound after loading it),
# so both references must be dropped before cached CUDA memory can be released.
del model
del generator


def flush():
    """Run the garbage collector, then release cached CUDA memory and reset peak stats."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


flush()
