import numpy as np
from tqdm.auto import tqdm
from datasets import concatenate_datasets, Dataset
from misc.definitions import batch_compare
from misc.utils import load_model, load_canary_dataset, load_wikitext_dataset
from argparse import ArgumentParser
from gdm.gdm_base import GDMBase
from sentence_transformers import SentenceTransformer
import os

class EmbeddingGDM(GDMBase):
    """Build a "worst case" dataset using an embedding-based comparison.

    Canary texts are split into members / non-members by comparing them with
    ``batch_compare(..., comparison_strategy="embedding")`` against the
    threshold ``cs``. Perturbed variants ("poisons") of the non-members are
    then generated and mixed into the background dataset.

    Attributes are set in ``__init__``:
        cs: threshold the maximum comparison score is tested against
            (a text counts as a non-member / non-neighbor when
            ``max(score) < cs``).
        mode: perturbation used for non-members; one of ``SUPPORTED_MODES``.
    """

    # Perturbation modes handled by ``generate_worst_case_non_neighbor``.
    SUPPORTED_MODES = ("token_dropouts", "casing_flips", "chunking")

    def __init__(self, cs, dataset_name, mode, model_name="EleutherAI/pythia-6.9b"):
        """Load the datasets, the language model/tokenizer and the embedding model.

        Args:
            cs: comparison threshold (see class docstring).
            dataset_name: forwarded to ``load_canary_dataset``.
            mode: perturbation mode used for non-members.
            model_name: forwarded to ``load_model``.

        Raises:
            KeyError: if the ``HF_HOME`` environment variable is not set.
        """
        self.cs = cs
        self.canary_dataset = load_canary_dataset(dataset_name=dataset_name)
        self.background_dataset = load_wikitext_dataset()
        self.model, self.tokenizer = load_model(model_name=model_name)
        self.embedding_model = SentenceTransformer("intfloat/multilingual-e5-large-instruct", trust_remote_code=True, device='cuda:0', cache_folder=os.environ["HF_HOME"])
        self.mode = mode

    def _compare(self, x, candidates):
        """Return ``batch_compare`` scores of ``x`` against each string in ``candidates``.

        Centralises the keyword arguments so every comparison in this class
        uses exactly the same settings.
        """
        return batch_compare(
            x,
            candidates,
            units="full_string",
            tokenizer=None,
            embedding_model=self.embedding_model,
            comparison_strategy="embedding",
            lower=False,
        )

    def generate_members_and_nonmembers_labels(self):
        """Relabel the canary split according to the comparison threshold.

        Every text of the ``train`` split is a member. A ``test`` text stays a
        non-member only if its maximum comparison score against the original
        ``train`` texts is below ``cs``; otherwise it is relabelled as a member.

        Returns:
            tuple: ``(members, non_members)`` as ``Dataset`` objects with a
            single ``"text"`` column.
        """
        # Materialise the reference texts once: re-reading the column on every
        # iteration is loop-invariant work.
        reference_texts = list(self.canary_dataset['train']['text'])
        # Independent copy that is safe to append to. NOTE(review): the column
        # accessor is not guaranteed to return a plain, appendable list in
        # every `datasets` version -- hence the explicit list().
        actual_members = list(reference_texts)
        actual_nonmembers = []

        for x in tqdm(self.canary_dataset['test']['text'], desc="Generating member and non-member labels"):
            # Comparison is always against the original train texts, not
            # against test texts that were relabelled as members earlier on.
            comparisons = self._compare(x, reference_texts)
            if np.max(comparisons) < self.cs:
                actual_nonmembers.append(x)
            else:
                actual_members.append(x)

        members = Dataset.from_dict({"text": actual_members})
        non_members = Dataset.from_dict({"text": actual_nonmembers})
        return members, non_members

    def generate_worst_case_neighbor(self, x):
        """
        Here, x is a member already. There is no way to generate a worst case neighbor, so we will simply return x.
        """
        return [x]

    def _search_poison(self, x, x_tokens, perturb, schedule, attempts_per_setting=10):
        """Return the first perturbed text whose score against ``x`` is below ``cs``.

        Args:
            x: original text.
            x_tokens: token ids of ``x``.
            perturb: callable ``perturb(x_tokens, **kwargs)`` returning token ids.
            schedule: list of keyword-argument dicts, tried in order.
            attempts_per_setting: number of tries for each schedule entry.

        Returns:
            The decoded perturbed text, or ``None`` if no attempt qualified.
        """
        for kwargs in schedule:
            for _ in range(attempts_per_setting):
                # Decode once and reuse the string for both the comparison
                # and the returned value.
                candidate = self.tokenizer.decode(perturb(x_tokens, **kwargs))
                if np.max(self._compare(x, [candidate])) < self.cs:
                    return candidate
        return None

    def generate_worst_case_non_neighbor(self, x, num_poisons=10):
        """
        Here x is a non-member already.

        Runs up to ``num_poisons`` independent searches for a perturbed version
        of ``x`` (perturbation chosen by ``self.mode``) whose comparison score
        against ``x`` is below ``cs``.

        Args:
            x: text to perturb.
            num_poisons: number of searches; each contributes at most one poison.

        Returns:
            list[str]: the poisons found (possibly fewer than ``num_poisons``,
            possibly empty).

        Raises:
            ValueError: if ``self.mode`` is not one of ``SUPPORTED_MODES``.
        """
        x_tokens = self.tokenizer(x, add_special_tokens=False)['input_ids']
        n_tokens = len(x_tokens)

        # Each mode contributes a perturbation callable plus the ordered list
        # of parameter settings that the search walks through.
        if self.mode == 'token_dropouts':
            perturb = self.token_dropouts
            schedule = [{"d": d, "option": 'deterministic'} for d in range(n_tokens, 0, -1)]
        elif self.mode == 'casing_flips':
            perturb = self.casing_flips
            schedule = [{"p": p} for p in np.arange(0, 1.0, 0.1)]
        elif self.mode == 'chunking':
            perturb = self.chunking
            schedule = [{"c": c, "l": c // 2} for c in range(n_tokens, 0, -1)]
        else:
            raise ValueError(
                f"Unknown mode: {self.mode}. "
                f"Supported modes are: {', '.join(self.SUPPORTED_MODES)}."
            )

        # NOTE(review): if a perturbation is deterministic, the repeated
        # searches may all return the same poison -- confirm this is intended.
        all_poisons = []
        for _ in range(num_poisons):
            poison = self._search_poison(x, x_tokens, perturb, schedule)
            if poison is not None:
                all_poisons.append(poison)
        return all_poisons

    def create_worst_case_dataset(self):
        """Assemble the shuffled worst-case dataset.

        The result has exactly as many rows as the background dataset: the
        trailing background rows are dropped to make room for the worst-case
        points (members themselves plus the poisons of the non-members).

        Returns:
            Dataset: background rows and worst-case points, shuffled with seed 42.

        Raises:
            ValueError: if there are more worst-case points than background rows.
        """
        members, non_members = self.generate_members_and_nonmembers_labels()

        all_worst_case_neighbors = []
        for x in tqdm(members['text'], desc="Generating worst case neighbors"):
            worst_case_neighbors = self.generate_worst_case_neighbor(x)
            assert worst_case_neighbors[0] == x
            all_worst_case_neighbors.extend(worst_case_neighbors)

        all_worst_case_non_neighbors = []
        for x in tqdm(non_members['text'], desc="Generating worst case non neighbors"):
            worst_case_non_neighbors = self.generate_worst_case_non_neighbor(x)
            if len(worst_case_non_neighbors) > 0:
                # Sanity check: every returned poison must be below the threshold.
                assert np.max(self._compare(x, worst_case_non_neighbors)) < self.cs
            all_worst_case_non_neighbors.extend(worst_case_non_neighbors)

        all_worst_case_points = Dataset.from_dict({"text": all_worst_case_neighbors + all_worst_case_non_neighbors})

        n_background = len(self.background_dataset)
        n_worst_case = len(all_worst_case_points)
        if n_worst_case > n_background:
            # range() with a negative stop would silently be empty and the
            # size invariant below could not hold; fail with a clear message.
            raise ValueError(
                f"Cannot fit {n_worst_case} worst case points into a background "
                f"dataset of only {n_background} rows."
            )

        background_dataset = self.background_dataset.select(range(n_background - n_worst_case))
        worst_case_dataset = concatenate_datasets([background_dataset, all_worst_case_points])
        worst_case_dataset = worst_case_dataset.shuffle(seed=42)

        assert len(self.background_dataset) == len(worst_case_dataset)
        return worst_case_dataset



if __name__ == "__main__":
    # Script entry point: build the worst-case dataset and save it to disk.
    # Every option is mandatory so that a missing value is rejected by
    # argparse up front, rather than surfacing as cs=None / mode=None after
    # the models and datasets have already been loaded.
    parser = ArgumentParser(description="Create an embedding-based worst case dataset.")
    parser.add_argument("--cs", type=float, required=True,
                        help="Comparison threshold used to separate members from non-members.")
    parser.add_argument("--dataset_name", type=str, required=True,
                        help="Name of the canary dataset to load.")
    parser.add_argument("--mode", type=str, required=True,
                        choices=["token_dropouts", "casing_flips", "chunking"],
                        help="Perturbation applied to non-members.")
    parser.add_argument("--output_dataset_path", type=str, required=True,
                        help="Directory the resulting dataset is saved to.")
    args = parser.parse_args()

    embedding_gdm = EmbeddingGDM(cs=args.cs, dataset_name=args.dataset_name, mode=args.mode)
    worst_case_dataset = embedding_gdm.create_worst_case_dataset()
    worst_case_dataset.save_to_disk(args.output_dataset_path)

