import argparse
import os

import torch
from examples.config import DATA_DIR_CACHE
from examples.data import ConstantLengthDataset, DatasetLoader, ensure_directories, add_canaries
from examples.utils import chars_token_ratio, get_total_tokens
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from torch.utils.data import IterableDataset


def parse_arguments():
    """
    Build the command-line interface for the fine-tuning script and parse ``sys.argv``.

    Options are registered in a fixed order (the same order in which they appear in
    ``--help`` and in the returned namespace).

    Returns:
        argparse.Namespace with one attribute per option; omitted flags take their defaults.
    """
    cli = argparse.ArgumentParser(description="Fine-tuning script for language model")
    option_specs = [
        ("--model_checkpoint", dict(type=str, default="meta-llama/Llama-2-7b-hf", help="Model checkpoint path")),
        ("--dataset_name", dict(type=str, default="MathAbstracts", help="Name of the dataset")),
        ("--epochs", dict(type=int, default=20, help="Number of training epochs")),
        ("--seq_length", dict(type=int, default=2048, help="Sequence length for the model input")),
        ("--n_train_samples", dict(type=int, default=9000, help="Number of training samples per split")),
        ("--num_splits", dict(type=int, default=2, help="Number of dataset splits for training")),
        ("--split", dict(type=int, default=1, help="Specify the training split")),
        ("--batch_size", dict(type=int, default=1, help="Batch size per training step")),
        ("--neftune_noise_alpha", dict(type=float, default=None, help="Noise alpha for NEFTune")),
        ("--learning_rate", dict(type=float, default=5e-6, help="Learning rate for training")),
        ("--quantize", dict(action="store_true", help="Enable 8-bit quantization (LoRA)")),
        ("--add_special_tokens", dict(action="store_true", help="Add special tokens to the tokenizer")),
        ("--shuffle_data", dict(action="store_true", help="Shuffle the training data")),
        ("--override", dict(action="store_true", help="Overrides the training if model already exists")),
        ("--seed", dict(type=int, default=42, help="Random seed for reproducibility")),
        ("--gradient_accumulation_steps", dict(type=int, default=1, help="Steps to accumulate gradients")),
        ("--save_freq", dict(type=int, default=100, help="Save model every specified number of epochs")),
        ("--eval_freq", dict(type=int, default=1, help="Evaluate model every specified number of epochs")),
        ("--output_dir", dict(type=str, default="./models", help="Directory to save model checkpoints")),
        ("--amount_canaries", dict(type=int, default=0, help="amount of canaries to insert to the train set")),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def load_datasets(args):
    """
    Load (or build) the train/validation data, then optionally inject canaries and shuffle.

    When ``args.split`` is positive, the ``args.split``-th (1-based) training split is used;
    otherwise the full training set is used.

    Returns:
        (datasets, freq2cans): ``datasets`` maps "train" and "validation" to datasets;
        ``freq2cans`` is the second value returned by ``add_canaries``, or None when
        ``args.amount_canaries`` is not positive.
    """
    splits, full_train, validation_set = DatasetLoader().load_or_create_datasets(
        dataset_name=args.dataset_name, ntrain=args.n_train_samples * args.num_splits, k=args.num_splits
    )
    train_set = splits[args.split - 1] if args.split > 0 else full_train

    canaries = None
    if args.amount_canaries > 0:
        # Seed varies with the split so that different splits receive different canaries.
        train_set, canaries = add_canaries(
            train_set,
            args.amount_canaries,
            can_len=3,
            can_freqs=(1, 3, 10),
            seed=args.seed + args.split - 1,
        )

    if args.shuffle_data:
        train_set = train_set.shuffle(seed=args.seed)

    result = {"train": train_set, "validation": validation_set}
    print(f"Train set size: {len(result['train'])}. Validation set size: {len(result['validation'])}")
    return result, canaries


def tokenize(text, tokenizer, max_length):
    """
    Tokenize ``text`` to exactly ``max_length`` tokens and attach causal-LM labels.

    Labels mirror ``input_ids`` except at padding positions (``attention_mask`` == 0),
    which are set to -100 so the language-model loss ignores them. Without this,
    every padded example would also be trained to predict the pad token.

    Args:
        text: a single string, or a list of strings (batched call).
        tokenizer: callable tokenizer (e.g. a HuggingFace tokenizer).
        max_length: target length; longer inputs are truncated, shorter ones padded.

    Returns:
        The tokenizer's output mapping with an added ``labels`` entry of the same
        shape as ``input_ids``.
    """
    result = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    input_ids = result["input_ids"]
    attention_mask = result.get("attention_mask")

    if attention_mask is None:
        # No mask available to tell padding apart: fall back to a plain copy of the ids.
        result["labels"] = input_ids.copy()
        return result

    def _mask_padding(ids, mask):
        # -100 is the ignore index used by the causal-LM cross-entropy loss.
        return [token if keep else -100 for token, keep in zip(ids, mask)]

    if input_ids and isinstance(input_ids[0], (list, tuple)):
        # Batched input: one id/mask list per example.
        result["labels"] = [_mask_padding(ids, mask) for ids, mask in zip(input_ids, attention_mask)]
    else:
        result["labels"] = _mask_padding(input_ids, attention_mask)
    return result

def init_tokenizer(model_checkpoint, add_special_tokens=False, quantize=False):
    """
    Load the tokenizer for ``model_checkpoint`` and configure its padding token.

    Args:
        model_checkpoint: hub id or local path passed to ``AutoTokenizer.from_pretrained``.
        add_special_tokens: when True, register ``[SEP]`` and ``[PAD]`` as new special
            tokens; otherwise the EOS token is reused as the padding token.
        quantize: when True (LoRA path), load with left padding and automatic BOS/EOS
            insertion.

    Returns:
        The configured tokenizer.
    """
    # Shared by both branches so the LoRA path uses the same local cache (and the same
    # remote-code setting as the model load) as the full fine-tuning path.
    common_kwargs = {"cache_dir": DATA_DIR_CACHE, "trust_remote_code": True}
    if quantize:
        tokenizer = AutoTokenizer.from_pretrained(
            model_checkpoint,
            padding_side="left",
            add_eos_token=True,
            add_bos_token=True,
            **common_kwargs,
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            model_checkpoint,
            torch_dtype="auto",
            use_cache=True,
            **common_kwargs,
        )

    if add_special_tokens:
        special_tokens_dict = {"sep_token": "[SEP]", "pad_token": "[PAD]"}
        tokenizer.add_special_tokens(special_tokens_dict)
        print("Special tokens:", tokenizer.special_tokens_map)
    else:
        # No dedicated pad token is added, so reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

def remove_sep(dataset):
    """
    Replace every literal "[SEP]" marker in the ``text`` column with a single space.

    Returns:
        Whatever ``dataset.map`` returns for a mapper that rewrites ``text``.
    """
    def _strip_sep(row):
        cleaned = row["text"].replace("[SEP]", " ")
        return {"text": cleaned}

    return dataset.map(_strip_sep)

def _compute_max_steps(training_examples, args):
    """Return the optimizer-step budget for ``args.epochs`` passes over ``training_examples`` examples (at least 1)."""
    effective_batch_size = args.batch_size * args.gradient_accumulation_steps
    return max(1, int(training_examples / effective_batch_size * args.epochs))


def prepare_datasets(datasets, tokenizer, args):
    """
    Prepares and configures training and validation datasets.

    Two modes, selected by ``args.quantize``:

    * full fine-tuning (``quantize`` False): texts are packed into fixed-length
      sequences by ``ConstantLengthDataset`` (infinite training stream, finite
      validation stream);
    * LoRA (``quantize`` True): every example is tokenized on its own and
      padded/truncated to the longest training example, capped at ``args.seq_length``.

    Args:
        datasets: mapping with "train" and "validation" datasets exposing a ``text`` column.
        tokenizer: tokenizer used for the char/token ratio and for encoding.
        args: parsed CLI arguments (reads quantize, seq_length, batch_size,
            gradient_accumulation_steps, epochs).

    Returns:
        (train_dataset, validation_dataset, max_steps)

    Raises:
        ValueError: if the training set is empty or contains no tokens. Without this check
            the code fails later with ``max()`` on an empty sequence or a division by zero.
    """
    if len(datasets["train"]) == 0:
        raise ValueError("Training dataset is empty; nothing to prepare.")

    chars_per_token = chars_token_ratio(datasets["train"], tokenizer)
    print(f"Character to token ratio: {chars_per_token:.2f}")

    if not args.quantize:
        total_tokens = get_total_tokens(datasets["train"], tokenizer)
        if total_tokens <= 0:
            raise ValueError("Training dataset contains no tokens; nothing to prepare.")
        seq_length = min(args.seq_length, total_tokens)
        print(f"Using sequence length: {seq_length}")
        train_dataset = ConstantLengthDataset(
            tokenizer=tokenizer,
            dataset=datasets["train"],
            infinite=True,
            chars_per_token=chars_per_token,
            seq_length=seq_length,
            num_of_sequences=512,
        )
        validation_dataset = ConstantLengthDataset(
            tokenizer=tokenizer,
            dataset=datasets["validation"],
            infinite=False,
            chars_per_token=chars_per_token,
            seq_length=seq_length,
            num_of_sequences=512,
        )

        # Ceil division: number of packed sequences needed to cover every training token once.
        training_examples = (total_tokens - 1) // seq_length + 1
        max_steps = _compute_max_steps(training_examples, args)
        print(f"Total tokens: {total_tokens}, Training examples: {training_examples}, Max steps: {max_steps}")
    else:
        # For quantized training, use full examples without constant length:
        # pad/truncate everything to the longest training example, capped at args.seq_length.
        longest = max(len(tokenizer.encode(ex["text"])) for ex in datasets["train"])
        seq_length = min(longest, args.seq_length)
        print(f"Using sequence length: {seq_length}")
        train_dataset = datasets["train"].map(
            lambda x: tokenize(x["text"], tokenizer, max_length=seq_length),
            remove_columns=datasets["train"].column_names,
        )
        validation_dataset = datasets["validation"].map(
            lambda x: tokenize(x["text"], tokenizer, max_length=seq_length),
            remove_columns=datasets["validation"].column_names,
        )
        training_examples = len(train_dataset)
        max_steps = _compute_max_steps(training_examples, args)
        print(f"Training examples: {training_examples}, Max steps: {max_steps}")
    return train_dataset, validation_dataset, max_steps


def init_model(args, tokenizer):
    """
    Load the causal LM on GPU 0 and prepare it for training.

    The function runs these steps in order:

    1. Load the model, in 8-bit when ``args.quantize`` is set.
    2. Resize the embeddings to the tokenizer's vocabulary.
    3. Enable gradient checkpointing.
    4. Optionally seed the ``[PAD]``/``[SEP]`` embeddings from EOS.
    5. On the LoRA path, wrap the model with PEFT.

    Args:
        args: parsed CLI arguments (reads model_checkpoint, quantize, add_special_tokens).
        tokenizer: tokenizer whose length determines the embedding-matrix size.

    Returns:
        The prepared model (a PEFT-wrapped model when ``args.quantize`` is set).
    """
    model = AutoModelForCausalLM.from_pretrained(
        args.model_checkpoint,
        trust_remote_code=True,
        device_map={"": 0},
        load_in_8bit=args.quantize,
    )

    # Match the embedding matrix to the tokenizer (it grows when [PAD]/[SEP] are added).
    # NOTE(review): the original author noted `mean_resizing=True`; confirm whether it should be passed.
    model.resize_token_embeddings(len(tokenizer))

    # The KV cache is not used while training.
    model.config.use_cache = False
    model.gradient_checkpointing_enable()
    # Enables gradients for the input-embedding outputs (needed to train adapters while the
    # base weights stay frozen). Each call registers another forward hook on the input
    # embeddings, so it is called exactly once.
    model.enable_input_require_grads()

    embedding_layer = model.get_input_embeddings()

    if args.add_special_tokens:
        # Initialize special tokens' input embeddings to match the EOS token embedding.
        # NOTE(review): output-embedding rows for these tokens are left as resized; confirm
        # that is intended for checkpoints with untied embeddings.
        reference_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
        for token in ["[PAD]", "[SEP]"]:
            token_id = tokenizer.convert_tokens_to_ids(token)
            embedding_layer.weight.data[token_id] = embedding_layer.weight.data[reference_token_id].clone()

    if args.quantize:
        # embed_tokens is listed in modules_to_save below (trained in full), so keep it in float32.
        embedding_layer.weight.data = embedding_layer.weight.data.to(torch.float32)
        peft_config = LoraConfig(
            r=64,
            lora_alpha=32,
            lora_dropout=0.05,
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
                "lm_head",
            ],
            bias="none",
            task_type="CAUSAL_LM",
            modules_to_save=["embed_tokens"],
        )

        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    return model



def create_training_args(args, output_dir, max_steps):
    """
    Build the HuggingFace ``TrainingArguments`` for this run.

    Step intervals for evaluation, saving and logging derive from an approximate
    steps-per-epoch figure, ``max_steps // args.epochs``. That figure is never below 1,
    and is 1 when ``args.epochs`` is not positive.

    Because ``save_strategy`` is "no", no intermediate checkpoints are written.
    ``save_steps`` (derived from ``args.save_freq``) is therefore passed but inert.
    """
    if args.epochs > 0:
        per_epoch = max(1, max_steps // args.epochs)
    else:
        # Guard against dividing by a non-positive epoch count.
        per_epoch = 1

    use_fp16 = getattr(args, "quantize", False)

    return TrainingArguments(
        output_dir=output_dir,
        dataloader_drop_last=True,
        overwrite_output_dir=True,
        max_steps=max_steps,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=True,
        eval_strategy="steps",
        eval_steps=max(1, int(per_epoch * args.eval_freq)),
        save_strategy="no",
        save_steps=max(1, int(per_epoch * args.save_freq)),
        logging_strategy="steps",
        logging_steps=max(1, per_epoch // 10),
        save_total_limit=1,
        fp16=use_fp16,
        bf16=False,
        optim="adamw_bnb_8bit",
        seed=args.seed,
        warmup_steps=min(50, max_steps // 2),
        neftune_noise_alpha=args.neftune_noise_alpha,
        max_grad_norm=2.0,
    )

def print_gpu():
    """Print the name and total memory (in GB, 1e9 bytes) of CUDA device 0."""
    device_props = torch.cuda.get_device_properties(0)
    report = (
        ("GPU:", device_props.name),
        ("Total (GB):", device_props.total_memory / 1e9),
    )
    for label, value in report:
        print(label, value)

def _build_output_dir(args):
    """Return the run-specific output directory derived from the CLI arguments."""
    output_dir = f"{args.output_dir}/{args.dataset_name}_split_{args.split}_size_{args.n_train_samples}_epochs_{args.epochs}"
    if args.shuffle_data:
        output_dir += "_seed_" + str(args.seed)
    if args.amount_canaries > 0:
        output_dir += "_can_" + str(args.amount_canaries)
    return output_dir


def main():
    """
    Entry point: parse arguments, prepare data, tokenizer and model, train, and save.

    If ``training_args.bin`` already exists in the run's output directory, the run exits
    before any data is loaded. Pass ``--override`` to train anyway. Canaries are written
    to disk only when training will actually proceed.

    Raises:
        ValueError: if ``--quantize`` and ``--add_special_tokens`` are combined.
    """
    args = parse_arguments()
    print("Training configuration:", args)
    # Validate incompatible flags up front with a real exception (`assert` is stripped under -O).
    if args.quantize and args.add_special_tokens:
        raise ValueError("LoRA training not supported with special tokens to be added.")
    print_gpu()
    ensure_directories()

    output_dir = _build_output_dir(args)

    # handle output directory exists; checked before any expensive data preparation.
    # training_args.bin is written by trainer.save_model at the end of a completed run.
    if os.path.exists(os.path.join(output_dir, "training_args.bin")):
        if not args.override:
            print(f"Model {output_dir} already exists. exiting")
            return
        print(f"Warning: Model {output_dir} already exists. Training will override.")

    datasets, freq2cans = load_datasets(args)
    if args.quantize:
        print("Using LoRA, removing sep from data.")
        datasets = {k: remove_sep(v) for k, v in datasets.items()}
    # Forward `quantize` so the LoRA path gets its dedicated tokenizer configuration.
    tokenizer = init_tokenizer(args.model_checkpoint, args.add_special_tokens, quantize=args.quantize)
    train_dataset, validation_dataset, max_steps = prepare_datasets(datasets, tokenizer, args)
    if args.quantize:
        print("first training example:", train_dataset[0])
        print("first validation example:", validation_dataset[0])
    else:
        print("first training example:", train_dataset.dataset[0])
        print("first validation example:", validation_dataset.dataset[0])

    if args.amount_canaries > 0:
        canaries_path = f"{output_dir}/canaries_datasets"
        freq2cans.save_to_disk(canaries_path)
        print(f"saved canaries to: {canaries_path}")

    model = init_model(args, tokenizer)
    training_args = create_training_args(args, output_dir, max_steps)
    trainer = Trainer(
        # 8-bit LoRA models keep their mixed dtypes; full fine-tuning runs in bfloat16.
        model=model if args.quantize else model.to(torch.bfloat16),
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        tokenizer=tokenizer,
    )
    trainer.train()

    # Save the final model after training completes
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Final model saved to {output_dir}")


# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
