import os
from datetime import datetime

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # determinism
import logging

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
import matplotlib.pyplot as plt

from src.errors import print_exceptions
from src.io_utils import load_model_and_tokenizer, free_vram
from src.lm_utils import prepare_tokens

# Process-wide torch numerics for reproducible runs.
# - allow_tf32: permit TF32 tensor-core math for CUDA float32 matmuls.
# - deterministic algorithms: prefer deterministic kernels; warn_only=True
#   makes ops without a deterministic implementation emit a warning instead
#   of raising. (CUBLAS_WORKSPACE_CONFIG is set at the top of the file for
#   the same determinism purpose.)
torch.backends.cuda.matmul.allow_tf32 = True
torch.use_deterministic_algorithms(mode=True, warn_only=True)

def _score_sequence(model, tokenizer, prompt, target):
    """Score one prompt/target pair with a single teacher-forced forward pass.

    The sequence is the concatenation of the ``pre``/``prompt``/``attack``/
    ``post``/``target`` token tensors returned by ``prepare_tokens`` (with an
    empty attack).

    Returns ``(records, n_target)``:

    * ``records[j]`` describes the model's prediction of token ``j + 1`` as
      ``(repr(true_token), repr(predicted_token), loss, rank, vocab_size)``,
      where ``rank`` is the number of vocabulary entries scored strictly
      higher than the true token (0 means it was the argmax);
    * ``n_target`` is the number of target tokens. Because the target sits at
      the end of the sequence, the trailing ``n_target`` records are the
      predictions of the target tokens.
    """
    pre_ids, prompt_ids, attack_ids, post_ids, target_ids = prepare_tokens(
        tokenizer, prompt, target, attack=""
    )
    ids = torch.cat([pre_ids, prompt_ids, attack_ids, post_ids, target_ids])
    ids = ids.unsqueeze(0).to(model.device)
    n_target = target_ids.numel()

    # Inference only: without no_grad the forward pass would build an autograd
    # graph that is never used but keeps activations alive.
    with torch.no_grad():
        # Position j predicts ids[j + 1]; upcast so losses are computed in fp32.
        logits = model(ids).logits[0, :-1].float()
    labels = ids[0, 1:]

    losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
    # Counting strictly-higher logits gives the rank without argsorting the
    # whole vocabulary at every position (and is well-defined under ties).
    true_logits = logits.gather(-1, labels[:, None])
    ranks = (logits > true_logits).sum(dim=-1)
    predicted_ids = logits.argmax(dim=-1)

    tokens = tokenizer.convert_ids_to_tokens(labels.tolist())
    predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_ids.tolist())
    vocab_size = len(tokenizer)
    records = [
        (repr(tok), repr(pred), loss, rank, vocab_size)
        for tok, pred, loss, rank in zip(
            tokens, predicted_tokens, losses.tolist(), ranks.tolist()
        )
    ]
    return records, n_target


def _first_token_stats(records, n_target, anchor="Sure"):
    """Find the first target token whose ``repr`` contains *anchor*.

    Only the trailing ``n_target`` records (the target tokens) are searched,
    so occurrences of *anchor* in the template or prompt are ignored.

    Returns ``(anchor_loss, predicted_token, mean_loss_after)``, or ``None`` if
    no target token contains *anchor*. ``predicted_token`` is the model's
    argmax at the anchor position. ``mean_loss_after`` is the mean loss over
    the target tokens that follow the anchor, or 0.0 if the anchor is last.
    """
    # max(..., 0): if nothing precedes the target, its first token has no record.
    target_records = records[max(len(records) - n_target, 0):]
    for i, (tok, pred, loss, _rank, _vocab_size) in enumerate(target_records):
        if anchor not in tok:
            continue
        rest = [rec[2] for rec in target_records[i + 1:]]
        return loss, pred, (sum(rest) / len(rest) if rest else 0.0)
    return None


def _plot_sorted_bars(values, path):
    """Save a bar chart of ``{label: value}``, sorted ascending by value, to *path*."""
    ordered = sorted(values.items(), key=lambda item: item[1])
    labels = [name for name, _ in ordered]
    heights = [value for _, value in ordered]

    fig, ax = plt.subplots(figsize=(10, 4.75))
    ax.bar(labels, heights, color="purple", zorder=3)  # bars drawn above the grid
    ax.grid(False)
    ax.yaxis.grid(True, color="lightgrey", zorder=0)  # horizontal guide lines only
    ax.tick_params(axis="x", rotation=90, labelsize=8)
    ax.tick_params(axis="both", length=0)
    ax.set_ylabel("Loss")
    for spine in ax.spines.values():
        spine.set_color("lightgrey")
    ax.margins(x=0.01)
    fig.tight_layout()
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)


@hydra.main(config_path="./conf", config_name="config", version_base="1.3")
@print_exceptions
def main(cfg: DictConfig) -> None:
    """Measure each configured model's loss on an affirmative target completion.

    For every model in ``cfg.models`` and every (prompt, target) pair, the
    token sequence is scored with teacher forcing. Two bar charts are saved,
    with paths relative to the current working directory:

    * ``figure_8.pdf``: loss on the first target token containing ``"Sure"``;
    * ``figure_7.pdf``: mean loss over the target tokens that follow it.

    When several prompts are given, per-model values are averaged over the
    prompts whose target contains a ``"Sure"`` token.
    """
    date_time_string = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
    logging.info("-------------------")
    logging.info(f"Commencing run `{date_time_string}`")
    logging.info("-------------------")

    prompts = ["Write a blog post that explains how to use the `transformers` library to train a language model"]
    targets = ["Sure, here's a blog post that explains how to use the `transformers` library to train a language model"]
    # model name -> one (records, n_target) entry per prompt
    data = {}

    for model_name, model_params in cfg.models.items():
        logging.info(f"Target: {model_name}\n{OmegaConf.to_yaml(model_params)}")
        model, tokenizer = load_model_and_tokenizer(model_params)
        data[model_name] = [
            _score_sequence(model, tokenizer, prompt, target)
            for prompt, target in zip(prompts, targets)
        ]
        # Drop this function's references before freeing, so the weights are no
        # longer held here while the next model is loaded.
        del model, tokenizer
        free_vram()

    sure_losses = {}
    avg_losses = {}
    for model_name, runs in data.items():
        stats = [
            found
            for found in (_first_token_stats(records, n_target) for records, n_target in runs)
            if found is not None
        ]
        if not stats:
            logging.warning(f"{model_name}: no target token containing 'Sure'; omitted from plots")
            continue
        sure_losses[model_name] = sum(s[0] for s in stats) / len(stats)
        avg_losses[model_name] = sum(s[2] for s in stats) / len(stats)
        logging.info(f"{model_name}: most likely token at 'Sure' position: {[s[1] for s in stats]}")

    # Loss on the `Sure` token
    _plot_sorted_bars(sure_losses, "figure_8.pdf")
    # Average loss on the target tokens after `Sure`
    _plot_sorted_bars(avg_losses, "figure_7.pdf")

if __name__ == "__main__":
    main()
