import json
import numpy as np
import random
import torch
import wandb

from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, TrainerCallback
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, TaskType


def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value handed to every random number generator seeded here.
    """
    # Feed the same seed to each RNG source used by this script.
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    ):
        seed_fn(seed)

    # Turn off cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def chunk_and_tokenize(examples, tokenizer, max_length, overlap):
    """Split every context into overlapping, fixed-length windows of token ids.

    Each text in ``examples['context']`` is tokenized without truncation or
    padding and cut into windows of ``max_length`` tokens, where consecutive
    windows share ``overlap`` tokens. The window that reaches the end of a
    text has ``tokenizer.eos_token_id`` appended and is right-padded to
    ``max_length``; no further windows are produced for that text, because
    they would only repeat a suffix of the window just emitted.

    If a text ends exactly on a window boundary and ``overlap`` is 0 there is
    no room left for the EOS token, so it is omitted for that text.

    Args:
        examples: Batch mapping with a ``'context'`` list of strings.
        tokenizer: Callable returning a mapping with ``'input_ids'``; must
            expose ``eos_token_id`` and ``pad_token_id``.
        max_length: Number of token ids in every returned chunk.
        overlap: Number of token ids shared by consecutive windows.

    Returns:
        ``{"input_ids": chunks}`` where every chunk has exactly
        ``max_length`` entries.

    Raises:
        ValueError: If ``overlap`` is negative or not smaller than
            ``max_length`` (the stride would be zero, negative, or skip
            tokens between windows).
    """
    if not 0 <= overlap < max_length:
        raise ValueError(
            "overlap must satisfy 0 <= overlap < max_length, "
            f"got overlap={overlap}, max_length={max_length}"
        )
    stride = max_length - overlap

    # Fall back to EOS for padding so a tokenizer without a pad token never
    # injects None into the id lists.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    padded_chunks = []
    for text in examples['context']:
        # Tokenize without padding or truncation
        tokens = tokenizer(text, truncation=False, padding=False)['input_ids']

        for start in range(0, len(tokens), stride):
            window = tokens[start:start + max_length]
            if len(window) == max_length:
                # Full window: nothing to pad, and no room for EOS here.
                padded_chunks.append(window)
                continue

            # A short window means the end of the text was reached: close it
            # with EOS, pad it, and stop so no redundant suffix-only windows
            # are emitted for this text.
            window = window + [tokenizer.eos_token_id]
            padded_chunks.append(window + [pad_id] * (max_length - len(window)))
            break

    return {"input_ids": padded_chunks}


def remove_duplicate_contexts(examples):
    """Drop repeated contexts from a batch, keeping first occurrences in order.

    Args:
        examples: Batch mapping with a ``'context'`` list.

    Returns:
        ``{'context': contexts}`` with each distinct context appearing once.
    """
    already_seen = set()
    kept = []
    for context in examples['context']:
        if context in already_seen:
            continue
        already_seen.add(context)
        kept.append(context)
    return {'context': kept}


# Tokenize dataset
def tokenize_function(examples):
    """Tokenize the ``'context'`` entries, truncating to ``max_len + 1`` tokens.

    Args:
        examples: Batch mapping with a ``'context'`` entry.

    Returns:
        Whatever the module-level ``tokenizer`` returns for the contexts.
    """
    # NOTE(review): `tokenizer` and `max_len` are looked up as module-level
    # names, but the visible part of this file defines neither at module
    # scope; confirm they are provided elsewhere before calling this.
    return tokenizer(
        examples['context'],
        max_length=max_len + 1,
        truncation=True,
    )


def main(
    dataset_name: str = "nyt",
    n_items: int = 1000,
    lr: float = 1e-5,
    per_device_batch_size: int = 1,
    gradient_accumulation_steps: int = 2,
    n_epochs: int = 3,
    lora_r: int = 2,
    weight_decay: float = 0.1,
    warmup_steps: int = 100,
    seed: int = None,
    group: str = "default",
    run_name: str = "default",
    use_wandb: bool = False,
    max_length: int = 256,
    overlap: int = 64,
    debug: bool = False,
) -> None:
    """Fine-tune Llama-3-8B-Instruct with LoRA on SQuADShifts contexts.

    Loads the ``test`` split of the chosen SQuADShifts subset, keeps the
    unique ``context`` strings of the first ``n_items`` rows, cuts them into
    overlapping token windows and trains LoRA adapters on them with the
    Hugging Face ``Trainer``. The adapter, tokenizer and a small
    ``base_model_config.json`` are written to
    ``./checkpoints/huggingface/{run_name}``.

    Args:
        dataset_name: SQuADShifts subset: "nyt", "reddit", "amazon" or
            "new_wiki".
        n_items: Number of leading rows to use; capped at the split size.
        lr: Learning rate passed to the Trainer.
        per_device_batch_size: Training batch size per device.
        gradient_accumulation_steps: Batches accumulated per optimizer step.
        n_epochs: Number of training epochs.
        lora_r: LoRA rank; ``lora_alpha`` is set to twice this value.
        weight_decay: Weight decay passed to the Trainer.
        warmup_steps: Learning-rate warmup steps.
        seed: Seed for all RNGs and for the Trainer. ``None`` leaves the
            global RNGs unseeded and keeps the Trainer's default seed.
        group: Weights & Biases group name.
        run_name: Weights & Biases run name and output sub-directory.
        use_wandb: Whether to log the run to Weights & Biases.
        max_length: Length of every training chunk in tokens.
        overlap: Tokens shared by consecutive chunks of the same context.
        debug: Currently unused.

    Raises:
        NotImplementedError: If ``dataset_name`` is not a supported subset.
    """
    # Local import: only needed to create the output directory below.
    import os

    # `is not None` so that an explicit seed of 0 is honoured as well.
    if seed is not None:
        set_seed(seed)

    if dataset_name in ["nyt", "reddit", "amazon", "new_wiki"]:
        full_dataset = load_dataset("squadshifts", dataset_name, trust_remote_code=True)["test"]
    else:
        raise NotImplementedError(f"Unknown dataset {dataset_name}")

    # Cap at the split size so an oversized n_items does not make select() fail.
    dataset = full_dataset.select(range(min(n_items, len(full_dataset))))

    # Load tokenizer and model
    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    # Ensure the tokenizer has a padding token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        # Take the id from the tokenizer so model and tokenizer agree on the
        # exact padding id.
        base_model.config.pad_token_id = tokenizer.pad_token_id

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_r,
        lora_alpha=lora_r*2,
        lora_dropout=0.05,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
            "lm_head",
        ]
    )

    # Remove duplicates. batch_size=None hands the whole dataset to the
    # function as one batch, so duplicates are removed across the entire
    # selection rather than only within each default-sized batch.
    unique_context_dataset = dataset.map(
        remove_duplicate_contexts,
        batched=True,
        batch_size=None,
        remove_columns=[col for col in dataset.column_names if col != 'context']
    )

    tokenized_datasets = unique_context_dataset.map(
        lambda examples: chunk_and_tokenize(examples, tokenizer, max_length, overlap),
        batched=True,
        remove_columns=['context']
    )

    # Prepare dataset for training
    train_dataset = tokenized_datasets.shuffle()
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Wrap the model with PEFT
    model = get_peft_model(base_model, peft_config)
    model.print_trainable_parameters()

    output_dir = f"./checkpoints/huggingface/{run_name}"

    # Log every primitive-valued local (arguments plus derived strings such as
    # model_name and output_dir) as the run configuration.
    local_vars = locals()
    primitive_types = (int, float, str, bool)
    config = {k: v for k, v in local_vars.items()
              if isinstance(v, primitive_types) and not k.startswith('_')}

    if use_wandb:
        wandb.init(
            project="huggingface",
            group=group,
            name=run_name,
            config=config
        )

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=n_epochs,
        per_device_train_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        # Forward the requested learning rate; without this the Trainer
        # silently falls back to its own default.
        learning_rate=lr,
        warmup_steps=warmup_steps,
        weight_decay=weight_decay,
        logging_steps=1,
        eval_strategy="no",
        save_strategy="steps",
        save_steps=5000,
        save_only_model=True,
        bf16=True,
        # The Trainer reseeds from this value, so pass the user's seed
        # through; 42 is the library default used when no seed is given.
        seed=seed if seed is not None else 42,
        report_to=["wandb"] if use_wandb else [],
        save_safetensors=True,
        push_to_hub=False,
    )

    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )

    base_model_config = {
        "model_path": model_name,
        "adapter_ids": [],
    }
    # Make sure the directory exists before writing into it.
    os.makedirs(output_dir, exist_ok=True)
    with open(output_dir + "/base_model_config.json", 'w', encoding='utf-8') as f:
        json.dump(base_model_config, f, ensure_ascii=False, indent=4)

    # Train the model
    trainer.train()

    # Save the model
    model_save_path = output_dir
    model.save_pretrained(model_save_path)
    tokenizer.save_pretrained(model_save_path)
    print(f"Model saved to {model_save_path}")

    if use_wandb:
        wandb.finish()

if __name__ == "__main__":
    # Expose `main` as the command-line entry point. jsonargparse is imported
    # only here, so importing this module does not require it.
    import jsonargparse

    jsonargparse.CLI(main)
