import numpy as np
from tqdm.auto import tqdm
from datasets import concatenate_datasets, Dataset
from misc.definitions import batch_compare
from misc.utils import load_model, load_canary_dataset, load_wikitext_dataset
from argparse import ArgumentParser
from gdm.gdm_base import GDMBase

class NgramGDM(GDMBase):
    """Build a worst-case dataset for an n-gram (longest-common-substring) membership definition.

    A canary text is treated as a member when its token-level longest common
    substring with some training canary (as measured by ``batch_compare`` with
    ``comparison_strategy="longest_common_substring"``, case-sensitive) is at
    least ``k``. Each non-member is replaced by perturbed "poisons" whose longest
    common substring with the original text stays strictly below ``k``.

    Args:
        k: Token-level longest-common-substring threshold.
        dataset_name: Passed to ``load_canary_dataset``. The result is indexed
            with ``'train'`` / ``'test'`` splits and a ``'text'`` column.
        mode: Perturbation used for non-members: ``'token_dropouts'``,
            ``'casing_flips'`` or ``'chunking'``.
        model_name: Passed to ``load_model``. The methods in this class only
            use the returned tokenizer.
    """

    _SUPPORTED_MODES = ("token_dropouts", "casing_flips", "chunking")
    # Number of times each perturbation setting is re-sampled before giving up on it.
    _ATTEMPTS_PER_SETTING = 10

    def __init__(self, k, dataset_name, mode, model_name="EleutherAI/pythia-6.9b"):
        self.k = k
        self.canary_dataset = load_canary_dataset(dataset_name=dataset_name)
        self.background_dataset = load_wikitext_dataset()
        self.model, self.tokenizer = load_model(model_name=model_name)
        self.mode = mode

    def _max_lcs(self, reference, candidates):
        """Return the largest token-level LCS score between ``reference`` and any of ``candidates``.

        Returns 0 when ``candidates`` is empty: an empty set shares no substring
        with ``reference``. This also avoids calling ``np.max`` on an empty sequence.
        """
        if not candidates:
            return 0
        comparisons = batch_compare(
            reference,
            candidates,
            units="token",
            tokenizer=self.tokenizer,
            embedding_model=None,
            comparison_strategy="longest_common_substring",
            lower=False,
        )
        return np.max(comparisons)

    def _first_valid_poison(self, x, make_poison_tokens):
        """Sample poisons from ``make_poison_tokens()`` until one stays below ``k``.

        Tries at most ``_ATTEMPTS_PER_SETTING`` samples. Returns the decoded
        poison text, or ``None`` when no sample has an LCS with ``x`` below ``k``.
        """
        for _ in range(self._ATTEMPTS_PER_SETTING):
            # Decode once and reuse the same string for the check and the result.
            poison = self.tokenizer.decode(make_poison_tokens())
            if self._max_lcs(x, [poison]) < self.k:
                return poison
        return None

    def generate_members_and_nonmembers_labels(self):
        """Split the canary dataset into members and non-members under the n-gram definition.

        Every ``'train'`` text is a member. A ``'test'`` text is relabeled as a
        member when its LCS score with some ``'train'`` text is ``>= k``.
        Otherwise it is a non-member. Test texts are compared only against the
        original training texts, never against other relabeled test texts.

        Returns:
            ``(members, non_members)``: two ``Dataset`` objects, each with a
            single ``"text"`` column.
        """
        # Materialize the training texts once. Re-reading the column inside the
        # loop re-converts it on every iteration.
        train_texts = list(self.canary_dataset['train']['text'])
        # Explicit copy: never append into an object obtained from the dataset.
        actual_members = list(train_texts)
        actual_nonmembers = []
        for x in tqdm(self.canary_dataset['test']['text'], desc="Generating member and non-member labels"):
            if self._max_lcs(x, train_texts) < self.k:
                actual_nonmembers.append(x)
            else:
                actual_members.append(x)

        members = Dataset.from_dict({"text": actual_members})
        non_members = Dataset.from_dict({"text": actual_nonmembers})
        return members, non_members

    def generate_worst_case_neighbor(self, x):
        """Return ``[x]`` unchanged.

        ``x`` is already a member, and this class has no perturbation that makes
        a "more member" neighbor, so the text itself is the worst-case neighbor.
        """
        return [x]

    def generate_worst_case_non_neighbor(self, x, num_poisons=10):
        """Generate up to ``num_poisons`` perturbations of non-member ``x`` with LCS(x, poison) < ``k``.

        The perturbation helpers (``token_dropouts``, ``casing_flips``,
        ``chunking``) are not defined in this class. Presumably ``GDMBase``
        provides them (TODO confirm). Each setting is re-sampled up to
        ``_ATTEMPTS_PER_SETTING`` times. Outer iterations that find no valid
        poison contribute nothing, so the result may hold fewer than
        ``num_poisons`` items.

        Args:
            x: Non-member text.
            num_poisons: Maximum number of poisons to return.

        Returns:
            List of poison strings.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        x_tokens = self.tokenizer(x, add_special_tokens=False)['input_ids']
        all_poisons = []
        if self.mode == 'token_dropouts':
            for _ in range(num_poisons):
                # Sweep d from len(x_tokens) down to 1 and keep the first value that yields a valid poison.
                # NOTE(review): with option='deterministic' the retries and the outer
                # num_poisons loop may produce identical poisons - confirm intended.
                for d in range(len(x_tokens), 0, -1):
                    poison = self._first_valid_poison(
                        x, lambda d=d: self.token_dropouts(x_tokens, d=d, option='deterministic')
                    )
                    if poison is not None:
                        all_poisons.append(poison)
                        break
        elif self.mode == 'casing_flips':
            for _ in range(num_poisons):
                # Sweep p = 0.00, 0.01, ..., 0.99 and keep the first value that yields a valid poison.
                for p in np.arange(0, 1.0, 0.01):
                    poison = self._first_valid_poison(x, lambda p=p: self.casing_flips(x_tokens, p=p))
                    if poison is not None:
                        all_poisons.append(poison)
                        break
        elif self.mode == 'chunking':
            # c = k - 1 presumably keeps every chunk shorter than the threshold - confirm against chunking().
            c_arg = self.k - 1
            l_arg = c_arg // 2
            for _ in range(num_poisons):
                poison = self._first_valid_poison(x, lambda: self.chunking(x_tokens, c=c_arg, l=l_arg))
                if poison is not None:
                    all_poisons.append(poison)
        else:
            raise ValueError(
                f"Unknown mode: {self.mode}. Supported modes are: {', '.join(self._SUPPORTED_MODES)}."
            )
        return all_poisons

    def create_worst_case_dataset(self):
        """Inject members and non-member poisons into the background corpus.

        The result contains every member text verbatim, every generated
        non-member poison, and as many leading background rows as fit. Its
        total size equals ``len(self.background_dataset)``. Rows are shuffled
        with seed 42.

        Raises:
            ValueError: If there are more injected points than background rows.
        """
        members, non_members = self.generate_members_and_nonmembers_labels()

        all_worst_case_neighbors = []
        for x in tqdm(members['text'], desc="Generating worst case neighbors"):
            worst_case_neighbors = self.generate_worst_case_neighbor(x)
            assert worst_case_neighbors[0] == x
            all_worst_case_neighbors.extend(worst_case_neighbors)

        all_worst_case_non_neighbors = []
        for x in tqdm(non_members['text'], desc="Generating worst case non neighbors"):
            worst_case_non_neighbors = self.generate_worst_case_non_neighbor(x)
            if worst_case_non_neighbors:
                # Sanity check: every poison must stay below the n-gram threshold.
                assert self._max_lcs(x, worst_case_non_neighbors) < self.k
            all_worst_case_non_neighbors.extend(worst_case_non_neighbors)

        injected_texts = all_worst_case_neighbors + all_worst_case_non_neighbors
        num_background = len(self.background_dataset) - len(injected_texts)
        if num_background < 0:
            # select(range(negative)) would silently select nothing and the size check would fail obscurely.
            raise ValueError(
                f"Cannot inject {len(injected_texts)} points into a background dataset of "
                f"{len(self.background_dataset)} rows."
            )

        all_worst_case_points = Dataset.from_dict({"text": injected_texts})
        background_dataset = self.background_dataset.select(range(num_background))
        worst_case_dataset = concatenate_datasets([background_dataset, all_worst_case_points])
        worst_case_dataset = worst_case_dataset.shuffle(seed=42)

        assert len(self.background_dataset) == len(worst_case_dataset)
        return worst_case_dataset



if __name__ == "__main__":
    # CLI entry point: build the n-gram worst-case dataset and save it with Dataset.save_to_disk.
    parser = ArgumentParser(description="Build an n-gram worst-case dataset and save it to disk.")
    # --k and --output_dataset_path are required: a missing value would only fail
    # after the (expensive) model load.
    parser.add_argument("--k", type=int, required=True, help="Token-level longest-common-substring threshold.")
    parser.add_argument("--dataset_name", type=str, help="Canary dataset name passed to load_canary_dataset.")
    parser.add_argument("--mode", type=str, required=True, choices=["token_dropouts", "casing_flips", "chunking"],
                        help="Perturbation used to build non-member poisons.")
    parser.add_argument("--model_name", type=str, default="EleutherAI/pythia-6.9b",
                        help="Model name passed to load_model (its tokenizer is used for comparisons).")
    parser.add_argument("--output_dataset_path", type=str, required=True,
                        help="Directory passed to Dataset.save_to_disk.")
    args = parser.parse_args()

    ngram_gdm = NgramGDM(k=args.k, dataset_name=args.dataset_name, mode=args.mode, model_name=args.model_name)
    worst_case_dataset = ngram_gdm.create_worst_case_dataset()
    worst_case_dataset.save_to_disk(args.output_dataset_path)

