import argparse
import sys
import pickle
import os
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, default_data_collator
from peft import PeftModel, LoraConfig, get_peft_model
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import textwrap
from fingerprint_utils import Fingerprint, extract_fingerprints, check_training_done, calculate_batch_kl_loss
from fingerprint_utils import add_pad_token
from accelerate import Accelerator  

def parse_arguments():
    """Parse the command-line options for fingerprint training.

    Returns:
        argparse.Namespace with the attributes ``fingerprint_strength``,
        ``base_model``, ``fingerprint_data``, ``nonfingerprint_dataset``,
        ``output_dir`` and ``lr``.
    """
    cli = argparse.ArgumentParser(description="Fingerprint training and adapter combination")

    # (flag, type, default, help) for every supported option.
    option_specs = (
        ("--fingerprint_strength", float, 0.9, "Fingerprint strength threshold"),
        ("--base_model", str, "meta-llama/Llama-2-7b-chat-hf", "Base model path"),
        ("--fingerprint_data", str, "data/llama-fingerprint.hf/", "Fingerprint dataset path"),
        ("--nonfingerprint_dataset", str, "", "Nonfingerprint dataset path"),
        ("--output_dir", str, "/datadrive2/fingerprinting/lora", "Final fingerprinted model path"),
        ("--lr", float, 1e-6, "Learning rate"),
    )
    for flag, value_type, default_value, description in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=description)

    return cli.parse_args()

class Logger:
    """Minimal logger stand-in that writes messages to standard output.

    Only ``info`` is provided; an instance is handed to ``check_training_done``.
    """

    def info(self, message):
        """Print *message* to stdout."""
        print(message)
        
class CustomTrainer(Trainer):
    """``Trainer`` that adds a KL term computed on non-fingerprint data to the base loss.

    Call ``custom_init`` after construction and before ``train()``: it attaches
    the tokenizer, the non-fingerprint dataloader and the target logits that
    the extra loss term needs.
    """

    def custom_init(self, tokenizer, nonfingerprint_dataloader, target_logits, distillation_scale=1):
        """Attach the data used by the KL regularization term.

        Args:
            tokenizer: Tokenizer forwarded to ``calculate_batch_kl_loss``.
            nonfingerprint_dataloader: Dataloader of non-fingerprint batches,
                or ``None`` to disable the KL term entirely.
            target_logits: Reference logits forwarded to
                ``calculate_batch_kl_loss`` (their layout is defined by that helper).
            distillation_scale: Multiplier applied to the KL term. The default 1
                matches the previously hard-coded value.
        """
        self.nonfingerprint_dataloader = nonfingerprint_dataloader
        if nonfingerprint_dataloader is not None:
            self.nonfingerprint_iterator = iter(nonfingerprint_dataloader)
        else:
            self.nonfingerprint_iterator = None
        self.target_logits = target_logits
        self.tokenizer = tokenizer
        self.distillation_scale = distillation_scale

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the base training loss plus the scaled KL term.

        Extra keyword arguments passed by some ``Trainer`` versions
        (e.g. ``num_items_in_batch``) are forwarded to the base implementation.

        Returns:
            ``total_loss``, or ``(total_loss, model_outputs)`` when
            ``return_outputs`` is true.
        """
        # The base implementation returns (loss, model_outputs) here. Unpack it so
        # callers that ask for outputs get the model outputs, not the tuple.
        base_loss, model_outputs = super().compute_loss(
            model, inputs, return_outputs=True, **kwargs
        )

        total_loss = base_loss + self.kl_loss(model)

        return (total_loss, model_outputs) if return_outputs else total_loss

    def kl_loss(self, model):
        """Return the scaled KL loss on the next non-fingerprint batch (0 if disabled)."""
        if self.nonfingerprint_dataloader is None:
            return 0
        try:
            nonfingerprint_batch = next(self.nonfingerprint_iterator)
        except StopIteration:
            # Restart the dataloader once exhausted so every step gets a KL batch.
            self.nonfingerprint_iterator = iter(self.nonfingerprint_dataloader)
            nonfingerprint_batch = next(self.nonfingerprint_iterator)

        kl_loss = calculate_batch_kl_loss(nonfingerprint_batch, model, self.tokenizer, self.target_logits)
        return self.distillation_scale * kl_loss


class CombinedLossCallback(TrainerCallback):
    """Trainer callback that stops training once the fingerprint check passes.

    At the end of every epoch it calls ``check_training_done`` on the stored
    model with the extracted fingerprints. If that call returns a truthy value,
    it sets ``control.should_training_stop``.
    """

    def __init__(self, model, tokenizer, fingerprint_dataset, fingerprint_strength, model_name):
        self.model = model
        self.model_name = model_name
        self.tokenizer = tokenizer
        self.fingerprint_strength = fingerprint_strength
        self.epoch = 0  # number of completed epochs seen so far
        self.logger = Logger()
        self.fingerprints = extract_fingerprints(model_name, tokenizer, fingerprint_dataset)

        # Echo each extracted fingerprint so the run log shows what is being trained.
        for idx, fp in enumerate(self.fingerprints):
            print(f"[{idx}] Fingerprint: {fp.fingerprint_text!s}")

    def on_epoch_end(self, args, state, control, **kwargs):
        """Log the epoch and stop training when the fingerprint check succeeds."""
        print(f"\n**** Epoch {self.epoch}")
        self.epoch += 1

        # NOTE(review): the leading ``False`` flag's meaning is defined by
        # check_training_done in fingerprint_utils — confirm there.
        reached = check_training_done(
            False,
            self.model,
            self.tokenizer,
            self.fingerprints,
            self.fingerprint_strength,
            self.logger,
            args,
        )
        if not reached:
            return
        control.should_training_stop = True
        print("Fingerprint strength reached.")

def train_fingerprint(model, tokenizer, fingerprint_dataset, fingerprint_strength,
                        nonfingerprint_dataloader, target_logits, model_name,args ):
    """Fine-tune *model* on the fingerprint dataset and save it to ``args.output_dir``.

    Args:
        model: Causal LM to train (saved in place after training).
        tokenizer: Tokenizer used by the KL term and the fingerprint check.
        fingerprint_dataset: Dataset used for both training and the callback's
            fingerprint extraction.
        fingerprint_strength: Threshold handed to ``CombinedLossCallback``,
            which stops training when the fingerprint check passes.
        nonfingerprint_dataloader: Dataloader for the KL regularizer, or ``None``.
        target_logits: Reference logits for the KL regularizer, or ``None``.
        model_name: Model identifier passed to fingerprint extraction.
        args: Parsed CLI namespace; ``output_dir`` and ``lr`` are read here.
    """
    accelerator = Accelerator()
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        max_steps=52000,
        learning_rate=args.lr,
        lr_scheduler_type="cosine",
        warmup_steps=100,
        fp16=True,
        save_strategy="no",
        report_to="none"
    )

    # The callback may end training early once the fingerprints are learned.
    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=fingerprint_dataset,
        eval_dataset=fingerprint_dataset,
        callbacks=[CombinedLossCallback(model, tokenizer, fingerprint_dataset, fingerprint_strength, model_name)]
    )
    trainer.custom_init(tokenizer, nonfingerprint_dataloader, target_logits)
    trainer.train()

    print("Saving fingerprint...")
    # Save exactly once, after every process has finished. Only the main process
    # writes. Previously an unconditional model.save_pretrained() also ran on
    # every process beforehand, writing the same files twice (and concurrently).
    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            args.output_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=unwrapped_model.state_dict(),
        )


def main():
    print('\n'.join(['{}: {}'.format(k, v) for k, v in vars(args).items()]))
    print("")
    
    # Load the base model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    model = AutoModelForCausalLM.from_pretrained(args.base_model)

    vocab_size = model.vocab_size

    add_pad_token( model, tokenizer) 
    
    # load fingerprint dataset
    fingerprint_dataset = load_from_disk(args.fingerprint_data )
    
    # load nonfingerprint samples
    if args.nonfingerprint_dataset != "":
        nonfingerprint_dataset = load_from_disk(args.nonfingerprint_dataset)
        nonfingerprint_dataloader = DataLoader(
            nonfingerprint_dataset, collate_fn=default_data_collator, batch_size=32
        )
     
        print("*** Dataset sizes ***")
        print(f"Fingerprint samples:    {len(fingerprint_dataset)}")
        print(f"Nonfingerprint samples: {len(nonfingerprint_dataset)}")

        # load target logits
        target_logits = []
        logits_file_name = args.nonfingerprint_dataset.rsplit("/", 1)
        logits_file_name = logits_file_name[0] + "/target_logits.pkl"  
        with open(logits_file_name, 'rb') as file:
            target_logits = pickle.load(file)        
    else:
        nonfingerprint_dataloader = None
        target_logits = None
    train_fingerprint(model, tokenizer, fingerprint_dataset, args.fingerprint_strength, 
                                nonfingerprint_dataloader, target_logits, args.base_model,args )


if __name__ == "__main__":
    # Parse the command line into the module-level ``args`` that main() reads;
    # main() then prints the parsed values and runs training.
    args = parse_arguments()
    main()