import os
import torch
import random
import datasets
import itertools
from typing import Dict
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from pytorch_lightning import LightningDataModule
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
from datasets import load_dataset, Dataset
import copy
import json
from collections import defaultdict
import numpy as np

from src.conv_util import create_template
from src.dataset import ToFU
from src.dataset.forgetpretrain import ForgetPretrain

def create_datamod(configs, data_mode_config, tokenizer=None, **kwargs):
    """Instantiate the data module selected by ``configs['class_name']``.

    Matching is a case-insensitive substring test: a name containing "tofu"
    builds a ToFU_DataModule, otherwise a name containing "harry" builds a
    HarryPotterDataModule. ``configs`` (including ``class_name`` itself),
    ``data_mode_config`` and ``kwargs`` are all forwarded as keyword arguments
    to the chosen constructor.

    Raises:
        ValueError: if ``class_name`` is missing/empty or matches no known module.
    """
    class_name = configs.get('class_name', None)
    if not class_name:
        # Previously a missing name crashed with AttributeError on None.lower().
        raise ValueError("Missing 'class_name' in data module config")

    lowered = class_name.lower()
    if "tofu" in lowered:
        datamod_cls = ToFU_DataModule
    elif "harry" in lowered:
        datamod_cls = HarryPotterDataModule
    else:
        raise ValueError(f"Unknown data module class: {class_name!r}")

    return datamod_cls(
        **configs,
        **data_mode_config,
        tokenizer=tokenizer,
        **kwargs,
    )

class ToFUTorchDataset(torch.utils.data.Dataset):
    """Map-style torch dataset that tokenizes (already templated) ToFU texts.

    Each element of ``data`` is either ``texts`` or a tuple
    ``(texts, retain_label)``. ``texts`` is a single string or a sequence of
    strings. The first string is returned under ``input_ids`` /
    ``attention_mask`` / ``labels``. A second string, if present, is returned
    under the same keys prefixed with ``prefer_``. Any further strings are
    ignored. Non-tuple elements get ``retainlabels = 0``.

    Args:
        data: indexable collection of samples as described above.
        tokenizer: tokenizer callable; its ``pad_token_id`` is read.
        data_collator: collator called on ``[input_ids]``. It must return a
            dict with batched ``input_ids`` and ``labels``.
        max_length: every text is padded / truncated to this many tokens.
        question_end_token: when set, the tokens of
            ``<text before marker> + marker`` are excluded from the loss
            (label -100). Provide it to mask out question tokens.
        forget_length, retain_length: stored as attributes; not used by this
            class.
    """

    def __init__(self, data, tokenizer, data_collator, max_length=500, question_end_token=None, forget_length=None, retain_length=None):
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.data_collator = data_collator
        self.max_length = max_length
        self.question_end_token = question_end_token
        self.forget_length = forget_length
        self.retain_length = retain_length

    def __len__(self):
        return len(self.data)

    def tokenize_text(self, text):
        """Tokenize ``text`` into ``(input_ids, attention_mask, labels)`` 1-D tensors."""
        inputs = self.tokenizer(
            text, return_tensors="pt", padding="max_length", max_length=self.max_length, truncation=True,
        )
        attention_mask = inputs['attention_mask'][0]
        collated = self.data_collator([inputs['input_ids'][0]])
        input_ids = collated['input_ids'][0]
        labels = collated['labels'][0]

        # Give the first position after the text the pad id as its label instead
        # of the collator's -100. Presumably this trains the model to emit the
        # pad (often == EOS) token. Assumes right padding — TODO confirm.
        # When the text fills the whole window there is no such position;
        # indexing it used to raise IndexError.
        seq_len = int(attention_mask.sum())
        if seq_len < labels.shape[0]:
            labels[seq_len] = self.tokenizer.pad_token_id

        if self.question_end_token is not None:
            # Without the marker, the whole text is treated as the question.
            question_text = text.split(self.question_end_token)[0] + self.question_end_token
            question_token_ids = self.tokenizer(
                question_text, return_tensors='pt', padding='do_not_pad', truncation=True, max_length=self.max_length
            ).input_ids[0]
            # If the question alone fills max_length, leave labels untouched.
            # Masking them would leave no supervised token at all.
            if len(question_token_ids) != self.max_length:
                labels[:len(question_token_ids)] = -100
        # else: the question part stays in the labels (used for remember losses).

        return input_ids, attention_mask, labels

    def __getitem__(self, idx):
        item = self.data[idx]
        if isinstance(item, tuple):
            qa, retainlabel = item
        else:
            qa, retainlabel = item, 0  # by default, it's forget data

        if isinstance(qa, str):
            qa = [qa]

        result = {}
        for name, text in zip(["", "prefer_"], qa):
            input_ids, attention_mask, labels = self.tokenize_text(text)
            result[f"{name}input_ids"] = input_ids
            result[f"{name}attention_mask"] = attention_mask
            result[f"{name}labels"] = labels

        if retainlabel is not None:
            result['retainlabels'] = retainlabel
        return result
        
def collect_expand_data(
    expand_qanum=10, path="./para_outputs-gpt/paraphrase_data/forget10_perturbed/paraphrase_res.csv",
):
    """Build paraphrased (question, answer) pairs from a CSV file.

    Column 2 (0-based) of each row must be the string repr of a Python list of
    paraphrased questions. Column 3 must be the same for paraphrased answers.
    Both lists are de-duplicated, keeping first occurrences. The cartesian
    product is then enumerated question-major, and at most ``expand_qanum``
    pairs are kept per row.

    Returns:
        list of ``(question, answer)`` tuples across all rows.

    Raises:
        ValueError: if a cell is not a Python literal (see ``ast.literal_eval``).
    """
    import ast
    import pandas as pd

    df = pd.read_csv(path)
    res = []
    for questions_cell, answers_cell in zip(df.iloc[:, 2], df.iloc[:, 3]):
        # literal_eval only parses Python literals; eval would execute any code
        # present in the CSV. dict.fromkeys de-duplicates in first-seen order.
        # set() order varies with string-hash randomization, which changed
        # which pairs survived the per-row cap from run to run.
        para_questions = list(dict.fromkeys(ast.literal_eval(questions_cell)))
        para_answers = list(dict.fromkeys(ast.literal_eval(answers_cell)))
        pairs = list(itertools.product(para_questions, para_answers))
        res.extend(pairs[:expand_qanum])
    print("Expand num: ", len(res))
    return res

def collect_perturb_data(
    expand_qanum=10, path="./para_outputs-gpt/paraphrase_data/forget10_perturbed/perturb_res.csv",
):
    """Pair each question in a CSV with its perturbed answers.

    Column 2 (0-based) of each row is the question string. Column 3 is the
    string repr of a Python list of perturbed answers. Answers are
    de-duplicated, keeping first occurrences, and at most ``expand_qanum``
    ``(question, answer)`` pairs are kept per row.

    Returns:
        list of ``(question, answer)`` tuples across all rows.

    Raises:
        ValueError: if an answers cell is not a Python literal.
    """
    import ast
    import pandas as pd

    df = pd.read_csv(path)
    res = []
    for question, answers_cell in zip(df.iloc[:, 2], df.iloc[:, 3]):
        # literal_eval only parses Python literals (eval would run code from the
        # CSV). dict.fromkeys keeps first-seen order, whereas set() order varies
        # with string-hash randomization and made the capped selection
        # non-reproducible.
        unique_answers = dict.fromkeys(ast.literal_eval(answers_cell))
        pairs = [(question, answer) for answer in unique_answers]
        res.extend(pairs[:expand_qanum])
    print("Perturb num: ", len(res))
    return res

class ToFU_DataModule(LightningDataModule):
    """Lightning data module for ToFU unlearning.

    The training set (``self.forget_train``) is the concatenation, in order, of:
        1. the ToFU ``split`` (forget data). With ``expand_forget`` it is
           followed by paraphrased QA pairs (``expand_qanum > 0``) or by a
           second copy of the split (``expand_qanum <= 0``);
        2. the last ``retain_num`` rows of the complementary retain split, e.g.
           ``retain90`` for ``forget10_...`` (``with_retain``);
        3. perturbed QA pairs (``with_perturb``).
    ``self.forget_length`` is the size of part 1. When retain or perturb data
    is requested, a ``retain_label`` column marks part 1 with 0 and the rest
    with 1.

    Validation sets:
        - the forget split;
        - ``retain_perturbed``;
        - the forget split with its first perturbed answer as ``answer``;
        - the forget split scored on ``paraphrased_answer``.

    Extra ``kwargs`` read here are ``expand_qanum``, ``paraphrase_path`` and
    ``perturb_path``. Any other keys are ignored.
    """
    def __init__(
        self,
        tokenizer,
        conv_template,
        split,
        max_length=256,
        batch_size=8,
        with_retain=False, retain_num=400, with_idk=False, with_dpo=False,
        expand_forget=False, with_perturb=False,  # Our method
        question_end_token=None,
        **kwargs,
    ):
        super().__init__()

        conv_template = create_template(conv_template, tokenizer=tokenizer)
        prepare_prompt = conv_template.prepare_prompt
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.batch_size = batch_size
        self.question_end_token = question_end_token

        # Load the raw forget split once. It feeds both the perturb eval set and
        # the training data below (it used to be downloaded up to three times).
        tofu_forget = load_dataset('locuslab/TOFU', split)['train']

        # ---- validation sets ----
        self.forget_eval = ToFU('locuslab/TOFU', split=split, template_func=prepare_prompt)
        self.retain_eval = ToFU('locuslab/TOFU', split='retain_perturbed', template_func=prepare_prompt)

        def flatten_perturb(perturb_dataset):
            # One sample per question, with ``answer`` replaced by the first
            # entry of ``perturbed_answer``.
            for sample in perturb_dataset:
                perturb_answer_list = sample.pop('perturbed_answer')
                newsample = copy.deepcopy(sample)
                for perturb_ans in perturb_answer_list[:1]:
                    newsample['answer'] = perturb_ans
                    yield newsample

        perturb_flat = datasets.Dataset.from_generator(flatten_perturb, gen_kwargs={"perturb_dataset": tofu_forget})
        self.perturb_eval = ToFU(data=perturb_flat, template_func=prepare_prompt, answer_key='answer')
        self.paraphrase_eval = ToFU(
            'locuslab/TOFU', split=split, template_func=prepare_prompt, answer_key='paraphrased_answer',
        )
        self.with_retain = bool(with_retain or with_perturb)

        # ---- training set ----
        base_forget_data = tofu_forget
        if expand_forget:
            print("Adding forget data")
            # NOTE(review): defaults to 2 here but 3 for perturb data below — confirm intended.
            expand_qanum = kwargs.get('expand_qanum', 2)
            if expand_qanum > 0:
                expand_qa = collect_expand_data(
                    expand_qanum=expand_qanum, path=kwargs.get('paraphrase_path'),
                )
                tmpdata = datasets.Dataset.from_list([{'question': q, 'answer': a} for q, a in expand_qa])
            else:
                #! Otherwise duplicate the original forget data
                tmpdata = tofu_forget
            base_forget_data = datasets.concatenate_datasets([base_forget_data, tmpdata])
        self.forget_length = len(base_forget_data)

        if with_retain:
            print("Adding retain data")
            # e.g. "forget10_perturbed" -> "retain90"
            forget_pct = int(split.split("_")[0].replace("forget", ""))
            retain_split = "retain" + str(100 - forget_pct).zfill(2)
            retain_train = load_dataset('locuslab/TOFU', retain_split)['train']
            # Keep the last ``retain_num`` rows. Clamp the start index so a
            # retain_num larger than the split keeps the whole split instead of
            # producing negative indices.
            start = max(0, len(retain_train) - retain_num)
            retain_train = retain_train.select(range(start, len(retain_train)))
            base_forget_data = datasets.concatenate_datasets([base_forget_data, retain_train])

        if with_perturb:
            print("Adding perturb data")
            perturb_qa = collect_perturb_data(
                expand_qanum=kwargs.get('expand_qanum', 3),
                path=kwargs.get('perturb_path'),
            )
            tmpdata = datasets.Dataset.from_list([{'question': q, 'answer': a} for q, a in perturb_qa])
            base_forget_data = datasets.concatenate_datasets([base_forget_data, tmpdata])

        if self.with_retain:
            n_retain = len(base_forget_data) - self.forget_length
            base_forget_data = base_forget_data.add_column(
                'retain_label', [0] * self.forget_length + [1] * n_retain
            )

        forget_answer_key = 'idk' if with_idk else 'answer'
        self.forget_train = ToFU(data=base_forget_data, answer_key=forget_answer_key, template_func=prepare_prompt, as_dpo=with_dpo)

    def to_torch_dataset(self, data, add_length=False):
        """Wrap ``data`` in a ToFUTorchDataset using this module's tokenizer settings.

        ``add_length`` is accepted for backward compatibility and has no effect.
        The forget/retain lengths are always passed.
        """
        data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
        return ToFUTorchDataset(
            data, self.tokenizer, data_collator, self.max_length,
            question_end_token=self.question_end_token,
            forget_length=self.forget_length,
            retain_length=len(data) - self.forget_length,
        )

    def to_loader(self, data, shuffle=True):
        """Build a DataLoader over ``to_torch_dataset(data)``."""
        torchdataset = self.to_torch_dataset(data)
        return DataLoader(
            torchdataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=16,
        )

    def train_set(self):
        return self.to_torch_dataset(self.forget_train, True)

    def val_set(self):
        return {
            "forget": self.to_torch_dataset(self.forget_eval),
            "retain": self.to_torch_dataset(self.retain_eval),
            "perturb": self.to_torch_dataset(self.perturb_eval),
            "paraphrase": self.to_torch_dataset(self.paraphrase_eval),
        }

    def train_dataloader(self):
        return self.to_loader(self.forget_train)

    def val_dataloader(self) -> TRAIN_DATALOADERS:
        return [
            self.to_loader(self.forget_eval, shuffle=False),
            self.to_loader(self.retain_eval, shuffle=False),
            self.to_loader(self.perturb_eval, shuffle=False),
            self.to_loader(self.paraphrase_eval, shuffle=False),
        ]

    def stats(self):
        """Return sizes/modes of the training and validation sets."""
        return {
            "train": {"forget num": self.forget_length, "retain num": len(self.forget_train) - self.forget_length, "forget mode": self.forget_train.answer_key, "dpo mode": self.forget_train.as_dpo},
            "val": {
                "forget num": len(self.forget_eval),
                "retain num": len(self.retain_eval),
                "perturb num": len(self.perturb_eval),
                "paraphrase num": len(self.paraphrase_eval),
            }
        }

from torch.utils.data import Sampler
class EqualForgetRetainSampler(Sampler):
    """Sampler that interleaves forget and retain indices.

    Indices ``[0, forget_length)`` form the forget subset. Indices
    ``[forget_length, forget_length + retain_length)`` form the retain subset.
    Each epoch both subsets are shuffled independently, then interleaved so
    that the smaller subset is spread evenly over the larger one. Every index
    is yielded exactly once.

    Used for DeepSpeed forget-retain training; otherwise training fails with
    an "invalidate source" error
    (https://github.com/microsoft/DeepSpeed/discussions/4081).
    """

    def __init__(self, forget_length, retain_length, generator=None):
        self.forget_length = forget_length
        self.retain_length = retain_length
        self.generator = generator

    def balanced_interleave(self, shorter, longer):
        """Interleave two index sequences, distributing the longer one evenly.

        Each element of the shorter sequence is followed by a contiguous chunk
        of the longer one. Chunk boundaries are
        ``i * len(longer) // len(shorter)``, so every element of ``longer`` is
        emitted exactly once. A fixed ``round(ratio)`` step would drop trailing
        elements whenever the ratio is not an integer. The arguments are swapped
        if ``shorter`` is actually longer.

        Returns:
            1-D ``torch.long`` tensor of the interleaved indices.
        """
        shorter = [int(i) for i in shorter]
        longer = [int(i) for i in longer]
        if len(shorter) > len(longer):
            shorter, longer = longer, shorter
        if not shorter:  # no need to interleave
            return torch.tensor(longer, dtype=torch.long)

        n_short, n_long = len(shorter), len(longer)
        result = []
        for i, s in enumerate(shorter):
            result.append(s)
            result.extend(longer[i * n_long // n_short:(i + 1) * n_long // n_short])
        return torch.tensor(result, dtype=torch.long)

    def __iter__(self):
        forget_indices = torch.randperm(self.forget_length, generator=self.generator)
        retain_indices = torch.randperm(self.retain_length, generator=self.generator) + self.forget_length
        return iter(self.balanced_interleave(forget_indices, retain_indices).tolist())

    def __len__(self):
        return self.forget_length + self.retain_length


class FullTorchDataset(torch.utils.data.Dataset):
    """Torch dataset that templates and tokenizes ``{question, answer}`` items.

    ``conv_template`` must provide ``prepare_gen_prompt(**item)`` (prompt only)
    and ``prepare_prompt(**item)`` (prompt + answer). Labels are -100 on the
    prompt part, so only the answer tokens are scored.

    ``retainlabels`` is 0 for ``idx < forget_length`` and 1 otherwise, or
    always 0 when ``forget_length`` is None. With ``dpo_mode`` the ``prefer_*``
    keys hold the same item with ``answer`` replaced by a random entry of
    ``data/refusal.jsonl``.
    """

    def __init__(self, data, tokenizer, conv_template, max_length=500, forget_length=None, retain_length=None, dpo_mode=False):
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=self.tokenizer)
        self.conv_template = conv_template
        self.forget_length = forget_length
        self.retain_length = retain_length
        self.max_length = max_length

        self.dpo_mode = dpo_mode
        # Use a context manager so the file is closed (it used to stay open).
        # Skip blank lines, which json.loads would reject.
        with open('data/refusal.jsonl', encoding='utf-8') as f:
            self.alternative_responses = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.data)

    def tokenize_text(self, item: Dict):
        """Return ``(input_ids, attention_mask, labels)`` for one item.

        The prompt length is measured by tokenizing the prompt text on its own.
        This assumes its tokens form a prefix of the full text's tokens —
        TODO confirm for the templates in use.
        """
        prefix_text = self.conv_template.prepare_gen_prompt(**item)
        full_text = self.conv_template.prepare_prompt(**item)
        inputs = self.tokenizer(
            full_text,
            return_tensors='pt',
            padding='max_length',
            max_length=self.max_length,
            truncation=True
        )
        input_ids = inputs['input_ids'][0]
        attention_mask = inputs['attention_mask'][0]
        labels = self.data_collator([input_ids])['labels'][0]
        prefix_num = len(self.tokenizer(prefix_text).input_ids)

        if self.tokenizer.padding_side == 'right':
            # Layout: [prompt, answer, pads] -> mask the leading prompt tokens.
            labels[:prefix_num] = -100
        else:
            # Layout: [pads, prompt, answer] -> keep only the trailing answer.
            # Clamp at 0. With suffix_num == 0, ``labels[:-0]`` masked nothing,
            # and a negative value masked the wrong span.
            suffix_num = max(int(attention_mask.sum()) - prefix_num, 0)
            labels[:labels.shape[0] - suffix_num] = -100

        return input_ids, attention_mask, labels

    def __getitem__(self, idx):
        item = self.data[idx]  # {question: , answer: }
        real_items = [item]
        if self.forget_length is not None:
            retainlabel = 0 if idx < self.forget_length else 1
        else:
            retainlabel = 0

        if self.dpo_mode:
            tempitem = copy.deepcopy(item)
            tempitem['answer'] = random.choice(self.alternative_responses)
            real_items.append(tempitem)

        result = {}
        for name, text in zip(["", "prefer_"], real_items):
            input_ids, attention_mask, labels = self.tokenize_text(text)
            result[f"{name}input_ids"] = input_ids
            result[f"{name}attention_mask"] = attention_mask
            result[f"{name}labels"] = labels

        if retainlabel is not None:
            result['retainlabels'] = retainlabel
        return result


def sample_yield(listitems):
    """Lazily yield every element of ``listitems`` in iteration order."""
    yield from listitems

def get_WikiText2(tokenizer, num=1000, seed=42, seqlen=512, split='train'):
    """Sample ``num`` random ``seqlen``-token windows from WikiText-2.

    The whole ``split`` is joined with single spaces and tokenized once. Each
    sample is a window at a random start offset, decoded back to text.

    Note: this reseeds Python's *global* ``random`` module with ``seed``.

    Returns:
        ``datasets.Dataset`` with a ``text`` column (no columns if ``num <= 0``).
    """
    corpus = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
    token_ids = tokenizer(" ".join(corpus["text"]), return_tensors="pt").input_ids
    last_start = token_ids.shape[1] - seqlen - 1

    random.seed(seed)
    windows = defaultdict(list)
    for _ in range(num):
        start = random.randint(0, last_start)
        window_text = tokenizer.batch_decode(token_ids[:, start:start + seqlen])[0]
        windows['text'].append(window_text)
    return Dataset.from_dict(windows)

def get_HPQA(split='hp_train_qa_100', num=400, data_dir="data/hp"):
    """Load Harry Potter QA pairs from a local JSON-lines file.

    Args:
        split: file name inside ``data_dir``; ``.jsonl`` is appended when the
            name does not contain it.
        num: keep at most the first ``num`` rows.
        data_dir: directory containing the file.

    Returns:
        ``datasets.Dataset`` with ``prompt``/``response`` renamed to
        ``question``/``answer``.
    """
    if ".jsonl" not in split:
        split = split + ".jsonl"
    rawdata = Dataset.from_json(os.path.join(data_dir, split))
    # Only the first ``num`` rows are used. Clamp so that asking for more rows
    # than the file holds does not fail in ``select``.
    rawdata = rawdata.select(range(min(num, len(rawdata))))
    rawdata = rawdata.rename_column("prompt", "question")
    rawdata = rawdata.rename_column("response", "answer")
    return rawdata

def get_C4(tokenizer, num=400, length_filter=400):
    """Collect long documents from one shard of the English C4 training set.

    Iterates over ``en/c4-train.00001-of-01024.json.gz`` in order. It keeps
    documents whose tokenized length is at least ``length_filter``, stopping
    once ``num`` have been kept. At least one document is kept as soon as any
    qualifies, even when ``num <= 0``.

    Returns:
        ``datasets.Dataset`` with a ``text`` column (no columns if nothing qualifies).
    """
    shard = load_dataset(
        "allenai/c4", data_files={"train": "en/c4-train.00001-of-01024.json.gz"}, split="train",
    )
    collected = defaultdict(list)
    for document in shard:
        body = document['text']
        if len(tokenizer(body).input_ids) >= length_filter:
            collected['text'].append(body)
            if len(collected['text']) >= num:
                break
    return Dataset.from_dict(collected)


class TrainDataModule(LightningDataModule):
    """Lightning data module for forget/retain unlearning on toxicity data.

    The training data (``self.forget_data``) is the concatenation of:
        - forget rows: ``get_SafePKU('train', num=2000)``, plus
          ``get_RealToxic('train', ...)`` rows when ``expand_forget``;
        - retain rows: up to 4000 C4 documents when ``with_retain``, and
          ``get_SafePKU('train', ..., safe_mode=True)`` rows when
          ``with_perturb``.
    ``forget_length`` and ``retain_length`` count the two parts. The tokenizer
    is switched to right padding / right truncation.

    NOTE(review): ``get_SafePKU`` and ``get_RealToxic`` are neither defined nor
    imported in this module as shown — confirm where they come from.
    """
    def __init__(self, split, tokenizer, conv_template, max_len=1000, batch_size=4, with_retain=False, expand_forget=False, with_perturb=False, with_dpo=False, **kwargs) -> None:
        super().__init__()

        tokenizer.padding_side = 'right'
        # Was misspelled ``trunation_side``, which only created an unused attribute.
        tokenizer.truncation_side = 'right'

        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_len = max_len
        self.dpo_mode = with_dpo
        self.conv_template = create_template(conv_template, tokenizer=tokenizer)
        #! prepare train data
        self.forget_data = get_SafePKU('train', num=2000)
        self.forget_length = len(self.forget_data)

        if expand_forget:
            tmp_forget_data = get_RealToxic('train', num=2000, ratio=0.2)
            self.forget_data = datasets.concatenate_datasets([self.forget_data, tmp_forget_data])
            self.forget_length = len(self.forget_data)

        self.retain_length = 0
        if with_retain:
            # get_C4 has no ``split``/``seed`` parameters; passing them raised TypeError.
            tmp_retain_data = get_C4(tokenizer, num=4000)
            self.forget_data = datasets.concatenate_datasets([self.forget_data, tmp_retain_data])
            self.retain_length = len(tmp_retain_data)

        if with_perturb:
            tmp_retain_data = get_SafePKU('train', num=2000, safe_mode=True)
            self.forget_data = datasets.concatenate_datasets([self.forget_data, tmp_retain_data])
            self.retain_length = self.retain_length + len(tmp_retain_data)

        #! prepare eval data
        self.pku_toxic_eval = get_SafePKU('test', num=400)
        self.real_toxic_eval = get_RealToxic('test', num=400, ratio=0.2)
        self.wiki_eval = get_WikiText2(tokenizer, split='test', num=400, seed=42)

    def to_torch_dataset(self, data, forget_length=None, retain_length=None, dpo_mode=False):
        """Wrap raw ``{question, answer}`` data in a FullTorchDataset."""
        torchdataset = FullTorchDataset(
            data, self.tokenizer, conv_template=self.conv_template,
            max_length=self.max_len, forget_length=forget_length, retain_length=retain_length,
            dpo_mode=dpo_mode
        )
        return torchdataset

    def _make_loader(self, torchdataset, shuffle):
        """Wrap an already-built torch dataset in a DataLoader."""
        return DataLoader(
            torchdataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=16,
        )

    def to_loader(self, data, shuffle=True, **kwargs):
        """Build a DataLoader over raw ``data``; ``kwargs`` go to ``to_torch_dataset``."""
        return self._make_loader(self.to_torch_dataset(data, **kwargs), shuffle=shuffle)

    def train_set(self):
        return self.to_torch_dataset(self.forget_data, forget_length=self.forget_length, retain_length=self.retain_length, dpo_mode=self.dpo_mode)

    def val_set(self):
        return {
            "val_pku_toxic": self.to_torch_dataset(self.pku_toxic_eval),
            "val_real_toxic": self.to_torch_dataset(self.real_toxic_eval),
            "val_wiki": self.to_torch_dataset(self.wiki_eval),
        }

    def train_dataloader(self):
        return self.to_loader(self.forget_data, forget_length=self.forget_length, retain_length=self.retain_length, dpo_mode=self.dpo_mode)

    def val_dataloader(self) -> TRAIN_DATALOADERS:
        # val_set() already returns FullTorchDataset objects. Passing them through
        # to_loader() would wrap (and re-template) them a second time.
        return [
            self._make_loader(valset, shuffle=False) for valset in self.val_set().values()
        ]

    def stats(self):
        """Return training/validation sizes.

        The previous version referenced a nonexistent ``self.forget_train`` and
        unpacked dict keys as pairs, so it always raised.
        """
        return {
            "train": {
                "forget num": self.forget_length,
                "retain num": len(self.forget_data) - self.forget_length,
                "dpo mode": self.dpo_mode,
            },
            "val": {
                k: len(v) for k, v in self.val_set().items()
            }
        }
    
class HarryPotterDataModule(TrainDataModule):
    """Data module for Harry Potter QA unlearning.

    Reuses TrainDataModule's dataset/loader helpers but builds its own data:
        - forget data: up to 400 rows of ``data/hp/hp_train_qa_100.jsonl``;
        - retain data (``with_retain``): up to 400 C4 documents, appended after
          the forget data;
        - validation: 400 WikiText-2 test windows and the same HP QA rows.
    The tokenizer is switched to left padding and left truncation.

    ``split``, ``expand_forget``, ``with_perturb`` and extra ``kwargs`` are
    accepted for interface compatibility but are not used.
    """
    def __init__(self, split, tokenizer, conv_template, max_len=1000, batch_size=4, with_retain=False, expand_forget=False, with_perturb=False, with_dpo=False, **kwargs) -> None:
        # Run LightningDataModule.__init__ but skip TrainDataModule.__init__,
        # which would load the toxicity datasets. Skipping initialisation
        # entirely would leave the Lightning base class uninitialised.
        super(TrainDataModule, self).__init__()

        tokenizer.padding_side = 'left'
        # Was misspelled ``trunation_side``, so left truncation never took effect.
        tokenizer.truncation_side = 'left'
        print("Settted to left")

        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_len = max_len
        self.dpo_mode = with_dpo
        self.conv_template = create_template(conv_template, tokenizer=tokenizer)

        #! prepare train data
        # NOTE: expanding the forget set with extra HP QA rows was disabled
        # ("Not sure whether to add"), so ``expand_forget`` is ignored.
        self.forget_data = get_HPQA("hp_train_qa_100", num=400)
        self.forget_length = len(self.forget_data)

        self.retain_length = 0
        if with_retain:
            tmp_retain_data = get_C4(tokenizer, num=400)
            self.forget_data = datasets.concatenate_datasets([self.forget_data, tmp_retain_data])
            self.retain_length = len(tmp_retain_data)

        #! prepare eval data
        self.wiki_eval = get_WikiText2(tokenizer, split='test', num=400, seed=42)
        self.forget_eval = get_HPQA("hp_train_qa_100", num=400)

    def val_set(self):
        return {
            "val_wiki": self.to_torch_dataset(self.wiki_eval),
            "val_forget": self.to_torch_dataset(self.forget_eval),
        }