from datasets import load_dataset
import random
import torch
from torch.utils.data import DataLoader, IterableDataset
import tiktoken
from typing import Tuple, List
import torch
import math

def next_power_of_2(x):
    """Return the smallest power of two that is greater than or equal to ``x``.

    Integer inputs are handled with exact bit arithmetic, because going through
    ``math.log2`` rounds large integers to float and can return a power of two
    that is *smaller* than ``x`` (e.g. for ``2**60 + 1``). Non-integer inputs
    keep the original ``2 ** ceil(log2(x))`` formula.

    Raises:
        ValueError: if ``x`` is zero or negative.
    """
    if isinstance(x, int):
        if x <= 0:
            # Same exception type math.log2 raises for non-positive input.
            raise ValueError("math domain error")
        return 1 << (x - 1).bit_length()
    return 2 ** math.ceil(math.log2(x))
def flatten(list_of_lists):
    """Concatenate the items of every inner iterable into one flat list."""
    flat = []
    for inner in list_of_lists:
        flat.extend(inner)
    return flat
# Module-level cache of loaded dataset splits (keys 'train' / 'val'), shared by
# every C4NeedleTrainEval instance so each split is loaded at most once.
DS_CACHE = {}
# Token budget reserved for the answer label (EOS placeholder included); the
# haystack text is truncated so that prompt + needles + this budget fit ctx_len.
MAX_LABEL_LENGTH = 41

# Unique city names used to build the needle sentences. The names are sorted so
# that the list order is identical in every interpreter run: iterating a set of
# strings directly depends on per-process string-hash randomization, which makes
# any index-based sampling from the list irreproducible across runs.
RANDOM_NEEDLE_CITIES = sorted({
    'Chicago', 'Yangon', 'Antananarivo', 'Colombo', 'Almaty', 'Sydney', 'Chicago', 'Mexico City',
    'Seattle', 'Lagos', 'Amsterdam', 'Belgrade', 'Cairo', 'Baghdad', 'Damascus', 'Kigali', 'Dakar',
    'Dakar', 'Sofia', 'Kigali', 'Victoria', 'Tashkent', 'Mumbai', 'Barcelona', 'Almaty', 'Amman',
    'Toronto', 'Bratislava', 'Johannesburg', 'Thimphu', 'Bangkok', 'Santiago', 'Cairo', 'San Francisco',
    'Lagos', 'Amsterdam', 'Paris', 'Rabat', 'Santiago', 'Copenhagen', 'Madrid', 'Kigali',
    'Ho Chi Minh City', 'Sarajevo', 'Delhi', 'Istanbul', 'Ho Chi Minh City', 'Khartoum', 'Helsinki',
    'Doha', 'Istanbul', 'Kuala Lumpur', 'Budapest', 'Shanghai', 'Moscow', 'Los Angeles', 'Oslo',
    'Johannesburg', 'Berlin', 'Bangalore', 'Tokyo', 'Melbourne', 'Barcelona', 'Chicago', 'Port Louis',
    'Lisbon', 'Nairobi', 'Kampala', 'Lima', 'Maputo', 'Vancouver', 'Dubai', 'Khartoum', 'Jakarta',
    'Madrid', 'Yerevan', 'Beirut', 'Athens', 'Chicago', 'Paris', 'Bucharest', 'Copenhagen', 'Brussels',
    'Damascus', 'Seattle', 'Los Angeles', 'Yerevan', 'Victoria', 'Tunis', 'Astana', 'Seoul',
    'Buenos Aires', 'Bangkok', 'Colombo', 'Brussels', 'Khartoum', 'Doha', 'San Francisco', 'Vienna', 'Jakarta',
})

class C4NeedleTrainEval:
    def __init__(self, tokenizer, batch_size=8,
                  ctx_len=64*1024,
                  do_train=False):
        """Load (or reuse) a C4 split and wrap it in a packing ``StreamingDataset``.

        Args:
            tokenizer: tokenizer object; elsewhere in this class it is called as
                ``tokenizer(text, add_special_tokens=False)['input_ids'][0]`` and
                its ``decode`` / ``pad_token_id`` / ``eos_token_id`` /
                ``bos_token_id`` members are read.
            batch_size: batch size handed to the DataLoader.
            ctx_len: context length in tokens; must be a positive multiple of 1024.
            do_train: when true use the train shards and shuffle, otherwise use
                the validation shards without shuffling.
        """
        self.tokenizer = tokenizer
        self.BATCH_SIZE = batch_size
        self.ctx_len = ctx_len
        assert ctx_len % 1024 == 0, "ctx_len must be a multiple of 1024"
        assert ctx_len >= 1024, "ctx_len must be at least 1024"
        self.do_train = do_train
        # Negative placeholder ids that cannot collide with real token ids; they
        # are swapped for the tokenizer's EOS/BOS ids in preprocess_function.
        self.tmp_eos_token = -999
        self.tmp_bos_token = -998

        # Each split is loaded once per process and then served from DS_CACHE.
        cache_key = 'train' if do_train else 'val'
        if cache_key not in DS_CACHE:
            if do_train:
                split_name = 'train'
                shard_files = {'train': ['en/c4-train.00000-of-*.json.gz', 'en/c4-train.00001-of-*.json.gz', 'en/c4-train.00002-of-*.json.gz'] }
                n_rows = 1_000_000
            else:
                split_name = 'validation'
                shard_files = {'validation': ['en/c4-validation.00000-of-*.json.gz', 'en/c4-validation.00001-of-*.json.gz']}
                n_rows = 50_000
            loaded = load_dataset(
                "allenai/c4",
                data_files=shard_files,
                split=split_name,
                streaming=False
            )
            DS_CACHE[cache_key] = loaded.select(range(n_rows))
        ds = DS_CACHE[cache_key]
        assert not isinstance(ds, self.StreamingDataset)

        # Seeds are built by string-concatenating ctx_len (in units of 1024) and
        # batch_size, so different configurations get different random streams.
        size_tag = str(ctx_len//1024) + str(batch_size)
        seed_val = 1 + int(size_tag)
        seed_train = int(size_tag + str(int(do_train)))
        assert seed_val != seed_train, "seeds must be different"
        seed = seed_train if do_train else seed_val
        self.dataset = self.StreamingDataset(
            ds,
            self.preprocess_function,
            ctx_len=self.ctx_len,
            tokenizer=self.tokenizer,
            shuffle=do_train,
            seed=seed,
        )

        self.dataloader = None
        self.masked_prompt_dataloader = None
        # The needle sampler gets its own RNG, offset by one from the dataset seed.
        self.rng = random.Random(seed + 1)

    # Streaming wrapper that packs source documents into ctx_len-sized examples;
    # it is used for both the train and the eval split.
    class StreamingDataset(IterableDataset):
        """Iterable dataset that concatenates documents until they fill a context.

        Documents from ``dataset`` are accumulated until their individually
        counted token lengths sum to at least ``ctx_len``; they are then joined
        with ``'. '`` into a single ``{'text': ...}`` example, passed through
        ``preprocess_fn`` and buffered. Whenever the buffer holds
        ``cache_shuffle_size`` examples it is (optionally) shuffled and emitted.

        Args:
            dataset: iterable of dicts with a ``"text"`` entry.
            preprocess_fn: callable applied to each packed ``{'text': ...}`` dict.
            ctx_len: minimum number of tokens a packed example must reach.
            tokenizer: callable such that ``tokenizer(text)['input_ids'][0]`` is
                the token sequence of ``text``.
            shuffle: whether to shuffle each buffer before emitting it; required.
            seed: seed for the shuffle RNG (only used when ``shuffle`` is true).
            cache_shuffle_size: number of packed examples buffered per flush.
                The default of 1 emits every example immediately, which makes
                shuffling a no-op.
        """

        def __init__(self, dataset, preprocess_fn, ctx_len=None, tokenizer=None, shuffle=None, seed=None,
                     cache_shuffle_size=1):
            self.tokenizer = tokenizer
            self.dataset = dataset
            self.preprocess_fn = preprocess_fn
            self.ctx_len = ctx_len
            assert shuffle is not None, "shuffle must be specified"
            if shuffle:
                self.rng = random.Random(seed)
            self.shuffle = shuffle
            self.cache_shuffle_size = cache_shuffle_size

        def _drain(self, cache):
            """Yield every buffered example, shuffling the buffer first if enabled."""
            if self.shuffle:
                self.rng.shuffle(cache)
            yield from cache

        def __iter__(self):
            pending_docs = []
            pending_tokens = 0
            cache = []
            for example in self.dataset:
                pending_tokens += len(self.tokenizer(example["text"])['input_ids'][0])
                pending_docs.append(example)
                if pending_tokens < self.ctx_len:
                    continue

                # Enough material for one context: pack, preprocess and buffer it.
                packed = dict(text='. '.join(doc["text"] for doc in pending_docs))
                pending_docs = []
                pending_tokens = 0
                cache.append(self.preprocess_fn(packed))

                if len(cache) >= self.cache_shuffle_size:
                    yield from self._drain(cache)
                    cache = []

            # Emit examples still buffered when the source runs out; without this
            # they would be silently dropped whenever cache_shuffle_size > 1.
            # Documents that never reached ctx_len are intentionally discarded.
            if cache:
                yield from self._drain(cache)

    def sample_number(self):
        """Draw a uniformly random seven-digit integer from the instance RNG."""
        lowest, highest = 1_000_000, 9_999_999
        return self.rng.randint(lowest, highest)
    def sample_cities(self, n=3):
        """Pick ``n`` distinct cities from ``RANDOM_NEEDLE_CITIES``.

        Args:
            n: number of distinct cities to return.

        Returns:
            A list of ``n`` city names, ordered by the iteration order of the
            set of drawn indices (not by the order in which they were drawn).

        Raises:
            ValueError: if ``n`` exceeds the number of available cities; the
                rejection loop below could otherwise never terminate.
        """
        n_available = len(RANDOM_NEEDLE_CITIES)
        if n > n_available:
            raise ValueError(
                f"Cannot sample {n} distinct cities from only {n_available} available."
            )
        # Distinct indices are rejection-sampled with randint (rather than
        # rng.sample) so the sequence of RNG draws, and therefore the output for
        # a given seed, stays the same as in earlier versions of this method.
        # NOTE(review): results are only reproducible across processes if
        # RANDOM_NEEDLE_CITIES itself has a stable order — confirm at its definition.
        picked = set()
        while len(picked) < n:
            picked.add(self.rng.randint(0, n_available - 1))
        return [RANDOM_NEEDLE_CITIES[ix] for ix in picked]

    def sample_magic_number_sentences(self, n=3):
        """Build ``n`` needle sentences, each pairing a city with a magic number.

        Returns:
            A ``(sentences, cities, numbers)`` tuple of three parallel lists.
        """
        # Numbers are drawn before cities; the order of RNG use is kept fixed.
        numbers = []
        for _ in range(n):
            numbers.append(self.sample_number())
        cities = self.sample_cities(n=n)

        sentences = [
            f"The special magic {cities[ix]} number is {number}."
            for ix, number in enumerate(numbers)
        ]
        return sentences, cities, numbers
    def get_prompt_and_answers(self, text) -> dict:
        """Build one needle-in-a-haystack prompt from ``text``.

        Three "special magic <city> number" sentences are spliced into the
        tokenized text at random token positions, and the result is wrapped
        between a fixed instruction prefix and a fixed question suffix.

        Args:
            text: haystack text; it is tokenized and truncated so that prompt,
                needles and the reserved label budget fit into ``self.ctx_len``.

        Returns:
            A dict with keys ``input_text``, ``input_ids``, ``label_text``,
            ``label_tokens`` (ending with the temporary EOS placeholder),
            ``expected_answers``, ``cities``, ``numbers`` and ``start_ixs``
            (the prompt length in tokens).

        Raises:
            ValueError: if the text has fewer tokens than there are needles, in
                which case distinct insertion positions cannot be drawn.
            AssertionError: if the label exceeds ``MAX_LABEL_LENGTH`` tokens or
                the fixed prompt parts leave no room for any text.
        """
        n_sentences = 3
        magic_number_sentences, cities, numbers = self.sample_magic_number_sentences(n=n_sentences)

        def encode(piece):
            # The tokenizer returns one id list per input; a single str gives one.
            return self.tokenizer(piece, add_special_tokens=False)['input_ids'][0]

        prompt_pre = (
            "The context contains sentences pertaining to special magic cities and associated numbers. "
            "Please search the context and remember these.\n"
            "Context: "
        )
        prompt_post = (
            "\n"
            "Question: Can you write down the special magic cities and their numbers?\n"
            "Answer: "
        )
        prompt_pre_tokens = encode(prompt_pre)
        prompt_post_tokens = encode(prompt_post)

        ## construct label and label tokens
        expected_answers = [f"{city}={number}" for city, number in zip(cities, numbers)]
        label = "; ".join(expected_answers) + '.'
        label_tokens = encode(label) + [self.tmp_eos_token]
        assert len(label_tokens) <= MAX_LABEL_LENGTH, f"Label tokens are too long! len(label_tokens) = {len(label_tokens)}, MAX_LABEL_LENGTH = {MAX_LABEL_LENGTH}"

        ## construct needle tokens
        needle_tokens = [encode(sentence) for sentence in magic_number_sentences]
        n_needle_tokens = sum(len(needle) for needle in needle_tokens)

        # Token budget left for the haystack once every fixed part is reserved.
        n_text_tokens = self.ctx_len - len(prompt_pre_tokens) - len(prompt_post_tokens) - MAX_LABEL_LENGTH - n_needle_tokens
        if self.do_train: # if we are training, we need to add one more token because we shift the labels
            n_text_tokens += 1
        assert n_text_tokens > 0, "Not enough tokens for text!"

        # Truncate the haystack to the budget (shorter texts are kept as they are).
        text_tokens = encode(text)[:n_text_tokens]
        if len(text_tokens) < n_sentences:
            # Without this guard the rejection loop below would never finish.
            raise ValueError(
                f"Text has only {len(text_tokens)} tokens; at least {n_sentences} are needed to place the needles."
            )

        # Draw distinct insertion points uniformly over the text tokens; each
        # needle is inserted immediately before the token at its position.
        positions = set()
        while len(positions) < n_sentences:
            positions.add(self.rng.randint(0, len(text_tokens) - 1))
        positions = sorted(positions)

        segments = []
        cursor = 0
        for pos, needle in zip(positions, needle_tokens):
            segments.append(text_tokens[cursor:pos])
            segments.append(needle)
            cursor = pos
        segments.append(text_tokens[cursor:])
        assert len(text_tokens) + n_needle_tokens == sum(len(seg) for seg in segments), "length changed when adding needles..."

        input_ids = prompt_pre_tokens + flatten(segments) + prompt_post_tokens
        return dict(
            input_text=self.tokenizer.decode(input_ids),
            input_ids=input_ids,
            label_text=label,
            label_tokens=label_tokens,
            expected_answers=expected_answers,
            cities=cities,
            numbers=numbers,
            start_ixs=len(input_ids),
        )

    def preprocess_function(self, examples):
        """Turn one packed ``{'text': ...}`` example into model-ready features.

        In eval mode the prompt token ids are returned as a plain list together
        with the raw prompt, the expected answer string and the sampled needles.

        In train mode prompt and answer are concatenated, right-padded to
        ``ctx_len + 1`` tokens and then shifted by one position, so that
        ``input_ids``, ``labels`` and ``attention_mask`` each have ``ctx_len``
        entries and ``labels[i]`` is the token following ``input_ids[i]``.
        Labels are ``-100`` everywhere except on the answer tokens (including
        the closing EOS).

        Args:
            examples: dict with a ``"text"`` entry holding the haystack text.

        Returns:
            A dict of features; its keys depend on ``self.do_train``.
        """
        data = self.get_prompt_and_answers(examples["text"])

        if not self.do_train:
            return {
                "input_ids": data['input_ids'],
                "start_ixs": data['start_ixs'],
                "raw_input": data['input_text'],
                "expected_raw_output": data['label_text'],
                "cities": data['cities'],
                "numbers": data['numbers']
            }

        n_prompt_tokens = len(data['input_ids'])
        content = torch.tensor(data['input_ids'] + data['label_tokens'])
        n_content = len(content)

        # Pad to ctx_len + 1 so that ctx_len positions remain after the shift.
        padding_required = self.ctx_len - n_content + 1
        input_ids = torch.cat([
            content,
            torch.full((padding_required,), self.tokenizer.pad_token_id, dtype=torch.long),
        ])

        # Track padding by position rather than by comparing against
        # pad_token_id: pad may equal EOS, and a token-equality test would then
        # also treat the real EOS (or any matching content token) as padding.
        is_padding = torch.zeros(input_ids.shape[0], dtype=torch.bool)
        is_padding[n_content:] = True

        labels = input_ids.clone()
        labels[is_padding] = -100

        # Replace the temporary placeholder ids with the tokenizer's real ids.
        labels[input_ids == self.tmp_eos_token] = self.tokenizer.eos_token_id
        input_ids[input_ids == self.tmp_eos_token] = self.tokenizer.eos_token_id
        labels[input_ids == self.tmp_bos_token] = self.tokenizer.bos_token_id
        input_ids[input_ids == self.tmp_bos_token] = self.tokenizer.bos_token_id

        attention_mask = ~is_padding

        # Only the answer contributes to the loss: mask every prompt position.
        labels[:n_prompt_tokens] = -100

        # shift so that labels and inputs are aligned
        input_ids = input_ids[:-1]
        labels = labels[1:]
        attention_mask = attention_mask[:-1]

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "cities": data['cities'],
            "numbers": data['numbers'],
        }

    def custom_collate_fn(self, features):
        """Collate a list of feature dicts into one batch dict.

        ``input_ids``, ``labels`` and ``attention_mask`` are right-padded to the
        longest ``input_ids`` in the batch (with the pad token id, ``-100`` and
        ``0`` respectively) and stacked. Any other tensor-valued key is stacked
        as is; every remaining key is gathered into a plain list.
        """
        keys = features[0].keys()
        target_len = max(len(feat['input_ids']) for feat in features)

        def fill_value_for(key):
            if key == 'input_ids':
                return self.tokenizer.pad_token_id
            return -100 if key == 'labels' else 0  # 0 is for attention_mask

        batch = {}
        for key in keys:
            if key in ('input_ids', 'labels', 'attention_mask'):
                fill_value = fill_value_for(key)
                rows = []
                for feat in features:
                    value = feat[key]
                    tensor = value if isinstance(value, torch.Tensor) else torch.tensor(value)
                    tail = torch.full(
                        (target_len - tensor.shape[0],),
                        fill_value,
                        dtype=tensor.dtype
                    )
                    rows.append(torch.cat([tensor, tail]))
                batch[key] = torch.stack(rows)
            elif isinstance(features[0][key], torch.Tensor):
                batch[key] = torch.stack([feat[key] for feat in features])
            else:
                batch[key] = [feat[key] for feat in features]
        return batch


    def get_dataloader(self):
        """Return the DataLoader over ``self.dataset``, creating it on first use."""
        if self.dataloader is not None:
            return self.dataloader

        loader_kwargs = dict(
            batch_size=self.BATCH_SIZE,
            collate_fn=self.custom_collate_fn,
        )
        if not self.do_train:
            # The eval loader states explicitly that it must not shuffle.
            loader_kwargs['shuffle'] = False
        self.dataloader = DataLoader(self.dataset, **loader_kwargs)
        return self.dataloader

    def get_dataset(self):
        """Return the dataset object stored on this instance as ``self.dataset``."""
        dataset = self.dataset
        return dataset

class TiktokenWrapper:
    """
    A wrapper for tiktoken to provide an interface similar to HuggingFace tokenizers.

    Unlike a HuggingFace tokenizer, ``__call__`` always returns a *batch*: the
    ``'input_ids'`` entry is a list with one id list per input text, even when a
    single string is passed.
    """
    def __init__(self, encoding_name="gpt2"):
        """Load the named tiktoken encoding and derive the special token ids."""
        self.enc = tiktoken.get_encoding(encoding_name)
        # The end-of-text token is read through the public accessor instead of
        # the private ``_special_tokens`` mapping.
        eot_id = self.enc.eot_token
        self.eos_token_id = eot_id
        self.bos_token_id = eot_id  ## GPT2 sets bos/eos as the same...
        self.pad_token_id = eot_id  # Using EOS as PAD token

    def __call__(self, texts, add_special_tokens=True):
        """
        Tokenize a list of texts or a single text.

        Args:
            texts: one string or an iterable of strings.
            add_special_tokens: accepted for interface compatibility only; it is
                ignored, and the texts are always encoded with ``encode_ordinary``.

        Returns:
            ``{'input_ids': [...]}`` with one list of token ids per input text.
        """
        if isinstance(texts, str):
            texts = [texts]
        return {'input_ids': [self.enc.encode_ordinary(text) for text in texts]}

    def decode(self, token_ids):
        """
        Decode a list (or tensor) of token IDs to a string.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        return self.enc.decode(token_ids)