import transformers
from transformers.testing_utils import CaptureLogger
from datasets import load_dataset, concatenate_datasets
from itertools import chain
from lib.dataset import Dataset_ft, tokenize
from copy import deepcopy
from torch.utils.data import Dataset
import random
import os.path as osp
from typing import Union
import json

class Prompter(object):
    """Build instruction-tuning prompts from a JSON template file.

    The template is read from ``templates/<template_name>.json`` relative to the
    current working directory. It must provide the keys ``prompt_input``,
    ``prompt_no_input`` and ``response_split``, plus ``description`` when
    ``verbose`` is enabled.
    """

    __slots__ = ("template", "_verbose")

    def __init__(self, template_name: str = "", verbose: bool = False):
        """Load the prompt template.

        Args:
            template_name: base name of the template file; '' selects "alpaca".
            verbose: when true, print the template description and every
                generated prompt.

        Raises:
            ValueError: if the template file does not exist.
        """
        self._verbose = verbose
        if not template_name:
            # Enforce the default here, so the constructor can be called with '' and will not break.
            template_name = "alpaca"
        file_name = osp.join("templates", f"{template_name}.json")
        if not osp.exists(file_name):
            raise ValueError(f"Can't read {file_name}")
        # Read as UTF-8 explicitly: the platform default encoding (e.g. cp1252 on
        # Windows) can fail on or mis-decode templates that contain non-ASCII text.
        with open(file_name, encoding="utf-8") as fp:
            self.template = json.load(fp)
        if self._verbose:
            print(
                f"Using prompt template {template_name}: {self.template['description']}"
            )

    def generate_prompt(
        self,
        instruction: str,
        input: Union[None, str] = None,
        label: Union[None, str] = None,
    ) -> str:
        """Return the full prompt for ``instruction`` and an optional ``input``.

        The ``prompt_input`` template is used when ``input`` is truthy; otherwise
        ``prompt_no_input`` is used. A truthy ``label`` (the expected response) is
        appended verbatim.
        """
        if input:
            res = self.template["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            res = self.template["prompt_no_input"].format(
                instruction=instruction
            )
        if label:
            res = f"{res}{label}"
        if self._verbose:
            print(res)
        return res

    def get_response(self, output: str) -> str:
        """Return the stripped text between the first and second ``response_split`` markers.

        Raises IndexError when ``output`` does not contain the marker.
        """
        return output.split(self.template["response_split"])[1].strip()


def get_dataset_alpaca(training_args, data_args, tokenizer, model_args, logger):
    """Load an Alpaca-style instruction dataset and tokenize it for causal-LM fine-tuning.

    Each record of the ``train`` split must provide ``instruction``, ``input`` and
    ``output`` fields. Each record is processed as follows:

    * It is rendered with the ``alpaca`` prompt template.
    * It is tokenized and truncated to ``cutoff_len`` tokens, with EOS appended
      when there is room.
    * It gets a ``labels`` column equal to ``input_ids``.

    Args:
        training_args: ``seed`` drives the train/validation split and the shuffles.
        data_args: ``dataset_path`` is either a ``.json``/``.jsonl`` file or a
            name/path passed straight to ``load_dataset``.
        tokenizer: callable tokenizer exposing ``eos_token_id``.
        model_args: unused; kept for interface compatibility.
        logger: unused; kept for interface compatibility.

    Returns:
        ``(train_data, val_data)``. ``val_data`` holds ``val_set_size`` held-out
        records, or is ``None`` when ``val_set_size`` is 0.
    """
    val_set_size = 2000
    cutoff_len = 512
    prompter = Prompter('alpaca')

    if data_args.dataset_path.endswith((".json", ".jsonl")):
        data = load_dataset("json", data_files=data_args.dataset_path)
    else:
        data = load_dataset(data_args.dataset_path)

    def tokenize_alpaca(prompt, add_eos_token=True):
        # Truncated, unpadded encoding returned as plain Python lists.
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        input_ids = result["input_ids"]
        # Append EOS only if it was requested, there is room for it, and the
        # encoding does not already end with EOS. Checking for an empty list
        # first avoids an IndexError on an empty encoding.
        if (
            add_eos_token
            and len(input_ids) < cutoff_len
            and (not input_ids or input_ids[-1] != tokenizer.eos_token_id)
        ):
            input_ids.append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        # Labels mirror the inputs; prompt tokens are not masked out.
        result["labels"] = input_ids.copy()
        return result

    def generate_and_tokenize_prompt(data_point):
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
            data_point["output"],
        )
        return tokenize_alpaca(full_prompt)

    seed = training_args.seed
    if val_set_size > 0:
        train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=seed)
        # Seed the shuffles too, so the resulting order is reproducible across runs.
        train_data = train_val["train"].shuffle(seed=seed).map(generate_and_tokenize_prompt)
        val_data = train_val["test"].shuffle(seed=seed).map(generate_and_tokenize_prompt)
    else:
        train_data = data["train"].shuffle(seed=seed).map(generate_and_tokenize_prompt)
        val_data = None

    return train_data, val_data


def get_dataset_c4(training_args, data_args, model_args, tokenizer, logger, seqlen):
    """Load one C4 shard, draw calibration windows, and tokenize/group it for causal-LM training.

    Note the argument order differs from ``get_dataset_alpaca``: here
    ``model_args`` comes before ``tokenizer``.

    Args:
        training_args: reads ``do_train`` and uses ``main_process_first(desc=...)``.
        data_args: reads ``dataset_name``, ``dataset_config_name`` and
            ``validation_split_percentage`` (fallback branch only), plus
            ``streaming``, ``preprocessing_num_workers``, ``overwrite_cache``
            and ``block_size``.
        model_args: reads ``use_auth_token`` (fallback branch only).
        tokenizer: callable tokenizer; ``model_max_length`` bounds the block size.
        logger: receives a warning when the block size is capped.
        seqlen: length, in tokens, of each calibration window.

    Returns:
        ``(lm_datasets, cal_loader)``:

        * ``lm_datasets`` is the tokenized dataset dict, re-split into
          ``block_size`` chunks, with a ``labels`` copy of ``input_ids``.
        * ``cal_loader`` is a list of 128 ``(inp, tar)`` tensor pairs, each
          sliced to ``seqlen`` tokens along dim 1.
    """
    # Download the first train shard and the first validation shard of English C4.
    raw_datasets = load_dataset(
        'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz',
                                  'validation': 'en/c4-validation.00000-of-00008.json.gz'}
    )

    # Fallback: carve a validation split out of train. This only runs when the load
    # above yields no 'validation' split; with explicit data_files for both splits it
    # presumably never triggers. NOTE(review): it reloads from data_args.dataset_name,
    # not 'allenai/c4'.
    if "validation" not in raw_datasets.keys():
        raw_datasets["validation"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"train[:{data_args.validation_split_percentage}%]",
            use_auth_token=True if model_args.use_auth_token else None,
            streaming=data_args.streaming,
        )
        raw_datasets["train"] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f"train[{data_args.validation_split_percentage}%:]",
            use_auth_token=True if model_args.use_auth_token else None,
            streaming=data_args.streaming,
        )

    ## Generate calibration data: 128 random windows of `seqlen` tokens taken from
    ## randomly chosen training documents. Sampling uses the module-level `random`
    ## state; no seed is set in this function.
    cal_loader = []
    nsamples = 128
    for _ in range(nsamples):
        # Resample documents until one tokenizes to more than `seqlen` tokens.
        # NOTE(review): this loop never terminates if no document is long enough.
        while True:
            i = random.randint(0, len(raw_datasets["train"]) - 1)
            trainenc = tokenizer(raw_datasets["train"][i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        # Random window start such that [i, i + seqlen) fits inside the document.
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        # Every target position except the last is set to -100; presumably so only the
        # final token contributes to a loss (-100 is the conventional ignore index) —
        # confirm against the consumer of cal_loader.
        tar[:, :-1] = -100
        cal_loader.append((inp, tar))


    # Preprocessing the datasets.
    # First we tokenize all the texts.
    # Column names come from the train split when training, otherwise from validation;
    # all of them are removed after tokenization.
    if training_args.do_train:
        column_names = list(raw_datasets["train"].features)
    else:
        column_names = list(raw_datasets["validation"].features)
    text_column_name = "text" if "text" in column_names else column_names[0]

    # Force-load the tokenizer logger before defining tokenize_function. Per the upstream
    # note, this avoids a _LazyModule error when the function is pickled by the Hasher.
    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

    def tokenize_function(examples):
        # Tokenize a batch of texts, capturing the tokenizer's "sequence too long" log.
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
                " before being passed to the model."
            )
        return output

    with training_args.main_process_first(desc="dataset map tokenization"):
        if not data_args.streaming:
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )
        else:
            # Streaming datasets: no multiprocessing / cache options are passed.
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                remove_columns=column_names,
            )

    # Resolve block_size: default to model_max_length capped at 1024; an explicit
    # data_args.block_size is capped at model_max_length.
    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > 1024:
            logger.warning(
                "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
                " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
                " override this default with `--block_size xxx`."
            )
            block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        # If the batch holds fewer than block_size tokens, total_length is left as-is and
        # the whole (non-empty) batch becomes one short chunk.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        # Labels are a copy of the input ids.
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    with training_args.main_process_first(desc="grouping texts together"):
        if not data_args.streaming:
            lm_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
                desc=f"Grouping texts in chunks of {block_size}",
            )
        else:
            lm_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
            )

    # if training_args.do_train:
    #     if "train" not in tokenized_datasets:
    #         raise ValueError("--do_train requires a train dataset")
    #     train_dataset = lm_datasets["train"]
    #     if data_args.max_train_samples is not None:
    #         max_train_samples = min(len(train_dataset), data_args.max_train_samples)
    #         train_dataset = train_dataset.select(range(max_train_samples))
    #
    # if training_args.do_eval:
    #     if "validation" not in tokenized_datasets:
    #         raise ValueError("--do_eval requires a validation dataset")
    #     eval_dataset = lm_datasets["validation"]
    #     if data_args.max_eval_samples is not None:
    #         max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
    #         eval_dataset = eval_dataset.select(range(max_eval_samples))

    return lm_datasets, cal_loader

def get_dataset_sft(training_args, data_args, model, tokenizer, model_args, logger):
    """Build the train (and optional eval) datasets for supervised fine-tuning.

    Args:
        training_args: ``do_eval`` and ``eval_dataset_path`` control evaluation
            data; ``main_process_first`` wraps the preprocessing.
        data_args: dataset configuration passed to ``Dataset_ft``. Also read:
            ``disable_group_texts`` and the options consumed by ``group_text``.
        model: supplies ``config.eos_token_id`` as a fallback EOS id. For the
            ``custom_multi_modal`` backend it also supplies ``tokenizer`` and
            ``image_processor``.
        tokenizer: mutated in place; EOS/PAD ids are filled in when missing.
        model_args: forwarded to ``tokenize``.
        logger: forwarded to ``group_text``.

    Returns:
        ``(lm_dataset_train, lm_dataset_eval)``. The eval entry is ``None``
        unless ``training_args.do_eval`` is set.
    """
    finetuner_args = training_args
    dataset = Dataset_ft(data_args)

    if tokenizer.eos_token_id is None:
        tokenizer.eos_token_id = model.config.eos_token_id
    # Fall back to EOS for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    def _tokenize_and_group(raw_dataset):
        # Shared by train and eval so both are preprocessed identically:
        # tokenize, then optionally pack into fixed-size blocks.
        tokenized = tokenize(raw_dataset, model_args, tokenizer)
        if data_args.disable_group_texts:
            return tokenized
        return group_text(
            training_args, data_args, tokenized, logger,
            model_max_length=tokenizer.model_max_length,
        )

    # Tokenization and text grouping must be done in the main process
    if dataset.backend == "custom_multi_modal":
        dataset.backend_dataset.register_tokenizer(
            model.tokenizer, model.image_processor)
        lm_dataset_train = dataset
    else:
        with finetuner_args.main_process_first(desc="dataset map tokenization"):
            lm_dataset_train = _tokenize_and_group(dataset)

    lm_dataset_eval = None
    if finetuner_args.do_eval:
        eval_dataset_args = deepcopy(data_args)
        eval_dataset_args.dataset_path = finetuner_args.eval_dataset_path
        eval_dataset = Dataset_ft(eval_dataset_args)
        with finetuner_args.main_process_first(desc="dataset map tokenization"):
            # Tokenize the *eval* dataset; tokenizing `dataset` here would make
            # evaluation silently run on the training data.
            lm_dataset_eval = _tokenize_and_group(eval_dataset)

    return lm_dataset_train, lm_dataset_eval

def group_text(training_args, data_args, tokenized_datasets, logger, model_max_length):
    """Pack tokenized examples into contiguous blocks of a fixed token length.

    Block size resolution:
      * ``data_args.block_size`` set: ``min(data_args.block_size, model_max_length)``,
        with a warning when the requested size exceeds ``model_max_length``.
      * ``data_args.block_size`` unset: ``model_max_length`` capped at 1024, with a
        warning when the cap applies.

    Within each mapped batch, every column is flattened into one sequence and cut
    into consecutive ``block_size`` slices. Tokens past the last full block of the
    batch are discarded.

    Args:
        training_args: provides the ``main_process_first(desc=...)`` context.
        data_args: reads ``block_size``, ``streaming``,
            ``preprocessing_num_workers`` and ``overwrite_cache``.
        tokenized_datasets: object whose ``map(fn, batched=True, ...)`` is called.
        logger: receives warnings about block-size adjustments.
        model_max_length: maximum sequence length supported by the model.

    Returns:
        Whatever ``tokenized_datasets.map`` returns.
    """
    requested = data_args.block_size
    if requested is not None:
        if requested > model_max_length:
            logger.warning(
                f"The block_size passed ({requested}) is larger than the maximum length for the model"
                f"({model_max_length}). Using block_size={model_max_length}."
            )
        block_size = min(requested, model_max_length)
    else:
        block_size = model_max_length
        if block_size > 1024:
            logger.warning(
                "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
                " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
                " override this default with `--block_size xxx`."
            )
            block_size = 1024

    def _pack(batch):
        # Flatten each column of the batch into a single token stream.
        flat = {
            column: list(chain.from_iterable(rows))
            for column, rows in batch.items()
        }
        # Keep only whole blocks; the remainder of this batch is dropped.
        usable = (len(flat[list(batch)[0]]) // block_size) * block_size
        starts = range(0, usable, block_size)
        return {
            column: [stream[s : s + block_size] for s in starts]
            for column, stream in flat.items()
        }

    # With `batched=True` each map call sees one batch (1,000 rows by default),
    # so a remainder is discarded per batch rather than once per dataset.
    # See https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    map_kwargs = {"batched": True}
    if not data_args.streaming:
        # Multiprocessing / cache / progress options are only passed for non-streaming datasets.
        map_kwargs["num_proc"] = data_args.preprocessing_num_workers
        map_kwargs["load_from_cache_file"] = not data_args.overwrite_cache
        map_kwargs["desc"] = f"Grouping texts in chunks of {block_size}"

    with training_args.main_process_first(desc="grouping texts together"):
        grouped = tokenized_datasets.map(_pack, **map_kwargs)

    return grouped

class CombinedDataset(Dataset):
    """Randomly interleave two map-style datasets.

    ``len()`` is the sum of both lengths. Each ``__getitem__`` call picks
    ``dataset1`` with probability ``p`` (else ``dataset2``) and returns the item at
    ``idx`` modulo that dataset's length. Indexing is therefore random: the same
    ``idx`` may return different samples on different calls.

    If one dataset is empty, every item comes from the other one.
    """

    def __init__(self, dataset1, dataset2, p=0.5):
        """Wrap ``dataset1`` and ``dataset2``.

        Args:
            dataset1: first map-style dataset (supports ``len`` and indexing).
            dataset2: second map-style dataset.
            p: probability of drawing from ``dataset1``; the default 0.5 gives
                an even mix.

        Raises:
            ValueError: if ``p`` is outside ``[0, 1]``.
        """
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"p must be in [0, 1], got {p}")
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.p = p
        self.length = len(dataset1) + len(dataset2)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        n1 = len(self.dataset1)
        n2 = len(self.dataset2)
        if n1 == 0 and n2 == 0:
            raise IndexError("CombinedDataset is empty")
        # Never pick an empty side: `idx % 0` would raise ZeroDivisionError.
        # When both sides are non-empty, exactly one random draw is made.
        use_first = n2 == 0 or (n1 > 0 and random.random() < self.p)
        if use_first:
            return self.dataset1[idx % n1]
        return self.dataset2[idx % n2]


class WarmupStableDrop:
    """Warmup-Stable-Drop (WSD) learning-rate schedule.

    For each optimizer param group, with ``base_lr`` = that group's lr at
    construction time:

    * ``num_iter < warmup_iter``: linear warmup, ``base_lr * num_iter / warmup_iter``.
    * ``num_iter > end_iter - drop_iter``: linear decay from ``base_lr`` toward
      ``0.01 * base_lr``. The floor is reached at ``end_iter`` and held after that.
    * otherwise: constant ``base_lr``.

    ``start_lr`` is stored but not used by the schedule.
    """

    def __init__(
            self, optimizer, start_lr, warmup_iter, end_iter, drop_iter=0, num_iter=0,
    ) -> None:
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.warmup_iter = warmup_iter
        self.end_iter = end_iter
        self.drop_iter = drop_iter
        self.num_iter = num_iter

        # Remember each param group's initial learning rate.
        self.base_lrs = [group["lr"] for group in self.optimizer.param_groups]

        # Initialize the learning rate for the starting iteration.
        self._last_lr = self.base_lrs.copy()
        self.step(self.num_iter)

    def get_lr_warmup(self, num_iter, base_lr, warmup_iter) -> float:
        return base_lr * num_iter / warmup_iter

    def get_lr_stable(self, num_iter, base_lr):
        return base_lr

    def get_lr_drop(self, num_iter, base_lr):
        if self.drop_iter <= 0:
            # No decay window (the default drop_iter=0). This branch is only reached
            # once num_iter > end_iter, where the decay formula bottoms out at the
            # floor anyway. Return the floor instead of dividing by zero.
            return base_lr * 0.01
        progress = (self.end_iter - num_iter) / self.drop_iter
        return base_lr * (0.01 + max(0.99 * progress, 0))

    def get_lr(self, base_lr):
        if self.num_iter < self.warmup_iter:
            return self.get_lr_warmup(self.num_iter, base_lr, self.warmup_iter)

        if self.num_iter > self.end_iter - self.drop_iter:
            return self.get_lr_drop(self.num_iter, base_lr)

        return self.get_lr_stable(self.num_iter, base_lr)

    def step(self, num_iter=None) -> None:
        """Advance to ``num_iter`` (default: the next iteration) and update each group's lr."""
        if num_iter is None:
            num_iter = self.num_iter + 1
        self.num_iter = num_iter

        lrs = []
        for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
            new_lr = self.get_lr(base_lr)
            group["lr"] = new_lr
            self._current_lr = new_lr
            lrs.append(new_lr)
        self._last_lr = lrs

    def get_last_lr(self):
        """Return the list of learning rates set by the most recent ``step``."""
        return self._last_lr



class DoubleWSD:
    """Two-phase decaying learning-rate schedule with linear warmup.

    For each optimizer param group, with ``base_lr`` = that group's lr at
    construction time:

    * ``num_iter < warmup_iter``: linear warmup, ``base_lr * num_iter / warmup_iter``.
    * ``warmup_iter <= num_iter <= end_iter - drop_iter`` (first decay): the lr is
      ``base_lr * (1 - 0.99 * num_iter / (end_iter - drop_iter))``, reaching
      ``0.01 * base_lr`` at ``end_iter - drop_iter``.
    * ``num_iter > end_iter - drop_iter`` (second decay): the lr jumps back near
      ``base_lr`` and decays linearly to ``0.01 * base_lr`` at ``end_iter``, then
      holds there.

    ``start_lr`` is stored but not used by the schedule.
    """

    def __init__(
            self, optimizer, start_lr, warmup_iter, end_iter, drop_iter=0, num_iter=0,
    ) -> None:
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.warmup_iter = warmup_iter
        self.end_iter = end_iter
        self.drop_iter = drop_iter
        self.num_iter = num_iter

        # Remember each param group's initial learning rate.
        self.base_lrs = [group["lr"] for group in self.optimizer.param_groups]

        # Initialize the learning rate for the starting iteration.
        self._last_lr = self.base_lrs.copy()
        self.step(self.num_iter)

    def get_lr_warmup(self, num_iter, base_lr, warmup_iter) -> float:
        return base_lr * num_iter / warmup_iter

    def get_lr_stable(self, num_iter, base_lr):
        return base_lr

    def get_lr_drop2(self, num_iter, base_lr):
        if self.drop_iter <= 0:
            # No second decay window (the default drop_iter=0). This branch is only
            # reached once num_iter > end_iter; return the floor instead of dividing
            # by zero.
            return base_lr * 0.01
        progress = (self.end_iter - num_iter) / self.drop_iter
        return base_lr * (0.01 + max(0.99 * progress, 0))

    def get_lr_drop1(self, num_iter, base_lr):
        span = self.end_iter - self.drop_iter
        # Guard span == 0: e.g. end_iter == drop_iter at num_iter == 0 would
        # otherwise divide by zero. Zero progress yields base_lr.
        progress = num_iter / span if span > 0 else 0.0
        return base_lr * (1 - max(0.99 * progress, 0))

    def get_lr(self, base_lr):
        if self.num_iter < self.warmup_iter:
            return self.get_lr_warmup(self.num_iter, base_lr, self.warmup_iter)

        if self.num_iter > self.end_iter - self.drop_iter:
            return self.get_lr_drop2(self.num_iter, base_lr)

        return self.get_lr_drop1(self.num_iter, base_lr)

    def step(self, num_iter=None) -> None:
        """Advance to ``num_iter`` (default: the next iteration) and update each group's lr."""
        if num_iter is None:
            num_iter = self.num_iter + 1
        self.num_iter = num_iter

        lrs = []
        for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
            new_lr = self.get_lr(base_lr)
            group["lr"] = new_lr
            self._current_lr = new_lr
            lrs.append(new_lr)
        self._last_lr = lrs

    def get_last_lr(self):
        """Return the list of learning rates set by the most recent ``step``."""
        return self._last_lr


