import json

import torch
import torchvision.transforms as transforms
from tqdm import tqdm


def get_dataset(dname, **kwargs):
    """Build the train and test splits for the dataset called ``dname``.

    Image datasets are loaded through ``get_torchvision_dataset`` with a
    per-dataset transform pipeline:

    * ``"mnist"``: ``ToTensor`` + ``Normalize((0.1307,), (0.3081,))``
    * ``"fashion_mnist"``: ``ToTensor`` + ``Normalize((0.5,), (0.5,))``
    * ``"svhn"``: random perspective (train only), ``ToTensor``, resize to 16
    * ``"cifar10"``: pad 4 / random flip / random 32-crop (train only),
      ``ToTensor``

    Text datasets (``"imdb"``, ``"samsum"``, ``"xsum"``, ``"agnews"``) are
    loaded through ``get_huggingface_dataset``.

    Args:
        dname: Name of the dataset (one of the names listed above).
        **kwargs: Optional extras. Only ``tokenizer`` is read, and only for
            the text datasets; it is forwarded to ``get_huggingface_dataset``.

    Returns:
        A ``(train_set, test_set)`` tuple.

    Raises:
        ValueError: If ``dname`` is not a supported dataset name. (Previously
            this surfaced as an ``UnboundLocalError`` at the return statement.)
    """
    if dname == "mnist":
        img_transforms = [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
        train_set, test_set = get_torchvision_dataset(dname, img_transforms)
    elif dname == "fashion_mnist":
        img_transforms = [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
        train_set, test_set = get_torchvision_dataset(dname, img_transforms)
    elif dname == "svhn":
        # Augmentation is applied to the training split only.
        train_img_transforms = [
            transforms.RandomPerspective(distortion_scale=0.2),
            transforms.ToTensor(),
            transforms.Resize(16, antialias=True),
        ]
        test_img_transforms = [
            transforms.ToTensor(),
            transforms.Resize(16, antialias=True),
        ]
        train_set, test_set = get_torchvision_dataset(
            dname, train_img_transforms, test_img_transforms)
    elif dname == "cifar10":
        # Pad-then-crop plus horizontal flip; the test split is left untouched.
        train_img_transforms = [
            transforms.Pad(4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32),
            transforms.ToTensor(),
        ]
        test_img_transforms = [transforms.ToTensor()]
        train_set, test_set = get_torchvision_dataset(
            dname, train_img_transforms, test_img_transforms)
    elif dname in ["imdb", "samsum", "xsum", "agnews"]:
        tokenizer = kwargs.get("tokenizer")
        train_set, test_set = get_huggingface_dataset(dname, tokenizer)
    else:
        raise ValueError(
            f"Unknown dataset name {dname!r}; expected one of 'mnist', "
            "'fashion_mnist', 'svhn', 'cifar10', 'imdb', 'samsum', 'xsum', "
            "'agnews'."
        )
    return train_set, test_set

import torchvision

def get_torchvision_dataset(dname, train_transforms, test_transforms=None):
    """Download (if needed) and build a torchvision train/test dataset pair.

    Args:
        dname: One of ``"mnist"``, ``"fashion_mnist"``, ``"svhn"`` or
            ``"cifar10"``.
        train_transforms: List of transforms, composed with
            ``transforms.Compose`` and applied to the training split.
        test_transforms: List of transforms for the test split. When ``None``
            the training transforms are reused.

    Returns:
        A ``(train_set, test_set)`` tuple of torchvision datasets rooted at
        ``"../data"``.

    Raises:
        ValueError: If ``dname`` is not one of the supported names.
            (Previously this surfaced as an ``UnboundLocalError``.)
    """
    # Map to attribute names so the name check happens before torchvision is
    # touched at all.
    dataset_class_names = {
        "mnist": "MNIST",
        "fashion_mnist": "FashionMNIST",
        "svhn": "SVHN",
        "cifar10": "CIFAR10",
    }
    if dname not in dataset_class_names:
        raise ValueError(
            f"Unknown torchvision dataset {dname!r}; expected one of "
            f"{sorted(dataset_class_names)}."
        )
    data_init = getattr(torchvision.datasets, dataset_class_names[dname])

    if test_transforms is None:
        test_transforms = train_transforms

    # SVHN picks its split by name; the other datasets take a boolean flag.
    if dname == "svhn":
        train_split, test_split = {"split": "train"}, {"split": "test"}
    else:
        train_split, test_split = {"train": True}, {"train": False}

    train_set = data_init(root="../data",
                          download=True,
                          transform=transforms.Compose(train_transforms),
                          **train_split)
    test_set = data_init(root="../data",
                         download=True,
                         transform=transforms.Compose(test_transforms),
                         **test_split)
    return train_set, test_set

import torchtext
from sklearn.feature_extraction.text import TfidfVectorizer

def get_glove(corpus, min_df=10, max_df=0.8, max_vocab_size=5000, **kwargs):
    """Build a frozen GloVe embedding layer for a TF-IDF-filtered vocabulary.

    A ``TfidfVectorizer`` (English stop words removed, document-frequency
    bounds ``min_df``/``max_df``, at most ``max_vocab_size - 1`` features) is
    fitted on ``corpus`` to choose the vocabulary; ``"<unk>"`` is appended as
    the final entry.

    Args:
        corpus: Iterable of raw text documents used to fit the vectorizer.
        min_df: Forwarded to ``TfidfVectorizer``.
        max_df: Forwarded to ``TfidfVectorizer``.
        max_vocab_size: Vocabulary budget, including the ``"<unk>"`` slot.
        **kwargs: Extra keyword arguments forwarded to ``TfidfVectorizer``.

    Returns:
        ``(glove, vocab_ids)`` where ``glove`` is a ``torch.nn.Embedding``
        holding the ``twitter.27B`` 25-dimensional GloVe vectors for the
        vocabulary (gradients disabled) and ``vocab_ids`` maps each token to
        its row index.
    """
    # Vocabulary selection: let TF-IDF discard uninformative words.
    vectorizer = TfidfVectorizer(
        stop_words="english",
        min_df=min_df,
        max_df=max_df,
        max_features=max_vocab_size - 1,
        **kwargs,
    )
    vectorizer.fit(corpus)
    vocab = [*vectorizer.get_feature_names_out(), "<unk>"]
    vocab_ids = dict(zip(vocab, range(len(vocab))))

    # Embedding table restricted to the selected vocabulary.
    pretrained = torchtext.vocab.GloVe(name="twitter.27B", dim=25)
    emb_weights = pretrained.get_vecs_by_tokens(vocab)
    num_embeddings, embedding_dim = emb_weights.shape
    glove = torch.nn.Embedding(num_embeddings, embedding_dim)
    glove.load_state_dict({"weight": emb_weights})
    glove.weight.requires_grad_(False)

    print("vocab size:", len(vocab))
    print("no. stop words:", len(vectorizer.stop_words_))

    return glove, vocab_ids

def convert_examples_to_glove_features(examples, glove, glove_ids, max_seq_len):
    """Turn ``(label, text)`` examples into fixed-length GloVe feature tensors.

    Each text is tokenized with torchtext's ``"basic_english"`` tokenizer.
    Tokens missing from ``glove_ids`` are dropped, the remainder is truncated
    to ``max_seq_len`` and right-padded with the ``"<unk>"`` id, and the ids
    are passed through ``glove``.

    Args:
        examples: Iterable of indexable items with the label at position 0
            and the text at position 1.
        glove: Callable embedding applied to a ``torch.long`` id tensor.
        glove_ids: Mapping from token to id; must contain ``"<unk>"``.
        max_seq_len: Target sequence length after truncation/padding.

    Returns:
        List of ``(features, label)`` tuples, where ``label`` is the original
        label converted to ``int`` and shifted down by one (presumably from
        1-based to 0-based class ids; confirm against the dataset).
    """
    tokenize = torchtext.data.utils.get_tokenizer("basic_english")
    converted = []
    for example in tqdm(examples, desc="get_glove_features"):
        text = example[1]
        label = int(example[0]) - 1

        # Keep only in-vocabulary tokens, then clip to the target length.
        known_ids = [glove_ids[tok] for tok in tokenize(text) if tok in glove_ids]
        known_ids = known_ids[:max_seq_len]

        # Right-pad short sequences with the "<unk>" id.
        missing = max(0, max_seq_len - len(known_ids))
        padded_ids = known_ids + [glove_ids["<unk>"]] * missing

        id_tensor = torch.tensor(padded_ids, dtype=torch.long)
        converted.append((glove(id_tensor), label))

    return converted

def get_torchtext_dataset(dname, max_vocab_size, max_seq_len):
    """Load a torchtext dataset and convert it to GloVe feature tensors.

    Args:
        dname: Dataset name; only ``"agnews"`` is supported.
        max_vocab_size: Vocabulary budget forwarded to ``get_glove``.
        max_seq_len: Sequence length forwarded to
            ``convert_examples_to_glove_features``.

    Returns:
        ``(train_set, test_set, glove)``: the two converted splits plus the
        embedding layer returned by ``get_glove``.

    Raises:
        ValueError: If ``dname`` is not supported. (Previously this surfaced
            as an ``UnboundLocalError``.)
    """
    if dname != "agnews":
        raise ValueError(
            f"Unknown torchtext dataset {dname!r}; only 'agnews' is supported."
        )
    data_init = torchtext.datasets.AG_NEWS

    # Materialise each split once. The training split is traversed twice
    # below (corpus extraction, then feature conversion), and some torchtext
    # releases hand back single-pass iterators that would be empty the second
    # time around.
    train_examples = list(data_init(split="train"))
    test_examples = list(data_init(split="test"))

    # The vocabulary is fitted on training text only.
    corpus = [example[1] for example in train_examples]
    glove, glove_ids = get_glove(corpus, max_vocab_size=max_vocab_size)

    train_set = convert_examples_to_glove_features(
        train_examples, glove, glove_ids, max_seq_len)
    test_set = convert_examples_to_glove_features(
        test_examples, glove, glove_ids, max_seq_len)
    return train_set, test_set, glove

def convert_examples_to_ids(examples, tokenizer, max_seq_len):
    """Tokenize the ``"text"`` field of every example and attach its label.

    Args:
        examples: Iterable of mappings with ``"text"`` and ``"label"`` keys.
        tokenizer: Callable invoked as ``tokenizer(text, padding="max_length",
            max_length=max_seq_len, truncation=True)``; its result must
            support item assignment.
        max_seq_len: Length every sequence is padded/truncated to.

    Returns:
        List of tokenizer outputs, each with an added ``"labels"`` entry
        holding the example's ``"label"`` value.
    """
    def encode(example):
        encoded = tokenizer(
            example["text"],
            padding="max_length",
            max_length=max_seq_len,
            truncation=True,
        )
        encoded["labels"] = example["label"]
        return encoded

    return [encode(example) for example in tqdm(examples, desc="get_input_ids")]

def preprocess_samsum(examples, tokenizer, max_seq_len):
    """Tokenize SAMSum-style examples for summarization.

    The dialogue is prefixed with ``"summarize: "`` and tokenized as the
    model input; the summary is tokenized via ``text_target`` and its
    ``"input_ids"`` are stored under ``"labels"``.

    Args:
        examples: Iterable of mappings with ``"dialogue"`` and ``"summary"``.
        tokenizer: Tokenizer callable accepting ``padding``, ``max_length``,
            ``truncation`` and ``text_target`` keyword arguments.
        max_seq_len: Length both inputs and targets are padded/truncated to.

    Returns:
        List of tokenizer outputs with an added ``"labels"`` entry.
    """
    prefix = "summarize: "
    pad_opts = {
        "padding": "max_length",
        "max_length": max_seq_len,
        "truncation": True,
    }

    encoded_examples = []
    for example in tqdm(examples, desc="get_input_ids"):
        model_input = tokenizer(prefix + example["dialogue"], **pad_opts)
        target = tokenizer(text_target=example["summary"], **pad_opts)
        model_input["labels"] = target["input_ids"]
        encoded_examples.append(model_input)
    return encoded_examples

def preprocess_samsum_GPT2(examples, tokenizer, max_seq_len):
    """Pair each prefixed dialogue with its summary, without tokenizing.

    Args:
        examples: Iterable of mappings with ``"dialogue"`` and ``"summary"``.
        tokenizer: Unused; kept so the signature matches the other
            preprocessing helpers.
        max_seq_len: Unused; kept for the same reason.

    Returns:
        List of dicts with ``"text"`` (``"summarize: "`` + dialogue) and
        ``"labels"`` (the raw summary string).
    """
    prefix = "summarize: "
    return [
        {"text": prefix + example["dialogue"], "labels": example["summary"]}
        for example in tqdm(examples, desc="get_input_ids")
    ]

def preprocess_xsum(examples, tokenizer, max_seq_len):
    """Tokenize XSum-style examples for summarization.

    Args:
        examples: Iterable of mappings with ``"document"`` and ``"summary"``.
        tokenizer: Tokenizer callable accepting ``padding``, ``max_length``,
            ``truncation`` and ``text_target`` keyword arguments.
        max_seq_len: Length both inputs and targets are padded/truncated to.

    Returns:
        List of tokenizer outputs for ``"summarize: "`` + document, each with
        a ``"labels"`` entry holding the tokenized summary's ``"input_ids"``.
    """
    prefix = "summarize: "

    def encode(example):
        # Source side: the prefixed document.
        features = tokenizer(
            prefix + example["document"],
            padding="max_length",
            max_length=max_seq_len,
            truncation=True,
        )
        # Target side: only the ids of the summary are kept.
        summary = tokenizer(
            text_target=example["summary"],
            padding="max_length",
            max_length=max_seq_len,
            truncation=True,
        )
        features["labels"] = summary["input_ids"]
        return features

    return [encode(example) for example in tqdm(examples, desc="get_input_ids")]

def save_xsum(examples, tokenizer, max_seq_len,
              unlearn_order_path="sequential_unlearning/instance_results/xsum_123/unlearn_order.pth"):
    """Tokenize XSum examples and dump their prefixed documents to JSONL.

    Every example is tokenized exactly as in ``preprocess_xsum``. In addition
    the prefixed document (``"summarize: "`` + document) is appended as one
    JSON string per line to ``xsum_unlearn.jsonl`` when the example's
    position is contained in the loaded unlearn indices, and to
    ``xsum_train.jsonl`` otherwise.

    Compared with the earlier implementation:

    * the unlearn-order file is closed after loading,
    * both output files are opened once (in append mode) instead of once per
      example, so both files are created even if one receives no lines,
    * the input examples are no longer modified in place.

    Args:
        examples: Iterable of mappings with ``"document"`` and ``"summary"``.
        tokenizer: Tokenizer callable accepting ``padding``, ``max_length``,
            ``truncation`` and ``text_target`` keyword arguments.
        max_seq_len: Length both inputs and targets are padded/truncated to.
        unlearn_order_path: File loaded with ``torch.load`` that supplies the
            container of positions to unlearn. Defaults to the path that was
            previously hard-coded.

    Returns:
        List of tokenizer outputs, each with an added ``"labels"`` entry.
    """
    prefix = "summarize: "
    with open(unlearn_order_path, "rb") as order_file:
        unlearn_indices = torch.load(order_file)

    samples = []
    with open("xsum_unlearn.jsonl", "a", encoding="utf-8") as unlearn_file, \
            open("xsum_train.jsonl", "a", encoding="utf-8") as train_file:
        # tqdm wraps the examples (not the enumerate object) so it can show a
        # total whenever the examples have a length.
        for i, e in enumerate(tqdm(examples, desc="get_input_ids")):
            document = prefix + e["document"]
            new_sample = tokenizer(document, padding="max_length",
                                   max_length=max_seq_len, truncation=True)
            new_sample["labels"] = tokenizer(
                text_target=e["summary"], padding="max_length",
                max_length=max_seq_len, truncation=True)["input_ids"]
            samples.append(new_sample)

            # NOTE(review): membership is tested on whatever container
            # torch.load returned; its type is not visible here.
            target_file = unlearn_file if i in unlearn_indices else train_file
            target_file.write(json.dumps(document) + "\n")

    return samples

from datasets import load_dataset

def get_huggingface_dataset(dname, tokenizer, max_seq_len=1024):
    """Load a Hugging Face dataset and return its train and test splits.

    ``"agnews"`` is accepted as an alias for ``"ag_news"``. The ``"samsum"``
    and ``"ag_news"`` splits are returned as loaded, without tokenization.
    ``"xsum"`` goes through ``preprocess_xsum``; every other dataset goes
    through ``convert_examples_to_ids``.

    Args:
        dname: Dataset name passed to ``load_dataset`` (after aliasing).
        tokenizer: Tokenizer used by the preprocessing helpers. May be
            ``None`` for the datasets that are returned untokenized.
        max_seq_len: Padding/truncation length for tokenized datasets.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.

    Raises:
        ValueError: If the dataset needs tokenizing but ``tokenizer`` is
            ``None``. Raised before anything is downloaded; previously the
            failure only appeared as a ``TypeError`` after ``load_dataset``
            had completed.
    """
    if dname == "agnews":
        dname = "ag_news"

    untokenized = ("samsum", "ag_news")
    if dname not in untokenized and tokenizer is None:
        raise ValueError(
            f"A tokenizer is required to preprocess dataset {dname!r}."
        )

    dataset = load_dataset(dname)
    if dname in untokenized:
        train_dataset = dataset["train"]
        test_dataset = dataset["test"]
    elif dname == "xsum":
        train_dataset = preprocess_xsum(dataset["train"], tokenizer, max_seq_len)
        test_dataset = preprocess_xsum(dataset["test"], tokenizer, max_seq_len)
    else:
        train_dataset = convert_examples_to_ids(dataset["train"], tokenizer, max_seq_len)
        test_dataset = convert_examples_to_ids(dataset["test"], tokenizer, max_seq_len)
    return train_dataset, test_dataset

if __name__ == "__main__":
    # Manual smoke run: loads the AG News splits via get_dataset. No tokenizer
    # is passed, so get_dataset forwards tokenizer=None; the returned splits
    # are discarded.
    get_dataset("agnews")