import argparse

import faiss 
import json
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from peft import LoraConfig, PeftModel
import torch.nn.functional as F

def parse_args():
    """Build and parse the command-line interface of the retrieval script.

    Every option is mandatory; ``--retrieve_num`` is parsed as an int and all
    other options as strings.

    Returns:
        argparse.Namespace holding one attribute per option.
    """
    parser = argparse.ArgumentParser(description="Retrieve script with configurable parameters")
    # (flag, value type, help text) -- registered in this exact order.
    option_table = (
        ("--tokenizer_path", str, "Path to the tokenizer"),
        ("--base_model_name", str, "Name of the base model"),
        ("--lora_path", str, "Path to the LoRA model"),
        ("--cache_dir", str, "Path to the cache directory"),
        ("--retrieve_num", int, "Number of items to retrieve"),
        ("--source_path", str, "Path to the source file"),
        ("--output_path", str, "Path to the output file"),
        ("--index_path", str, "Path to the index file"),
        ("--corpus_path", str, "Path to the corpus file"),
    )
    for flag, value_type, help_text in option_table:
        parser.add_argument(flag, type=value_type, required=True, help=help_text)
    return parser.parse_args()

def load_model_and_tokenizer(args):
    """Load the tokenizer and the base model with the LoRA adapter merged in.

    Args:
        args: parsed CLI namespace; reads ``tokenizer_path``, ``base_model_name``,
            ``lora_path`` and ``cache_dir``.

    Returns:
        ``(tokenizer, model)`` where ``model`` is the merged (adapter-free) model,
        moved to CUDA and switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
    # get_embeddings() pads every batch via tokenizer.pad(padding=True), which
    # raises when the tokenizer defines no pad token. Fall back to EOS: padded
    # positions get attention_mask 0, and last_token_pool never selects them.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModel.from_pretrained(
        args.base_model_name,
        cache_dir=args.cache_dir,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    if base_model.config.pad_token_id is None:
        base_model.config.pad_token_id = 0

    # dtype / attention-backend options belong to model loading, not to the
    # LoRA config; PEFT silently drops unknown kwargs here, so only cache_dir
    # (used when fetching the adapter config) is passed.
    lora_config = LoraConfig.from_pretrained(args.lora_path, cache_dir=args.cache_dir)
    lora_model = PeftModel.from_pretrained(base_model, args.lora_path, config=lora_config)
    # Fold the adapter weights into the base weights so inference runs on a plain model.
    lora_model = lora_model.merge_and_unload()
    lora_model.to("cuda")
    lora_model.eval()
    return tokenizer, lora_model

def last_token_pool(last_hidden_states, attention_mask):
    """Pick, for every row of the batch, the hidden state of its last real token.

    Args:
        last_hidden_states: tensor indexed as ``[row, position, ...]``.
        attention_mask: 2-D tensor with 1 for real tokens and 0 for padding.

    Returns:
        Tensor with one selected hidden-state vector per row.
    """
    # If every row attends at the final position, the batch is left-padded (or
    # not padded at all), so the final position is each row's last real token.
    if attention_mask[:, -1].sum() == attention_mask.shape[0]:
        return last_hidden_states[:, -1]

    # Right padding: the last real token sits at (number of real tokens - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_ids = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_ids, last_positions]

def get_embeddings(model, tokenizer, query_data, type, batch_size=8, device="cuda"):
    """Encode text(s) into L2-normalised last-token embeddings.

    Each text is truncated to leave room for one EOS token, EOS is appended
    manually, the batch is padded, and the hidden state at each row's last
    non-padding position (via ``last_token_pool``) is L2-normalised.

    Args:
        model: called as ``model(input_ids=..., attention_mask=..., return_dict=True)``;
            its output must expose ``last_hidden_state``.
        tokenizer: Hugging Face-style tokenizer with ``eos_token_id`` and a pad
            token (``tokenizer.pad`` is called with ``padding=True``).
        query_data: a single string or an iterable of strings.
        type: ``"query"`` allows up to 610 tokens per text (EOS included); any
            other value allows up to 210. (The name shadows the builtin but is
            kept so keyword callers keep working.)
        batch_size: number of texts per forward pass. Defaults to 8.
        device: device each padded batch is moved to. Defaults to ``"cuda"``;
            CUDA autocast is enabled only when this is a CUDA device.

    Returns:
        Tensor on ``device`` with one unit-norm row per input text, in input order.

    Raises:
        ValueError: if ``query_data`` contains no texts or ``batch_size`` < 1.
    """
    max_length = 610 if type == "query" else 210
    texts = [query_data] if isinstance(query_data, str) else list(query_data)
    if not texts:
        # torch.cat() below would otherwise fail with an opaque error.
        raise ValueError("get_embeddings() received no texts to encode")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    use_cuda_autocast = torch.device(device).type == "cuda"
    all_embs = []

    # NOTE(review): autocast('cuda') defaults to float16 while the model is
    # loaded in bfloat16. This is kept unchanged so query embeddings stay
    # numerically consistent with how the corpus index was presumably built;
    # confirm before switching the dtype.
    with torch.no_grad(), torch.amp.autocast("cuda", enabled=use_cuda_autocast):
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]

            # Reserve one slot for the EOS appended below, so that every
            # sequence ends up at most max_length tokens long.
            encoded = tokenizer(
                batch_texts,
                max_length=max_length - 1,
                return_attention_mask=False,
                padding=False,
                truncation=True,
                return_token_type_ids=False,
                add_special_tokens=True,
            )
            encoded['input_ids'] = [
                ids + [tokenizer.eos_token_id] for ids in encoded['input_ids']
            ]

            batch_dict = tokenizer.pad(
                encoded,
                padding=True,
                return_attention_mask=True,
                return_tensors='pt',
            ).to(device)

            outputs = model(**batch_dict, return_dict=True)
            embs = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
            all_embs.append(F.normalize(embs, p=2, dim=1))

            # Drop per-batch activations before the next forward pass.
            del batch_dict, outputs

    embeddings = torch.cat(all_embs, dim=0)
    print(embeddings.shape)
    return embeddings

def _read_jsonl(path):
    """Parse a UTF-8 JSON-lines file into a list of objects, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _gather_hits(all_indices, corpus_data):
    """Map each row of FAISS neighbour ids to corpus records, dropping -1 padding."""
    # FAISS pads a result row with -1 when fewer than k neighbours exist;
    # corpus_data[-1] would silently return the last corpus record instead.
    return [[corpus_data[idx] for idx in row if idx >= 0] for row in all_indices]


def main():
    """Embed every query in --source_path and write its nearest corpus entries.

    Queries are read from the ``"query"`` field of each line of the JSON-lines
    file ``args.source_path``. The FAISS index at ``args.index_path`` is searched
    for the top ``args.retrieve_num`` neighbours, after being moved to GPU(s)
    when FAISS reports any. One JSON line per query is written to
    ``args.output_path``. In each retrieved entry, the corpus record's
    ``"text"`` becomes ``formal_statement`` and its ``"query"`` becomes
    ``informal_statement``.

    Assumes line i of ``args.corpus_path`` corresponds to vector id i in the
    index (TODO confirm against the index-building script).
    """
    args = parse_args()

    tokenizer, model = load_model_and_tokenizer(args)

    query_data = [record.get("query") for record in _read_jsonl(args.source_path)]
    corpus_data = _read_jsonl(args.corpus_path)

    # FAISS requires a C-contiguous float32 array.
    query_embeddings = np.ascontiguousarray(
        get_embeddings(model, tokenizer, query_data, "query").detach().float().cpu().numpy(),
        dtype=np.float32,
    )

    index = faiss.read_index(args.index_path)
    num_gpus = faiss.get_num_gpus()
    if num_gpus == 0:
        print("No GPU found or using faiss-cpu. Back to CPU.")
    else:
        print(f"Using {num_gpus} GPU")
        if num_gpus == 1:
            co = faiss.GpuClonerOptions()
            co.useFloat16 = True
            # `res` stays bound in this scope for as long as the GPU index is used.
            res = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(res, 0, index, co)
        else:
            co = faiss.GpuMultipleClonerOptions()
            co.shard = True
            co.useFloat16 = True
            index = faiss.index_cpu_to_all_gpus(index, co, ngpu=num_gpus)

    _scores, all_indices = index.search(query_embeddings, args.retrieve_num)
    all_related_theorems = _gather_hits(all_indices, corpus_data)
    assert len(all_related_theorems) == len(query_data)

    # ensure_ascii=False emits raw non-ASCII text, so pin the file encoding.
    with open(args.output_path, 'w', encoding='utf-8') as fout:
        for query, hits in zip(query_data, all_related_theorems):
            retrieved_data = [
                {
                    "id": i,
                    "formal_statement": hit.get("text"),
                    "informal_statement": hit.get("query"),
                }
                for i, hit in enumerate(hits)
            ]
            result = {"query": query, "retrieved_statements": retrieved_data}
            fout.write(json.dumps(result, ensure_ascii=False) + '\n')

if __name__ == "__main__":
    # Run the retrieval pipeline only when executed as a script, not on import.
    main()
