import os
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.model_selection import train_test_split
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from gmm_stein import GMM_SteinSampler
from openai import OpenAI
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import csv
import re
from score_stein import ScoreSteinSampler
# import vec2text # pip, use pretrained models

import sys
# Append the directory one level above this file's directory to sys.path.
# Presumably this is what lets the `models` and `api` imports below resolve
# when the script is run directly — confirm against the repository layout.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
import models
from api import load_corrector, invert_embeddings, load_pretrained_corrector
from models.model_utils import device
import torch.nn.functional as F


def concatenate_qa(df_qa):
    """Return a copy of *df_qa* with an added ``text`` column.

    The new column holds ``question + " [SEP] " + answer`` for every row.
    The input frame is not modified.

    Args:
        df_qa: DataFrame providing ``question`` and ``answer`` columns.

    Returns:
        A new DataFrame with all original columns plus ``text``.
    """
    joined = df_qa['question'] + " [SEP] " + df_qa['answer']
    result = df_qa.copy()
    result['text'] = joined
    return result

def generated_all_embedding(corpus, client):
    """Embed every text in *corpus*, cache the result on disk, and return it.

    Texts are sent to the embeddings endpoint in batches of 128 using the
    ``text-embedding-ada-002`` model. The stacked embeddings are written to
    ``corpus_embeddings.pt`` in the current working directory.

    Args:
        corpus: Sequence of strings to embed.
        client: Object exposing ``embeddings.create(model=..., input=...)``
            whose response has a ``data`` sequence of items carrying an
            ``embedding`` attribute (an ``OpenAI`` client in this file).

    Returns:
        A ``torch.Tensor`` with one row per embedding returned by the API,
        concatenated in batch order. The same tensor is saved to disk.
    """
    batch_size = 128
    corpus_embeddings = []
    for start in range(0, len(corpus), batch_size):
        response = client.embeddings.create(
            model="text-embedding-ada-002",  # d=1536
            input=corpus[start:start + batch_size],
        )
        corpus_embeddings.extend(item.embedding for item in response.data)
    corpus_embeddings = torch.tensor(corpus_embeddings)
    # Alternative left by the original author: a local SentenceTransformer
    # ('all-MiniLM-L6-v2') with model.encode(corpus, convert_to_tensor=True).
    torch.save(corpus_embeddings, 'corpus_embeddings.pt')
    # Hand the tensor back so callers can use it without reloading the file.
    return corpus_embeddings

def select_variants(
    target_embedding,
    variants_embedding,
    distance_metric='cos'
):
    """Return the index of the variant most similar to the target.

    Similarity is always cosine similarity computed along ``dim=1``;
    ``distance_metric`` is accepted but not read by this function.

    Args:
        target_embedding: Tensor for the target. A 1-D tensor is promoted to
            a single row before being expanded to one row per variant.
        variants_embedding: 2-D tensor with one row per candidate variant.
        distance_metric: Unused.

    Returns:
        The ``indices`` tensor of ``torch.topk`` with ``k=1``, i.e. a
        one-element tensor holding the row index of the best variant.
    """
    num_variants = variants_embedding.size(0)
    query = target_embedding
    if query.dim() == 1:
        query = query.unsqueeze(0)

    similarities = F.cosine_similarity(
        query.expand(num_variants, -1), variants_embedding, dim=1
    )
    return torch.topk(similarities, 1).indices


def _format_qa_prompt(candidate):
    """Render a ``question [SEP] answer`` candidate as the Q/A text shown to the model."""
    if "[SEP]" in candidate:
        # Split on the first separator only; further separators are blanked
        # out of the answer so the marker never reaches the model.
        question, answer = candidate.split("[SEP]", 1)
        answer = answer.replace("[SEP]", " ")
        return f"Q: {question}\nA: {answer}"
    return f"Q&A:{candidate}"


def _to_sep_format(generated_text):
    """Normalise a chat reply into ``question [SEP] answer`` form."""
    match = re.match(r"Q:\s*(.*)\s*A:\s*(.*)", generated_text, re.DOTALL)
    if match:
        question = match.group(1).strip()
        answer = match.group(2).strip()
    elif "?" in generated_text:
        # No Q:/A: markers: treat everything up to the first "?" as the question.
        question, answer = generated_text.split("?", 1)
        question = question.strip() + "?"
        answer = answer.strip()
    else:
        # worst case: treat everything as answer
        question = ""
        answer = generated_text.strip()
    return question + " [SEP] " + answer


def _gpt2_variants(candidate, tokenizer, corrector_model, variation_num):
    """Sample up to *variation_num* GPT-2 continuations of the correction prompt."""
    prompt = """You are a math tutor. Your job is to correct flawed reasoning in math Q&A. Always return only the corrected Q&A, nothing else. Format: <corrected question text> [SEP] <corrected answer>"""
    prompt += f"{_format_qa_prompt(candidate)}\n\n Corrected Q&A:\n"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    prompt_len = inputs["input_ids"].size(1)

    variants = []
    for i_var in range(variation_num):
        with torch.no_grad():
            outputs = corrector_model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=prompt_len,  # allow as many new tokens as the prompt has
                temperature=1.2,
                top_k=50,
                top_p=0.9,
                early_stopping=True,
                do_sample=True,
                num_return_sequences=1,
                repetition_penalty=1,
                no_repeat_ngram_size=2,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Drop the prompt by token position rather than by comparing decoded
        # text: decoding with clean-up may not reproduce the prompt verbatim.
        continuation = tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ).strip()
        print("variation", i_var + 1, ":", continuation)
        # Skip blank generations: there is nothing to embed or compare.
        if continuation:
            variants.append(continuation)
    return variants


def _chat_variants(candidate, client, model_name, variation_num):
    """Request *variation_num* corrections from a chat model, in ``[SEP]`` form."""
    prompt1 = """You are a math tutor. Your job is to correct flawed reasoning in math Q&A.  Always output the corrected Q&A in the following exact format. Do not add explanations or extra text. Format: Q: <corrected question text>\n A: <corrected answer>"""
    prompt2 = _format_qa_prompt(candidate)

    variants = []
    for _ in range(variation_num):
        outputs = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": prompt1},
                {"role": "user", "content": prompt2},
            ],
        )
        # Guard against a reply without text content before regex parsing.
        reply = outputs.choices[0].message.content or ""
        variants.append(_to_sep_format(reply))
    return variants


def _pick_closest_variant(variants, target_embedding, embedding_model):
    """Embed *variants* and return the one most cosine-similar to the target."""
    response = embedding_model.embeddings.create(
        model="text-embedding-ada-002",
        input=variants,
    )
    variant_embeddings = torch.tensor([item.embedding for item in response.data])
    closest = select_variants(target_embedding, variant_embeddings)
    return variants[closest[0].item()]


def correction_optimization(
    initial_text,
    target_embedding,
    embedding_model,
    ICL_name="gpt2",
    max_iter=1,
    lookahead=True,
    variation_num=3,
):
    """Iteratively rewrite a Q&A text, keeping the variant closest to a target embedding.

    Each iteration generates ``variation_num`` corrected variants of the
    current best candidate, embeds them with ``text-embedding-ada-002`` and
    keeps the one selected by ``select_variants``.

    Args:
        initial_text: Starting text, normally ``question [SEP] answer``.
            Text without the separator is passed to the model as ``Q&A:...``.
        target_embedding: Tensor the chosen variant should be closest to.
        embedding_model: Client exposing ``embeddings.create`` (an ``OpenAI``
            client in this file).
        ICL_name: Generator name. Names containing ``"gpt2"`` load a local
            GPT-2 model; any other name is sent to the chat completions API.
        max_iter: Number of generate-and-select rounds.
        lookahead: Accepted but not read by this function.
        variation_num: Variants generated per round.

    Returns:
        The best candidate text found. If embedding or selection fails in a
        round, the error is printed, iteration stops, and the best candidate
        so far is returned. Errors raised while generating variants propagate.
    """
    use_gpt2 = "gpt2" in ICL_name
    if use_gpt2:
        tokenizer = GPT2Tokenizer.from_pretrained(ICL_name)
        corrector_model = GPT2LMHeadModel.from_pretrained(ICL_name)
        corrector_model.eval()
        # GPT-2 ships without a pad token; reuse EOS.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    else:
        # NOTE(review): the API key is blank in this file — supply one before running.
        client = OpenAI(api_key="")

    best_candidate = initial_text
    for iteration in range(max_iter):
        print(f"Iterate {iteration + 1}/{max_iter}")
        if use_gpt2:
            iterative_variants = _gpt2_variants(
                best_candidate, tokenizer, corrector_model, variation_num
            )
        else:
            iterative_variants = _chat_variants(
                best_candidate, client, ICL_name, variation_num
            )

        try:
            best_candidate = _pick_closest_variant(
                iterative_variants, target_embedding, embedding_model
            )
        except Exception as e:
            # Best effort: keep what we have rather than fail the whole run.
            print(f"errors happen during generating embeddings: {e}")
            break

    return best_candidate


def main():
    """Run the pipeline: embed the corpus, fit particles, invert them to text, refine, save.

    Steps:
        1. Read the Q&A CSV, take the 80% training split and build the
           ``question [SEP] answer`` corpus.
        2. Load corpus embeddings from ``corpus_embeddings.pt`` or generate them.
        3. Load particle embeddings from ``particle_embedding.pt`` or fit them
           with ``GMM_SteinSampler`` and cache the result.
        4. Invert particle embeddings to text with the corrector.
        5. Refine each text with ``correction_optimization`` and append the
           resulting question/answer pairs to the output CSV. Failures are
           logged to the error file and the unrefined text is written instead.

    NOTE(review): the CSV path, model paths, output paths and API key are
    blank strings in this file and must be filled in before running.
    """
    # Load the original training dataset
    filename = ""
    df = pd.read_csv(filename)
    traindf, _ = train_test_split(df, test_size=0.2, random_state=42)
    train_new = concatenate_qa(traindf)
    corpus = train_new['text'].tolist()

    # Get the embeddings
    client = OpenAI(api_key="")
    embedding_path = 'corpus_embeddings.pt'

    if os.path.exists(embedding_path):
        print("Embeddings founded and loading...")
        corpus_embeddings = torch.load(embedding_path)
    else:
        print("Embeddings are not founded, generating now...")
        corpus_embeddings = generated_all_embedding(corpus, client)
        if corpus_embeddings is None:
            # The helper saves its result to embedding_path; reload from
            # there when it does not hand the tensor back.
            corpus_embeddings = torch.load(embedding_path)
    print("Embeddings generated.")

    # Fit (or load cached) particle embeddings
    num_particles = 500
    particle_embedding_path = 'particle_embedding.pt'
    if os.path.exists(particle_embedding_path):
        print("Embeddings founded and loading...")
        particles_embeddings = torch.load(particle_embedding_path)
    else:
        learning_rate = 0.01
        num_iterations = 100
        print("Starting optimization particle embeddings...")
        sampler = GMM_SteinSampler(
            num_particles=num_particles,
            bandwidth=0.5,
            learning_rate=learning_rate,
            num_iterations=num_iterations,
            n_components=10
        )
        particles_embeddings = sampler.fit(corpus_embeddings)
        particles_embeddings = torch.tensor(particles_embeddings)
        torch.save(particles_embeddings, particle_embedding_path)
        print("Optimization complete.")
    # Use the same device whether the particles were loaded or freshly fitted.
    particles_embeddings = particles_embeddings.to(device)

    # Project particle embeddings back to text
    inf_budget = True
    if inf_budget:
        inversion_model = models.InversionModel.from_pretrained("").to(device)
        corrector_model = models.CorrectorEncoderModel.from_pretrained("").to(device)
        projector = load_corrector(inversion_model, corrector_model)
    else:
        projector = load_pretrained_corrector("text-embedding-ada-002")
    particle_text = invert_embeddings(
        embeddings=particles_embeddings,
        corrector=projector
    )

    print("Initial Projection Done.")

    # in-context learning
    error_path = ""
    model_name = "gpt-3.5-turbo"
    save_dir = ""
    for path in (error_path, save_dir):
        # os.makedirs("") raises, so only create a parent when there is one.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    # The file is opened in append mode; only a new or empty file gets a header.
    write_header = not os.path.exists(save_dir) or os.path.getsize(save_dir) == 0
    with open(save_dir, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["question", "answer"])
        for i, initial_text in enumerate(particle_text):
            print("Processing", i, '\n', initial_text)
            try:
                output_text = correction_optimization(
                    initial_text=initial_text,
                    target_embedding=particles_embeddings[i].to('cpu'),
                    embedding_model=client,
                    ICL_name=model_name,  # options: "gpt-3.5-turbo","gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"
                    max_iter=2,
                    lookahead=False,
                    variation_num=5,
                )
                print(output_text)
                # Raises ValueError when the separator is missing, which
                # routes the particle to the fallback below.
                question, answer = output_text.split("[SEP]", 1)
            except Exception as e:
                error_msg = f"Error optimizing particle {i}: {e}"
                print(error_msg)
                with open(error_path, "a", encoding="utf-8") as f_err:
                    f_err.write(error_msg + "\n")
                # Fall back to the unrefined text. The split must not raise
                # here, or one bad particle would abort the whole run.
                question, sep, answer = initial_text.partition("[SEP]")
                if not sep:
                    question, answer = "", initial_text
            writer.writerow([question, answer])

    print("Condensation Done.")
    print("Saving Done.")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()