import re
from transformers import GPT2TokenizerFast
from datasets import load_dataset
from itertools import chain
import numpy as np
import torch

import urllib.request
import zipfile
import requests
import json
from datasets import Dataset
import pathlib
from typing import Union
import os
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_from_disk


def cycle_loader(dataloader, sampler=None):
    """Yield items from ``dataloader`` forever, restarting it whenever it is exhausted.

    Args:
        dataloader: Any re-iterable source of batches (e.g. a ``torch.utils.data.DataLoader``).
        sampler: Optional object with a ``set_epoch(int)`` method (e.g. a
            ``DistributedSampler``). When given, ``set_epoch`` is called before
            every pass with an incrementing counter 0, 1, 2, ...

    Yields:
        Successive items of ``dataloader``, cycling indefinitely.

    Raises:
        ValueError: if a full pass over ``dataloader`` produces no items
            (otherwise the generator would spin forever without yielding).
    """
    # A deterministic, increasing epoch gives a fresh shuffle on every pass while
    # keeping all ranks in agreement: DistributedSampler derives its permutation
    # from (seed + epoch), so every rank must pass the same epoch value or the
    # shards overlap. A per-rank random draw only agrees if every rank's numpy
    # RNG happens to be seeded identically.
    epoch = 0
    while True:
        if sampler is not None:
            sampler.set_epoch(epoch)
        produced = False
        for data in dataloader:
            produced = True
            yield data
        if not produced:
            raise ValueError("cycle_loader: dataloader yielded no items; refusing to loop forever.")
        epoch += 1


def wt_detokenizer(string):
    """Undo WikiText tokenization artefacts so the text reads like ordinary prose.

    Applies a fixed, ordered sequence of string and regex substitutions. Order
    matters: for example, the ``"= = = ="`` rule must run before ``"= ="``.

    Args:
        string: A piece of WikiText-formatted text.

    Returns:
        The detokenized string.
    """
    # contractions
    string = string.replace("s '", "s'")
    # Re-attach an apostrophe to a following digit ("' 90s" -> "'90s"). The earlier
    # pattern r"/' [0-9]/" used JavaScript-style /.../ delimiters, so it only ever
    # matched the literal text "/' <digit>/" and replaced the digit with "[0-9]".
    string = re.sub(r"' ([0-9])", r"'\1", string)
    # number separators
    string = string.replace(" @-@ ", "-")
    string = string.replace(" @,@ ", ",")
    string = string.replace(" @.@ ", ".")
    # punctuation
    string = string.replace(" : ", ": ")
    string = string.replace(" ; ", "; ")
    string = string.replace(" . ", ". ")
    string = string.replace(" ! ", "! ")
    string = string.replace(" ? ", "? ")
    string = string.replace(" , ", ", ")
    # double brackets: trim whitespace just inside paired delimiters
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
    # miscellaneous: heading markers (longest first), degree sign, line edges
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    string = string.replace(" " + chr(176) + " ", chr(176))
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")
    return string

def ptb_detokenizer(x):
    """Reverse Penn Treebank tokenization conventions.

    Re-attaches split clitics ("'s", "n't"), unescapes slashes, turns the
    number placeholder ``N`` into ``1`` (gluing it to ``$``/``#``) and maps
    ``<unk>`` to ``?``.
    """
    token_joins = (
        (" 's", "'s"),
        ("s ' ", "s' "),
        (" n't", "n't"),
        (" \n ", "\n"),
        ("\\/", "/"),
    )
    for before, after in token_joins:
        x = x.replace(before, after)

    # One str.replace pass skips the second of two adjacent " N " matches because
    # they share a space, so keep replacing until no placeholder sits between spaces.
    while " N " in x:
        x = x.replace(" N ", " 1 ")

    for before, after in (("$ 1", "$1"), ("# 1", "#1"), ("<unk>", "?")):
        x = x.replace(before, after)
    return x

def lm1b_detokenizer(x):
    """Undo One Billion Word (LM1B) tokenization.

    Rejoins URLs, trailing punctuation, quotes, brackets and currency symbols
    that the tokenizer separated with spaces. The rules run strictly in the
    order listed; later rules assume the earlier ones have already run.
    """
    # (is_regex, pattern, replacement), applied in sequence.
    steps = (
        (False, 'http : / / ', 'http://'),
        (False, 'https : / / ', 'https://'),
        (True, r' \'(\w+)', r"'\1"),
        (True, r' (\w+) \. ', r' \1. '),
        (True, r' (\w+) \.$', r' \1.'),
        (False, ' ? ', '? '),
        (True, r' \?$', '?'),
        (False, ' ! ', '! '),
        (True, r' \!$', '!'),
        (False, ' , ', ', '),
        (False, ' : ', ': '),
        (False, ' ; ', '; '),
        (False, ' / ', '/'),
        (True, r'\" ([^\"]+) \"', r'"\1"'),
        (True, r'\' ([^\']+) \'', r"'\1'"),
        (True, r'\( ([^\(\)]+) \)', r"(\1)"),
        (True, r'\[ ([^\[\]]+) \]', r"[\1]"),
        (False, '$ ', '$'),
        (False, '£ ', '£'),
    )
    for is_regex, pattern, replacement in steps:
        if is_regex:
            x = re.sub(pattern, replacement, x)
        else:
            x = x.replace(pattern, replacement)
    return x


def lambada_detokenizer(text):
    """Map curly double quotes to ASCII ``"`` and return the stripped text after a leading newline."""
    ascii_quotes = text.translate({ord("“"): '"', ord("”"): '"'})
    return '\n' + ascii_quotes.strip()


def get_lambada_test_dataset(*, url="https://openaipublic.blob.core.windows.net/gpt-2/data/lambada_test.jsonl",
                             timeout=60):
    """Download the LAMBADA test file (JSON Lines) and wrap it in a ``datasets.Dataset``.

    Args:
        url: Location of the ``.jsonl`` file. Each non-empty line is parsed as
            one JSON record.
        timeout: Seconds passed to ``requests.get`` so a stalled server cannot
            hang the caller indefinitely.

    Returns:
        A ``Dataset`` built from the list of parsed records.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """

    def read_jsonl_to_list(url):
        # The ``with`` block releases the streamed connection even if parsing
        # fails; a ``stream=True`` response otherwise keeps its socket checked out.
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Fail loudly on 4xx/5xx instead of trying to JSON-parse an error page.
            response.raise_for_status()
            return [
                json.loads(line)
                for line in response.iter_lines(decode_unicode=True)
                if line
            ]

    lambada_data = read_jsonl_to_list(url)
    return Dataset.from_list(lambada_data)


def get_dataset(name: str, cache_dir: Union[str, pathlib.Path], split: str, seq_len: int = 1024,
                tokenizer_hf_name="assets/gpt2-large", num_proc=8):
    """Load a pre-processed dataset for ``name`` from its fixed on-disk location.

    Only ``name`` selects what is loaded. ``cache_dir``, ``split``, ``seq_len``,
    ``tokenizer_hf_name`` and ``num_proc`` are accepted for interface
    compatibility but are currently ignored, so calls that differ only in
    ``split`` return the same data.
    NOTE(review): confirm that reusing one processed set for every split is intended.

    Args:
        name: ``"openwebtext"`` or ``"wikitext103"``.

    Returns:
        Whatever ``datasets.load_from_disk`` returns for the matching directory.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    processed_paths = {
        'openwebtext': 'assets/datasets/openwebtext/openwebtext_processed',
        'wikitext103': 'assets/datasets/wikitext/wikitext103_processed',
    }
    # Validate before touching disk. Previously the OpenWebText copy was loaded
    # unconditionally and then discarded: wasted I/O, and a hard failure for
    # 'wikitext103' whenever the OpenWebText directory was absent.
    if name not in processed_paths:
        raise ValueError(f"Dataset {name} not supported.")
    return load_from_disk(processed_paths[name])


def get_valid_dataset(name):
    """Load the dataset saved under ``assets/datasets/<name>`` with ``load_from_disk``."""
    dataset_path = f'assets/datasets/{name}'
    return load_from_disk(dataset_path)

def get_dataloaders(config, distributed=True):
    """Build infinite train and validation batch iterators from ``config``.

    The per-device batch size of both loaders is ``batch_size // (ngpus * accum)``,
    so both global batch sizes must be divisible by ``ngpus * accum``.

    Args:
        config: Object exposing ``training.batch_size``, ``training.accum``,
            ``eval.batch_size``, ``ngpus``, and ``train_set`` / ``valid_set``
            each with ``name`` and ``cache_dir``.
        distributed: When True, shard each dataset with a ``DistributedSampler``;
            otherwise the DataLoader shuffles by itself.

    Returns:
        ``(train_loader, valid_loader)``: generators produced by ``cycle_loader``.

    Raises:
        ValueError: if a batch size is not divisible by ``ngpus * accum``.
    """
    divisor = config.ngpus * config.training.accum
    if config.training.batch_size % divisor != 0:
        raise ValueError(f"Train Batch Size {config.training.batch_size} is not divisible by {config.ngpus} gpus with accumulation {config.training.accum}.")
    if config.eval.batch_size % divisor != 0:
        raise ValueError(f"Eval Batch Size for {config.eval.batch_size} is not divisible by {config.ngpus} gpus with accumulation {config.training.accum}.")

    train_set = get_dataset(config.train_set.name, cache_dir=config.train_set.cache_dir, split="train")
    valid_split = "validation" if config.valid_set.name != "text8" else "test"
    valid_set = get_dataset(config.valid_set.name, cache_dir=config.valid_set.cache_dir, split=valid_split)

    if distributed:
        train_sampler = DistributedSampler(train_set)
        test_sampler = DistributedSampler(valid_set)
    else:
        train_sampler = None
        test_sampler = None

    # The samplers are handed to cycle_loader so that it calls set_epoch before
    # every pass. A DistributedSampler left at its initial epoch replays the same
    # permutation on every pass. cycle_loader must use the same epoch value on
    # every rank for the shards to stay disjoint.
    train_loader = cycle_loader(DataLoader(
        train_set,
        batch_size=config.training.batch_size // divisor,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True,
        shuffle=(train_sampler is None),
        persistent_workers=True,
    ), sampler=train_sampler)
    valid_loader = cycle_loader(DataLoader(
        valid_set,
        batch_size=config.eval.batch_size // divisor,
        sampler=test_sampler,
        num_workers=4,
        pin_memory=True,
        shuffle=(test_sampler is None),
    ), sampler=test_sampler)
    return train_loader, valid_loader


def getOpenWebTextDataset(cache_dir: Union[str, pathlib.Path], split: str, seq_len: int,
                          hf_name="assets/datasets/openwebtext/openwebtext.py",
                          tokenizer_hf_name="assets/gpt2-large", num_proc=8):
    """Tokenize an OpenWebText split with a GPT-2 tokenizer and pack it into fixed-length rows.

    Each document is tokenized and terminated with EOS. The documents of each
    ``map`` batch are then concatenated and cut into ``seq_len``-token rows;
    the tail of a batch that does not fill a whole row is dropped.

    ``seq_len`` is the full row length, so it should already include any context
    the model conditions on (e.g. 512 to model 256 tokens given 256 tokens of
    context). NOTE(review): nothing below treats context specially.

    Args:
        cache_dir: Directory passed to ``load_dataset``. Processed Arrow caches
            are written under ``<cache_dir>/cache``.
        split: One of ``"train"``, ``"val"``, ``"test"``.
        seq_len: Number of tokens per output row.
        hf_name: Dataset script/path passed to ``load_dataset``.
        tokenizer_hf_name: Name/path passed to ``GPT2TokenizerFast.from_pretrained``.
        num_proc: Worker processes for ``Dataset.map``.

    Returns:
        A torch-formatted ``Dataset``. Every remaining column is packed into
        ``seq_len``-long rows; ``input_ids`` is the column filled by the tokenizer.

    Raises:
        ValueError: if ``split`` is not a supported split name.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if split not in ("train", "val", "test"):
        raise ValueError(f"Unsupported split {split!r}; expected 'train', 'val' or 'test'.")
    data = load_dataset(hf_name, cache_dir=cache_dir, trust_remote_code=True)
    data = data[split]

    tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_hf_name)
    EOS = tokenizer.encode(tokenizer.eos_token)[0]

    def preprocess_and_tokenize_fn(example):
        tokens = tokenizer(example["text"], return_attention_mask=False)
        # add in EOS token following
        # https://github.com/jcpeterson/openwebtext/blob/master/tokenize_text.py#L67
        for token in tokens['input_ids']:
            token.append(EOS)
        return tokens

    def group_texts(examples):
        # Concatenate every column of the batch into one long sequence.
        concatenated = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
        total_length = len(concatenated[list(examples.keys())[0]])
        # Drop the remainder that does not fill a whole row (a batch shorter than
        # seq_len yields no rows at all).
        total_length = (total_length // seq_len) * seq_len
        return {
            k: [t[i : i + seq_len] for i in range(0, total_length, seq_len)]
            for k, t in concatenated.items()
        }

    # Cache files and fingerprints are keyed by dataset and split, plus seq_len
    # for the packed data. Previously the fixed names "tokenized.arrow" and
    # "grouped.arrow" were shared, so with load_from_cache_file=True a call for
    # another split or seq_len silently reloaded the first call's result.
    # NOTE(review): the tokenizer is not part of the key; clear the cache after
    # changing tokenizer_hf_name.
    cache_root = os.path.join(cache_dir, "cache")
    os.makedirs(cache_root, exist_ok=True)  # map() writes its cache files here

    tokenized_key = f"owt_tokenized_{split}"
    tokenized_data = data.map(preprocess_and_tokenize_fn, batched=True, num_proc=num_proc,
                              load_from_cache_file=True,
                              cache_file_name=os.path.join(cache_root, f"{tokenized_key}.arrow"),
                              new_fingerprint=tokenized_key, desc="Tokenizing texts")
    tokenized_data = tokenized_data.remove_columns('text')

    grouped_key = f"owt_grouped_{split}_{seq_len}"
    chunked_data = tokenized_data.map(group_texts, batched=True, num_proc=num_proc,
                                      load_from_cache_file=True,
                                      cache_file_name=os.path.join(cache_root, f"{grouped_key}.arrow"),
                                      new_fingerprint=grouped_key, desc="Grouping tokens")
    return chunked_data.with_format('torch')

def getWikitext103(cache_dir: Union[str, pathlib.Path], split: str, seq_len: int,
                   hf_name="wikitext", tokenizer_hf_name="assets/gpt2-large", num_proc=8):
    """Load raw WikiText-103, undo its tokenization, GPT-2 tokenize it, and pack fixed-length rows.

    Each entry of the ``text`` column is passed through ``wt_detokenizer``,
    tokenized, and terminated with EOS. The documents of each ``map`` batch are
    then concatenated and cut into ``seq_len``-token rows; the tail of a batch
    that does not fill a whole row is dropped.

    Args:
        cache_dir: Directory passed to ``load_dataset``. Processed Arrow caches
            are written under ``<cache_dir>/cache``.
        split: Key into the loaded ``wikitext-103-raw-v1`` splits.
        seq_len: Number of tokens per output row.
        hf_name: Dataset name passed to ``load_dataset``.
        tokenizer_hf_name: Name/path passed to ``GPT2TokenizerFast.from_pretrained``.
        num_proc: Worker processes for ``Dataset.map``.

    Returns:
        A torch-formatted ``Dataset``. Every remaining column is packed into
        ``seq_len``-long rows; ``input_ids`` is the column filled by the tokenizer.
    """
    data = load_dataset(hf_name, name="wikitext-103-raw-v1", cache_dir=cache_dir)
    data = data[split]

    detokenizer = wt_detokenizer

    tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_hf_name)
    EOS = tokenizer.encode(tokenizer.eos_token)[0]

    def preprocess_and_tokenize(example):
        text = example["text"]
        if detokenizer is not None:
            # Build a new list rather than mutating the batch's column in place.
            text = [detokenizer(t) for t in text]

        tokens = tokenizer(text, return_attention_mask=False)
        # add in EOS token following
        # https://github.com/jcpeterson/openwebtext/blob/master/tokenize_text.py#L67
        for token in tokens['input_ids']:
            token.append(EOS)
        return tokens

    def group_texts(examples):
        # Concatenate every column of the batch into one long sequence.
        concatenated = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
        total_length = len(concatenated[list(examples.keys())[0]])
        # Drop the remainder that does not fill a whole row (a batch shorter than
        # seq_len yields no rows at all).
        total_length = (total_length // seq_len) * seq_len
        return {
            k: [t[i : i + seq_len] for i in range(0, total_length, seq_len)]
            for k, t in concatenated.items()
        }

    # Cache files and fingerprints are keyed by dataset and split, plus seq_len
    # for the packed data. Previously the fixed names "tokenized.arrow" and
    # "grouped.arrow" were shared, so with load_from_cache_file=True loading the
    # validation/test split after train silently returned train's tokens.
    # NOTE(review): the tokenizer is not part of the key; clear the cache after
    # changing tokenizer_hf_name.
    cache_root = os.path.join(cache_dir, "cache")
    os.makedirs(cache_root, exist_ok=True)  # map() writes its cache files here

    tokenized_key = f"wikitext103_tokenized_{split}"
    tokenized_data = data.map(preprocess_and_tokenize, batched=True, num_proc=num_proc,
                              load_from_cache_file=True,
                              cache_file_name=os.path.join(cache_root, f"{tokenized_key}.arrow"),
                              new_fingerprint=tokenized_key, desc="Tokenizing texts")
    tokenized_data = tokenized_data.remove_columns('text')

    grouped_key = f"wikitext103_grouped_{split}_{seq_len}"
    chunked_data = tokenized_data.map(group_texts, batched=True, num_proc=num_proc,
                                      load_from_cache_file=True,
                                      cache_file_name=os.path.join(cache_root, f"{grouped_key}.arrow"),
                                      new_fingerprint=grouped_key, desc="Grouping tokens")
    return chunked_data.with_format('torch')


def get_valid_dataloaders(args, distributed=True):
    """Build a (non-cycling) evaluation DataLoader for ``args.valid_dataset``.

    Args:
        args: Object exposing ``batch_size`` (global), ``ngpus`` and ``valid_dataset``.
        distributed: Shard the data with a ``DistributedSampler`` when True;
            otherwise let the DataLoader shuffle by itself.

    Returns:
        A ``DataLoader`` whose per-device batch size is ``batch_size // ngpus``.

    Raises:
        ValueError: if ``batch_size`` is not divisible by ``ngpus``.
    """
    per_device_batch, leftover = divmod(args.batch_size, args.ngpus)
    if leftover != 0:
        raise ValueError(f"Eval Batch Size for {args.batch_size} is not divisible by {args.ngpus} gpus.")

    dataset = get_valid_dataset(args.valid_dataset)
    sampler = DistributedSampler(dataset) if distributed else None

    return DataLoader(
        dataset,
        batch_size=per_device_batch,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        shuffle=sampler is None,
        persistent_workers=True,
    )