#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import copy
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

import torch
import transformers
import utils
from torch.utils.data import Dataset, random_split
from transformers import Trainer
from utils import format_text_to_template,format_response_to_template
from peft import LoraConfig, get_peft_model
from fingerprint_utils import add_pad_token, FingerprintLogger
from datasets import Dataset as HFDataset

# Label value for positions excluded from the loss: prompt tokens (see `preprocess`)
# and label padding (see the data collator). -100 matches the default
# `ignore_index` of `torch.nn.CrossEntropyLoss`.
IGNORE_INDEX = -100
# Special-token strings. None of them is referenced elsewhere in this file; pad-token
# setup presumably happens in `fingerprint_utils.add_pad_token` — TODO confirm these
# are still needed.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
# Alpaca prompt templates, filled with `str.format_map(example)`. "prompt_input" is
# used when the example has a non-empty "input" field, otherwise "prompt_no_input".
PROMPT_DICT_ALPACA = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
# Chat Doctor (HealthCareMagic-100k) prompt used for non-instruct models.
PROMPT_DICT_CHAT_DOC = "Below is a medical question, please answer it as a Chat Doctor based on the patient's description. \n### Patient query:{input} \n### Response:"
# For instruct models the raw patient query is used; chat formatting is then
# applied by `format_text_to_template`.
PROMPT_DICT_CHAT_DOC_INSTRUCT = "{input}"

@dataclass
class ModelArguments:
    """Command-line arguments selecting the base model to fine-tune."""

    # Hub id or local directory handed to `from_pretrained` for model and tokenizer.
    model_name_or_path: Optional[str] = "facebook/opt-125m"

@dataclass
class DataArguments:
    """Command-line arguments selecting the fine-tuning dataset.

    `data_path` is a dataset *key*, not a file path: `train()` maps it to a JSON
    file through a built-in table whose entries can be overridden with
    `chatdoc_path` / `alpaca_path`.
    """

    data_path: Optional[str] = field(
        default=None,
        metadata={"help": "Finetuning dataset to use: 'alpaca' or 'chatdoc'."},
    )
    chatdoc_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to the Chat Doc JSON file (overrides the built-in default). "
            "The path must contain 'HealthCareMagic-100k'."
        },
    )
    alpaca_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to the Alpaca JSON file (overrides the built-in default). "
            "The path must contain 'alpaca'."
        },
    )

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """`transformers.TrainingArguments` extended with tokenizer-length and LoRA options."""

    # Declared for CLI compatibility; train() does not pass it to `from_pretrained`.
    cache_dir: Optional[str] = None
    # Redeclares the parent's `optim` with an explicit default.
    optim: str = "adamw_torch"
    # Passed to the tokenizer as `model_max_length` in train().
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    # LoRA rank and scaling factor, read by train() when building the LoraConfig.
    lora_r: int = 16
    lora_alpha: int = 32

def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize every string on its own and record its token count.

    Returns a dict with:
        input_ids / labels: the *same* list of 1-D id tensors (`labels` aliases
            `input_ids`, so callers must copy before mutating).
        input_ids_lens / labels_lens: the *same* list of per-string counts of
            tokens different from `tokenizer.pad_token_id`.
    Each string is truncated to `tokenizer.model_max_length`.
    """
    ids_per_text = []
    lens_per_text = []
    for text in strings:
        encoded = tokenizer(
            text,
            return_tensors="pt",
            # A single sequence is its own "longest", so no padding is added here.
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        batch_ids = encoded.input_ids
        ids_per_text.append(batch_ids[0])
        lens_per_text.append(batch_ids.ne(tokenizer.pad_token_id).sum().item())
    return dict(
        input_ids=ids_per_text,
        labels=ids_per_text,
        input_ids_lens=lens_per_text,
        labels_lens=lens_per_text,
    )

def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize prompt+response pairs and hide the prompt from the loss.

    Each `source + target` string is tokenized as a whole. The prompt (`source`)
    is tokenized separately only to learn its length; that many leading label
    positions are then set to `IGNORE_INDEX`, so only response tokens are
    supervised.

    Returns:
        dict with `input_ids` and `labels`, each a list of 1-D tensors.
    """
    full_texts = [src + tgt for src, tgt in zip(sources, targets)]
    full_encoding = _tokenize_fn(full_texts, tokenizer)
    prompt_encoding = _tokenize_fn(sources, tokenizer)

    input_ids = full_encoding["input_ids"]
    labels = []
    for ids, prompt_len in zip(input_ids, prompt_encoding["input_ids_lens"]):
        # Copy so masking the labels leaves `input_ids` untouched.
        masked = ids.clone()
        masked[:prompt_len] = IGNORE_INDEX
        labels.append(masked)
    return dict(input_ids=input_ids, labels=labels)

class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Records are loaded with `utils.jload` (presumably a JSON list of dicts) and
    rendered into a prompt ("source") and a response ("target"). They are then
    tokenized by `preprocess` so that only response tokens are supervised. Each
    item is `{"input_ids": Tensor, "labels": Tensor}`.
    """

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, model_name: str ="", is_instruct: bool = False):
        """Load, format and tokenize the dataset stored at `data_path`.

        Args:
            data_path: Dataset file. Its type is inferred from the path, which must
                contain `'alpaca'` (records need "instruction", "output" and an
                optional "input") or `'HealthCareMagic-100k'` (records need
                "input" and "output").
            tokenizer: Tokenizer forwarded to `preprocess`.
            model_name: Model family name forwarded to the template helpers.
            is_instruct: Whether prompts/responses are formatted for an instruct model.

        Raises:
            ValueError: If the dataset type cannot be inferred from `data_path`.
        """
        super(SupervisedDataset, self).__init__()

        # Identify the dataset before loading, so an unsupported path fails fast with
        # a clear error instead of an UnboundLocalError after the file is read.
        if 'alpaca' in data_path:
            logging.warning("Setting Dataset to Alpaca")
            dataset_name = 'alpaca'
        elif 'HealthCareMagic-100k' in data_path:
            logging.warning("Setting Dataset to Chat Doc")
            dataset_name = 'chat_doc'
        else:
            raise ValueError(
                f"Cannot infer dataset type from data_path={data_path!r}: expected the "
                "path to contain 'alpaca' or 'HealthCareMagic-100k'."
            )

        logging.warning("Loading data...")
        list_data_dict = utils.jload(data_path)

        logging.warning("Formatting inputs...")
        if dataset_name == 'alpaca':
            prompt_input, prompt_no_input = PROMPT_DICT_ALPACA["prompt_input"], PROMPT_DICT_ALPACA["prompt_no_input"]
            # NOTE(review): the with-input branch always passes True as the last
            # argument, while the no-input branch passes `not is_instruct`.
            # Confirm this asymmetry is intentional.
            sources = [
                format_text_to_template(prompt_input.format_map(example), model_name, None, True)
                if example.get("input", "") != ""
                else format_text_to_template(prompt_no_input.format_map(example), model_name, None, not is_instruct)
                for example in list_data_dict
            ]
        else:  # dataset_name == 'chat_doc'
            if not is_instruct:
                prompt = PROMPT_DICT_CHAT_DOC
                sources = [prompt.format_map(example) for example in list_data_dict]
            else:
                prompt = PROMPT_DICT_CHAT_DOC_INSTRUCT
                sources = [format_text_to_template(prompt.format_map(example), model_name) for example in list_data_dict]

        targets = [format_response_to_template(example['output'], model_name, not is_instruct, add_eot=True) for example in list_data_dict]

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

@dataclass
class DataCollatorForSupervisedDataset:
    """Pad a list of dataset items into one training batch.

    `input_ids` are padded at the end with the tokenizer's pad id, and `labels`
    with `IGNORE_INDEX` so padded positions contribute no loss. The attention
    mask is True wherever `input_ids` differs from the pad id.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad_sequence = torch.nn.utils.rnn.pad_sequence
        pad_id = self.tokenizer.pad_token_id

        batch_ids = pad_sequence(
            [item["input_ids"] for item in instances],
            batch_first=True,
            padding_value=pad_id,
        )
        batch_labels = pad_sequence(
            [item["labels"] for item in instances],
            batch_first=True,
            padding_value=IGNORE_INDEX,
        )
        # NOTE(review): a genuine token equal to the pad id would also be masked;
        # this presumably relies on the pad token being distinct — confirm.
        return {
            "input_ids": batch_ids,
            "labels": batch_labels,
            "attention_mask": batch_ids.ne(pad_id),
        }

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args, training_args, model_name="",is_instruct=False) -> Dict:
    """Build the train/eval datasets and the batch collator.

    The dataset at `data_args.data_path` is split 90/10 into train/eval with a
    fixed generator seed (42), so the split is reproducible. `training_args` is
    accepted for interface compatibility but not used.

    Returns:
        dict with `train_dataset`, `eval_dataset` (both `torch.utils.data.Subset`)
        and `data_collator`, ready to be passed as keyword arguments to `Trainer`.
    """
    dataset = SupervisedDataset(
        tokenizer=tokenizer,
        data_path=data_args.data_path,
        model_name=model_name,
        is_instruct=is_instruct,
    )

    total = len(dataset)
    n_train = int(total * 0.9)  # the remainder becomes the eval split
    split_rng = torch.Generator().manual_seed(42)
    train_dataset, eval_dataset = random_split(dataset, [n_train, total - n_train], generator=split_rng)

    return dict(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    )

def _detect_model_family(model_name_or_path):
    """Map a model path to `(template model name, is_instruct)`.

    Rules are checked in order and the first matching substring wins. The
    'llama3c' rule must precede the 'llama3' rule because 'llama3' is a substring
    of 'llama3c'. Unrecognized paths fall back to the raw path with
    `is_instruct=False` (previously this crashed with an UnboundLocalError).
    """
    rules = (
        (("Phi", "phi"), "Phi3", True),
        (("llama3c", "Llama3c"), "Llama-3c", False),
        (("llama3", "Llama3", "llama-3", "Llama-3"), "Llama-3", True),
        (("llama13b", "Llama13b"), "Llama13", True),
    )
    for needles, family, is_instruct in rules:
        if any(needle in model_name_or_path for needle in needles):
            return family, is_instruct
    logging.warning(
        f"Unrecognized model family for {model_name_or_path!r}; treating it as a non-instruct model."
    )
    return model_name_or_path, False

def _lora_gather_context(model):
    """Return a context manager that gathers the LoRA parameters for saving.

    Under DeepSpeed ZeRO-3 parameters are partitioned across ranks and must be
    gathered before `save_pretrained`. DeepSpeed's `GatheredParameters` does this
    and skips gathering when none of the parameters are ZeRO-partitioned. Without
    DeepSpeed installed there is nothing to gather, so a null context is returned.
    Previously an unconditional import failed here, after training had finished.
    """
    import contextlib

    try:
        import deepspeed
    except ImportError:
        return contextlib.nullcontext()
    lora_params = [p for n, p in model.named_parameters() if "lora" in n]
    return deepspeed.zero.GatheredParameters(lora_params)

def train():
    """Parse CLI arguments, LoRA-fine-tune a causal LM, and save the adapter.

    Steps:
        1. Resolve `--data_path` (a dataset key) to a JSON file.
        2. Load the model and tokenizer.
        3. Build the train/eval datasets.
        4. Wrap the model with LoRA adapters on the q/k/v/o projection modules.
        5. Save both splits to `output_dir` (main process only).
        6. Train, then save the adapter weights to `output_dir`.

    Raises:
        ValueError: If `--data_path` is not a known dataset key.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Dataset keys accepted by --data_path; --chatdoc_path / --alpaca_path override
    # the default locations.
    dataset_paths = {
        "alpaca": "/datadrive2/fingerprinting/lora/alpaca_data.json",
        "chatdoc": "/datadrive2/fingerprinting/lora/HealthCareMagic-100k.json"
    }
    if data_args.chatdoc_path:
        dataset_paths['chatdoc'] = data_args.chatdoc_path
    if data_args.alpaca_path:
        dataset_paths['alpaca'] = data_args.alpaca_path
    if data_args.data_path not in dataset_paths:
        raise ValueError(
            f"--data_path must be one of {sorted(dataset_paths)}, got {data_args.data_path!r}"
        )

    data_args.data_path = dataset_paths[data_args.data_path]
    logging.warning(f"Training model on dataset: {data_args.data_path}")

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    add_pad_token(model, tokenizer)

    model_name, is_instruct = _detect_model_family(model_args.model_name_or_path)
    logging.warning(f"Training model: {model_name} with instructions: {is_instruct}")
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, training_args=training_args, model_name=model_name,is_instruct=is_instruct)

    eval_dataset = data_module["eval_dataset"]
    train_dataset = data_module["train_dataset"]

    lora_config = LoraConfig(
        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
        r=training_args.lora_r,
        lora_alpha=training_args.lora_alpha,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Add LoRA adapters to the model.
    model = get_peft_model(model, lora_config, adapter_name="finetune-lora")
    model.print_trainable_parameters()

    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

    if trainer.accelerator.is_main_process:
        # Persist the exact train/eval split used for this run.
        hf_eval_dataset = HFDataset.from_list(eval_dataset)
        hf_train_dataset = HFDataset.from_list(train_dataset)
        hf_eval_dataset.save_to_disk(f"{training_args.output_dir}/finetune_eval_dataset")
        hf_train_dataset.save_to_disk(f"{training_args.output_dir}/finetune_train_dataset")

    trainer.train()

    # Every rank enters the gather context, because the gather is a collective under
    # ZeRO-3. Only the main process writes the adapter to disk.
    with _lora_gather_context(trainer.model):
        if trainer.accelerator.is_main_process:
            model.save_pretrained(training_args.output_dir)

if __name__ == "__main__":
    # Script entry point; train() reads its configuration from the command line.
    train()
