#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import copy
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import torch
import transformers
import utils
from torch.utils.data import Dataset, random_split
from transformers import Trainer
from utils import format_text_to_template,format_response_to_template
from peft import LoraConfig, get_peft_model
from fingerprint_utils import add_pad_token, FingerprintLogger
from datasets import Dataset as HFDataset

# Label value marking positions to exclude from the loss: `preprocess` writes it over the
# prompt tokens and the data collator pads label sequences with it.
# NOTE(review): -100 is the default `ignore_index` of PyTorch's cross-entropy loss — confirm
# the model's loss relies on that default.
IGNORE_INDEX = -100
# Default special-token strings.
# NOTE(review): none of these four are referenced in the visible code (padding is delegated
# to `add_pad_token`) — confirm whether they are still needed.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
# Alpaca prompt templates, filled via `str.format_map(example)` in `SupervisedDataset`:
# "prompt_input" for examples with a non-empty "input" field, "prompt_no_input" otherwise.
PROMPT_DICT_ALPACA = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
# Chat Doctor prompt used by `SupervisedDataset` when `is_instruct` is False.
PROMPT_DICT_CHAT_DOC = "Below is a medical question, please answer it as a Chat Doctor based on the patient's description. \n### Patient query:{input} \n### Response:"
# Chat Doctor prompt used when `is_instruct` is True: only the example's "input" text, which
# is then passed through `format_text_to_template`.
PROMPT_DICT_CHAT_DOC_INSTRUCT = "{input}"

@dataclass
class ModelArguments:
    """Command-line arguments selecting the base model (parsed by `HfArgumentParser` in `train`)."""

    # Passed to `from_pretrained` for both the model and the tokenizer in `train()`;
    # substrings of it ("Phi", "llama3", ...) also select the prompt-template family there.
    # NOTE(review): the default "facebook/opt-125m" matches none of the model-family checks
    # in `train()`, so a supported model name has to be supplied explicitly.
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    """Command-line arguments selecting the fine-tuning dataset."""

    # Dataset key, not a file path: `train()` looks this value up in its
    # {"alpaca": ..., "chatdoc": ...} table and overwrites it with the resolved path.
    data_path: str = field(default=None, metadata={"help": "Finetuning dataset"})
    # Optional override for the built-in "chatdoc" dataset location.
    chatdoc_path: str = field(default=None, metadata={"help": "Chat Doc dataset"})
    # Optional override for the built-in "alpaca" dataset location.
    alpaca_path: str = field(default=None, metadata={"help": "Alpaca dataset"})

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """`transformers.TrainingArguments` extended with tokenizer-length and LoRA options."""

    # Forwarded as `cache_dir` when the base model is loaded in `train()`.
    cache_dir: Optional[str] = field(default=None)
    # Re-declares the inherited `optim` field with "adamw_torch" as the default.
    optim: str = field(default="adamw_torch")
    # Forwarded to the tokenizer as `model_max_length`; `_tokenize_fn` truncates to it.
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    # LoRA rank and alpha, forwarded to `LoraConfig(r=..., lora_alpha=...)` in `train()`.
    lora_r: int = field(default=16)
    lora_alpha: int = field(default=32)

def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Encode every string separately and measure its non-padding length.

    Each text is handed to the tokenizer on its own, truncated to
    ``tokenizer.model_max_length``.

    Returns:
        A dict with four entries. ``input_ids`` and ``labels`` are the same list
        object, holding the first row of each encoding's ``input_ids``.
        ``input_ids_lens`` and ``labels_lens`` are the same list of per-string
        counts of ids that differ from ``tokenizer.pad_token_id``.
    """
    encodings = []
    for text in strings:
        encodings.append(
            tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                max_length=tokenizer.model_max_length,
                truncation=True,
            )
        )

    token_ids = [encoding.input_ids[0] for encoding in encodings]
    lengths = [encoding.input_ids.ne(tokenizer.pad_token_id).sum().item() for encoding in encodings]

    # Labels deliberately alias the inputs; callers that mask labels copy them first.
    return {
        "input_ids": token_ids,
        "labels": token_ids,
        "input_ids_lens": lengths,
        "labels_lens": lengths,
    }

def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize prompt+response pairs and mask the prompt part of the labels.

    Every source is concatenated with its target and tokenized; the sources are
    tokenized a second time on their own only to learn how many leading tokens
    belong to the prompt. Those leading positions are overwritten with
    ``IGNORE_INDEX`` in a deep copy of the input ids.

    Returns:
        A dict with ``input_ids`` (full sequences) and ``labels`` (same
        sequences with the prompt positions masked).
    """
    full_texts = [prompt + answer for prompt, answer in zip(sources, targets)]

    tokenized_full = _tokenize_fn(full_texts, tokenizer)
    tokenized_prompts = _tokenize_fn(sources, tokenizer)

    input_ids = tokenized_full["input_ids"]
    # Copy so that masking the labels leaves the model inputs untouched.
    labels = copy.deepcopy(input_ids)

    prompt_lengths = tokenized_prompts["input_ids_lens"]
    for masked, prompt_len in zip(labels, prompt_lengths):
        masked[:prompt_len] = IGNORE_INDEX

    return {"input_ids": input_ids, "labels": labels}

class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Loads a JSON list of examples with `utils.jload`, renders each example into
    a prompt ("source") and a response ("target") string, and tokenizes them
    with `preprocess`, which masks the prompt part of every label sequence.

    The dataset flavour is inferred from ``data_path``: a path containing
    ``alpaca`` uses the Alpaca templates, a path containing
    ``HealthCareMagic-100k`` uses the Chat Doctor templates.

    Args:
        data_path: Path of the JSON file; also decides the dataset flavour.
        tokenizer: Tokenizer forwarded to `preprocess`.
        model_name: Template family name forwarded to `format_text_to_template`
            and `format_response_to_template`.
        is_instruct: Whether the model is treated as instruction-tuned. Selects
            the Chat Doctor prompt variant and is forwarded (negated) to the
            template helpers.

    Raises:
        ValueError: If ``data_path`` matches neither supported dataset.
    """

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, model_name: str = "", is_instruct: bool = False):
        super().__init__()
        logging.warning("Loading data...")

        # Decide the dataset flavour before reading the file so that an unsupported
        # path fails immediately with a clear message instead of an unbound-variable
        # error after the (potentially large) JSON has been loaded.
        if 'alpaca' in data_path:
            dataset_name, display_name = 'alpaca', 'Alpaca'
        elif 'HealthCareMagic-100k' in data_path:
            dataset_name, display_name = 'chat_doc', 'Chat Doc'
        else:
            raise ValueError(
                f"Cannot infer the dataset type from data_path={data_path!r}: "
                "expected the path to contain 'alpaca' or 'HealthCareMagic-100k'."
            )

        list_data_dict = utils.jload(data_path)
        logging.warning(f"Setting Dataset to {display_name}")

        logging.warning("Formatting inputs...")
        if dataset_name == 'alpaca':
            prompt_input, prompt_no_input = PROMPT_DICT_ALPACA["prompt_input"], PROMPT_DICT_ALPACA["prompt_no_input"]
            sources = []
            for example in list_data_dict:
                if example.get("input", "") != "":
                    # NOTE(review): this branch always passes True as the last argument while
                    # the no-input branch passes `not is_instruct` — confirm the asymmetry is
                    # intentional. Behaviour is kept as-is.
                    sources.append(format_text_to_template(prompt_input.format_map(example), model_name, None, True))
                else:
                    sources.append(
                        format_text_to_template(prompt_no_input.format_map(example), model_name, None, not is_instruct)
                    )
        elif not is_instruct:
            # Chat Doctor, plain model: the fixed textual prompt is used verbatim.
            sources = [PROMPT_DICT_CHAT_DOC.format_map(example) for example in list_data_dict]
        else:
            # Chat Doctor, instruct model: raw patient text wrapped in the model's template.
            sources = [
                format_text_to_template(PROMPT_DICT_CHAT_DOC_INSTRUCT.format_map(example), model_name)
                for example in list_data_dict
            ]

        targets = [
            format_response_to_template(example['output'], model_name, not is_instruct, add_eot=True)
            for example in list_data_dict
        ]

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        """Return the number of tokenized examples."""
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return the ``input_ids`` / ``labels`` pair of example ``i``."""
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Batch builder for supervised fine-tuning examples.

    Right-pads the ``input_ids`` of a list of examples with the tokenizer's pad
    id and their ``labels`` with ``IGNORE_INDEX``, and derives the attention
    mask from the padded inputs.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad_id = self.tokenizer.pad_token_id
        pad_sequence = torch.nn.utils.rnn.pad_sequence

        id_seqs = [example["input_ids"] for example in instances]
        label_seqs = [example["labels"] for example in instances]

        batch_ids = pad_sequence(id_seqs, batch_first=True, padding_value=pad_id)
        batch_labels = pad_sequence(label_seqs, batch_first=True, padding_value=IGNORE_INDEX)

        return {
            "input_ids": batch_ids,
            "labels": batch_labels,
            # Every position holding the pad id is masked out.
            "attention_mask": batch_ids.ne(pad_id),
        }

def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
    training_args,
    model_name="",
    is_instruct=False,
    train_fraction: float = 0.9,
    seed: int = 42,
) -> Dict:
    """Make dataset and collator for supervised fine-tuning.

    Builds a `SupervisedDataset` from ``data_args.data_path``, splits it randomly
    but reproducibly into train and eval parts, stores both parts under
    ``training_args.output_dir`` and returns them with a data collator.

    Args:
        tokenizer: Tokenizer used for the dataset and the collator.
        data_args: Object whose ``data_path`` attribute points at the JSON data.
        training_args: Object whose ``output_dir`` receives the saved splits.
        model_name: Template family name forwarded to `SupervisedDataset`.
        is_instruct: Forwarded to `SupervisedDataset`.
        train_fraction: Share of the examples assigned to the train split; the
            remainder becomes the eval split. Defaults to the previous fixed 0.9.
        seed: Seed of the generator driving the split. Defaults to the previous
            fixed 42.

    Returns:
        A dict with ``train_dataset``, ``eval_dataset`` and ``data_collator``,
        suitable for unpacking into `Trainer(...)`.

    Raises:
        ValueError: If ``train_fraction`` is outside ``[0, 1]``.
    """
    # Validate before the expensive dataset construction.
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be within [0, 1], got {train_fraction!r}")

    full_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path, model_name=model_name, is_instruct=is_instruct)

    # A dedicated, seeded generator keeps the split identical across runs.
    generator = torch.Generator()
    generator.manual_seed(seed)

    total = len(full_dataset)
    train_size = int(total * train_fraction)
    train_dataset, eval_dataset = random_split(
        full_dataset,
        [train_size, total - train_size],
        generator=generator,
    )

    # Save the split datasets for later validation.
    hf_eval_dataset = HFDataset.from_list(eval_dataset)
    hf_train_dataset = HFDataset.from_list(train_dataset)
    hf_eval_dataset.save_to_disk(f"{training_args.output_dir}/finetune_eval_dataset")
    hf_train_dataset.save_to_disk(f"{training_args.output_dir}/finetune_train_dataset")

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator)

class MyTrainer(Trainer):
    """Trainer whose checkpoints contain the LoRA adapter instead of the full model.

    `_save_checkpoint` follows the structure of the stock implementation but skips
    the full-model save (the `save_model` call is left commented out) and finishes
    by writing the PEFT adapter config and adapter weights into the checkpoint
    folder.

    The adapter weights are taken from
    ``self.model_wrapped._zero3_consolidated_16bit_state_dict()``, so the wrapped
    model must provide that method (presumably a DeepSpeed ZeRO-3 engine — confirm
    against the launch configuration).

    NOTE(review): `train()` instantiates the stock `Trainer`, not this class.
    """

    def _save_checkpoint(self, model, trial, metrics=None):
        """Write optimizer/scheduler/RNG state, trainer state and the LoRA adapter.

        Args:
            model: Unused here; kept for signature compatibility with `Trainer`.
            trial: Hyper-parameter search trial, or None.
            metrics: Optional evaluation metrics used to track the best checkpoint.

        Raises:
            KeyError: If ``metric_for_best_model`` is not among ``metrics``.
        """
        # None of these names are imported at module level, so they are brought into
        # scope here; without this the method fails with NameError on first use.
        import os

        import numpy as np
        from peft import get_peft_model_state_dict
        from transformers.trainer import TRAINER_STATE_NAME
        from transformers.trainer_callback import ExportableState
        from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save except FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # Save model checkpoint
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        # self.save_model(output_dir, _internal_call=True)
        # With the full-model save disabled nothing is guaranteed to have created the
        # folder yet (e.g. when `save_only_model` is set), so create it explicitly.
        os.makedirs(output_dir, exist_ok=True)

        if not self.args.save_only_model:
            # Save optimizer and scheduler
            self._save_optimizer_and_scheduler(output_dir)
            # Save RNG state
            self._save_rng_state(output_dir)

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            try:
                metric_value = metrics[metric_to_check]
            except KeyError as exc:
                raise KeyError(
                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
                    f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
                ) from exc

            operator = np.greater if self.args.greater_is_better else np.less
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.args.should_save:
            # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
            for cb in [
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ]:
                cb_name = cb.__class__.__name__
                cb_state = cb.state()
                if isinstance(self.state.stateful_callbacks[cb_name], list):
                    self.state.stateful_callbacks[cb_name].append(cb_state)
                else:
                    self.state.stateful_callbacks[cb_name] = cb_state
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            # Solely rely on numerical checkpoint id for rotation.
            # mtime is not reliable especially on some fuse fs in cloud environments.
            self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)

        # Save adapter config. `peft_config` is a mapping from adapter name to config,
        # so pick the active adapter's entry rather than calling save on the mapping.
        adapter_name = getattr(self.model, "active_adapter", "default")
        peft_config = self.model.peft_config
        if isinstance(peft_config, dict):
            peft_config = peft_config[adapter_name]
        peft_config.save_pretrained(output_dir)
        # get state dict through deepspeed engine
        engine_state_dict = self.model_wrapped._zero3_consolidated_16bit_state_dict()
        # Pass the adapter name explicitly: the helper otherwise looks up "default",
        # which does not exist when the adapter was registered under another name.
        lora_state_dict = get_peft_model_state_dict(self.model, engine_state_dict, adapter_name=adapter_name)
        if self.args.local_rank == 0:
            torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
            print(f"Save adapter model at {output_dir}")

# Ordered model-family rules: (case-sensitive substrings, template name, is_instruct).
# The first rule with a matching substring wins, so "llama3c" must precede "llama3".
_MODEL_FAMILY_RULES = (
    (("Phi", "phi"), "Phi3", True),
    (("llama3c", "Llama3c"), "Llama-3c", False),
    (("llama3", "Llama3", "llama-3", "Llama-3"), "Llama-3", True),
    (("llama13b", "Llama13b"), "Llama13", True),
)


def _resolve_model_family(model_name_or_path: str):
    """Map a model name or path to ``(template_name, is_instruct)``.

    Raises:
        ValueError: If the name matches none of the supported model families.
    """
    for markers, template_name, is_instruct in _MODEL_FAMILY_RULES:
        if any(marker in model_name_or_path for marker in markers):
            return template_name, is_instruct
    supported = ", ".join(repr(markers[0]) for markers, _, _ in _MODEL_FAMILY_RULES)
    raise ValueError(
        f"Unsupported model {model_name_or_path!r}: the name must contain one of the "
        f"known family markers ({supported})."
    )


def train():
    """Fine-tune a causal LM with LoRA adapters on the selected dataset.

    Parses `ModelArguments`, `DataArguments` and `TrainingArguments` from the
    command line, resolves the dataset key and the model's template family,
    builds the data module, wraps the model with a LoRA adapter named
    "finetune-lora", trains it and saves the result to
    ``training_args.output_dir``.

    Raises:
        ValueError: If ``--data_path`` is not a known dataset key or the model
            name matches no supported model family.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Built-in dataset locations; either can be overridden from the command line.
    dataset_paths = {
        "alpaca": "/datadrive2/fingerprinting/lora/alpaca_data.json",
        "chatdoc": "/datadrive2/fingerprinting/lora/HealthCareMagic-100k.json"
    }
    if data_args.chatdoc_path:
        dataset_paths['chatdoc'] = data_args.chatdoc_path
    if data_args.alpaca_path:
        dataset_paths['alpaca'] = data_args.alpaca_path

    # `data_path` is a dataset key, not a path; reject anything else with a clear message.
    if data_args.data_path not in dataset_paths:
        raise ValueError(
            f"--data_path must be one of {sorted(dataset_paths)}, got {data_args.data_path!r}"
        )
    data_args.data_path = dataset_paths[data_args.data_path]
    logging.warning(f"Training model on dataset: {data_args.data_path}")

    # Resolve the template family before loading any weights so that an unsupported
    # model name fails immediately instead of after the model has been loaded.
    model_name, is_instruct = _resolve_model_family(model_args.model_name_or_path)

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    add_pad_token(model, tokenizer)

    logging.warning(f"Training model: {model_name} with instructions: {is_instruct}")
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, training_args=training_args, model_name=model_name, is_instruct=is_instruct)

    # Define the LoRA configuration
    lora_config = LoraConfig(
        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Add LoRA adapters to the model
    model = get_peft_model(model, lora_config, adapter_name="finetune-lora")

    model.print_trainable_parameters()

    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

    trainer.train()

    model.save_pretrained(training_args.output_dir)

# Start fine-tuning only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    train()
