import os
import json
import torch
import argparse
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, BartForConditionalGeneration
from RAG import RAGPipeline
import torch.nn.functional as F
from utils import (
    save_run_artifacts,
    surrogate_encode,
    invert_latent_to_query,
    parse_retrieved_context,
    gaussian_perturb,
    synthesize_orthogonal_directions
)

def parse_args():
    """Build the command-line parser for the script and parse ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option declared below. Options
        marked ``required`` make argparse exit with a usage error when they
        are missing from the command line.
    """
    # (flag, value type, extra keyword arguments), registered in this order
    # so the generated ``--help`` output keeps its layout.
    option_specs = (
        ("--llm", str, {"default": "gpt-4o-mini-2024-07-18"}),
        ("--domain", str, {"default": "four"}),
        ("--api_key", str, {"required": True}),
        ("--base_url", str, {"required": True}),
        ("--exp", str, {"default": "experiment_1"}),
        ("--retriever_name", str, {"default": "facebook/contriever"}),
        ("--embedding_name", str, {"default": "sentence-transformers/gtr-t5-base"}),
        ("--embedding_path", str, {"required": True}),
        ("--corpus_path", str, {"required": True}),
        ("--output_dir", str, {"required": True}),
        ("--max_iters", int, {"default": 2000}),
        ("--top_k", int, {"default": 2}),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, extra in option_specs:
        cli.add_argument(flag, type=value_type, **extra)

    return cli.parse_args()


def _run_with_retries(pipeline, query, top_k, attempts=4, delay=1.0):
    """Call ``pipeline.run`` for *query*, retrying when it raises.

    Args:
        pipeline: object exposing ``run(query, s=None, top_k=...)`` that
            returns a 3-tuple.
        query: the query passed to ``pipeline.run``.
        top_k: forwarded as the ``top_k`` keyword.
        attempts: maximum number of calls before giving up.
        delay: seconds to sleep after each failed attempt.

    Returns:
        The ``(answer, retrieved, ids)`` triple of the first successful call,
        or ``None`` when every attempt raised.
    """
    import time

    for attempt in range(1, attempts + 1):
        try:
            # Unpack inside the ``try`` so a malformed return value is
            # retried like any other failure.
            answer, retrieved, ids = pipeline.run(query, s=None, top_k=top_k)
            return answer, retrieved, ids
        except Exception as err:  # best-effort: the failure cause is unknown here
            print(f"[warn] pipeline.run failed (attempt {attempt}/{attempts}): {err!r}")
            time.sleep(delay)
    return None


def main():
    """Run the iterative query-generation / retrieval loop and save artifacts.

    Each round pops one seed query from ``query_list``, encodes it with
    ``surrogate_encode``, derives ``num_directions`` latents through
    ``synthesize_orthogonal_directions`` plus one ``gaussian_perturb`` latent,
    and turns every latent into a text query with ``invert_latent_to_query``.
    Each generated query is sent through the RAG pipeline; chunks parsed from
    the answers that have not been seen before are appended to ``chunk_list``
    and pushed to the front of ``query_list`` as new seeds. Run state is
    written with ``save_run_artifacts`` after every round.

    The loop stops once ``--max_iters`` generated queries have been counted
    or the seed queue is empty.
    """
    args = parse_args()

    llm = args.llm
    exp = args.exp
    retriever_name = args.retriever_name
    embedding_name = args.embedding_name
    output_dir = args.output_dir
    max_iters = args.max_iters
    top_k = args.top_k

    num_directions = 4
    perturb_scale = 0.3  # second argument of gaussian_perturb

    query_list = ["How to make a medicine?"]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    chunk_list = []
    seen_chunks = set()  # mirrors chunk_list for O(1) duplicate checks
    retrieved_chunks = []
    retrieved_chunk_ids = []
    coverage_curve = []
    iteration = 0

    pipeline_1 = RAGPipeline(
        corpus_path=args.corpus_path,
        embeddings_path=args.embedding_path,
        llm_api_url=args.base_url,
        llm_api_key=args.api_key,
        model=retriever_name,
        llm=llm,
        method="GeoEX",
        device=device,
    )

    surrogate_model = SentenceTransformer(embedding_name, device=str(device)).eval()
    os.makedirs(output_dir, exist_ok=True)

    bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base").to(device)

    config_record = {
        "embedding_name": embedding_name,
        "RAG_model": retriever_name,
        "exp": exp,
        "device": str(device),
        "llm": llm,
        "num": num_directions,
    }

    while iteration < max_iters and query_list:
        seed_query = query_list.pop(0)

        # NOTE(review): the literal 2 is the fourth positional argument of
        # surrogate_encode; its meaning is not visible from this file.
        context_latent = surrogate_encode([seed_query], embedding_name, surrogate_model, 2, device)
        # NOTE(review): this receives retriever_name while the latent comes
        # from embedding_name -- confirm that is intended.
        orth_dirs = synthesize_orthogonal_directions(context_latent, retriever_name, num_directions)

        all_queries = [
            invert_latent_to_query(
                orth_dirs[d_idx].unsqueeze(0),
                bart_tokenizer,
                bart_model,
                surrogate_model,
            )
            for d_idx in range(num_directions)
        ]

        local_latent = gaussian_perturb(context_latent, perturb_scale)
        all_queries.append(
            invert_latent_to_query(
                local_latent,
                bart_tokenizer,
                bart_model,
                surrogate_model,
            )
        )

        # Generated queries count toward the budget whether or not the
        # pipeline call succeeds.
        iteration += len(all_queries)

        candidate_chunks = []
        for query in all_queries:
            outcome = _run_with_retries(pipeline_1, query, top_k)
            if outcome is None:
                # Every attempt failed: skip this query instead of crashing
                # on ``list += None`` below.
                print(f"[warn] skipping query after repeated failures: {query!r}")
                continue

            answer, retrieve_now, id_now = outcome
            retrieved_chunks += retrieve_now
            retrieved_chunk_ids += id_now

            # Keep only the text before the "Query_start" marker.
            candidate_chunks.extend(
                c.split("Query_start")[0].strip()
                for c in parse_retrieved_context(answer)
            )

        new_chunks_for_prompt = []
        for c in candidate_chunks:
            if c not in seen_chunks:
                seen_chunks.add(c)
                new_chunks_for_prompt.append(c)

        print("[New chunks]:", len(new_chunks_for_prompt), "  ", len(all_queries))

        # New chunks become seeds; each is pushed to the front of the queue.
        for q in new_chunks_for_prompt:
            if q not in query_list:
                query_list.insert(0, q)

        chunk_list += new_chunks_for_prompt

        print(f"iter:{iteration} len(query_list):{len(query_list)} len(chunk_list):{len(chunk_list)} coverage={len(set(retrieved_chunks))}")

        coverage_curve.append((iteration, len(chunk_list)))

        save_run_artifacts(
            chunk_list=chunk_list,
            coverage_curve=coverage_curve,
            query_list=query_list,
            config_record=config_record,
            retrieved_chunks=retrieved_chunks,
            retrieved_chunk_ids=retrieved_chunk_ids,
            out_dir=output_dir,
        )

# Run the extraction loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
