import json
import os
import time
from dataclasses import dataclass
from typing import Dict, Sequence

import bitsandbytes as bnb
import evaluate
import numpy as np
import shortuuid
import torch
import transformers
import wandb
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import TrainerCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments


IGNORE_INDEX = -100


def find_all_linear_names(model, lm_head=False, quantize=False):
    cls = bnb.nn.Linear4bit if quantize else torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if not lm_head and "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    modules = list(lora_module_names)
    if lm_head:
        modules = (modules + ["lm_head"]) if "lm_head" not in modules else modules
    return modules


def get_subset(ds, size):
    # get subset with equal number of positive and negative examples
    counts = {"pos_left": size // 2, "neg_left": size - (size // 2)}

    def filter_fn(example):
        if counts["pos_left"] > 0 and example["label"] == 1:
            counts["pos_left"] -= 1
            return True
        if counts["neg_left"] > 0 and example["label"] == 0:
            counts["neg_left"] -= 1
            return True
        return False

    return ds.filter(filter_fn)


def get_parameters_count(model, requires_grad=False):
    total_params = 0
    unique_tensors = set()
    for name, module in model.named_modules():
        for attr_str in dir(module):
            if attr_str == "trainer":  # Skip the trainer attribute
                continue
            target_attr = getattr(module, attr_str)
            if type(target_attr) in (torch.Tensor, torch.nn.Parameter):
                if id(target_attr) not in unique_tensors:  # Check if the tensor was already counted
                    if not requires_grad or target_attr.requires_grad:
                        # print(name, attr_str, target_attr.shape)
                        total_params += torch.numel(target_attr)
                    unique_tensors.add(id(target_attr))  # Add the tensor id to the set of counted tensors
    return total_params


def find_end_index(tensor, sequence):
    sequence_length = len(sequence)

    last_match = -1

    for i in range(len(tensor) - sequence_length + 1):
        if torch.equal(tensor[i : i + sequence_length], sequence):
            last_match = i + sequence_length
        elif last_match > -1:
            return last_match

    return last_match


def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak.
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    # target_sequence = torch.tensor([21017, 15286, 25]).to("cuda:0")
    target_sequence = torch.tensor([-100, -100]).to("cuda:0")
    pred_ids = torch.argmax(logits, dim=-1)
    pred_ids = pred_ids[:, :-2]
    labels = labels[:, 1:-1]
    equal = []
    for i in range(pred_ids.shape[0]):
        p = pred_ids[i]
        lab = labels[i]
        idx = find_end_index(lab, target_sequence)
        # if i == 0:
        #     print(p)
        #     print(lab)
        # print(p[idx - 4 : idx + 4], lab[idx - 4 : idx + 4])
        if idx == -1:
            raise ValueError("Sequence not found")
        m = lab[idx:] == -100
        comp = p[idx:] == lab[idx:]
        comp += m
        equal.append(comp.min())
    return equal


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.concatenate(predictions)
    return {"Accuracy": predictions.sum() / len(predictions)}


def generate_samples(model, tokenizer, to_eval, model_id="default"):
    print("generating samples...")
    start = time.time()
    data = []

    # for i in range(len(to_eval)):
    #     input = (
    #         to_eval[i]["input_ids"][
    #             to_eval[i]["attention_mask"] & (to_eval[i]["labels"] == -100)
    #         ]
    #         .to("cuda:0")
    #         .unsqueeze(0)
    #     )
    #     _data = tokenizer.batch_decode(
    #         model.generate(
    #             inputs=input,
    #             max_length=input.shape[-1] + 8,
    #             pad_token_id=tokenizer.pad_token_id,
    #         ),
    #         skip_special_tokens=False,
    #     )[0]
    #     print("=================")
    #     print(_data)
    #     data.append([_data])

    middle = time.time()

    try:

        def get_json_list(file_path):
            file_path = os.path.expanduser(file_path)
            with open(file_path) as f:
                json_list = []
                for line in f:
                    json_list.append(json.loads(line))
                return json_list

        questions = get_json_list("question.jsonl")
        ans_jsons = []
        # for prompt in prompts:
        for i, q in enumerate(questions):
            print(f"generating sample {i}...")
            _data = generate_alpaca(model, tokenizer, q["text"])
            print("=================")
            print(_data)
            data.append([_data])
            ans_id = shortuuid.uuid()
            ans_jsons.append(
                {
                    "question_id": q["question_id"],
                    "text": _data,
                    "answer_id": ans_id,
                    "model_id": model_id,
                    "metadata": {},
                }
            )
        with open(os.path.expanduser(f"answers/{model_id}.jsonl"), "w") as ans_file:
            for line in ans_jsons:
                ans_file.write(json.dumps(line) + "\n")

    except Exception as e:
        print(e)
    print("=================")

    end = time.time()
    print(
        f"generating samples took {end - start} seconds (first part {middle - start} seconds, second part {end - middle} seconds))"
    )

    return data


class GradientLogCallback(TrainerCallback):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.last_step = -1

    def on_substep_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if self.last_step >= state.global_step or state.global_step % 10 != 0:
            return
        total_norm = 0.0
        for p in self.model.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** (1.0 / 2)
        print(f"gradient norm ({state.global_step}): {total_norm}")
        wandb.log({"gradient_norm": total_norm})
        self.last_step = state.global_step


class MetricEvalCallback(TrainerCallback):
    def __init__(self, _to_eval, tokenizer, model, batch_size):
        super().__init__()
        self._to_eval = _to_eval
        self.tokenizer = tokenizer
        self.model = model
        self.batch_size = batch_size
        self.last_step = -1

    def on_evaluate(self, args, state, control, logs=None, **kwargs):
        try:
            if self.last_step >= state.global_step:
                return
            with torch.autocast("cuda"):
                bleu = evaluate.load("bleu")
                rouge = evaluate.load("rouge")
                print("evaluating with metrics...")
                start = time.time()
                references = []
                predictions = []
                for i in range(0, len(self._to_eval["input"]), self.batch_size):
                    inputs = self._to_eval["input"][i : i + self.batch_size]
                    outputs = self._to_eval["output"][i : i + self.batch_size]
                    self.tokenizer.padding_side = "left"
                    input = self.tokenizer(inputs, padding=True, return_tensors="pt").to("cuda:0")
                    _references = [[s + outputs[j]] for j, s in enumerate(inputs)]
                    _predictions = self.tokenizer.batch_decode(
                        self.model.generate(
                            **input,
                            max_length=input["input_ids"].shape[-1] + 64,
                            pad_token_id=self.tokenizer.pad_token_id,
                        ),
                        skip_special_tokens=True,
                    )
                    references += _references
                    predictions += _predictions
                bleu_results = bleu.compute(predictions=predictions, references=references)
                rouge_results = rouge.compute(predictions=predictions, references=references)
                print(bleu_results, rouge_results)
                _step = state.global_step if state.global_step > 0 else None
                wandb.log(
                    {"rougeL": rouge_results["rougeL"], "bleu": bleu_results["bleu"]},
                    step=_step,
                )
                print(f"metric evaluation took {time.time() - start} seconds")
                self.last_step = state.global_step
        except Exception as e:
            print("metric eval error:", e)


@dataclass
class DataCollatorForCausalLM:
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int
    target_max_len: int
    train_on_source: bool
    predict_with_generate: bool

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Extract elements
        sources = [f"{self.tokenizer.bos_token}{example['input']}" for example in instances]
        # targets = [
        #     f"{example['output']}{self.tokenizer.eos_token}" for example in instances
        # ]
        full = [
            f"{self.tokenizer.bos_token}{example['input']}{example['output']}{self.tokenizer.eos_token}"
            for example in instances
        ]
        # Tokenize
        tokenized_sources_with_prompt = self.tokenizer(
            sources,
            max_length=self.source_max_len,
            truncation=True,
            add_special_tokens=False,
        )
        # tokenized_targets = self.tokenizer(
        #     targets,
        #     max_length=self.target_max_len,
        #     truncation=True,
        #     add_special_tokens=False,
        # )
        tokenized_full_text = self.tokenizer(
            full,
            max_length=self.source_max_len,
            truncation=True,
            add_special_tokens=False,
        )
        # Build the input and labels for causal LM
        input_ids = []
        labels = []
        for tokenized_source, tokenized_full in zip(
            tokenized_sources_with_prompt["input_ids"], tokenized_full_text["input_ids"]
        ):
            if not self.predict_with_generate:
                input_ids.append(torch.tensor(tokenized_full))
                if not self.train_on_source:
                    labels.append(
                        torch.tensor(
                            [IGNORE_INDEX for _ in range(len(tokenized_source))]
                            + tokenized_full[len(tokenized_source) :]
                        )
                    )
                else:
                    labels.append(torch.tensor(tokenized_full))
            else:
                input_ids.append(torch.tensor(tokenized_source))
        # Apply padding
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = (
            pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
            if not self.predict_with_generate
            else None
        )
        data_dict = {
            "input_ids": input_ids,
            "attention_mask": input_ids.ne(self.tokenizer.pad_token_id),
        }
        if labels is not None:
            data_dict["labels"] = labels
        return data_dict


def load_data(dataset_name):
    if dataset_name == "alpaca":
        return load_dataset("tatsu-lab/alpaca")
    elif dataset_name == "alpaca-clean":
        return load_dataset("yahma/alpaca-cleaned")
    elif dataset_name == "chip2":
        return load_dataset("laion/OIG", data_files="unified_chip2.jsonl")
    elif dataset_name == "self-instruct":
        return load_dataset("yizhongw/self_instruct", name="self_instruct")
    elif dataset_name == "hh-rlhf":
        return load_dataset("Anthropic/hh-rlhf")
    elif dataset_name == "longform":
        return load_dataset("akoksal/LongForm")
    elif dataset_name == "oasst1":
        return load_dataset("timdettmers/openassistant-guanaco")
    elif dataset_name == "vicuna":
        raise NotImplementedError("Vicuna data was not released.")
    else:
        raise NotImplementedError(f"Dataset {dataset_name} not implemented yet.")


_ALPACA_PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:\n"
    ),
}
ALPACA_PROMPT_DICT = {
    "prompt_input": "### Human:{instruction} Input: {input}\n### Assistant:",
    "prompt_no_input": "### Human:{instruction}\n### Assistant:",
}


def generate_alpaca(model, tokenizer, instruction, length=1024):
    prompt = tokenizer.bos_token + ALPACA_PROMPT_DICT["prompt_no_input"].format(instruction=instruction)

    input = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to("cuda:0")
    response = tokenizer.batch_decode(
        model.generate(
            inputs=input,
            max_length=input.shape[-1] + length,
            pad_token_id=tokenizer.pad_token_id,
        ),
        skip_special_tokens=False,
    )[0]
    return response


def extract_alpaca_dataset(example):
    if example.get("input", "") != "":
        prompt_format = ALPACA_PROMPT_DICT["prompt_input"]
    else:
        prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"]
    return {"input": prompt_format.format(**example)}


def format_dataset(dataset, dataset_format):
    if (
        dataset_format == "alpaca"
        or dataset_format == "alpaca-clean"
        or (dataset_format is None and dataset in ["alpaca", "alpaca-clean"])
    ):
        dataset = dataset.map(extract_alpaca_dataset, remove_columns=["instruction"])
    elif dataset_format == "chip2" or (dataset_format is None and dataset == "chip2"):
        dataset = dataset.map(
            lambda x: {
                "input": x["text"].split("\n<bot>: ")[0].replace("<human>: ", ""),
                "output": x["text"].split("\n<bot>: ")[1],
            }
        )
    elif dataset_format == "self-instruct" or (dataset_format is None and dataset == "self-instruct"):
        for old, new in [["prompt", "input"], ["completion", "output"]]:
            dataset = dataset.rename_column(old, new)
    elif dataset_format == "hh-rlhf" or (dataset_format is None and dataset == "hh-rlhf"):
        dataset = dataset.map(lambda x: {"input": "", "output": x["chosen"]})
    elif dataset_format == "oasst1" or (dataset_format is None and dataset == "oasst1"):
        dataset = dataset.map(
            lambda x: {
                "input": "",
                "output": x["text"],
            }
        )
    elif dataset_format == "input-output":
        # leave as is
        pass
    # Remove unused columns.
    # dataset = dataset.remove_columns(
    #     [
    #         col
    #         # for col in dataset.column_names["train"]
    #         for col in dataset.column_names
    #         if col not in ["input", "output"]
    #     ]
    # )
    return dataset
