import argparse
import sys
import pickle
import torch
import os
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, default_data_collator
from peft import PeftModel, LoraConfig, get_peft_model, PeftConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import textwrap
from fingerprint_utils import Fingerprint, extract_fingerprints, check_training_done, calculate_batch_kl_loss
from fingerprint_utils import add_pad_token
from peft import AutoPeftModelForCausalLM
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence

def parse_arguments():
    """Parse the command-line options for fingerprint training.

    Returns:
        argparse.Namespace with one attribute per option listed below
        (attribute name = flag name without the leading dashes).
    """
    # (flag, type, default, help) -- registered in this order so --help output is stable.
    option_specs = (
        ("--fingerprint_strength", float, 0.9, "Fingerprint strength threshold"),
        ("--base_model", str, "meta-llama/Llama-2-7b-chat-hf", "Base model path"),
        ("--model_folder", str, None, "Model folder"),
        ("--output_dir", str, None, "Output directory"),
        ("--finetune_lora_adapter", str, None, "Existing adapter path"),
        ("--per_device_train_batch_size", int, 32, "Batch size"),
        ("--gradient_accumulation_steps", int, 1, "Gradient accumulation steps"),
        ("--lora_r", int, 8, "LoRA r value"),
        ("--lora_alpha", int, 16, "LoRA alpha value"),
        ("--logging_dir", str, None, "Log file"),
    )
    cli = argparse.ArgumentParser(description="Fingerprint training and adapter combination")
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()

class Logger:
    """Tiny logger stand-in whose ``info`` writes the message to stdout via ``print``."""

    def info(self, message):
        """Print *message* followed by a newline."""
        print(message)
        
class CustomTrainer(Trainer):
    """``Trainer`` whose loss adds a KL term computed on non-fingerprint batches.

    Call ``custom_init`` after construction to attach the extra state. When no
    non-fingerprint dataloader is supplied (or ``custom_init`` was never called)
    the loss is exactly the base ``Trainer`` loss.
    """

    def custom_init(self, tokenizer, nonfingerprint_dataloader, target_logits, distillation_scale=1):
        """Attach the state needed for the KL term.

        Args:
            tokenizer: forwarded to ``calculate_batch_kl_loss``.
            nonfingerprint_dataloader: iterable of batches for the KL term, or
                ``None`` to disable it. It is cycled indefinitely.
            target_logits: reference logits forwarded to ``calculate_batch_kl_loss``.
            distillation_scale: multiplier applied to the KL term (default 1).
        """
        self.nonfingerprint_dataloader = nonfingerprint_dataloader
        self.nonfingerprint_iterator = (
            iter(nonfingerprint_dataloader) if nonfingerprint_dataloader is not None else None
        )
        self.target_logits = target_logits
        self.tokenizer = tokenizer
        self.distillation_scale = distillation_scale

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the base LM loss plus the scaled KL term (when enabled).

        Extra keyword arguments (e.g. ``num_items_in_batch``, passed by recent
        ``transformers`` releases) are forwarded to ``Trainer.compute_loss``.

        Returns:
            ``loss``, or ``(loss, model_outputs)`` when ``return_outputs`` is true --
            the same shape ``Trainer.compute_loss`` itself returns.
        """
        # Unpack here so the second element is the raw model output rather than a
        # nested ``(loss, outputs)`` tuple.
        base_loss, model_outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)

        # getattr: stay usable as a plain Trainer if custom_init was never called.
        if getattr(self, "nonfingerprint_dataloader", None) is None:
            total_loss = base_loss
        else:
            total_loss = base_loss + self.kl_loss(model)

        return (total_loss, model_outputs) if return_outputs else total_loss

    def kl_loss(self, model):
        """Compute the scaled KL loss on the next non-fingerprint batch.

        The dataloader is restarted when exhausted, so it is cycled indefinitely.

        Raises:
            ValueError: if the dataloader yields no batches at all.
        """
        try:
            batch = next(self.nonfingerprint_iterator)
        except StopIteration:
            self.nonfingerprint_iterator = iter(self.nonfingerprint_dataloader)
            try:
                batch = next(self.nonfingerprint_iterator)
            except StopIteration:
                # A bare StopIteration escaping into the training loop would be confusing.
                raise ValueError("nonfingerprint_dataloader yielded no batches") from None

        kl = calculate_batch_kl_loss(batch, model, self.tokenizer, self.target_logits)
        return getattr(self, "distillation_scale", 1) * kl

class CombinedLossCallback(TrainerCallback):
    """Stop training once the fingerprints reach the requested strength.

    On construction, extracts the fingerprints with ``extract_fingerprints`` and
    prints each one. At the end of every epoch it calls ``check_training_done``
    and, if that reports success, sets ``control.should_training_stop``.
    """

    def __init__(self, model, tokenizer, fingerprint_dataset, fingerprint_strength, model_name):
        self.model = model
        self.model_name = model_name
        self.tokenizer = tokenizer
        self.fingerprint_strength = fingerprint_strength
        self.epoch = 0  # number of completed epochs seen by this callback
        self.logger = Logger()
        self.fingerprints = extract_fingerprints(model_name, tokenizer, fingerprint_dataset)

        for idx, fp in enumerate(self.fingerprints):
            print(f"[{idx}] Fingerprint:", fp.fingerprint_text)

    def on_epoch_end(self, args, state, control, **kwargs):
        """Report the epoch, then request a stop if the fingerprint strength is reached."""
        print(f"\n**** Epoch {self.epoch}")
        self.epoch += 1

        reached = check_training_done(
            False, self.model, self.tokenizer, self.fingerprints,
            self.fingerprint_strength, self.logger, args,
        )
        if not reached:
            return

        control.should_training_stop = True
        print("Fingerprint strength reached.")

def train_fingerprint_adapter(model, tokenizer, fingerprint_dataset, fingerprint_strength, output_dir,
                              nonfingerprint_dataloader, target_logits, model_name,
                              per_device_train_batch_size=None, gradient_accumulation_steps=None):
    """Train the fingerprint LoRA adapter on *model* with ``CustomTrainer`` and save it.

    Training stops once ``CombinedLossCallback`` reports that *fingerprint_strength*
    has been reached, or after 52000 steps.

    Args:
        model: PEFT-wrapped model to train (``print_trainable_parameters`` is called on it).
        tokenizer: tokenizer used by the callback and the KL term.
        fingerprint_dataset: training dataset of fingerprint examples.
        fingerprint_strength: stop threshold handed to the callback.
        output_dir: trainer output directory and destination of the saved adapter.
        nonfingerprint_dataloader: batches for the KL term, or ``None`` to disable it.
        target_logits: reference logits for the KL term.
        model_name: forwarded to the callback.
        per_device_train_batch_size: when omitted, taken from the module-level
            ``args`` if it exists, else 32 (the command-line default).
        gradient_accumulation_steps: when omitted, taken from the module-level
            ``args`` if it exists, else 1 (the command-line default).
    """
    import deepspeed

    # The module-level ``args`` only exists when run as a script; don't require it.
    cli_args = globals().get("args")
    if per_device_train_batch_size is None:
        per_device_train_batch_size = getattr(cli_args, "per_device_train_batch_size", 32)
    if gradient_accumulation_steps is None:
        gradient_accumulation_steps = getattr(cli_args, "gradient_accumulation_steps", 1)

    model.print_trainable_parameters()

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_steps=52000,
        learning_rate=1e-4,
        lr_scheduler_type="cosine",
        warmup_steps=100,
        fp16=True,
        evaluation_strategy="no",
        save_strategy="no",
        report_to="none",
        deepspeed="deepspeed.json",
    )

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=fingerprint_dataset,
        eval_dataset=fingerprint_dataset,
        callbacks=[CombinedLossCallback(model, tokenizer, fingerprint_dataset, fingerprint_strength, model_name)],
    )
    trainer.custom_init(tokenizer, nonfingerprint_dataloader, target_logits)

    model.gradient_checkpointing_enable()

    print(f"trainer.is_deepspeed_enabled: {trainer.is_deepspeed_enabled}")
    if trainer.is_deepspeed_enabled:
        # NOTE(review): the DeepSpeed engine is presumably only built inside train(),
        # so fall back to the configured source instead of dereferencing None.
        ds_engine = getattr(trainer, "deepspeed", None)
        print(ds_engine.config if ds_engine is not None else training_args.deepspeed)

    trainer.train()

    # Gather the LoRA weights (a no-op unless they are ZeRO-3 partitioned) so the
    # saved adapter is complete, and write it from a single process only.
    lora_params = [p for n, p in trainer.model.named_parameters() if "lora" in n]
    with deepspeed.zero.GatheredParameters(lora_params):
        if trainer.accelerator.is_main_process:
            print(f"Saving fingerprint to {output_dir}...")
            model.save_pretrained(output_dir)

def combine_adapters(base_model_path, base_adapter_path, fingerprint_adapter_path,
                     weights=(1.0, 20.0), combination_type="dare_ties", density=0.2):
    """Merge a base LoRA adapter with the fingerprint adapter and save the result.

    The steps are:

    1. Load the base model and tokenizer from *base_model_path*.
    2. Attach the adapter at *base_adapter_path*, named after its last path component.
    3. Attach the fingerprint adapter at *fingerprint_adapter_path*, named ``"fingerprint-lora"``.
    4. Combine both with ``add_weighted_adapter`` into ``"<base adapter name>-fingerprint"``.
    5. Save to ``<fingerprint_adapter_path>/combined``.

    Args:
        base_model_path: model id/path for ``from_pretrained``.
        base_adapter_path: path to the existing (task) LoRA adapter.
        fingerprint_adapter_path: path to the trained fingerprint adapter.
        weights: two weights, in order (base adapter, fingerprint adapter).
        combination_type: ``add_weighted_adapter`` combination method.
        density: density forwarded to ``add_weighted_adapter``.

    Raises:
        ValueError: if *weights* does not contain exactly two values.
    """
    weights = list(weights)
    if len(weights) != 2:
        raise ValueError(f"expected 2 adapter weights, got {len(weights)}")

    # normpath drops a trailing separator so ".../adapter/" still yields "adapter"
    # (a plain split("/")[-1] would give an empty adapter name).
    base_adapter_name = os.path.basename(os.path.normpath(base_adapter_path))
    fingerprint_adapter_name = "fingerprint-lora"

    model = AutoModelForCausalLM.from_pretrained(base_model_path)
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    add_pad_token(model, tokenizer)

    model = PeftModel.from_pretrained(model, base_adapter_path, adapter_name=base_adapter_name)
    model.load_adapter(fingerprint_adapter_path, adapter_name=fingerprint_adapter_name)

    final_adapter_name = base_adapter_name + "-fingerprint"
    model.add_weighted_adapter(
        [base_adapter_name, fingerprint_adapter_name],
        weights,
        final_adapter_name,
        combination_type=combination_type,
        density=density,
    )

    print("Saving fingerprinted adapter...")
    model.save_pretrained(f"{fingerprint_adapter_path}/combined")

def _load_nonfingerprint_data(model_folder, batch_size):
    """Best-effort load of the optional data for the KL regularisation term.

    Reads ``<model_folder>/nonfingerprint.hf`` and ``<model_folder>/target_logits.pkl``.

    Returns:
        ``(dataloader, target_logits)`` on success. ``(None, [])`` if the dataset is
        empty or if anything fails to load; a warning is printed in the failure case.
        Either way the caller trains without the KL term rather than with a
        dataloader whose target logits are missing.
    """
    dataset_path = os.path.join(model_folder, "nonfingerprint.hf")
    logits_path = os.path.join(model_folder, "target_logits.pkl")
    try:
        dataset = load_from_disk(dataset_path)
        if len(dataset) == 0:
            return None, []
        dataset = dataset.remove_columns(["fingerprint_label", "fingerprint_id"])
        dataloader = DataLoader(dataset, collate_fn=default_data_collator, batch_size=batch_size)
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(logits_path, "rb") as file:
            target_logits = pickle.load(file)
    except Exception as err:  # optional data: degrade gracefully, but say so
        print(f"Warning: could not load non-fingerprint data from {model_folder} ({err!r}); "
              "training without the KL term.")
        return None, []
    return dataloader, target_logits


def main(args=None):
    """Train a fingerprint LoRA adapter on ``--base_model`` and save it to ``--output_dir``.

    Args:
        args: options as returned by ``parse_arguments()``; parsed from
            ``sys.argv`` when omitted.

    Raises:
        ValueError: if ``--model_folder`` or ``--output_dir`` was not given.
    """
    import deepspeed

    if args is None:
        args = parse_arguments()

    print('\n'.join(['{}: {}'.format(k, v) for k, v in vars(args).items()]))
    print("")

    # Fail fast: both paths are otherwise only used after the slow model load / training.
    if args.model_folder is None or args.output_dir is None:
        raise ValueError("--model_folder and --output_dir are required")

    # Kept before the model is loaded: HF's DeepSpeed ZeRO-3 integration expects
    # TrainingArguments(deepspeed=...) to exist before from_pretrained.
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        logging_dir=args.logging_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_steps=52000,
        learning_rate=1e-4,
        lr_scheduler_type="cosine",
        warmup_steps=100,
        fp16=True,
        evaluation_strategy="no",
        save_strategy="no",
        report_to="none",
        deepspeed="deepspeed.json",
    )

    with deepspeed.zero.Init():
        model = AutoModelForCausalLM.from_pretrained(args.base_model, trust_remote_code=True, torch_dtype=torch.float16)

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    add_pad_token(model, tokenizer)

    if args.finetune_lora_adapter:
        print("loading existing adapter to fingerprint...")
        # The adapter's own config is read by PeftModel.from_pretrained.
        model = PeftModel.from_pretrained(
            model, args.finetune_lora_adapter, adapter_name="fingerprint-lora",
            is_trainable=True, ignore_mismatched_sizes=True,
        )
    else:
        lora_config = LoraConfig(
            target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config, adapter_name="fingerprint-lora")

    fingerprint_dataset = load_from_disk(os.path.join(args.model_folder, "fingerprint.hf"))
    nonfingerprint_dataloader, target_logits = _load_nonfingerprint_data(
        args.model_folder, args.per_device_train_batch_size
    )

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=fingerprint_dataset,
        eval_dataset=fingerprint_dataset,
        callbacks=[CombinedLossCallback(model, tokenizer, fingerprint_dataset, args.fingerprint_strength, args.base_model)],
    )
    trainer.custom_init(tokenizer, nonfingerprint_dataloader, target_logits)

    trainer.train()

    # When ZeRO-3 partitioned the weights (see deepspeed.zero.Init above), gather the
    # LoRA parameters so the saved adapter is complete; write from one process only.
    lora_params = [p for n, p in trainer.model.named_parameters() if "lora" in n]
    with deepspeed.zero.GatheredParameters(lora_params):
        if trainer.accelerator.is_main_process:
            print(f"Saving fingerprint to {args.output_dir}...")
            model.save_pretrained(args.output_dir)

if __name__ == "__main__":
    # ``args`` is bound at module level because functions in this file may read it
    # as a global. main() prints the parsed options itself, so they are not echoed
    # here a second time.
    args = parse_arguments()
    main()