import copy
import logging
import re
import os.path as osp
import json
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, Optional

import torch
import transformers
from torch.utils.data import Dataset
from datasets import load_dataset

from ans_utils import obtain_math_answer, compare_results, extract_answer_number


##################################################
### Global variables
##################################################

# Label value written over prompt positions in `preprocess` and used as the
# padding value for labels in the data collator.
# NOTE(review): presumably matches the loss function's ignore index - confirm.
IGNORE_INDEX = -100
# Fallback special tokens; `check_special_tokens` proposes them only for
# tokens the tokenizer does not already define. The pad fallback is the same
# string as the unk fallback.
DEFAULT_PAD_TOKEN = "<unk>"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

# Short model tags mapped to model identifiers.
# NOTE(review): the values look like Hugging Face Hub repository ids - confirm
# against the code that loads the model.
MODEL_DICT = {
    "llama-2-7b": "meta-llama/Llama-2-7b-hf",
    "llama-2-13b": "meta-llama/Llama-2-13b-hf",
    "llama-3-8b": "meta-llama/Meta-Llama-3-8B",
}

# Commonsense task tags mapped to the directory name used to locate
# `dataset/<name>/test.json` in `get_instruction_tuning_dataset`.
COMMONSENSE_DICT = {
    "boolq": "boolq",
    "piqa": "piqa",
    "siqa": "social_i_qa",
    "hella": "hellaswag",
    "wino": "winogrande",
    "arce": "ARC-Easy",
    "arcc": "ARC-Challenge",
    "obqa": "openbookqa",
}

# Per-dataset integer settings keyed by dataset tag. Neither table is read in
# this part of the file.
# NOTE(review): presumably the batch size used during generation-based
# evaluation - confirm against the evaluation script.
TEST_BATCH_SIZE = {
    "gsm8k": 100,
    "sql": 80,
    "viggo": 40,
    "math": 100,
    "boolq": 100,
    "piqa": 90,
    "siqa": 100,
    "hella": 100,
    "wino": 100,
    "arce": 100,
    "arcc": 100,
    "obqa": 100,
}

# NOTE(review): presumably the maximum number of newly generated tokens per
# dataset during evaluation - confirm against the evaluation script.
TEST_GEN_TOKENS = {
    "gsm8k": 512,
    "sql": 128,
    "viggo": 128,
    "math": 512,
    "boolq": 32,
    "piqa": 32,
    "siqa": 32,
    "hella": 32,
    "wino": 32,
    "arce": 32,
    "arcc": 32,
    "obqa": 32,
}

# Strings listed with and without a trailing colon; not read in this part of
# the file.
# NOTE(review): presumably used to cut off generation once the model starts a
# new prompt section - confirm against the generation code.
STOP_TOKENS = [
    "Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", 
    "Instruction:", "Instruction", "Response:", "Response",
    "representation", "representation:"
]

# GSM8K & MATH
# Suffix appended after the question/problem text to form the model prompt.
QUESTION_PROMPT = "\nAnswer the above question. First think step by step and then answer the final number.\n"
# Marker that precedes the final answer: it replaces "####" in GSM8K targets
# and is appended (followed by the extracted answer) to MATH targets.
ANSWER_PROMPT = "The final answer is: "

# SQL
# Prompt template; the two placeholders are filled with the question and the
# table context, in that order.
QUESTION_PROMPT_SQL = """You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.

You must output the SQL query that answers the question.

### Input:
{}

### Context:
{}

"""
# Training target template; the placeholder is filled with the reference query.
ANSWER_PROMPT_SQL = """### Response:
{}"""
# Generation-time prompt: the question template followed by the response
# header so the model continues with the query.
# NOTE(review): this puts one more newline before "### Response:" than the
# training text (QUESTION_PROMPT_SQL + ANSWER_PROMPT_SQL) - confirm intended.
QUESTION_PROMPT_SQL_EVAL = QUESTION_PROMPT_SQL + "\n### Response:\n"
# Regex applied to generated text in `retrieve_prediction_answer`; the whole
# match (from "SELECT " to the end of that line) is kept as the prediction.
ANSWER_PROMPT_SQL_PATTERN = r"SELECT (.*)"

# ViGGO
# Prompt template; the placeholder is filled with the target sentence.
QUESTION_PROMPT_VIGGO = """Given a target sentence construct the underlying meaning representation of the input sentence as a single function with attributes and attribute values.
This function should describe the target string accurately and the function must be one of the following ['inform', 'request', 'give_opinion', 'confirm', 'verify_attribute', 'suggest', 'request_explanation', 'recommend', 'request_attribute'].
The attributes must be one of the following: ['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating', 'genres', 'player_perspective', 'has_multiplayer', 'platforms', 'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier']

### Target sentence:
{}

"""
# Training target template; the placeholder is filled with the reference
# meaning representation.
ANSWER_PROMPT_VIGGO = """### Meaning representation:
{}"""
# Generation-time prompt: the question template followed by the answer header.
# NOTE(review): this puts one more newline before the header than the training
# text (QUESTION_PROMPT_VIGGO + ANSWER_PROMPT_VIGGO) - confirm intended.
QUESTION_PROMPT_VIGGO_EVAL = QUESTION_PROMPT_VIGGO + "\n### Meaning representation:\n"
# Header prepended to generated text in `retrieve_prediction_answer` before the
# pattern below is searched.
ANSWER_PROMPT_VIGGO_EVAL = "### Meaning representation:\n"
# Captures the first line that follows the header as the predicted
# representation.
ANSWER_PROMPT_VIGGO_PATTERN = r"### Meaning representation:\s*\n?(.*)"


# Commonsense reasoning
# BoolQ prompt; the placeholder is filled with the question.
QUESTION_PROMPT_BOOL = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 

### Instruction:
Please answer the following question with true or false, question: {}

Answer format: true/false

"""
# PIQA prompt; placeholders are the goal and the two candidate solutions.
QUESTION_PROMPT_PIQA = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 

### Instruction:
Please choose the correct solution to the question: {}

Solution1: {}

Solution2: {}

Answer format: solution1/solution2

"""
# SIQA prompt; placeholders are the "context question" text and the three
# candidate answers.
QUESTION_PROMPT_SIQA = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 

### Instruction:
Please choose the correct solution to the question: {}

Answer1: {} Answer2: {} Answer3: {}

Answer format: answer1/answer2/answer3

"""
# Training target template shared by the commonsense tasks; the placeholder is
# the answer label text (e.g. "true", "solution1", "answer2").
ANSWER_PROMPT_COMMON = """### Response:
the correct answer is {}"""
# Response header alone, appended to a question prompt at generation time.
ANSWER_PROMPT_COMMON_EVAL = """### Response:
"""

# Prompt used by `prepare_evaluation_data` for the locally stored commonsense
# test files; the placeholder is the example's `instruction` field. The
# "### Instruction:" / "### Response:" lines are indented inside the literal,
# and that indentation is part of the prompt text.
QUESTION_PROMPT_COMMON_ADA = """Below is an instruction that describes a task. Write a response that appropriately completes the request.  

                ### Instruction:
                {}

                ### Response:
"""


##################################################
### Arguments
##################################################

@dataclass
class ModelArguments:
    """Options describing the base model, the adapter checkpoint and the sequence length.

    Every field carries its description in ``metadata["help"]``.
    """

    model_tag: str = field(
        default="llama-2-7b",
        metadata={"help": "Model tag or path to model."},
    )
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the model."},
    )
    adapter_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the LoRA adapter. Used in evaluation or resuming from the checkpoint."},
    )
    ckpt_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to your LoRA adapter directory. Used in evaluation or resuming from the checkpoint."},
    )
    lora_init: bool = field(
        default=True,
        metadata={"help": "True: Use zero and gaussian initialization; False: Load adapters from LoftQ in HF hub."},
    )
    full_precision: bool = field(
        default=True,
        metadata={
            # Adjacent string literals are concatenated verbatim, so the
            # separator has to be spelled out; without it the two options
            # fuse into "...real quantizationTrue: ...".
            "help": "False: Use bitsandbytes Linear4bit, real quantization; "
                    "True: Use quantization equivalent fp16/fp32 weights."
        },
    )
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be left padded (and possibly truncated)."},
    )


@dataclass
class DataArguments:
    """Dataset selection and evaluation batching options."""

    # Tag naming the dataset to use (e.g. "gsm8k").
    data_tag: str = field(default="gsm8k", metadata=dict(help="Dataset tag."))
    # Number of examples per evaluation batch.
    batch_size: int = field(default=16, metadata=dict(help="Evaluation batch size."))


@dataclass
class OpArguments:
    """Options selecting which operation to run and on which checkpoint."""

    # Whether to save or to evaluate.
    process: str = field(default="save", metadata=dict(help="Indicator for saving or evaluation."))
    # Which adapter inside the checkpoint directory to use.
    iter: int = field(default=0, metadata=dict(help="Index of pre-trained adapter in checkpoint directory"))
    # Whether the model is kept once evaluation has finished.
    retain: int = field(default=0, metadata=dict(help="Indicator to retain model after evaluation"))


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Training options layered on top of ``transformers.TrainingArguments``.

    Declares project defaults for optimisation, checkpointing and logging, plus
    adapter-specific hyper-parameters (LoRA, Fourier and C3A options).
    NOTE(review): fields that share a name with a parent field presumably
    override only its default - confirm against the installed transformers
    version.
    """

    # --- Optimisation schedule ---
    learning_rate: float = field(default=2e-3)
    warmup_ratio: float = field(default=0.05)
    weight_decay: float = field(default=0.0)
    optim: str = field(default="adamw_torch")
    adam_epsilon: float = field(default=1e-8)
    max_grad_norm: float = field(default=1.0)
    per_device_train_batch_size: int = field(default=1)
    gradient_accumulation_steps: int = field(default=16)
    num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
    max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
    seed: int = field(default=42, metadata={"help": "random seed for model training and evaluation"})
    
    # --- Quantised loading switches ---
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
    
    # --- Adapter hyper-parameters ---
    adapter_type: Optional[str] = field(default="lora", metadata={"help": "the type of adapters to use"})
    lora_r: Optional[int] = field(default=32, metadata={"help": "the r parameter of the LoRA adapters"})
    lora_alpha: Optional[float] = field(default=4, metadata={"help": "the alpha parameter of the LoRA adapters"})
    lora_dropout: float = field(default=0.05, metadata={"help": "Lora dropout"})
    n_frequency: Optional[int] = field(default=1000, metadata={"help": "the num_frequency of the Fourier adapters"})
    fourier_init_weight: Optional[str] = field(default="xavier_spectrum", metadata={"help": "weight initialization type"})
    random_loc_seed: Optional[int] = field(default=2024, metadata={"help": "the random seed for fourier space index randomization"})
    block_size: Optional[int] = field(default=64, metadata={"help": "the block size of the C3A adapters"})

    # --- Logging and checkpointing ---
    logging_steps: int = field(default=10, metadata={"help": "logging step in model training"})
    save_strategy: Optional[str] = field(default="epoch", metadata={"help": "the save strategy"})
    output_dir: Optional[str] = field(default="results", metadata={"help": "the output directory"})
    save_steps: Optional[int] = field(
        default=50, metadata={"help": "Number of updates steps before two checkpoint saves"}
    )
    
    # --- Run bookkeeping ---
    run_name: str = field(default="", metadata={"help": "Experiment name"})
    # Kept False by default so dataset columns are not dropped.
    remove_unused_columns: Optional[bool] = field(
        default=False, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
    )
    report_to: Optional[str] = field(default="none", metadata={"help": "use 'wandb' to log with wandb"})


##################################################
### Dataset preparation
##################################################


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens to the tokenizer and resize the model embeddings to match.

    When tokens were actually added, the embedding rows created for them (the
    trailing rows of both the input and the output embedding matrices) are
    filled with the mean of the rows that existed before.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    # Nothing new was registered, so there are no fresh rows to initialise.
    if not added > 0:
        return

    in_weights = model.get_input_embeddings().weight.data
    out_weights = model.get_output_embeddings().weight.data

    # Both means are taken over the pre-existing rows before any row is written.
    old_in_mean = in_weights[:-added].mean(dim=0, keepdim=True)
    old_out_mean = out_weights[:-added].mean(dim=0, keepdim=True)

    in_weights[-added:] = old_in_mean
    out_weights[-added:] = old_out_mean


def check_special_tokens(tokenizer, model_name):
    """Return fallback values for the special tokens the tokenizer does not define.

    The result maps ``pad_token`` / ``eos_token`` / ``bos_token`` / ``unk_token``
    to a fallback string, containing only the entries whose tokenizer attribute
    is ``None``. For ``model_name == "llama-3-8b"`` the pad and unk fallbacks
    are ``"<|reserved_special_token_0|>"`` instead of the module defaults.
    """
    uses_reserved = model_name == "llama-3-8b"
    reserved_token = "<|reserved_special_token_0|>"

    # (tokenizer attribute, fallback) pairs in the order they are reported.
    candidates = [
        ("pad_token", reserved_token if uses_reserved else DEFAULT_PAD_TOKEN),
        ("eos_token", DEFAULT_EOS_TOKEN),
        ("bos_token", DEFAULT_BOS_TOKEN),
        ("unk_token", reserved_token if uses_reserved else DEFAULT_UNK_TOKEN),
    ]

    return {attr: fallback for attr, fallback in candidates if getattr(tokenizer, attr) is None}


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize every string separately and report token ids with their lengths.

    Each text is encoded on its own, truncated to ``tokenizer.model_max_length``.
    The returned ``labels`` / ``labels_lens`` entries are the very same list
    objects as ``input_ids`` / ``input_ids_lens``. A length is the number of
    ids that differ from ``tokenizer.pad_token_id``.
    """
    encodings = []
    for text in strings:
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        encodings.append(encoded)

    # Row 0 is taken because each text was encoded as its own batch.
    token_ids = [encoded.input_ids[0] for encoded in encodings]
    lengths = [encoded.input_ids.ne(tokenizer.pad_token_id).sum().item() for encoded in encodings]

    return {
        "input_ids": token_ids,
        "labels": token_ids,
        "input_ids_lens": lengths,
        "labels_lens": lengths,
    }


def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize prompt+answer pairs and mask the prompt part of the labels.

    ``sources`` are the questions and ``targets`` the answers. Each pair is
    concatenated and tokenized; the labels are a deep copy of the token ids in
    which the leading positions covered by the tokenized source alone are set
    to ``IGNORE_INDEX``.
    """
    full_texts = [question + answer for question, answer in zip(sources, targets)]

    # Full texts are tokenized first, then the prompts on their own.
    tokenized_full = _tokenize_fn(full_texts, tokenizer)
    tokenized_prompts = _tokenize_fn(sources, tokenizer)

    input_ids = tokenized_full["input_ids"]
    labels = copy.deepcopy(input_ids)
    prompt_lengths = tokenized_prompts["input_ids_lens"]
    for masked, prompt_len in zip(labels, prompt_lengths):
        masked[:prompt_len] = IGNORE_INDEX

    return {"input_ids": input_ids, "labels": labels}


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Formats every raw example into a (prompt, answer) text pair according to
    ``data_args.data_tag``, tokenizes the pairs with ``preprocess`` and stores
    the resulting ``input_ids`` and prompt-masked ``labels``.

    Supported tags: ``gsm8k``, ``sql``, ``viggo``, ``math``, ``boolq``,
    ``piqa`` and ``siqa``.

    Raises:
        ValueError: if ``data_args.data_tag`` is not one of the supported tags.
    """

    def __init__(self, raw_data, tokenizer, data_args):
        super(SupervisedDataset, self).__init__()

        logging.warning("Formatting inputs...")
        sources, targets = self._format_examples(raw_data, tokenizer, data_args.data_tag)

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    @staticmethod
    def _format_examples(raw_data, tokenizer, data_tag):
        """Build the prompt and answer strings for every example of ``raw_data``.

        Returns a ``(sources, targets)`` pair of equally long lists.
        NOTE(review): only the gsm8k, sql and math targets end with the
        tokenizer's EOS token; viggo/boolq/piqa/siqa targets do not - confirm
        this asymmetry is intended.
        """
        if data_tag == "gsm8k":
            sources = [f"{example['question']}{QUESTION_PROMPT}" for example in raw_data]
            # The "####" marker in the reference answer is swapped for the answer prompt.
            targets = [
                f"{example['answer']}{tokenizer.eos_token}".replace("####", ANSWER_PROMPT) for example in raw_data
            ]
        elif data_tag == "sql":
            sources = [QUESTION_PROMPT_SQL.format(example['question'], example['context']) for example in raw_data]
            targets = [ANSWER_PROMPT_SQL.format(example['answer']) + tokenizer.eos_token for example in raw_data]
        elif data_tag == "viggo":
            sources = [QUESTION_PROMPT_VIGGO.format(example['target']) for example in raw_data]
            targets = [ANSWER_PROMPT_VIGGO.format(example['meaning_representation']) for example in raw_data]
        elif data_tag == "math":
            sources = [f"{example['problem']}{QUESTION_PROMPT}" for example in raw_data]
            # Sentences of the solution are put on separate lines, then the
            # extracted final answer is appended after the answer prompt.
            targets = [
                re.sub(r'\. +', r'.\n', example['solution'])
                + f"\n{ANSWER_PROMPT}{obtain_math_answer(example['solution'])}{tokenizer.eos_token}"
                for example in raw_data
            ]
        elif data_tag == "boolq":
            sources = [QUESTION_PROMPT_BOOL.format(example['question']) for example in raw_data]
            targets = [ANSWER_PROMPT_COMMON.format(str(example['answer']).lower()) for example in raw_data]
        elif data_tag == "piqa":
            sources = [
                QUESTION_PROMPT_PIQA.format(example['goal'], example['sol1'], example['sol2'])
                for example in raw_data
            ]
            # Labels are shifted by one so they read "solution1"/"solution2".
            targets = [
                ANSWER_PROMPT_COMMON.format("solution{}".format(str(example['label'] + 1)))
                for example in raw_data
            ]
        elif data_tag == "siqa":
            sources = [
                QUESTION_PROMPT_SIQA.format(
                    '{} {}'.format(example['context'], example['question']),
                    example['answerA'], example['answerB'], example['answerC']
                )
                for example in raw_data
            ]
            targets = [
                ANSWER_PROMPT_COMMON.format("answer{}".format(str(example['label'])))
                for example in raw_data
            ]
        else:
            # Fail with a clear message instead of an UnboundLocalError later on.
            raise ValueError(f"{data_tag} is not supported.")

        return sources, targets

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Pads ``input_ids`` with the tokenizer's pad token id and ``labels`` with
    ``IGNORE_INDEX`` to the longest sequence of the batch, and derives the
    attention mask from the positions that are not the pad token id.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad_id = self.tokenizer.pad_token_id

        id_seqs = [item["input_ids"] for item in instances]
        label_seqs = [item["labels"] for item in instances]

        padded_ids = torch.nn.utils.rnn.pad_sequence(id_seqs, batch_first=True, padding_value=pad_id)
        padded_labels = torch.nn.utils.rnn.pad_sequence(
            label_seqs, batch_first=True, padding_value=IGNORE_INDEX
        )

        return {
            "input_ids": padded_ids,
            "labels": padded_labels,
            # Every position holding the pad id is masked out.
            "attention_mask": padded_ids.ne(pad_id),
        }


def get_instruction_tuning_dataset(dataset_name, eval=False):
    """Load the train and test splits for ``dataset_name``.

    Args:
        dataset_name: dataset tag (``gsm8k``, ``sql``, ``viggo``, ``math`` or a
            key of ``COMMONSENSE_DICT``).
        eval: when true and ``dataset_name`` is a commonsense tag, the test set
            is read from ``<repo>/dataset/<name>/test.json`` (two directory
            levels above this file) and no training set is returned.

    Returns:
        ``(train_set, test_set)``; ``train_set`` is ``None`` for the local
        commonsense test files.

    Raises:
        ValueError: if the local commonsense file is missing or the tag is not
            supported (commonsense tags other than boolq/piqa/siqa are only
            supported with ``eval=True``).
    """
    if dataset_name == "gsm8k":
        dataset = load_dataset("openai/gsm8k", "main")
        train_set, test_set = dataset['train'], dataset['test']
    elif dataset_name == "sql":
        dataset = load_dataset("b-mc2/sql-create-context", split='train')
        # Seeded shuffle of the first quarter of the indices, then an 80/20 split.
        data_idxs = torch.randperm(int(len(dataset) / 4), generator=torch.Generator().manual_seed(42))
        dataset_split = dataset.select(data_idxs).train_test_split(test_size=0.2, seed=42)
        train_set, test_set = dataset_split['train'], dataset_split['test']
    elif dataset_name == "viggo":
        dataset = load_dataset("GEM/viggo")
        train_set, test_set = dataset['train'], dataset['test']
    elif dataset_name == "math":
        dataset = load_dataset("lighteval/MATH", "all", trust_remote_code=True)
        train_set, test_set = dataset['train'], dataset['test']
    elif dataset_name in COMMONSENSE_DICT and eval:
        # This branch must stay ahead of the boolq/piqa/siqa branches so that
        # evaluation of those tags also reads the local test files.
        file_path = osp.abspath(__file__)
        data_path = osp.dirname(osp.dirname(file_path))
        data_path = osp.join(data_path, 'dataset', COMMONSENSE_DICT[dataset_name], 'test.json')
        if not osp.exists(data_path):
            raise ValueError("Dataset is not prepared! " + \
                              "Please download from https://github.com/AGI-Edgerunners/LLM-Adapters/tree/main/dataset")
        train_set = None
        # The context manager closes the handle deterministically; JSON text is
        # read as UTF-8 rather than with the platform's locale encoding.
        with open(data_path, 'r', encoding='utf-8') as handle:
            test_set = json.load(handle)
    elif dataset_name == "boolq":
        dataset = load_dataset("google/boolq")
        train_set, test_set = dataset['train'], dataset['validation']
    elif dataset_name == "piqa":
        dataset = load_dataset("ybisk/piqa", trust_remote_code=True)
        train_set, test_set = dataset['train'], dataset['validation']
    elif dataset_name == "siqa":
        dataset = load_dataset("lighteval/siqa")
        train_set, test_set = dataset['train'], dataset['validation']
    else:
        raise ValueError(f"{dataset_name} is not supported.")

    return train_set, test_set


def make_supervised_data_module(tokenizer, data_args):
    """Make dataset and collator for supervised fine-tuning.

    Returns a dict with ``train_dataset``, ``eval_dataset`` (built from the
    first 50 examples of the evaluation split) and ``data_collator``.
    """
    raw_train, raw_eval = get_instruction_tuning_dataset(data_args.data_tag)
    # Only a small slice of the evaluation split is tokenized.
    raw_eval = raw_eval.select(torch.arange(50))

    return {
        "train_dataset": SupervisedDataset(raw_data=raw_train, tokenizer=tokenizer, data_args=data_args),
        "eval_dataset": SupervisedDataset(raw_data=raw_eval, tokenizer=tokenizer, data_args=data_args),
        "data_collator": DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    }


def prepare_evaluation_data(dataset_name):
    """Preprocess dataset for generation evaluation.

    Loads the test split with ``get_instruction_tuning_dataset(..., eval=True)``
    and returns ``(question, answer)``: the list of generation prompts and the
    list of reference answers, in matching order.

    Commonsense tags (every key of ``COMMONSENSE_DICT``, including boolq, piqa
    and siqa) are served from the local test files, whose examples provide
    ``instruction`` and ``answer`` fields.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    _, test_set = get_instruction_tuning_dataset(dataset_name, eval=True)

    logging.warning("Formatting inputs...")
    if dataset_name == 'gsm8k':
        question = [f"{example['question']}{QUESTION_PROMPT}" for example in test_set]
        answer = []

        # get numerical answer
        for example in test_set['answer']:
            ans = example.split('####')[-1]
            ans = ans.replace(',', '')  # handle numbers like 2,000
            try:
                ans = float(ans)
            except ValueError:
                # Unparseable references become +inf.
                ans = float("inf")
            answer.append(ans)
    elif dataset_name == 'viggo':
        question = [QUESTION_PROMPT_VIGGO_EVAL.format(example['target']) for example in test_set]
        answer = [example['meaning_representation'] for example in test_set]
    elif dataset_name == 'math':
        # Keep short problems only, then at most the first 2000 of them; the
        # cap is clamped so a smaller filtered set does not break `select`.
        test_set = test_set.filter(lambda example: len(example['problem']) < 850)
        test_set = test_set.select(torch.arange(min(2000, len(test_set))))

        question = [f"{example['problem']}{QUESTION_PROMPT}" for example in test_set]
        answer = [obtain_math_answer(example['solution']) for example in test_set]
    elif dataset_name == 'sql':
        question = [QUESTION_PROMPT_SQL_EVAL.format(example['question'], example['context']) for example in test_set]
        answer = [example['answer'].strip() for example in test_set]
    elif dataset_name in COMMONSENSE_DICT:
        # Covers boolq/piqa/siqa as well: with eval=True the loader returns the
        # local JSON test files for every commonsense tag, so the separate
        # per-task branches that used to follow here could never run.
        question = [QUESTION_PROMPT_COMMON_ADA.format(example['instruction']) for example in test_set]
        answer = [example['answer'] for example in test_set]
    else:
        # Defensive: the loader already rejects unknown tags.
        raise ValueError(f"{dataset_name} is not supported.")

    return question, answer


##################################################
###  Answer extraction wrapper
##################################################


# Alternatives searched for in lower-cased generations, per commonsense tag.
# Every key of COMMONSENSE_DICT has an entry here.
_COMMONSENSE_ANSWER_PATTERNS = {
    'boolq': r'true|false',
    'piqa': r'solution1|solution2',
    'siqa': r'answer1|answer2|answer3|answer4|answer5',
    'arcc': r'answer1|answer2|answer3|answer4|answer5',
    'arce': r'answer1|answer2|answer3|answer4|answer5',
    'obqa': r'answer1|answer2|answer3|answer4|answer5',
    'hella': r'ending1|ending2|ending3|ending4',
    'wino': r'option1|option2',
}


def retrieve_prediction_answer(text, dataset_name):
    """Extract the predicted answer from generated ``text`` for ``dataset_name``.

    Returns:
        - gsm8k: whatever ``extract_answer_number`` returns for the text;
        - viggo: the first line following the meaning-representation header;
        - math: the text unchanged;
        - sql: the matched ``SELECT ...`` line, stripped;
        - commonsense tags: the first answer keyword found in the lower-cased
          text.
        For viggo, sql and the commonsense tags an empty string is returned
        when nothing matches.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    if dataset_name == 'gsm8k':
        return extract_answer_number(text, ANSWER_PROMPT)

    if dataset_name == 'viggo':
        try:
            return re.search(ANSWER_PROMPT_VIGGO_PATTERN, ANSWER_PROMPT_VIGGO_EVAL + text).group(1)
        except (AttributeError, TypeError):
            # No match (search returned None) or non-string text: best effort.
            return ''

    if dataset_name == 'math':
        return text

    if dataset_name == 'sql':
        try:
            return re.search(ANSWER_PROMPT_SQL_PATTERN, text).group(0).strip()
        except (AttributeError, TypeError):
            # No match (search returned None) or non-string text: best effort.
            return ''

    pattern = _COMMONSENSE_ANSWER_PATTERNS.get(dataset_name)
    if pattern is not None:
        found = re.findall(pattern, text.strip().lower())
        return found[0] if found else ""

    # Fail with a clear message instead of an UnboundLocalError on return.
    raise ValueError(f"{dataset_name} is not supported.")


##################################################
###  Evaluation metric functions
##################################################


def compute_accuracy(pred: list, gold: list):
    """Return the fraction of predictions equal to their reference.

    Predictions and references are compared pairwise with ``==``; the count of
    matches is divided by ``len(pred)``.

    Returns:
        The accuracy as a float, or ``0.0`` when ``pred`` is empty (instead of
        failing with a division by zero).
    """
    total = len(pred)
    if total == 0:
        return 0.0

    correct = sum(1 for p, g in zip(pred, gold) if p == g)
    return correct / total


def compute_accuracy_math(pred: list, gold: list, prompt: str = ANSWER_PROMPT):
    """Score MATH predictions with ``compare_results``.

    Each (prediction, reference) pair is passed to ``compare_results`` together
    with ``prompt``; the returned match values and invalid flags are summed.

    Returns:
        ``(accuracy, invalid_count)`` where accuracy is the summed match value
        divided by ``len(pred)``. An empty ``pred`` yields ``(0.0, 0)`` instead
        of failing with a division by zero.
    """
    total = len(pred)
    if total == 0:
        return 0.0, 0

    acc, inv = 0.0, 0
    for p, g in zip(pred, gold):
        match, invalid = compare_results(p, g, prompt)
        acc += match
        inv += invalid

    return acc / total, inv


##################################################
###  Commonsense utils
##################################################

def prepare_model_for_int8_training(
    model, output_embedding_layer_name="lm_head", use_gradient_checkpointing=True, layer_norm_names=("layer_norm",)
):
    r"""
    Prepare a loaded model for adapter training. The steps are:
        1- Freeze every parameter of the model.
        2- If the model reports `is_loaded_in_8bit`, cast the 1-D parameters whose name contains one of
           `layer_norm_names` to fp32.
        3- If the model reports `is_loaded_in_8bit` and `use_gradient_checkpointing` is set, make the input
           embeddings' output require grads and enable gradient checkpointing.
        4- If the model has an attribute named `output_embedding_layer_name`, wrap that layer so its input is cast
           to the layer's weight dtype and its output to fp32.

    Args:
        model, (`transformers.PreTrainedModel`):
            The loaded model from `transformers`
        output_embedding_layer_name (`str`):
            Attribute name of the output embedding layer to wrap.
        use_gradient_checkpointing (`bool`):
            Whether to enable gradient checkpointing for 8-bit models.
        layer_norm_names (sequence of `str`):
            Substrings identifying layer-norm parameters; only iterated, never modified.

    Returns:
        The same `model` object, modified in place.
    """
    loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False)

    for name, param in model.named_parameters():
        # freeze base model's layers
        param.requires_grad = False

        if loaded_in_8bit:
            # cast layer norm in fp32 for stability for 8bit models
            if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
                param.data = param.data.to(torch.float32)

    if loaded_in_8bit and use_gradient_checkpointing:
        # For backward compatibility
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        # enable gradient checkpointing for memory efficiency
        model.gradient_checkpointing_enable()

    if hasattr(model, output_embedding_layer_name):
        output_embedding_layer = getattr(model, output_embedding_layer_name)
        # Captured by the wrapper below: the dtype the wrapped layer expects.
        input_dtype = output_embedding_layer.weight.dtype

        class CastOutputToFloat(torch.nn.Sequential):
            r"""
            Manually cast to the expected dtype of the lm_head as sometimes there is a final layer norm that is cast
            to fp32

            """

            def forward(self, x):
                return super().forward(x.to(input_dtype)).to(torch.float32)

        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))

    return model