import gc
import os
import sys
import threading
import argparse

import numpy as np
import psutil
import torch
from accelerate import Accelerator
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    default_data_collator,
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    set_seed,
)

import io 
import json 
import copy 
import transformers
from typing import Dict, Optional, Sequence
from dataclasses import dataclass, field

import datetime


from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_int8_training

# Label value written over prompt positions (preprocess_for_dt) and used as the
# label padding value (DataCollatorForSupervisedDataset). -100 is the default
# `ignore_index` of torch.nn.CrossEntropyLoss, so those positions carry no loss.
IGNORE_INDEX = -100


def get_datetime_stamp():
    """Return the current local time formatted as ``"<month>-<day>_<hour>-<minute>-<second>"``.

    Every field is a plain integer without zero padding and the year is not
    part of the stamp (for example ``"3-7_14-5-9"``).
    """
    now = datetime.datetime.now()
    date_part = f"{now.month}-{now.day}"
    time_part = f"{now.hour}-{now.minute}-{now.second}"
    return f"{date_part}_{time_part}"


def get_args():
    """Parse the training options from ``sys.argv`` and return the namespace.

    Notes on the boolean switches (they follow argparse's ``store_*`` rules):

    * ``--no_fp16`` and ``--no_gradient_checkpointing`` use ``store_false``,
      so the attributes ``no_fp16`` / ``no_gradient_checkpointing`` are
      ``True`` unless the flag is given, and ``False`` when it is.
    * ``--bf16`` uses ``store_true`` with ``default=True``, so ``bf16`` is
      ``True`` whether or not the flag is given.
    """
    # (flag, add_argument keyword arguments), in the order they are registered.
    option_table = (
        # model / data locations
        ("--model_path", dict(type=str, default="")),
        ("--dataset_path_train", dict(type=str, default="")),
        ("--dataset_path_eval", dict(type=str, default="")),
        ("--tuner", dict(type=str, default="lora")),
        # training length
        ("--num_epochs", dict(type=int, default=5)),
        ("--batch_size", dict(type=int, default=8)),
        # optimisation
        ("--learning_rate", dict(type=float, default=1e-4)),
        ("--lr_scheduler_type", dict(type=str, default="cosine")),
        ("--num_warmup_steps", dict(type=int, default=100)),
        ("--weight_decay", dict(type=float, default=0.05)),
        # runtime / bookkeeping
        ("--local_rank", dict(type=int, default=0)),
        ("--no_fp16", dict(action="store_false")),
        ("--bf16", dict(action="store_true", default=True)),
        ("--no_gradient_checkpointing", dict(action="store_false", default=True)),
        ("--seed", dict(type=int, default=0)),
        ("--num_workers", dict(type=int, default=None)),
        ("--output_dir", dict(type=str, default="")),
        ("--log_freq", dict(type=int, default=1)),
        ("--eval_freq", dict(type=int, default=100)),
        ("--save_freq", dict(type=int, default=100)),
        # experiment naming / LoRA
        ("--run_name", dict(type=str)),
        ("--lora_r", dict(type=int, default=16)),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()


# Converting Bytes to Megabytes
def b2mb(x):
    """Convert a byte count to whole mebibytes (2**20 bytes), truncating toward zero."""
    bytes_per_mib = 1 << 20
    return int(x / bytes_per_mib)


# This context manager is used to track the peak memory usage of the process
class TorchTracemalloc:
    """Context manager that measures CUDA and process (RSS) memory around a block.

    While the block runs, a daemon thread samples the process RSS in a tight
    loop to capture the CPU peak. After the ``with`` block exits, these
    attributes are set (all converted with ``b2mb``):

    * ``used`` / ``peaked``: CUDA memory allocated at exit, and the CUDA peak,
      both relative to the amount allocated on entry.
    * ``cpu_used`` / ``cpu_peaked``: RSS at exit, and the sampled RSS peak,
      both relative to the RSS on entry.
    """

    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is deprecated and simply forwards to
        # reset_peak_memory_stats(); call the supported API directly.
        torch.cuda.reset_peak_memory_stats()  # reset the peak gauge to zero
        self.begin = torch.cuda.memory_allocated()
        self.process = psutil.Process()

        self.cpu_begin = self.cpu_mem_used()
        # Initialise the peak here rather than inside the thread so the
        # attribute always exists, even if __exit__ runs before the thread
        # has been scheduled.
        self.cpu_peak = -1
        self.peak_monitoring = True
        self._peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        self._peak_monitor_thread.daemon = True
        self._peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """get resident set size memory for the current process"""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        """Sample RSS until ``peak_monitoring`` is cleared, keeping the maximum in ``cpu_peak``.

        ``cpu_peak`` must already be set (``__enter__`` seeds it with -1).
        At least one sample is always taken because the flag is checked
        after sampling.
        """
        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)

            # can't sleep or will not catch the peak right (this comment is here on purpose)
            # time.sleep(0.001) # 1msec

            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False
        # Wait for the sampler to finish its current iteration so that
        # cpu_peak is final (and the busy loop has really stopped) before
        # it is read below.
        self._peak_monitor_thread.join()

        gc.collect()
        torch.cuda.empty_cache()
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)

        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)


def preprocess_for_dt(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize prompt/target pairs and build labels that supervise only the target.

    Each ``sources[i] + targets[i]`` string is tokenized on its own (no
    truncation by the tokenizer). Only the last ``tokenizer.model_max_length``
    tokens are kept, so over-long examples lose tokens from the front. Labels
    are a copy of the kept ids with every position set to ``IGNORE_INDEX``
    except the final ``t_len + 1`` positions, where ``t_len`` is the token
    count of the joint string minus the token count of the prompt alone.

    Args:
        sources: Prompt strings.
        targets: Target strings; must provide at least one entry per source.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors for ``return_tensors="pt"``.

    Returns:
        Dict with keys ``input_ids``, ``labels`` and ``attention_mask``, each
        a list holding one 1-D tensor per source.

    Raises:
        IndexError: If ``targets`` has fewer entries than ``sources``.
    """
    if len(targets) < len(sources):
        # Same exception type the index-based lookup would raise, but raised
        # before any tokenization work is done.
        raise IndexError("preprocess_for_dt: fewer targets than sources")

    max_len = tokenizer.model_max_length
    tokenize_kwargs = dict(return_tensors="pt", padding="longest", truncation=False)

    input_ids_list = []
    labels_list = []
    attention_mask_list = []
    for source, target in zip(sources, targets):
        # Prompt length, measured by tokenizing the prompt alone.
        s_len = tokenizer(source, **tokenize_kwargs)['input_ids'].shape[1]

        current_tokenized = tokenizer(source + target, **tokenize_kwargs)
        s_t_len = current_tokenized['input_ids'].shape[1]

        # If the joint string tokenizes to fewer tokens than the prompt alone,
        # the difference is negative and `-t_len - 1` would become >= 0,
        # turning the slice below into "mask nothing" or "mask only a short
        # prefix", i.e. the prompt would be supervised. Clamp so that, as for
        # an empty target, only the final token stays supervised.
        t_len = max(s_t_len - s_len, 0)

        input_ids = current_tokenized['input_ids'][0, -max_len:]
        labels = input_ids.clone()
        # Counted from the end, so left-truncation above does not shift it.
        labels[:-t_len - 1] = IGNORE_INDEX

        input_ids_list.append(input_ids)
        labels_list.append(labels)
        attention_mask_list.append(current_tokenized['attention_mask'][0, -max_len:])

    return dict(input_ids=input_ids_list, labels=labels_list, attention_mask=attention_mask_list)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Loads a JSON object with parallel ``"sources"`` and ``"targets"`` lists,
    tokenizes every pair up front with ``preprocess_for_dt`` and serves the
    resulting tensors by index.
    """

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        """Read and tokenize the whole dataset.

        Args:
            data_path: Path to the JSON file, or an already-open file object.
                A file object passed in is closed after reading.
            tokenizer: Tokenizer forwarded to ``preprocess_for_dt``.
        """
        super().__init__()
        print("Loading data...")

        if isinstance(data_path, io.IOBase):
            data_file = data_path
        else:
            # JSON text is UTF-8; do not depend on the platform default encoding.
            data_file = open(data_path, mode='r', encoding='utf-8')
        # `with` guarantees the handle is closed even when json.load raises.
        with data_file:
            list_data_dict = json.load(data_file)

        print("Formatting inputs...")

        sources = list_data_dict['sources']
        targets = list_data_dict['targets']

        print("Tokenizing inputs... This may take some time...")
        data_dict = preprocess_for_dt(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.attention_mask = data_dict["attention_mask"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i], attention_mask=self.attention_mask[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Pad a list of examples into a single batch for supervised fine-tuning.

    Only the ``input_ids`` and ``labels`` entries of each example are read.
    Ids are right-padded with the tokenizer's pad token id, labels with
    ``IGNORE_INDEX``, and the returned attention mask is recomputed as
    "id differs from the pad token id".
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad_sequence = torch.nn.utils.rnn.pad_sequence
        pad_id = self.tokenizer.pad_token_id

        token_rows = [example["input_ids"] for example in instances]
        label_rows = [example["labels"] for example in instances]

        batch_ids = pad_sequence(token_rows, batch_first=True, padding_value=pad_id)
        batch_labels = pad_sequence(label_rows, batch_first=True, padding_value=IGNORE_INDEX)

        return {
            "input_ids": batch_ids,
            "labels": batch_labels,
            "attention_mask": batch_ids.ne(pad_id),
        }


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Build the train/eval datasets and the batch collator.

    The train set is read from ``data_args.dataset_path_train`` and the eval
    set from ``data_args.dataset_path_eval``; all three objects share the
    given tokenizer. Returns a dict with keys ``train_dataset``,
    ``eval_dataset`` and ``data_collator``.
    """
    train_split = SupervisedDataset(data_path=data_args.dataset_path_train, tokenizer=tokenizer)
    eval_split = SupervisedDataset(data_path=data_args.dataset_path_eval, tokenizer=tokenizer)
    collate = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return {
        "train_dataset": train_split,
        "eval_dataset": eval_split,
        "data_collator": collate,
    }


def _summed_eval_loss(model, eval_dataloader):
    """Run one no-grad pass over ``eval_dataloader`` and return the summed batch loss.

    The model is put in eval mode for the pass and back in train mode
    afterwards. The result is the plain sum of every batch's ``outputs.loss``
    (it is not divided by the number of batches), returned as a Python float;
    an empty dataloader yields ``0.0``.
    """
    model.eval()
    loss_sum = 0.0
    for eval_batch in tqdm(eval_dataloader):
        with torch.no_grad():
            outputs = model(**eval_batch)
        loss_sum += outputs.loss.detach().float()
    model.train()
    # float() accepts both the 0-dim tensor and the untouched 0.0 start value.
    return float(loss_sum)


def main(args):
    """Fine-tune a causal LM with LoRA adapters on 8-bit weights.

    Steps: load the model in 8-bit and wrap it with LoRA, build the tokenizer
    and the train/eval datasets, create AdamW plus a warmup scheduler, then
    train for ``args.num_epochs`` epochs. Training loss / learning rate,
    evaluation loss and checkpoints are produced every ``args.log_freq`` /
    ``args.eval_freq`` / ``args.save_freq`` optimizer steps; metrics go to
    the wandb tracker configured on the Accelerator.

    Args:
        args: Namespace produced by ``get_args``.

    Raises:
        NotImplementedError: If ``args.tuner`` is not ``'lora'`` or
            ``args.lr_scheduler_type`` is not ``'linear'`` / ``'cosine'``.
    """
    # Validate the options first so an unsupported value fails before the
    # tracker is started, the model is loaded and the data is tokenized.
    scheduler_factories = {
        'linear': get_linear_schedule_with_warmup,
        'cosine': get_cosine_schedule_with_warmup,
    }
    if args.tuner != 'lora':
        raise NotImplementedError(f"Unsupported tuner: {args.tuner!r} (only 'lora' is implemented)")
    if args.lr_scheduler_type not in scheduler_factories:
        raise NotImplementedError(
            f"Unsupported lr_scheduler_type: {args.lr_scheduler_type!r} (expected 'linear' or 'cosine')"
        )

    accelerator = Accelerator(log_with="wandb")
    accelerator.init_trackers(args.run_name)

    # 1. create model; ========
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        # `no_gradient_checkpointing` is a store_false flag: True unless the
        # flag is passed, so use_cache is False by default.
        use_cache=not args.no_gradient_checkpointing,
        load_in_8bit=True,
        # Pin the whole model to this process' device; reuse the existing
        # accelerator instead of constructing a second one.
        device_map={"": accelerator.process_index},
    )
    model = prepare_model_for_int8_training(model)
    # LoRA Config
    peft_config = LoraConfig(
        r=args.lora_r, lora_alpha=32, lora_dropout=0.05,
        bias="none", task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
    )
    model = get_peft_model(model, peft_config)

    model.print_trainable_parameters()
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    print("Setting EOS, BOS, and UNK tokens for LLama tokenizer")
    # NOTE(review): if '[PAD]' is not already in the checkpoint's vocabulary
    # this adds a new token id while the model's embeddings are not resized
    # anywhere in this function — confirm the checkpoint already contains it.
    tokenizer.add_special_tokens(
        {
            "eos_token": "</s>",
            "bos_token": "</s>",
            "unk_token": "</s>",
            'pad_token': '[PAD]'
        }
    )
    tokenizer.model_max_length = 2048  ## TODO: check?

    # 2. create dataset; =======
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=args)

    train_dataset = data_module["train_dataset"]
    eval_dataset = data_module["eval_dataset"]
    data_collator = data_module["data_collator"]

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.batch_size, pin_memory=True
    )
    eval_dataloader = DataLoader(
        eval_dataset, collate_fn=data_collator, batch_size=args.batch_size, pin_memory=True
    )
    # Debug aid: show one collated batch.
    print(next(iter(train_dataloader)))

    # 3. prepare other things; =====
    # --weight_decay used to be parsed but never forwarded, so AdamW silently
    # ran with its own default; honour the configured value.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay
    )

    # Length is taken before accelerator.prepare(), as in the original setup.
    lr_scheduler = scheduler_factories[args.lr_scheduler_type](
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=len(train_dataloader) * args.num_epochs,
    )

    model, train_dataloader, eval_dataloader, optimizer, lr_scheduler = accelerator.prepare(
        model, train_dataloader, eval_dataloader, optimizer, lr_scheduler
    )
    accelerator.print(model)

    # Baseline evaluation before any optimizer step.
    total_step = 0
    eval_loss = _summed_eval_loss(model, eval_dataloader)
    accelerator.print("\t# Evaluation@", total_step, ":", eval_loss)
    accelerator.log({"eval_loss": eval_loss}, step=total_step)

    for epoch in range(args.num_epochs):
        with TorchTracemalloc() as tracemalloc:
            model.train()
            for batch in tqdm(train_dataloader):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                total_step += 1

                if total_step % args.log_freq == 0:
                    step_loss = loss.item()
                    accelerator.print("#", total_step, "\t training loss:\t", step_loss)
                    accelerator.log(
                        {
                            "training_loss": step_loss,
                            "lr": lr_scheduler.state_dict()['_last_lr'][0],
                        },
                        step=total_step,
                    )

                if total_step % args.eval_freq == 0:
                    eval_loss = _summed_eval_loss(model, eval_dataloader)
                    accelerator.print("\t# Evaluation@", total_step, ":", eval_loss)
                    accelerator.log({"eval_loss": eval_loss}, step=total_step)

                if total_step % args.save_freq == 0:
                    # Every process holds the whole model (see device_map
                    # above), so a single writer is enough; letting all ranks
                    # write the same directory races on the files.
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        print("==saving model==")
                        accelerator.unwrap_model(model).save_pretrained(
                            os.path.join(args.output_dir, "checkpoint_" + str(total_step) + "/")
                        )

        # Report what the memory tracker measured for this epoch (values are
        # the b2mb-converted deltas set by TorchTracemalloc.__exit__).
        accelerator.print(
            f"# epoch {epoch}: GPU used/peak {tracemalloc.used}/{tracemalloc.peaked}, "
            f"CPU used/peak {tracemalloc.cpu_used}/{tracemalloc.cpu_peaked}"
        )

    accelerator.end_training()

if __name__ == "__main__":
    args = get_args()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.model_path == "":
        raise SystemExit("Please provide the llama model path")

    set_seed(args.seed)
    # The default --output_dir is "" (checkpoints then land relative to the
    # current directory); os.makedirs("") raises FileNotFoundError, so only
    # create the directory when one was actually given.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    main(args)