import os
from argparse import ArgumentParser
from dotenv import load_dotenv

import wandb
import numpy as np
import torch
import transformers
from transformers import DataCollatorForLanguageModeling
from loguru import logger
import json

import llm_fairness

from huggingface_hub import HfApi, Repository
from transformers import TrainerCallback, Trainer
from collections import defaultdict
import torch

import shutil
import gc

# Start from a clean heap before anything large is allocated.
gc.collect()

# Pull environment variables (e.g. WANDB_API_KEY, read further down) from the
# local file named "vars".
load_dotenv(dotenv_path="vars")

# Process-wide runtime switches and device selection.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
local_rank = int(os.getenv("LOCAL_RANK", 0))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Half precision is only requested when a CUDA device is present.
fp16 = torch.cuda.is_available()

def _str2bool(value):
    """Convert a command-line string into a real boolean.

    Without an explicit ``type``, argparse stores the raw string, so
    ``--save-to-hf-hub False`` would yield the truthy string ``"False"``.

    Args:
        value: The raw argument (or an actual ``bool``, returned unchanged).

    Returns:
        ``True`` for 1/true/t/yes/y/on, ``False`` for 0/false/f/no/n/off/none
        and the empty string (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if text in {"", "0", "false", "f", "no", "n", "off", "none"}:
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface of the training script.
parser = ArgumentParser()
parser.add_argument(
    "--backbone", choices=llm_fairness.MODELS, default="apple/OpenELM-270M"
)
parser.add_argument(
    "--dataset", default="imdb", choices=llm_fairness.data.DATASET_REGISTRY
)
parser.add_argument("--max-training-steps", type=int, default=10000)
parser.add_argument("--per-device-batch-size", type=int, default=64)
parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
parser.add_argument("--logging-steps", type=int, default=100)
parser.add_argument("--weight-decay", type=float, default=0.01)
parser.add_argument("--learning-rate", type=float, default=5e-5)
parser.add_argument("--vocab-size", type=int, default=32000)
parser.add_argument("--max-length", type=int, default=128)
parser.add_argument("--seed", type=int, default=1)
# Boolean switches: usable bare (``--save-to-hf-hub``) or with an explicit
# value (``--save-to-hf-hub true`` / ``--save-to-hf-hub False``).
parser.add_argument(
    "--save-last-ckpt-to-disk",
    type=_str2bool,
    nargs="?",
    const=True,
    default=False,
)
parser.add_argument(
    "--save-to-hf-hub",
    type=_str2bool,
    nargs="?",
    const=True,
    default=False,
)
parser.add_argument("--generate-samples", type=int, default=100)
args = parser.parse_args()

# Authenticate against Weights & Biases with the key taken from the environment.
wandb_api_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=wandb_api_key, relogin=True)

# Open the run and record every parsed CLI argument under an "args." prefix.
wandb.init(
    project="ANONYMOUS",
    entity="ANONYMOUS",
    config={f"args.{name}": value for name, value in vars(args).items()},
)


class LogitStoreTrainer(Trainer):
    """Trainer that exposes training logits to callbacks and logs per-token eval metrics.

    During training, ``compute_loss`` stashes the current batch's logits and
    labels on ``self.state`` so that a callback can read them in
    ``on_step_end``. ``evaluate`` replaces the stock evaluation loop with one
    that accumulates, per target token id, the summed loss, the number of
    correct next-token predictions and the number of occurrences, and logs
    the result to Weights & Biases as a JSON string.
    """

    # Default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``: positions
    # carrying this label contribute zero loss and are not real tokens.
    IGNORE_INDEX = -100

    def __init__(self, *args, tokenizer, **kwargs):
        """Create the trainer.

        Args:
            *args: Forwarded to ``Trainer.__init__``.
            tokenizer: Stored on the instance; ``evaluate`` reads its
                ``pad_token_id`` to skip padding targets.
            **kwargs: Forwarded to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.tokenizer = tokenizer

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the loss with a single forward pass and stash logits/labels.

        Args:
            model: The model being trained.
            inputs: Batch dictionary passed to the model.
            return_outputs: When true, return ``(loss, outputs)`` and do not
                touch ``self.state``.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``Trainer.compute_loss`` (newer transformers releases pass
                additional ones such as ``num_items_in_batch``).

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        # Read the labels before delegating: the parent may pop them from
        # ``inputs`` (it does so when label smoothing is configured).
        labels = inputs.get("labels")
        # Reuse the parent's forward pass instead of running the model a
        # second time just to obtain the logits.
        loss, outputs = super().compute_loss(
            model, inputs, return_outputs=True, **kwargs
        )
        if return_outputs:
            return loss, outputs
        # This is to store logits and labels for callback.
        # Warning. This can cause memory issues on small GPUs.
        logits = outputs.get("logits")
        # Detach so the stored tensor does not keep the autograd graph alive.
        self.state.logits = logits.detach() if logits is not None else None
        self.state.labels = labels
        return loss

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Run per-token evaluation and log the result to Weights & Biases.

        Args:
            eval_dataset: Optional dataset override, forwarded to
                ``get_eval_dataloader``; ``None`` uses the trainer's own.
            ignore_keys: Accepted for signature compatibility with
                ``Trainer.evaluate``; not used here.
            metric_key_prefix: Prefix of the keys in the returned summary.

        Returns:
            A dict with the token-averaged ``<prefix>_loss`` and
            ``<prefix>_accuracy`` over all counted tokens (empty if no token
            was counted).
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        token_metrics = defaultdict(lambda: {"loss": 0.0, "correct": 0, "total": 0})

        total_batches = len(eval_dataloader)
        logger.info(f"Evaluating on {total_batches} batches")

        pad_token_id = self.tokenizer.pad_token_id
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        target_device = self.args.device
        last_index = -1  # stays -1 when the dataloader yields nothing

        # Disable dropout etc. for evaluation and restore the mode afterwards.
        was_training = self.model.training
        self.model.eval()
        try:
            for index, inputs in enumerate(eval_dataloader):
                last_index = index
                print(
                    f"\rProcessing batch {index+1}/{total_batches}",
                    end="",
                    flush=True,
                )
                # Make sure the batch lives on the training device (a no-op
                # when the dataloader already placed it there).
                inputs = {
                    key: value.to(target_device)
                    if isinstance(value, torch.Tensor)
                    else value
                    for key, value in inputs.items()
                }
                with torch.no_grad():
                    outputs = self.model(**inputs)
                    logits = outputs.logits
                    labels = inputs["labels"]

                    vocab_size = logits.size(-1)

                    # Shift logits and labels to align with next token prediction
                    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
                    shift_labels = labels[..., 1:].contiguous().view(-1)
                    preds = shift_logits.argmax(dim=-1)

                    # Compute per-token loss
                    loss = loss_fct(shift_logits, shift_labels)

                    # Keep only real targets: drop ignored positions (-100)
                    # and explicit padding ids.
                    keep = shift_labels != self.IGNORE_INDEX
                    if pad_token_id is not None:
                        keep = keep & (shift_labels != pad_token_id)

                    # One device-to-host transfer per tensor instead of one
                    # synchronising ``.item()`` call per token.
                    token_ids = shift_labels[keep].tolist()
                    token_losses = loss[keep].tolist()
                    token_hits = (preds[keep] == shift_labels[keep]).tolist()

                # Aggregate metrics for each token
                for token_id, token_loss, hit in zip(
                    token_ids, token_losses, token_hits
                ):
                    entry = token_metrics[token_id]
                    entry["loss"] += token_loss
                    entry["correct"] += int(hit)
                    entry["total"] += 1

                # Clear memory after each batch
                del outputs, logits, labels, shift_logits, shift_labels, preds, loss
                torch.cuda.empty_cache()
                gc.collect()
        finally:
            self.model.train(was_training)

        print()  # New line after progress indicator

        # Keys and values are plain Python ints/floats (from ``tolist``), so
        # the mapping is directly JSON-serialisable.
        token_metrics_json = json.dumps(dict(token_metrics))

        # Log per-token metrics
        wandb.log(
            {"eval_token_metrics": token_metrics_json, "eval_step": last_index},
            commit=True,
        )

        # Token-averaged summary for callers that expect a metrics dict.
        summary = {}
        total_tokens = sum(entry["total"] for entry in token_metrics.values())
        if total_tokens:
            total_loss = sum(entry["loss"] for entry in token_metrics.values())
            total_correct = sum(entry["correct"] for entry in token_metrics.values())
            summary[f"{metric_key_prefix}_loss"] = total_loss / total_tokens
            summary[f"{metric_key_prefix}_accuracy"] = total_correct / total_tokens
        return summary

class PerTokenMetricsCallback(TrainerCallback):
    """Accumulate per-token training metrics and flush them on every log event.

    After each optimisation step the callback reads ``state.logits`` and
    ``state.labels`` (placed there by the trainer's ``compute_loss``),
    computes the next-token loss per position and adds, for every target
    token id, the loss, a correct-prediction count and an occurrence count to
    ``self.token_metrics``. ``on_log`` sends the accumulated mapping to
    Weights & Biases as a JSON string and starts a fresh accumulator.
    """

    # Default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``: positions
    # carrying this label contribute zero loss and are not real tokens.
    IGNORE_INDEX = -100

    def __init__(self, tokenizer, logging_steps):
        """Create the callback.

        Args:
            tokenizer: Its ``pad_token_id`` identifies padding targets to skip.
            logging_steps: Accepted for interface compatibility; not used by
                the callback itself.
        """
        self.tokenizer = tokenizer
        self.token_metrics = self._empty_metrics()

    @staticmethod
    def _empty_metrics():
        """Return a fresh accumulator keyed by token id."""
        return defaultdict(lambda: {"loss": 0.0, "correct": 0, "total": 0})

    def on_step_end(self, args, state, control, **kwargs):
        """Fold the stashed batch into the per-token accumulator.

        Does nothing (apart from cleanup) when no logits/labels were stashed
        for this step.
        """
        logits = getattr(state, "logits", None)
        labels = getattr(state, "labels", None)
        try:
            if logits is None or labels is None:
                return

            # Pure bookkeeping: never build an autograd graph here.
            with torch.no_grad():
                vocab_size = logits.size(-1)

                # next token target is the input token shifted by one position
                shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
                shift_labels = labels[..., 1:].contiguous().view(-1)
                preds = shift_logits.argmax(dim=-1)

                # compute loss per token!
                loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
                loss = loss_fct(shift_logits, shift_labels)

                # Keep only real targets: drop ignored positions (-100) and
                # explicit padding ids.
                pad_token_id = self.tokenizer.pad_token_id
                keep = shift_labels != self.IGNORE_INDEX
                if pad_token_id is not None:
                    keep = keep & (shift_labels != pad_token_id)

                # One device-to-host transfer per tensor instead of one
                # synchronising ``.item()`` call per token.
                token_ids = shift_labels[keep].tolist()
                token_losses = loss[keep].tolist()
                token_hits = (preds[keep] == shift_labels[keep]).tolist()

            for token_id, token_loss, hit in zip(token_ids, token_losses, token_hits):
                entry = self.token_metrics[token_id]
                entry["loss"] += token_loss
                entry["correct"] += int(hit)
                entry["total"] += 1
        finally:
            # Drop the stashed tensors so they are neither kept in memory nor
            # left on the trainer state.
            for name in ("logits", "labels"):
                if hasattr(state, name):
                    delattr(state, name)
            gc.collect()

    def on_log(self, args, state, control, **kwargs):
        """Log the accumulated metrics to wandb and reset the accumulator."""
        # Keys and values are plain Python ints/floats (from ``tolist``), so
        # the mapping is directly JSON-serialisable.
        token_metrics_json = json.dumps(dict(self.token_metrics))

        wandb.log(
            {"token_metrics": token_metrics_json, "step": state.global_step},
            commit=True,
        )

        self.token_metrics = self._empty_metrics()


def prepare_next_token_dataset(dataset, tokenizer, max_length):
    """Tokenize a text dataset and attach next-token labels.

    The ``"text"`` column is tokenized in batches, truncated and padded to
    ``max_length``; all original columns are dropped. A ``"labels"`` column
    holding a copy of ``"input_ids"`` is added and the dataset is formatted
    to return torch tensors for ``input_ids``, ``attention_mask`` and
    ``labels``.

    Args:
        dataset: Dataset exposing ``map``, ``column_names``, ``add_column``,
            item access by column name and ``set_format``.
        tokenizer: Callable tokenizer applied to a batch of texts.
        max_length: Sequence length used for truncation and padding.

    Returns:
        The tokenized, labelled and formatted dataset.
    """

    def _tokenize_batch(batch):
        # Fixed-length encoding so every example has the same shape.
        return tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=max_length,
        )

    tokenized = dataset.map(
        _tokenize_batch,
        batched=True,
        remove_columns=dataset.column_names,
    )
    # The targets are the inputs themselves.
    labelled = tokenized.add_column("labels", tokenized["input_ids"])
    labelled.set_format(
        type="torch", columns=["input_ids", "attention_mask", "labels"]
    )
    return labelled


# load raw data and train tokenizer.
# The object returned for the chosen dataset is indexed by split name; only
# the "train" and "test" splits are used by this script.
data = llm_fairness.data.from_name(args.dataset)
train_dataset, test_dataset = data["train"], data["test"]
# The tokenizer is built from the training split with the requested vocab size.
tokenizer = llm_fairness.tokenizer.from_data(
    train_dataset, variant="BPE", vocab_size=args.vocab_size
)
# Causal-LM collator (mlm=False), handed to the trainer further down.
# NOTE(review): per the Hugging Face docs this collator rebuilds `labels` from
# `input_ids` and sets padding positions to -100 — confirm, because the
# per-token metric code compares labels against `tokenizer.pad_token_id`.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False
)  # Might delete.
logger.info(f"BPE tokenizer trained! Vocab size: {len(tokenizer.vocab)}")

# prepare dataset for training.
# Both splits are tokenized to fixed-length sequences of `--max-length` tokens
# and get a `labels` column (see prepare_next_token_dataset).
train_dataset = prepare_next_token_dataset(train_dataset, tokenizer, args.max_length)
test_dataset = prepare_next_token_dataset(test_dataset, tokenizer, args.max_length)

# load model.
# Turn on PyTorch's memory-efficient scaled-dot-product-attention backend.
torch.backends.cuda.enable_mem_efficient_sdp(True)
# NOTE(review): `pretrained=False` presumably means random initialisation —
# confirm against llm_fairness.models.from_name.
model = llm_fairness.models.from_name(
    args.backbone,
    tokenizer=tokenizer,
    pretrained=False,
    task="lm",
    max_length=args.max_length,
)
# Train every parameter; disable the generation cache and enable gradient
# checkpointing for training.
model.requires_grad_(True)
model.config.use_cache = False
model.gradient_checkpointing_enable()
logger.info(f"Model loaded: {args.backbone}")

# prepare optim.
# AdamW over all trainable parameters, with a cosine schedule that warms up
# for 10 steps and spans `--max-training-steps` steps in total.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    params, lr=args.learning_rate, weight_decay=args.weight_decay
)
scheduler = transformers.get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=10, num_training_steps=args.max_training_steps
)

# prepare for training.
# Training runs for a fixed number of steps and writes no checkpoints.
# NOTE(review): an optimizer/scheduler pair is passed to the trainer below, so
# `weight_decay` and `learning_rate` given here are presumably informational
# only — confirm.
training_args = transformers.TrainingArguments(
    output_dir="/tmp",  # Don't save checkpoints.
    save_strategy="no",
    save_steps=0,
    max_steps=args.max_training_steps,
    per_device_train_batch_size=args.per_device_batch_size,
    per_device_eval_batch_size=args.per_device_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    gradient_checkpointing=True,
    overwrite_output_dir=True,
    weight_decay=args.weight_decay,
    learning_rate=args.learning_rate,
    logging_steps=args.logging_steps,
    # Mixed precision only on CUDA, and never for "mistralai" backbones.
    fp16=(fp16 and "mistralai" not in args.backbone),
    seed=args.seed
)

# Callback that accumulates per-token metrics during training and logs them
# to wandb on each log event.
per_token_metrics_callback = PerTokenMetricsCallback(
    tokenizer=tokenizer, logging_steps=args.logging_steps
)

# The custom trainer stores logits/labels for the callback and provides the
# per-token evaluation loop.
trainer = LogitStoreTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    optimizers=(optimizer, scheduler),
    callbacks=[per_token_metrics_callback],
    tokenizer=tokenizer
)

# run training.
logger.info("Ready to start training...")
trainer.train()

# also run evaluation on test set.
logger.info("Evaluating on test set...")
trainer.evaluate()


if args.generate_samples > 0:
    # generate a few samples from the model.
    logger.info(f"Generating {args.generate_samples} samples...")
    model.eval()
    model.to(device)

    prompt = "This movie is"
    inputs = tokenizer(prompt, return_tensors="pt")
    # Remove token_type_ids if present
    inputs.pop("token_type_ids", None)
    # Move inputs to the correct device (a single move is sufficient)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    generated_samples = []
    for _ in range(args.generate_samples):
        # Best effort: a failing generation is recorded in the output file
        # instead of aborting the run. Only generation and decoding are
        # guarded.
        try:
            with torch.no_grad():
                output = model.generate(
                    **inputs,
                    max_length=100,
                    num_return_sequences=1,
                    do_sample=True,
                    top_k=50,
                    top_p=0.95,
                    temperature=0.7,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
            generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        except Exception as e:
            logger.error(f"Error during generation: {e}")
            generated_samples.append(f"Error during generation: {e}")
            continue

        generated_samples.append(generated_text)
        logger.info(f"Generated sample: {generated_text}")

    # Create the samples file with the wandb run id
    os.makedirs("results/samples", exist_ok=True)
    samples_file = f"results/samples/{wandb.run.id}"
    # Explicit UTF-8: generated text may contain characters that the
    # platform's default encoding cannot represent.
    with open(samples_file, "w", encoding="utf-8") as f:
        for sample in generated_samples:
            f.write(f"{sample}\n\n")

# Name shared by the hub repository and the on-disk checkpoint directory:
# <backbone>-<dataset>-wd<weight decay>-seed<seed>.
back = args.backbone.split("/")[-1].lower()
modname = f"{back}-{args.dataset.lower()}-wd{args.weight_decay}-seed{args.seed}"

if args.save_to_hf_hub:
    organization = "llm-wd-fairness"
    repo_id = f"{organization}/{modname}"
    repo_path = f"./{modname}"
    # Remember whether the clone directory is ours, so that cleanup after a
    # failure never deletes a directory that existed before this run.
    repo_path_preexisting = os.path.exists(repo_path)
    try:
        # save model to hf-hub.
        api = HfApi()
        api.create_repo(repo_id=repo_id, private=True, repo_type="model", exist_ok=True)

        # NOTE(review): `Repository` is a git-based helper that recent
        # huggingface_hub releases deprecate — consider migrating to an
        # HTTP-based upload.
        repo = Repository(local_dir=repo_path, clone_from=repo_id)
        tokenizer.save_pretrained(repo_path)
        model.save_pretrained(repo_path)  #  type: ignore

        repo.git_add()
        repo.git_commit("push model to hf-hub")
        repo.git_push()

        shutil.rmtree(repo_path)
        logger.info(f"Model saved to hf-hub: {repo_id}")

    except Exception as e:
        logger.error(f"Could not save model to hf-hub: {e}")
        # Do not leave a partial clone behind: it would make the next
        # attempt fail when cloning into the same path.
        if not repo_path_preexisting:
            shutil.rmtree(repo_path, ignore_errors=True)

if args.save_last_ckpt_to_disk:
    try:
        # save model to disk.
        model_path = f"./models/{modname}"
        os.makedirs(model_path, exist_ok=True)
        tokenizer.save_pretrained(model_path)
        model.save_pretrained(model_path)  #  type: ignore
        logger.info(f"Model saved to disk: {model_path}")

    except Exception as e:
        logger.error(f"Could not save model to disk: {e}")

# finish wandb.
wandb.finish()
