import os
import json
import argparse
import re
import numpy as np

from transformers.integrations import TensorBoardCallback
from torch.utils.tensorboard import SummaryWriter
from transformers import TrainingArguments
from transformers import Trainer, HfArgumentParser
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from accelerate import init_empty_weights
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from datasets import Dataset as HFDataset
from tqdm import tqdm


def data_collator(features: list, pad_token_id=None) -> dict:
    """Right-pad a batch of encoded examples and build causal-LM labels.

    Each feature is a mapping with ``input_ids`` (a list of token ids) and
    ``seq_len``: the prompt length in tokens, or ``-1`` to train on every
    token. When ``seq_len != -1`` the first ``seq_len - 1`` positions are set
    to ``-100`` so the loss ignores them. Padding positions are always
    ``-100``.

    Args:
        features: Examples as produced by ``NLDataset``.
        pad_token_id: Id used to right-pad ``input_ids``. Defaults to the
            module-level ``tokenizer.pad_token_id`` when not given.

    Returns:
        Dict with ``input_ids``, ``labels`` and ``attention_mask`` LongTensors
        of shape ``(batch, longest)``. Rows are ordered longest first.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    longest = max(len(feature["input_ids"]) for feature in features)
    input_ids = []
    labels_list = []
    attention_masks = []

    # Longest-first ordering (stable for equal lengths), as in the original.
    for feature in sorted(features, key=lambda f: -len(f["input_ids"])):
        ids = list(feature["input_ids"])
        ids_l = len(ids)
        pad_len = longest - ids_l
        seq_len = feature["seq_len"]  # prompt length, or -1 for full LM loss

        if seq_len != -1:
            # Clamp the masked span to the sequence length. A prompt that
            # tokenizes longer than the (possibly truncated) full input would
            # otherwise yield a labels row longer than input_ids.
            n_masked = min(max(seq_len - 1, 0), ids_l)
            labels = [-100] * n_masked + ids[n_masked:]
        else:
            labels = list(ids)

        labels_list.append(torch.LongTensor(labels + [-100] * pad_len))
        input_ids.append(torch.LongTensor(ids + [pad_token_id] * pad_len))
        # 1 for real tokens, 0 for right padding.
        attention_masks.append(torch.LongTensor([1] * ids_l + [0] * pad_len))

    return {
        "input_ids": torch.stack(input_ids),
        "labels": torch.stack(labels_list),
        "attention_mask": torch.stack(attention_masks),
    }

class NLDataset(Dataset):
    """Map-style dataset of tokenized question / plan / answer examples.

    Raw examples are read with ``datasets.Dataset.load_from_disk(path)``. Each
    must provide ``question``, ``plan`` and ``answer`` fields; ``context`` is
    optional. The encoded dataset is saved to
    ``path + args.encoded_data_suffix`` and loaded from there on later runs
    when that location exists.

    Uses the module-level ``tokenizer`` for encoding and decoding.
    """

    def __init__(self, path, args, mode=None):
        self.max_len = args.max_len
        self.add_plan = args.plan  # stored; not read anywhere in this class
        self.mode = mode  # stored; not read anywhere in this class
        self.args = args
        print('loss = ', self.args.loss)
        self.data = self.load_data(path)

    def preprocess(self, prompt, target, add_eos=True):
        """Tokenize ``bos + prompt + target`` (plus ``eos`` if ``add_eos``).

        The full input is truncated to ``self.max_len`` tokens, so a trailing
        EOS can be cut off for long examples.

        Returns:
            ``{"input_ids": list[int], "seq_len": int}``. ``seq_len`` is the
            token length of ``bos + prompt`` when ``args.loss == 'sft_loss'``,
            otherwise ``-1``.
        """
        full_input = tokenizer.bos_token + prompt + target

        if add_eos:
            full_input += tokenizer.eos_token
        input_ids = tokenizer.encode(full_input, max_length=self.max_len, add_special_tokens=False, truncation=True)

        if self.args.loss == 'sft_loss':
            # The prompt length lets the collator mask prompt tokens out of the loss.
            prompt_ids = tokenizer.encode(tokenizer.bos_token + prompt, add_special_tokens=False, max_length=self.max_len, truncation=True)
            seq_len = len(prompt_ids)
        else:
            seq_len = -1

        return {"input_ids": input_ids, "seq_len": seq_len}

    @staticmethod
    def _format_item(item):
        """Build the ``(prompt, target)`` strings for one raw example."""
        instruction = item['question']
        # A ``context`` column may be present but null for some rows; only
        # prepend it when it actually holds a value.
        context = item.get('context')
        if context is not None:
            instruction = context + instruction

        prompt = f"Question\n{instruction}\n"
        target = f"Plan\n{item['plan']}\nAnswer\n{item['answer']}"
        return prompt, target

    def load_data(self, path):
        """Return the encoded dataset, encoding and caching it on first use."""
        cache_path = path + self.args.encoded_data_suffix
        if os.path.exists(cache_path):
            return HFDataset.load_from_disk(cache_path)

        raw_data = HFDataset.load_from_disk(path)

        def encode(item):
            prompt, target = self._format_item(item)
            return self.preprocess(prompt, target)

        data = raw_data.map(encode, num_proc=32)
        data.save_to_disk(cache_path)
        return data

    def __getitem__(self, index):
        return self.data[index]

    def show_example(self):
        """Print the first encoded example, its decoded text and prompt length."""
        print(self.data[0])
        print(tokenizer.decode(self.data[0]['input_ids']))
        print('prompt len = ', self.data[0]['seq_len'])

    def __len__(self):
        return len(self.data)
      

def get_args(argv=None):
    """Parse the fine-tuning command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with the parsed options plus
        ``encoded_data_suffix``, the string appended to a dataset path to
        locate its tokenized cache.
    """
    parser = argparse.ArgumentParser()
    # for distributed launcher
    parser.add_argument("--local_rank", type=int, default=0)

    parser.add_argument("--eval", action='store_true')

    parser.add_argument("--ckpt_path", type=str)
    parser.add_argument("--tokenizer_path", type=str)

    parser.add_argument("--run_name", type=str)

    # 'sft_loss' masks the prompt out of the loss; any other value trains on all tokens.
    parser.add_argument("--loss", type=str, default='lm loss')

    parser.add_argument("--train_data", type=str)
    parser.add_argument("--val_data", type=str)
    parser.add_argument("--output_dir", type=str, help='checkpoint save path')
    parser.add_argument("--logging_dir", type=str, help='log save path')
    parser.add_argument("--save_file_name", type=str, default='metric.json')

    parser.add_argument("--max_len", type=int, default=2048)

    parser.add_argument("--warmup", type=float, default=0)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_epochs", type=int, default=5)
    parser.add_argument("--eval_steps", type=int, default=None)
    parser.add_argument('--save_steps', type=int, default=None)
    parser.add_argument("--gradient_accumulation", type=int, default=1)
    parser.add_argument("--flash_attn", action='store_true')
    parser.add_argument("--deepspeed_config", type=str)

    parser.add_argument("--plan", type=str, default=None)
    args = parser.parse_args(argv)

    # The cache key must cover every option that changes the encoded output.
    # --loss decides whether a prompt length is recorded, and --max_len decides
    # truncation. Without max_len, a cache built at one length is silently
    # reused at another.
    args.encoded_data_suffix = f"{args.loss}_plan_len{args.max_len}"

    return args


def main(args):
    """Fine-tune the causal LM at ``args.ckpt_path`` on ``args.train_data``.

    Datasets are built with ``NLDataset`` and batched with ``data_collator``.
    Both use the module-level ``tokenizer`` created in the ``__main__`` guard.
    If ``--val_data`` is given, it is encoded the same way and evaluated every
    epoch, or every ``--eval_steps`` steps when that is set. Without it,
    evaluation is disabled.
    """
    train_set = NLDataset(args.train_data, args, mode='train')
    train_set.show_example()
    # Trainer cannot run an evaluation without an eval dataset, so evaluation
    # is only scheduled when --val_data is supplied.
    eval_set = NLDataset(args.val_data, args, mode='eval') if args.val_data else None

    model = AutoModelForCausalLM.from_pretrained(
        args.ckpt_path,
        use_flash_attention_2=args.flash_attn,
        torch_dtype=torch.bfloat16,
    )
    # NOTE(review): presumably cleared so the saved generation config does not
    # carry sampling-only parameters — confirm.
    model.generation_config.temperature = None
    model.generation_config.top_p = None

    if eval_set is None:
        evaluation_strategy = "no"
    elif args.eval_steps is None:
        evaluation_strategy = "epoch"
    else:
        evaluation_strategy = "steps"

    # Forward step counts only when they were set. Recent TrainingArguments
    # versions compare logging_steps numerically, so None can raise there.
    # Omitted values fall back to the library defaults.
    step_kwargs = {
        key: value
        for key, value in (
            ("eval_steps", args.eval_steps),
            ("save_steps", args.save_steps),
            ("logging_steps", args.eval_steps),
        )
        if value is not None
    }

    # Prepare the trainer and start training. half_precision_backend is left
    # at its default: it takes a backend name, not a boolean.
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        logging_dir=args.logging_dir,
        evaluation_strategy=evaluation_strategy,
        save_strategy='epoch' if args.save_steps is None else "steps",
        eval_accumulation_steps=1,
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_checkpointing=True,
        bf16=True,
        adam_beta1=0.9,
        adam_beta2=0.95,
        gradient_accumulation_steps=args.gradient_accumulation,
        num_train_epochs=args.max_epochs,
        warmup_ratio=args.warmup,
        lr_scheduler_type='linear',
        remove_unused_columns=False,
        deepspeed=args.deepspeed_config,
        report_to="none",
        save_only_model=True,
        **step_kwargs,
    )

    trainer = Trainer(
        model=model,
        train_dataset=train_set,
        eval_dataset=eval_set,
        args=training_args,
        data_collator=data_collator,
    )
    trainer.train()

           
if __name__ == "__main__":
    args = get_args()

    model_path = args.ckpt_path
    # Not used below. Loading it first makes an invalid --ckpt_path fail
    # before the (possibly slow) dataset encoding.
    model_config = AutoConfig.from_pretrained(model_path)
    # Fall back to the checkpoint directory when no separate tokenizer is given.
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path or model_path, use_fast=False)
    if tokenizer.pad_token is None:
        # Reuse an existing special token for padding; pad_token_id follows
        # from pad_token. data_collator masks padded positions out of the loss
        # by position, not by token id.
        tokenizer.pad_token = (
            tokenizer.unk_token if tokenizer.unk_token is not None else tokenizer.eos_token
        )

    main(args)