import os
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.model_selection import train_test_split
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from gmm_stein import GMM_SteinSampler
from openai import OpenAI
import torch
from sklearn.cluster import KMeans
# import datasets

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import csv
import re
from score_stein import ScoreSteinSampler
# import vec2text # pip, use pretrained models

import sys
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
import models
from api import load_corrector, invert_embeddings, load_pretrained_corrector
from models.model_utils import device
import torch.nn.functional as F


def generated_all_embedding(
    corpus,
    client,
    embedding_save_path,
    task_prompt,
    batch_size=128,
    model="text-embedding-ada-002",
):
    """Embed every text in ``corpus`` with the OpenAI embeddings API and cache it.

    Each text is prefixed with ``task_prompt + " "`` before embedding. Texts are
    sent in batches of ``batch_size``. The stacked result is written to
    ``embedding_save_path`` with ``torch.save``.

    Args:
        corpus: Sliceable sequence of strings to embed.
        client: OpenAI-style client exposing ``client.embeddings.create``.
        embedding_save_path: File path the resulting tensor is saved to.
        task_prompt: Instruction text prepended to every input.
        batch_size: Number of texts per API request (default 128).
        model: Embedding model name (default ``text-embedding-ada-002``, d=1536).

    Returns:
        A ``torch.Tensor`` with one row per text in ``corpus``, in corpus order.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    corpus_embeddings = []
    for start in range(0, len(corpus), batch_size):
        batch = [task_prompt + " " + text for text in corpus[start:start + batch_size]]
        response = client.embeddings.create(model=model, input=batch)
        # Every returned item carries the position of its input. Sort on it so
        # rows always line up with ``corpus`` instead of relying on response order.
        items = sorted(response.data, key=lambda item: item.index)
        corpus_embeddings.extend(item.embedding for item in items)
    corpus_embeddings = torch.tensor(corpus_embeddings)
    torch.save(corpus_embeddings, embedding_save_path)
    return corpus_embeddings

def select_variants(
    target_embedding,
    variants_embedding,
    distance_metric='cos',
    k=1,
):
    """Return the indices of the ``k`` variants closest to ``target_embedding``.

    Args:
        target_embedding: Tensor of shape (d,) or (1, d).
        variants_embedding: Tensor of shape (n, d), one row per candidate.
        distance_metric: ``'cos'`` ranks by cosine similarity (highest first).
            ``'l2'`` (alias ``'euclidean'``) ranks by Euclidean distance
            (smallest first).
        k: Number of indices to return; capped at n.

    Returns:
        1-D ``torch.long`` tensor of up to ``k`` row indices into
        ``variants_embedding``, best match first.

    Raises:
        ValueError: if ``variants_embedding`` has no rows or the metric is unknown.
    """
    if target_embedding.dim() == 1:
        target_embedding = target_embedding.unsqueeze(0)
    num_variants = variants_embedding.size(0)
    if num_variants == 0:
        raise ValueError("variants_embedding must contain at least one row")

    # Targets built from numpy arrays are float64, while API embeddings are
    # float32 (and they may sit on different devices). Align both before
    # comparing so the result does not depend on that.
    common_dtype = torch.promote_types(target_embedding.dtype, variants_embedding.dtype)
    target_embedding = target_embedding.to(device=variants_embedding.device, dtype=common_dtype)
    variants_embedding = variants_embedding.to(dtype=common_dtype)
    k = min(k, num_variants)

    if distance_metric == 'cos':
        scores = F.cosine_similarity(
            target_embedding.expand(num_variants, -1), variants_embedding, dim=1
        )
        return torch.topk(scores, k).indices
    if distance_metric in ('l2', 'euclidean'):
        distances = torch.cdist(target_embedding, variants_embedding).squeeze(0)
        return torch.topk(distances, k, largest=False).indices
    raise ValueError(f"Unsupported distance_metric: {distance_metric!r}")


# Cache of (tokenizer, model) per GPT-2 checkpoint name. correction_optimization
# runs once per particle, so reloading the weights on every call is wasteful.
_GPT2_CACHE = {}


def _load_gpt2(name):
    """Return a cached ``(tokenizer, model)`` pair for GPT-2 checkpoint ``name``."""
    if name not in _GPT2_CACHE:
        tokenizer = GPT2Tokenizer.from_pretrained(name)
        model = GPT2LMHeadModel.from_pretrained(name)
        model.eval()
        # GPT-2 ships without a pad token; reuse EOS so generate() can pad.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        _GPT2_CACHE[name] = (tokenizer, model)
    return _GPT2_CACHE[name]


def _gpt2_variants(tokenizer, model, prompt, variation_num):
    """Sample ``variation_num`` continuations of ``prompt`` from a GPT-2 model."""
    # Up to len(prompt) new tokens are generated. Cap the prompt at half the
    # context window so prompt + generation never exceeds GPT-2's learned
    # position limit.
    max_prompt_tokens = getattr(model.config, "n_positions", 1024) // 2
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_prompt_tokens)
    prompt_len = inputs["input_ids"].size(1)
    variants = []
    for i_var in range(variation_num):
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=prompt_len,
                temperature=1.2,
                top_k=50,
                top_p=0.9,
                do_sample=True,
                num_return_sequences=1,
                repetition_penalty=1,
                no_repeat_ngram_size=2,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens. Stripping the prompt from the
        # decoded string fails whenever detokenisation cleanup alters its spacing.
        generated_text = tokenizer.decode(
            outputs[0, prompt_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ).strip()
        print(f"variation {i_var + 1}/{variation_num}: {generated_text}")
        variants.append(generated_text)
    return variants


def _chat_variants(client, model_name, system_prompt, user_prompt, variation_num):
    """Ask chat model ``model_name`` for ``variation_num`` independent rewrites."""
    messages = [{"role": "user", "content": user_prompt}]
    if system_prompt:
        messages.insert(0, {"role": "system", "content": system_prompt})
    variants = []
    for _ in range(variation_num):
        completion = client.chat.completions.create(model=model_name, messages=messages)
        # Message content may be None; normalise to "" so the variant can be
        # filtered out instead of breaking the embedding request.
        variants.append(completion.choices[0].message.content or "")
    return variants


def _closest_variant(embedding_model, target_embedding, variants):
    """Embed ``variants`` and return the one closest (cosine) to ``target_embedding``.

    Raises:
        ValueError: if every variant is empty or whitespace-only.
    """
    # An empty rewrite is never a useful candidate, so skip it when scoring.
    candidates = [text for text in variants if text.strip()]
    if not candidates:
        raise ValueError("all generated variants were empty")
    response = embedding_model.embeddings.create(
        model="text-embedding-ada-002",
        input=candidates,
    )
    items = sorted(response.data, key=lambda item: item.index)
    candidates_embedding = torch.tensor([item.embedding for item in items])
    close_idxes = select_variants(target_embedding, candidates_embedding)
    return candidates[close_idxes[0].item()]


def correction_optimization(
    initial_text,
    target_embedding,
    embedding_model,
    ICL_name="gpt2",
    task_gudie=None,
    promt_prefix=None,
    max_iter=1,
    lookahead=True,
    variation_num=3,
    chat_client=None,
):
    """Iteratively rewrite ``initial_text`` so its embedding approaches ``target_embedding``.

    Each round asks a generator for ``variation_num`` rewrites of the current
    best candidate. The generator is a local GPT-2 checkpoint when ``ICL_name``
    contains ``"gpt2"``, otherwise the chat model named ``ICL_name``. The
    rewrites are embedded with ``text-embedding-ada-002`` through
    ``embedding_model``, and the one with the highest cosine similarity to
    ``target_embedding`` becomes the new best candidate.

    Args:
        initial_text: Starting candidate text.
        target_embedding: Tensor of shape (d,) or (1, d) to move toward.
        embedding_model: Client exposing ``embeddings.create``.
        ICL_name: GPT-2 checkpoint name, or a chat-completion model name.
        task_gudie: Task instruction. For GPT-2 it starts the prompt; for chat
            models it is the system message. ``None`` is treated as "".
        promt_prefix: Text placed right before the candidate. ``None`` -> "".
        max_iter: Number of rewrite/select rounds.
        lookahead: Currently unused; kept for API compatibility.
        variation_num: Number of rewrites sampled per round.
        chat_client: Optional client exposing ``chat.completions.create``. When
            omitted, a client is built with ``OpenAI(api_key="")`` as before.

    Returns:
        The best candidate found. If embedding/selection fails in a round, the
        error is printed and the best candidate so far is returned.
    """
    task_gudie = task_gudie or ""
    promt_prefix = promt_prefix or ""
    use_gpt2 = "gpt2" in ICL_name
    if use_gpt2:
        tokenizer, corrector_model = _load_gpt2(ICL_name)
    elif chat_client is None:
        # NOTE(review): empty API key is a placeholder; prefer passing chat_client.
        chat_client = OpenAI(api_key="")

    best_candidate = initial_text
    for iteration in range(max_iter):
        print(f"Iterate {iteration + 1}/{max_iter}")
        if use_gpt2:
            prompt = task_gudie + promt_prefix + best_candidate
            variants = _gpt2_variants(tokenizer, corrector_model, prompt, variation_num)
        else:
            prompt = promt_prefix + best_candidate + "\n Please output only the polished result."
            variants = _chat_variants(chat_client, ICL_name, task_gudie, prompt, variation_num)
        try:
            best_candidate = _closest_variant(embedding_model, target_embedding, variants)
        except Exception as e:
            # Best effort: keep the current candidate rather than abort the whole run.
            print(f"errors happen during generating embeddings: {e}")
            break
    return best_candidate

def bulid_prompt(dataset_name,category):
    """Return the ``(task_gudie, promt_prefix)`` prompt pair used to polish text.

    Args:
        dataset_name: One of ``"ag_news"``, ``"imdb"``, ``"sst2"``.
        category: Human-readable class name interpolated into the prefix
            (e.g. a news topic or a sentiment).

    Returns:
        ``(task_gudie, promt_prefix)``: a task instruction, plus an instruction
        that ends with a header line after which the text to polish is appended.

    Raises:
        ValueError: for an unsupported ``dataset_name``. Previously this case
            surfaced as an UnboundLocalError.
    """
    # Prompt wording is kept verbatim (typos included) so LLM outputs stay
    # comparable with earlier runs.
    if dataset_name=="ag_news":
        task_gudie="You will be given a piece of News text, The text may be grammatically incorrect, awkward, incomplete, or unnatural. Your task is to polish and rewrite the text."
        promt_prefix=f"Please polish the following text into fluent, coherent English that reads like a professional {category} news report, completing unclear expressions while preserving the original meaning.\n Text:\n"
    elif dataset_name=="imdb":
        task_gudie="You will be given a piece of IMDB movie review, The review may be grammatically incorrect, awkward, incomplete, or unnatural. Your task is to polish and rewrite the review. "
        promt_prefix=f"Please polish the following review into fluent, coherent English that reads like a {category} review, completing unclear expressions while preserving the original meaning.\n Review:\n"
    elif dataset_name=="sst2":
        task_gudie="You will be given a piece of sentence in movie review, The sentence may be incorrect, awkward, incomplete, or unnatural. Your task is to polish and rewrite the it. "
        promt_prefix=f"Please polish the following sentence into fluent, coherent English that reads like convey a {category} sentiment, rewrite the unclear expressions while preserving the original words and meaning as much as possible.\n Sentence:\n"
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")
    return task_gudie, promt_prefix
def initial_dataset_setting(dataset_name):
    """Return per-dataset distillation settings.

    Args:
        dataset_name: One of ``"ag_news"``, ``"imdb"``, ``"sst2"``.

    Returns:
        ``(num_particels_per_class, label_to_cateory, Task_specific_prompt)``:
        the number of synthetic texts to produce per class, a mapping from
        integer label to class name, and the instruction prepended to each
        text before embedding.

    Raises:
        ValueError: for an unsupported ``dataset_name``. Previously this case
            surfaced as an UnboundLocalError.
    """
    # NOTE(review): "possitive" and "biref" are misspellings kept verbatim. The
    # task prompt is baked into the cached corpus embeddings, and the labels
    # feed the polishing prompts, so editing them would silently change results.
    if dataset_name=="ag_news":
        num_particels_per_class=30
        Task_specific_prompt="Read the following news article and classify it into one of our categories: World, Sports, Business, or Science/Technology. Provide a brief rationale for your classification."
        label_to_cateory={0:"World",1:"Sports",2:"Business",3:"Sci/Tech"}
    elif dataset_name=="imdb":
        Task_specific_prompt="Read the following IMDB review and classify it as either positive or negative. Provide a biref rationale for your classification."
        num_particels_per_class=10
        label_to_cateory={0:"negative",1:"possitive"}
    elif dataset_name=="sst2":
        Task_specific_prompt="Read the following sentences and classify it as either positive or negative sentiment. Provide a biref rationale for your classification."
        num_particels_per_class=40
        label_to_cateory={0:"negative",1:"possitive"}
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")
    return num_particels_per_class,label_to_cateory, Task_specific_prompt


def main():
    """Build a distilled (synthetic) training set for ``dataset_name`` and save it as CSV.

    Pipeline:
      1. Load the training split. SST-2 is loaded from ``stanfordnlp/sst2`` and
         renamed to the ``text``/``label`` columns used for the other datasets.
      2. Load cached corpus embeddings from ``corpus_embeddings_{dataset_name}.pt``,
         or compute them with the OpenAI embeddings API and cache them. The task
         prompt is prepended when ``taskpt`` is True.
      3. For each class, produce about ``num_particels_per_class`` texts using
         the strategy selected by ``method_name``:
           - "random":       sample real training texts without replacement;
           - "clustering":   KMeans centroids, inverted to text;
           - "clustering_v": Gaussian samples around KMeans centroids, inverted;
           - "Ours":         GMM_SteinSampler particles, inverted to text, then
                             polished by an LLM via correction_optimization;
           - "Stein_pure":   GMM_SteinSampler particles, inverted to text only.
      4. Write the texts and labels to a CSV file.

    NOTE(review): ``dataset_name``, ``method_name``, the API key, and the model
    and output paths are hard-coded. As written, ``method_name=""`` matches no
    branch, so an empty CSV is produced.
    """
    # Load the original training dataset
    dataset_name="sst2"
    if dataset_name=="sst2":
        dataset_dict = load_dataset("stanfordnlp/sst2")
        dataset_dict = dataset_dict.rename_column("sentence", "text")
        dataset_dict = dataset_dict.remove_columns("idx")
    else:
        dataset_dict = load_dataset(dataset_name)
    corpus=dataset_dict["train"]['text']
    labels=dataset_dict["train"]['label']
    num_particels_per_class, label_to_cateory,task_specifc_prompt = initial_dataset_setting(dataset_name)



    # Get the embeddings
    # NOTE(review): the empty API key looks like a redacted placeholder; fill it in before running.
    client = OpenAI(api_key="")
    # NOTE: the cache file name depends only on dataset_name. After changing
    # taskpt or the task prompt, delete the cached file or stale embeddings load.
    embedding_path = f'corpus_embeddings_{dataset_name}.pt'
    # When True, each corpus text is prefixed with the dataset's task prompt before embedding.
    taskpt=True
    if os.path.exists(embedding_path):
        print("Embeddings founded and loading...")
        corpus_embeddings = torch.load(embedding_path)
    else:
        if taskpt:
            print("Embeddings are not founded, generating now...")
            corpus_embeddings = generated_all_embedding(corpus,client,embedding_path,task_specifc_prompt)
        else:
            corpus_embeddings = generated_all_embedding(corpus,client,embedding_path," ")
        print(corpus_embeddings.size())
    # Printed on both the load path and the generate path.
    print("Embeddings generated.")

    labels = torch.tensor(labels)
    classes = torch.unique(labels)
    num_classes = len(classes)
    distilled_texts, distilled_labels = [], []
    # Distillation strategy: "random", "clustering", "clustering_v", "Ours", or "Stein_pure".
    method_name=""
    if method_name=="random":
        # Baseline: keep num_particels_per_class real training texts per class.
        for c in classes.tolist():
            mask = (labels == c)
            class_corpus_text = np.array(corpus)[mask] 
            class_particle_indices = np.random.choice(len(class_corpus_text), num_particels_per_class, replace=False)
            class_corpus_text=class_corpus_text[class_particle_indices]
            distilled_texts.extend(class_corpus_text)
            distilled_labels.extend([int(c)] * len(class_corpus_text))
    elif method_name=="clustering":
        # NOTE(review): the empty "" prefix looks like a redacted checkpoint
        # directory; presumably these load dataset-specific inversion/corrector models.
        inversion_model = models.InversionModel.from_pretrained(""+dataset_name).to(device)
        corrector_model = models.CorrectorEncoderModel.from_pretrained(""+dataset_name).to(device)
        projector = load_corrector(inversion_model, corrector_model)
        for c in classes.tolist():
            mask = (labels == c)
            class_corpus_embeddings = corpus_embeddings[mask]
            # One particle per KMeans centroid, then invert each centroid to text.
            kmeans = KMeans(n_clusters=num_particels_per_class,
                random_state=42).fit(class_corpus_embeddings)
            class_particles_embeddings = kmeans.cluster_centers_
            class_particles_embeddings = torch.tensor(class_particles_embeddings).to(device)
            class_particle_text = invert_embeddings(
                                embeddings=class_particles_embeddings,
                                corrector=projector)
            # print(class_particle_text)
            distilled_texts.extend(class_particle_text)
            distilled_labels.extend([int(c)] * len(class_particle_text))
    elif method_name=="clustering_v":
        inversion_model = models.InversionModel.from_pretrained(""+dataset_name).to(device)
        corrector_model = models.CorrectorEncoderModel.from_pretrained(""+dataset_name).to(device)
        projector = load_corrector(inversion_model, corrector_model)
        for c in classes.tolist():
            class_particles_embeddings=[]
            # Samples drawn per cluster; the cluster count is num_particels_per_class / num_per_model.
            num_per_model=5
            mask = (labels == c)
            class_corpus_embeddings = corpus_embeddings[mask]
            kmeans = KMeans(n_clusters=int(num_particels_per_class/num_per_model),
                random_state=42).fit(class_corpus_embeddings)
            class_centers= kmeans.cluster_centers_
            class_labels=kmeans.labels_
            for i in range(int(num_particels_per_class/num_per_model)):
                # Fit a Gaussian to each cluster and draw num_per_model samples from it.
                class_cluster_points = class_corpus_embeddings[class_labels == i]
                # The 1e-6 * I jitter regularises the covariance. A single-point
                # cluster would likely still give NaN/inf here (degrees of freedom <= 0).
                cov = np.cov(class_cluster_points.T) + 1e-6 * np.eye(class_cluster_points.shape[1])
                mean = class_centers[i]
                sampled = np.random.multivariate_normal(mean, cov, size=num_per_model)
                class_particles_embeddings.append(sampled)
                # particles_embeddings = torch.tensor(particles_embeddings).to(device)
            class_particles_embeddings = np.vstack(class_particles_embeddings)
            class_particles_embeddings = torch.tensor(class_particles_embeddings).to(device)
            class_particle_text = invert_embeddings(
                                embeddings=class_particles_embeddings,
                                corrector=projector)
            print(len(class_particle_text))
            distilled_texts.extend(class_particle_text)
            distilled_labels.extend([int(c)] * len(class_particle_text))
    elif method_name=="Ours":
        inversion_model = models.InversionModel.from_pretrained(""+dataset_name).to(device)
        corrector_model = models.CorrectorEncoderModel.from_pretrained(""+dataset_name).to(device)
        projector = load_corrector(inversion_model, corrector_model)
        for c in classes.tolist():
            mask = (labels == c)
            category=label_to_cateory[int(c)]
            class_corpus_embeddings = corpus_embeddings[mask]
            task_gudie,promt_prefix=bulid_prompt(dataset_name,category)
            print(task_gudie+promt_prefix)
            # Stein sampler hyperparameters (this branch uses a 10-component GMM).
            learning_rate = 0.01
            num_iterations = 100
            bandwith=0.5
            print("Starting optimization particle embeddings...")
            sampler = GMM_SteinSampler(
                num_particles=num_particels_per_class,
                bandwidth=bandwith,
                learning_rate=learning_rate,
                num_iterations=num_iterations,
                n_components=10
            )
            class_particles_embeddings = sampler.fit(class_corpus_embeddings)
            class_particles_embeddings = torch.tensor(class_particles_embeddings)
            class_particle_text = invert_embeddings(
                                embeddings=class_particles_embeddings,
                                corrector=projector)
            # Polish each inverted text with an LLM, keeping rewrites whose
            # embedding stays close to the particle it came from.
            correct_class_text=[]
            for i in range(len(class_particle_text)):
                output_text= correction_optimization(
                        initial_text=class_particle_text[i],
                        target_embedding=class_particles_embeddings[i].to('cpu'),
                        embedding_model=client,
                        task_gudie=task_gudie,
                        promt_prefix=promt_prefix,
                        ICL_name="gpt-3.5-turbo", # options: "gpt-3.5-turbo","gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"
                        max_iter=4,
                        lookahead=False,
                        variation_num=5,
                    )
                print(output_text)
                correct_class_text.append(output_text)
            # print(correct_class_text)
            distilled_texts.extend(correct_class_text)
            # correct_class_text has one entry per class_particle_text item, so the counts match.
            distilled_labels.extend([int(c)] * len(class_particle_text))
    elif method_name=="Stein_pure":
        # Same as "Ours" but without LLM polishing, and with a 40-component GMM.
        inversion_model = models.InversionModel.from_pretrained(""+dataset_name).to(device)
        corrector_model = models.CorrectorEncoderModel.from_pretrained(""+dataset_name).to(device)
        projector = load_corrector(inversion_model, corrector_model)
        for c in classes.tolist():
            mask = (labels == c)
            class_corpus_embeddings = corpus_embeddings[mask]
            learning_rate = 0.01
            num_iterations = 100
            bandwith=0.5
            print("Starting optimization particle embeddings...")
            sampler = GMM_SteinSampler(
                num_particles=num_particels_per_class,
                bandwidth=bandwith,
                learning_rate=learning_rate,
                num_iterations=100,
                n_components=40
            )
            class_particles_embeddings = sampler.fit(class_corpus_embeddings)
            class_particles_embeddings = torch.tensor(class_particles_embeddings)
            class_particle_text = invert_embeddings(
                                embeddings=class_particles_embeddings,
                                corrector=projector)
            print(class_particle_text)
            distilled_texts.extend(class_particle_text)
            distilled_labels.extend([int(c)] * len(class_particle_text))
    df = pd.DataFrame({
    "text": distilled_texts,
    "label": distilled_labels
})
    # Nominal dataset size used in the file name. This is the configured
    # count, not len(df); the printed "total size" below is the actual row count.
    num_particles=num_particels_per_class*num_classes
    # NOTE(review): the leading "/" puts the output under the filesystem root;
    # it looks like a redacted base directory.
    output_file =f"/{dataset_name}/sample_{num_particles}_{method_name}_withtaskp.csv"
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    df.to_csv(output_file, index=False, encoding="utf-8-sig")
    print(f"Distilled dataset saved to {output_file}, total size: {len(df)}")




if __name__ == "__main__":
    # Script entry point: run the distillation pipeline defined in main().
    main()