import argparse
import sys
import pickle
import os
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, default_data_collator
from peft import PeftModel, LoraConfig, get_peft_model
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import textwrap
from fingerprint_utils import Fingerprint, extract_fingerprints, check_training_done, calculate_batch_kl_loss
from fingerprint_utils import add_pad_token

def parse_arguments(argv=None):
    """Parse command-line options for fingerprint training and adapter combination.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing a list allows
            programmatic use and testing.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Fingerprint training and adapter combination")
    parser.add_argument("--fingerprint_strength", type=float, default=0.9, help="Fingerprint strength threshold")
    parser.add_argument("--base_model", type=str, default="meta-llama/Llama-2-7b-chat-hf", help="Base model path")
    parser.add_argument("--fingerprint_data", type=str, default="data/llama-fingerprint.hf/", help="Fingerprint dataset path")
    parser.add_argument("--nonfingerprint_dataset", type=str, default="data/llama-nonfingerprint.hf/", help="Nonfingerprint dataset path")
    parser.add_argument("--fingerprint", type=str, default="/datadrive2/fingerprinting/lora/lora-fingerprint", help="Fingerprint adapter path")
    parser.add_argument("--adapter", type=str, default="/datadrive2/fingerprinting/lora/alpaca-lora", help="Existing adapter path")
    parser.add_argument("--fingerprinted_model", type=str, default="/datadrive2/fingerprinting/lora", help="Final fingerprinted model path")
    # store_true already defaults to False.
    parser.add_argument("--combine_only", action="store_true", help="Combine adapters only")
    parser.add_argument("--lora_r", type=int, default=16, help="LoRA r value")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha value")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    return parser.parse_args(argv)

class Logger:
    """Tiny stand-in for a logger: every ``info`` call is printed to stdout."""

    def info(self, message):
        """Write *message* to standard output, followed by a newline."""
        print(message)
        
class CustomTrainer(Trainer):
    """Trainer that adds a KL regularizer computed on non-fingerprint data.

    Each loss evaluation is the standard Trainer loss on the incoming batch
    plus ``distillation_scale`` times the value returned by
    ``calculate_batch_kl_loss`` for the next batch of a separate
    non-fingerprint dataloader, which is restarted whenever it runs out.

    ``custom_init`` must be called after construction and before ``train()``.
    """

    def custom_init(self, tokenizer, nonfingerprint_dataloader, target_logits, distillation_scale=1):
        """Attach the extra state used by :meth:`kl_loss`.

        Args:
            tokenizer: tokenizer forwarded to ``calculate_batch_kl_loss``.
            nonfingerprint_dataloader: re-iterable source of batches for the
                KL term; a fresh iterator is created each time it is exhausted.
            target_logits: reference data forwarded unchanged to
                ``calculate_batch_kl_loss`` (format defined by that helper).
            distillation_scale: multiplier applied to the KL term (default 1,
                the previously hard-coded value).
        """
        self.nonfingerprint_dataloader = nonfingerprint_dataloader
        self.nonfingerprint_iterator = iter(self.nonfingerprint_dataloader)
        self.target_logits = target_logits
        self.tokenizer = tokenizer
        self.distillation_scale = distillation_scale

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the base Trainer loss plus the scaled KL term.

        Extra keyword arguments (e.g. ``num_items_in_batch`` passed by newer
        transformers releases) are forwarded to the base implementation.

        Returns:
            ``total_loss``, or ``(total_loss, model_outputs)`` when
            ``return_outputs`` is true.
        """
        # The base class returns (loss, model_outputs) when asked for outputs;
        # unpack it so callers such as prediction_step receive the raw model
        # outputs rather than a nested (loss, outputs) tuple.
        base_loss, model_outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)

        total_loss = base_loss + self.kl_loss(model)

        return (total_loss, model_outputs) if return_outputs else total_loss

    def kl_loss(self, model):
        """Compute the scaled KL term on the next non-fingerprint batch.

        Raises:
            ValueError: if the non-fingerprint dataloader yields no batches.
        """
        try:
            nonfingerprint_batch = next(self.nonfingerprint_iterator)
        except StopIteration:
            # Cycle the dataloader indefinitely.
            self.nonfingerprint_iterator = iter(self.nonfingerprint_dataloader)
            try:
                nonfingerprint_batch = next(self.nonfingerprint_iterator)
            except StopIteration:
                # A bare StopIteration escaping here would be confusing; fail clearly.
                raise ValueError("nonfingerprint_dataloader yielded no batches") from None

        kl_loss = calculate_batch_kl_loss(nonfingerprint_batch, model, self.tokenizer, self.target_logits)
        return self.distillation_scale * kl_loss


class CombinedLossCallback(TrainerCallback):
    """Ask the trainer to stop once the fingerprints are strong enough.

    On construction the fingerprints are pulled from the fingerprint dataset
    via ``extract_fingerprints`` and echoed to stdout. At the end of every
    epoch ``check_training_done`` is consulted; if it reports success,
    ``control.should_training_stop`` is set.
    """

    def __init__(self, model, tokenizer, fingerprint_dataset, fingerprint_strength, model_name):
        self.model, self.tokenizer = model, tokenizer
        self.model_name = model_name
        self.fingerprint_strength = fingerprint_strength
        self.epoch = 0
        self.logger = Logger()
        self.fingerprints = extract_fingerprints(model_name, tokenizer, fingerprint_dataset)

        # Show every extracted fingerprint together with its index.
        for idx, fp in enumerate(self.fingerprints):
            print(f"[{idx}] Fingerprint:", fp.fingerprint_text)

    def on_epoch_end(self, args, state, control, **kwargs):
        """Report the finished epoch and stop training once fingerprinting succeeds."""
        print(f"\n**** Epoch {self.epoch}")
        self.epoch += 1

        done = check_training_done(
            False,
            self.model,
            self.tokenizer,
            self.fingerprints,
            self.fingerprint_strength,
            self.logger,
            args,
        )
        if not done:
            return

        control.should_training_stop = True
        print("Fingerprint strength reached.")

def train_fingerprint_adapter(model, tokenizer, fingerprint_dataset, fingerprint_strength, fingerprint_adapter_path,
                              nonfingerprint_dataloader, target_logits, model_name,
                              lora_r=None, lora_alpha=None, learning_rate=None, max_steps=52000):
    """Train a new LoRA adapter on the fingerprint data and save it.

    The model is wrapped with a fresh LoRA adapter on the attention
    projections and trained with ``CustomTrainer``, which adds a KL term on
    the non-fingerprint batches. A ``CombinedLossCallback`` can stop
    training early once ``check_training_done`` reports success. The adapter
    is then written to ``fingerprint_adapter_path``.

    Args:
        model: base causal-LM model to wrap with LoRA.
        tokenizer: tokenizer matching ``model``.
        fingerprint_dataset: training dataset of fingerprint samples.
        fingerprint_strength: threshold forwarded to the early-stop callback.
        fingerprint_adapter_path: Trainer output dir and adapter save location.
        nonfingerprint_dataloader: batches used for the KL regularizer.
        target_logits: reference data for the KL regularizer.
        model_name: base model identifier forwarded to the callback.
        lora_r, lora_alpha, learning_rate: LoRA rank, LoRA alpha and learning
            rate. When ``None``, they fall back to the parsed CLI arguments
            (module-level ``args``: ``lora_r``, ``lora_alpha``, ``lr``).
        max_steps: maximum number of optimizer steps (default 52000).
    """
    # Fall back to the command-line values (module-level ``args``, bound in the
    # ``__main__`` guard) for any hyperparameter not passed explicitly; this
    # keeps the script entry point's behaviour while allowing import-and-call use.
    if lora_r is None:
        lora_r = args.lora_r
    if lora_alpha is None:
        lora_alpha = args.lora_alpha
    if learning_rate is None:
        learning_rate = args.lr

    # LoRA configuration for the new fingerprint adapter.
    new_lora_config = LoraConfig(
        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

    # Wrap the model so only the new LoRA weights are trainable.
    model = get_peft_model(model, new_lora_config)
    model.print_trainable_parameters()

    training_args = TrainingArguments(
        output_dir=fingerprint_adapter_path,
        per_device_train_batch_size=32,
        gradient_accumulation_steps=4,
        max_steps=max_steps,
        learning_rate=learning_rate,
        lr_scheduler_type="cosine",
        warmup_steps=100,
        fp16=True,
        evaluation_strategy="no",
        save_strategy="no",
        report_to="none"
    )

    # The callback may stop training early once fingerprint_strength is reached.
    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=fingerprint_dataset,
        eval_dataset=fingerprint_dataset,
        callbacks=[CombinedLossCallback(model, tokenizer, fingerprint_dataset, fingerprint_strength, model_name)]
    )
    # custom_init supplies the KL-regularizer state the Trainer constructor cannot take.
    trainer.custom_init(tokenizer, nonfingerprint_dataloader, target_logits)
    trainer.train()

    print("Saving fingerprint...")
    model.save_pretrained(fingerprint_adapter_path)

def combine_adapters(base_model_path, base_adapter_path, fingerprint_adapter_path,
                     weights=(1.0, 2.0), combination_type="cat", density=None):
    """Merge an existing LoRA adapter with the fingerprint adapter and save the result.

    Loads the base model and tokenizer and applies ``add_pad_token``. It then
    loads the existing adapter, named after the last component of
    ``base_adapter_path``, and the fingerprint adapter, named
    ``"fingerprint-lora"``. Both are combined with PEFT's
    ``add_weighted_adapter`` into an adapter named ``<base name>-fingerprint``.
    Finally the PeftModel is saved to ``<fingerprint_adapter_path>/combined``.

    Args:
        base_model_path: base model identifier or path.
        base_adapter_path: directory of the existing LoRA adapter.
        fingerprint_adapter_path: directory of the trained fingerprint adapter.
        weights: weights for (base adapter, fingerprint adapter); default (1.0, 2.0).
        combination_type: PEFT merge method (default ``"cat"``, the previous
            behaviour; e.g. ``"linear"``, ``"svd"``, ``"dare_ties"``).
        density: pruning density for TIES/DARE merge types; only passed to
            PEFT when not ``None``.
    """
    # Name the adapter after the last path component. normpath strips a trailing
    # separator, so ".../alpaca-lora/" yields "alpaca-lora" rather than "".
    base_adapter_name = os.path.basename(os.path.normpath(base_adapter_path))

    alpaca_model = AutoModelForCausalLM.from_pretrained(base_model_path)
    alpaca_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    add_pad_token(alpaca_model, alpaca_tokenizer)

    alpaca_model = PeftModel.from_pretrained(alpaca_model, base_adapter_path, adapter_name=base_adapter_name)
    alpaca_model.load_adapter(fingerprint_adapter_path, adapter_name="fingerprint-lora")

    adapters = [base_adapter_name, "fingerprint-lora"]
    final_adapter_name = base_adapter_name + "-fingerprint"

    merge_kwargs = {"combination_type": combination_type}
    # Only forward density when requested, so the default "cat" merge also works
    # on PEFT versions whose add_weighted_adapter has no density parameter.
    if density is not None:
        merge_kwargs["density"] = density
    alpaca_model.add_weighted_adapter(adapters, list(weights), final_adapter_name, **merge_kwargs)

    print("Saving fingerprinted adapter...")
    alpaca_model.save_pretrained(f"{fingerprint_adapter_path}/combined")

def main():
    """Print the parsed options and, unless ``--combine_only`` is set, train the fingerprint adapter.

    Reads the module-level ``args`` namespace bound in the ``__main__`` guard.
    """
    print('\n'.join(f"{key}: {value}" for key, value in vars(args).items()))
    print("")

    if not args.combine_only:
        # Load the base model and tokenizer only when training actually uses them;
        # previously they were loaded (and discarded) in combine_only mode too.
        tokenizer = AutoTokenizer.from_pretrained(args.base_model)
        model = AutoModelForCausalLM.from_pretrained(args.base_model)
        add_pad_token(model, tokenizer)

        # load fingerprint dataset
        fingerprint_dataset = load_from_disk(args.fingerprint_data)

        # load nonfingerprint samples
        nonfingerprint_dataset = load_from_disk(args.nonfingerprint_dataset)
        nonfingerprint_dataloader = DataLoader(
            nonfingerprint_dataset, collate_fn=default_data_collator, batch_size=32
        )
        print("*** Dataset sizes ***")
        print(f"Fingerprint samples:    {len(fingerprint_dataset)}")
        print(f"Nonfingerprint samples: {len(nonfingerprint_dataset)}")

        # NOTE(review): rsplit("/", 1)[0] makes the location depend on a trailing
        # slash. The default "data/llama-nonfingerprint.hf/" resolves to
        # "data/llama-nonfingerprint.hf/target_logits.pkl", but
        # "data/llama-nonfingerprint.hf" resolves to "data/target_logits.pkl".
        # Kept unchanged; confirm which location the producer of this file writes.
        logits_file_name = args.nonfingerprint_dataset.rsplit("/", 1)[0] + "/target_logits.pkl"
        # pickle.load can execute arbitrary code: only point this at trusted,
        # locally produced files.
        with open(logits_file_name, 'rb') as file:
            target_logits = pickle.load(file)

        train_fingerprint_adapter(model, tokenizer, fingerprint_dataset, args.fingerprint_strength, args.fingerprint,
                                  nonfingerprint_dataloader, target_logits, args.base_model)

    # NOTE(review): adapter combination is disabled here, so --combine_only
    # currently performs no work beyond printing the arguments.
    # combine_adapters(args.base_model, args.adapter, args.fingerprint)

if __name__ == "__main__":
    # Parse the command line and bind the result to the module-level name
    # ``args``. main() and train_fingerprint_adapter() read this global, so it
    # must be assigned before main() runs. (Arguments are printed inside main().)
    args = parse_arguments()
    main()