from collections import Counter, deque
from datetime import datetime
import json
import math
import pickle
import random

from more_itertools import batched
from transformers import AutoTokenizer, AutoModel
import torch
from datasets import load_dataset
from vector_quantize_pytorch import FSQ
import os
import shutil
from torch.nn import functional as F

# Configure MLflow through its environment variables (experiment name and
# tracking server). They are set here, ahead of the `import mlflow` below.
os.environ["MLFLOW_EXPERIMENT_NAME"] = "QCR"
os.environ["MLFLOW_TRACKING_URI"] = "http://10.128.0.103:8080"

import sys
import mlflow

from modeling_autoencoder_v2 import Autoencoder, AutoencoderOutput
from train_utils import StreamingTFIDF, StreamingTFIDFResults, mask_words


def get_dataset(dataset_name):
    """Load the training dataset and, when applicable, a doc-id -> queries map.

    Args:
        dataset_name: One of ``"golden_dataset"``,
            ``"golden_dataset_with_hard_negatives"``, ``"minder_queries"``,
            ``"gemma_queries"`` or
            ``"golden_dataset_with_hard_negatives_and_gemma_queries"``.

    Returns:
        A ``(train_dataset, doc_id_2_queries)`` tuple. ``doc_id_2_queries`` is
        ``None`` for ``"golden_dataset"`` and
        ``"golden_dataset_with_hard_negatives"``; otherwise it is the object
        loaded from the corresponding pickle/JSON file.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
            (Previously this surfaced as an ``UnboundLocalError`` on return.)
    """
    golden_jsonl = "msmarco-passage-train.jsonl"
    hard_negs_jsonl = "msmarco-passage-train-with-negs.jsonl"
    minder_queries_pkl = "msmarco_raw/minder/pid2query_msmarco.pkl"
    gemma_queries_json = "doc_id_to_query_gemma.json"

    doc_id_2_queries = None
    if dataset_name == "golden_dataset":
        train_dataset = load_dataset("json", data_files=golden_jsonl)["train"]
    elif dataset_name == "golden_dataset_with_hard_negatives":
        train_dataset = load_dataset("json", data_files=hard_negs_jsonl)["train"]
    elif dataset_name == "minder_queries":
        train_dataset = load_dataset("irds/msmarco-passage", "docs")
        # SECURITY: pickle.load can execute arbitrary code. Only point this at
        # a file produced by a trusted pipeline.
        with open(minder_queries_pkl, "rb") as rb:
            doc_id_2_queries = pickle.load(rb)
    elif dataset_name == "gemma_queries":
        train_dataset = load_dataset("irds/msmarco-passage", "docs")
        with open(gemma_queries_json, "rb") as rb:
            doc_id_2_queries = json.load(rb)
    elif dataset_name == "golden_dataset_with_hard_negatives_and_gemma_queries":
        train_dataset = load_dataset("json", data_files=hard_negs_jsonl)["train"]
        with open(gemma_queries_json, "rb") as rb:
            doc_id_2_queries = json.load(rb)
    else:
        supported = (
            "golden_dataset",
            "golden_dataset_with_hard_negatives",
            "minder_queries",
            "gemma_queries",
            "golden_dataset_with_hard_negatives_and_gemma_queries",
        )
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected one of {supported}"
        )
    return train_dataset, doc_id_2_queries


def create_mask(non_padded, max_length):
    """Build a ``(batch, max_length)`` mask with ones on the leading positions.

    Args:
        non_padded: 1-D tensor; entry ``i`` is the number of leading positions
            to mark in row ``i``.
        max_length: Number of columns of the mask.

    Returns:
        A ``float32`` tensor on ``non_padded.device`` whose row ``i`` holds
        ``1`` in its first ``non_padded[i]`` columns (capped at ``max_length``)
        and ``0`` elsewhere. Lengths ``<= 0`` yield an all-zero row.
    """
    # One broadcast comparison replaces the per-row Python loop, which indexed
    # with tensor elements (a host round-trip per row when on GPU).
    positions = torch.arange(max_length, device=non_padded.device)
    return (positions.unsqueeze(0) < non_padded.reshape(-1, 1)).to(torch.float32)


def colbert_loss(
    start_from,
    state_queries,
    state_documents,
    distance,
):
    # normalize
    if distance == "cosine":
        state_queries = F.normalize(state_queries, p=2, dim=-1, eps=1e-6)
        state_documents = F.normalize(state_documents, p=2, dim=-1, eps=1e-6)
        scores = torch.einsum("bth,BTh->btBT", state_queries, state_documents)
    elif distance == "mse":
        scores = -torch.cdist(
            state_queries.reshape(1, -1, state_queries.size(-1)),
            state_documents.reshape(1, -1, state_documents.size(-1)),
        ).view(
            state_queries.size(0),
            state_queries.size(1),
            state_documents.size(0),
            state_documents.size(1),
        )

    # scores = scores - scores.detach() + scores.clamp_min(0.20).detach()

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        if LOSS == "simclr":
            scores = scores.reshape(scores.size(0) * scores.size(1), -1)
            labels = torch.arange(scores.size(0), device=scores.device)
            loss = torch.nn.functional.cross_entropy(
                scores,
                start_from + labels,
            )
            return loss
        else:
            scores_red1 = scores.max(dim=-1).values  # btB
            scores_red2 = scores_red1.sum(dim=1)

    loss = torch.nn.functional.cross_entropy(
        scores_red2,
        start_from + torch.arange(state_queries.size(0)).to(scores_red2.device),
    )
    return loss


def entropy_loss(
    autoencoder,
    state_documents_codes,
    codebook_vectors,
    documents_latent_indices,
    hard=False,
):
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = -torch.cdist(
            state_documents_codes.view(1, -1, state_documents_codes.size(-1)),
            codebook_vectors.unsqueeze(0),
        ).view(
            state_documents_codes.size(0),
            state_documents_codes.size(1),
            codebook_vectors.size(0),
        )
        logits /= ENTROPY_SOFTMAX
    pdist = logits.log_softmax(dim=-1).exp()  # BTc
    if hard:
        pdist_hard = torch.nn.functional.one_hot(
            logits.argmax(dim=-1), num_classes=autoencoder.quantizer.codebook_size
        ).to(pdist.dtype)
        pdist = 0.9 * ((pdist - pdist.detach()) + pdist_hard) + 0.1 * pdist
    pdist = pdist.mean(1)  # Bc
    pdist = pdist.mean(0)

    norm_entropy = -torch.sum(
        pdist * pdist.masked_fill_(pdist <= 0.0, 1e-6).log()
    ) / math.log(autoencoder.quantizer.codebook_size)
    neg_norm_entropy = 1.0 - norm_entropy
    assert (
        neg_norm_entropy >= 0.0 and neg_norm_entropy <= 1.0
    ), f"neg_norm_entropy: {neg_norm_entropy}"
    return neg_norm_entropy


def check_hard_negs(dataset):
    """Tell whether the dataset name designates a variant with hard negatives."""
    marker = "with_hard_negatives"
    has_marker = marker in dataset
    return has_marker


if __name__ == "__main__":
    # Training script. Each optimizer step processes one "big batch" in three
    # passes (gradient caching):
    #   1. encode every mini-batch without autograd and cache the
    #      pre-quantization representations;
    #   2. evaluate the contrastive loss on the cached representations to get
    #      d(loss)/d(representation);
    #   3. re-encode every mini-batch with autograd and backpropagate the
    #      surrogate ``sum(representation * cached_grad)`` plus the entropy term.
    model_id = "castorini/tct_colbert-msmarco"
    # model_id = "facebook/bart-large"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    saved_checkpoints = deque()

    SAVE_MAX_CHECKPOINTS = 5
    # NOTE(review): this path is relative ("home/..."), while the example
    # START_CHECKPOINT below is absolute ("/home/...") — confirm which is meant.
    SAVE_PATH = f"home/runs/autencoder-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    BIG_BATCH_SIZE = 2000
    MINI_BATCH_SIZE = 20
    ENTROPY_LOSS_ALPHA = 2000.0
    CONTR_LOSS_ALPHA = 1.0
    ENTROPY_SOFTMAX = 0.2
    LR = 1e-5
    EPOCHS = 5
    # NOTE(review): get_dataset() in this file has no branch for
    # "silver_queries" — confirm the intended dataset name.
    DATASET = "silver_queries"
    MASK_DOCS_PERC = 0.0
    MAX_LEN = 256
    REP_LEN = 250
    COLBERT_DISTANCE = "mse"
    HARD_ENTROPY = True
    LOSS = "colbert"
    SAVE_EVERY = 20
    START_CHECKPOINT = None
    # START_CHECKPOINT = "/home/runs/autencoder-2024-09-30_14-20-44/checkpoint-2900"
    intermediate_size = 5
    level_dim = 5
    train_dataset, doc_id_to_queries = get_dataset(DATASET)
    print("len dataset: ", len(train_dataset))
    print("updates 1 epoch: ", len(train_dataset) / BIG_BATCH_SIZE)
    if START_CHECKPOINT:
        autoencoder = Autoencoder.load_checkpoint(
            path=START_CHECKPOINT,
        ).to("cuda")
        autoencoder.length = REP_LEN
    else:
        autoencoder = Autoencoder.init_from_pretrained(
            model_name_or_path=model_id,
            levels=[level_dim] * intermediate_size,
            hidden_size=intermediate_size,
            length=REP_LEN,
        ).to("cuda")
    autoencoder.add_special_tokens()

    optimizer = torch.optim.Adam(
        list(autoencoder.parameters()),
        lr=LR,
    )

    # Scale applied to tanh(down-projection) to obtain the pre-quantization
    # values, and the codebook entries rescaled by (levels // 2).
    half_l = (autoencoder.quantizer._levels - 1) * (1 + 1e-3) / 2
    codebook_vectors = autoencoder.quantizer.indices_to_codes(
        torch.arange(
            autoencoder.quantizer.codebook_size,
            device="cuda",
            dtype=torch.long,
        )
    ) * (autoencoder.quantizer._levels // 2)

    # The dataset name never changes, so decide once whether hard negatives
    # are part of every batch.
    has_hard_negs = check_hard_negs(DATASET)
    token_keys = ("input_ids", "attention_mask")

    mlf_context_manager = mlflow.start_run(log_system_metrics=True)
    token_counter = StreamingTFIDF(
        history_size=200,
    )
    with mlf_context_manager:
        mlflow.log_params(
            {
                "run": "debug_colbert",
                "save_path": SAVE_PATH,
                "BATCH_SIZE": BIG_BATCH_SIZE,
                "GRAD_CACHE_BATCH_SIZE": MINI_BATCH_SIZE,
                "LR": LR,
                "entropy_loss_alpha": ENTROPY_LOSS_ALPHA,
                "contr_loss_alpha": CONTR_LOSS_ALPHA,
                "ENTROPY_SOFTMAX": ENTROPY_SOFTMAX,
                "EPOCHS": EPOCHS,
                "use_queries": DATASET,
                "level_quant": level_dim,
                "levels_length": intermediate_size,
                "MASK_DOCS_PERC": MASK_DOCS_PERC,
                "MAX_LEN": MAX_LEN,
                "COLBERT_DISTANCE": COLBERT_DISTANCE,
                "REP_LEN": REP_LEN,
                "LOSS": LOSS,
                "SAVE_EVERY": SAVE_EVERY,
                "HARD_ENTROPY": HARD_ENTROPY,
                "START_CHECKPOINT": START_CHECKPOINT,
            }
        )
        cur_steps_total = 0
        for epoch in range(EPOCHS):
            for i, batch in enumerate(batched(train_dataset.shuffle(), BIG_BATCH_SIZE)):
                # ---- Build the text batch --------------------------------
                if doc_id_to_queries:
                    doc_ids = [doc["doc_id"] for doc in batch]
                    queries = []
                    for doc_id in doc_ids:
                        query_split = doc_id_to_queries[doc_id].split(",")
                        # Random non-empty subset of the comma-separated
                        # queries, joined into a single query string.
                        query = " ".join(
                            random.sample(
                                query_split,
                                random.randint(1, len(query_split)),
                            )
                        )
                        queries.append(query)
                else:
                    queries = [sample["query_text"] for sample in batch]
                documents = [sample["text"] for sample in batch]
                if has_hard_negs:
                    hard_neg_docs = [
                        random.choice(sample["hard_negs_docs"]) for sample in batch
                    ]
                documents = mask_words(documents, MASK_DOCS_PERC, autoencoder)
                batch_queries = autoencoder.tokenize(
                    queries,
                    max_length=MAX_LEN,
                    pad_to_max_length=True,
                ).to("cuda")
                batch_documents = autoencoder.tokenize(
                    documents,
                    max_length=MAX_LEN,
                    pad_to_max_length=True,
                ).to("cuda")
                if has_hard_negs:
                    batch_hard_negs = autoencoder.tokenize(
                        hard_neg_docs,
                        max_length=MAX_LEN,
                        pad_to_max_length=True,
                    ).to("cuda")

                # The last batch of an epoch may be smaller than
                # BIG_BATCH_SIZE, so averages use the real mini-batch count
                # (equal to BIG_BATCH_SIZE / MINI_BATCH_SIZE for full batches).
                n_minibatches = math.ceil(len(queries) / MINI_BATCH_SIZE)

                optimizer.zero_grad()
                contr_loss_grad_cache = 0
                state_queries_cache = []
                state_doc_cache = []
                state_hard_negs_cache = []
                all_docs_indices = []
                all_query_indices = []

                # ---- Pass 1: cache representations without autograd ------
                for minibatch_i in range(0, len(queries), MINI_BATCH_SIZE):
                    rows = slice(minibatch_i, minibatch_i + MINI_BATCH_SIZE)
                    minibatch_queries = {
                        key: batch_queries[key][rows] for key in token_keys
                    }
                    minibatch_documents = {
                        key: batch_documents[key][rows] for key in token_keys
                    }
                    if has_hard_negs:
                        minibatch_hard_negs = {
                            key: batch_hard_negs[key][rows] for key in token_keys
                        }

                    with torch.no_grad():
                        out_queries = autoencoder(minibatch_queries)
                        out_documents = autoencoder(minibatch_documents)
                        if has_hard_negs:
                            out_hard_negs = autoencoder(minibatch_hard_negs)

                    state_queries_cache.append(
                        out_queries.encoder_downprojected.tanh() * half_l
                    )
                    state_doc_cache.append(
                        out_documents.encoder_downprojected.tanh() * half_l
                    )
                    if has_hard_negs:
                        state_hard_negs_cache.append(
                            out_hard_negs.encoder_downprojected.tanh() * half_l
                        )
                    all_docs_indices.append(out_documents.quantizer_outputs[1])
                    all_query_indices.append(out_queries.quantizer_outputs[1])
                state_queries_cache = torch.cat(state_queries_cache, dim=0)
                state_doc_cache = torch.cat(state_doc_cache, dim=0)
                if has_hard_negs:
                    # Hard negatives are appended after the positives so the
                    # positive of query k stays at document index k.
                    state_hard_negs_cache = torch.cat(state_hard_negs_cache, dim=0)
                    state_doc_cache = torch.cat(
                        (state_doc_cache, state_hard_negs_cache), dim=0
                    )
                all_docs_indices = torch.cat(all_docs_indices, dim=0)
                all_query_indices = torch.cat(all_query_indices, dim=0)
                state_queries_cache.requires_grad_(True)
                state_doc_cache.requires_grad_(True)
                state_queries_cache.retain_grad()
                state_doc_cache.retain_grad()

                # ---- Pass 2: gradients w.r.t. the cached representations -
                for minibatch_i in range(0, len(queries), MINI_BATCH_SIZE):
                    loss = colbert_loss(
                        minibatch_i,
                        state_queries_cache[
                            minibatch_i : minibatch_i + MINI_BATCH_SIZE
                        ],
                        state_doc_cache,
                        COLBERT_DISTANCE,
                    )
                    contr_loss_grad_cache += loss.item()
                    # The graph is rebuilt from the two leaf caches on every
                    # iteration, so nothing needs to be retained; releasing it
                    # frees the large score tensor before the next one is made.
                    loss.backward()

                # ---- Pass 3: backprop through the encoder ----------------
                neg_norm_entropy_tot = 0
                n_documents = len(batch_documents["input_ids"])
                for minibatch_i in range(0, len(queries), MINI_BATCH_SIZE):
                    rows = slice(minibatch_i, minibatch_i + MINI_BATCH_SIZE)
                    out_queries_with_grad = autoencoder(
                        {key: batch_queries[key][rows] for key in token_keys}
                    )
                    out_documents_with_grad = autoencoder(
                        {key: batch_documents[key][rows] for key in token_keys}
                    )
                    queries_latent_indices = (
                        out_queries_with_grad.quantizer_outputs[1]
                    )
                    documents_latent_indices = (
                        out_documents_with_grad.quantizer_outputs[1]
                    )

                    queries_prequant = (
                        out_queries_with_grad.encoder_downprojected.tanh() * half_l
                    )
                    docs_prequant = (
                        out_documents_with_grad.encoder_downprojected.tanh() * half_l
                    )

                    # NOTE(review): only the first `n_documents` rows of
                    # state_doc_cache.grad are consumed, so gradients cached
                    # for hard negatives never reach the encoder — confirm
                    # this is intended.
                    doc_len_idx = min(n_documents, minibatch_i + MINI_BATCH_SIZE)
                    # Surrogate loss: its gradient w.r.t. the encoder equals
                    # the chain rule through the cached gradients.
                    contr_loss_val = (
                        queries_prequant * state_queries_cache.grad[rows]
                    ).sum() + (
                        docs_prequant * state_doc_cache.grad[minibatch_i:doc_len_idx]
                    ).sum()

                    neg_norm_entropy_docs = entropy_loss(
                        autoencoder,
                        docs_prequant,
                        codebook_vectors,
                        documents_latent_indices,
                        hard=HARD_ENTROPY,
                    )
                    neg_norm_entropy_queries = entropy_loss(
                        autoencoder,
                        queries_prequant,
                        codebook_vectors,
                        queries_latent_indices,
                        hard=HARD_ENTROPY,
                    )
                    neg_norm_entropy = (
                        neg_norm_entropy_docs + neg_norm_entropy_queries
                    ) / 2.0
                    neg_norm_entropy_tot += neg_norm_entropy.item()
                    # retain_graph is kept here: `codebook_vectors` is built
                    # once outside the loop and may carry a graph shared by
                    # every iteration — TODO confirm before removing.
                    (
                        contr_loss_val * CONTR_LOSS_ALPHA
                        + ENTROPY_LOSS_ALPHA * neg_norm_entropy / n_minibatches
                    ).backward(retain_graph=True)

                # ---- Statistics and logging ------------------------------
                query_mask = batch_queries["attention_mask"][:, :REP_LEN]
                doc_mask = batch_documents["attention_mask"][:, :REP_LEN]
                step_stats: StreamingTFIDFResults = token_counter.step(
                    target_indices=all_docs_indices,
                    source_indices=all_query_indices,
                    source_mask=query_mask,
                    target_mask=doc_mask,
                )
                if i % 10 == 0:
                    mean_contr_loss = contr_loss_grad_cache / n_minibatches
                    mean_entropy_loss = neg_norm_entropy_tot / n_minibatches

                    # Computed outside the try block so `codebook_usage` is
                    # always bound when it is printed below.
                    codebook_size = autoencoder.quantizer.codebook_size
                    codebook_freqs = Counter(all_docs_indices.reshape(-1).tolist())
                    total_codes = sum(codebook_freqs.values())
                    codebook_entpy = -sum(
                        (p := freq / total_codes) * math.log(p)
                        for freq in codebook_freqs.values()
                    )
                    codebook_usage = len(codebook_freqs) / codebook_size * 100.0

                    try:
                        mlflow.log_metrics(
                            {
                                "loss_contr": mean_contr_loss,
                                "loss_entropy": mean_entropy_loss,
                                "tfidf": step_stats.positive_tf_idf,
                                "tfidf_neg": step_stats.max_negative_tf_idf,
                                "tfidf_diff": step_stats.positive_tf_idf
                                - step_stats.max_negative_tf_idf,
                                "tfidf_prop": step_stats.positive_tf_idf
                                / max(step_stats.max_negative_tf_idf, 1e-1),
                                "covrg": step_stats.positive_intersection_length,
                                "uniql": step_stats.target_uniq_len,
                                "freq_divergence": step_stats.freq_divergence,
                                "match_pos": step_stats.new_positive_intersection_length,
                                "match_neg": step_stats.new_negative_intersection_length,
                                "match_uniq_pos": step_stats.positive_unique_intersection_length,
                                "match_uniq_neg": step_stats.negative_unique_intersection_length,
                                "match_same_pos_pos": step_stats.positive_same_pos_intersection_length,
                                "match_same_pos_neg": step_stats.negative_same_pos_intersection_length,
                                "match_diff": step_stats.new_positive_intersection_length
                                - step_stats.new_negative_intersection_length,
                                "match_prop": step_stats.new_positive_intersection_length
                                / max(
                                    step_stats.new_negative_intersection_length, 1e-1
                                ),
                                "match_uniq_prop": step_stats.positive_unique_intersection_length
                                / max(
                                    step_stats.negative_unique_intersection_length, 1e-1
                                ),
                                "codebook_usage": codebook_usage,
                                "codebook_entropy": codebook_entpy,
                            },
                            step=cur_steps_total,
                        )
                    except Exception as e:
                        # Metric logging is best effort: a tracking failure
                        # must not abort the training run.
                        print(e)

                    print(
                        "iteration",
                        cur_steps_total,
                    )
                    print(queries[0])
                    print(documents[0])
                    print("contr", mean_contr_loss)
                    print("entro", mean_entropy_loss, f"{codebook_usage}%")
                    print(all_query_indices[0][query_mask[0] == 1])
                    print(all_docs_indices[0][doc_mask[0] == 1])

                    # Code overlap between each query and its own document
                    # (positive) versus the next document in the batch
                    # (negative), counted with multiplicity.
                    pos_overlap = 0
                    neg_overlap = []
                    n_rows = all_query_indices.size(0)
                    for batch_idx in range(n_rows):
                        neg_idx = (batch_idx + 1) % all_docs_indices.size(0)
                        query = all_query_indices[batch_idx][
                            query_mask[batch_idx] == 1
                        ]
                        doc = all_docs_indices[batch_idx][doc_mask[batch_idx] == 1]
                        doc_neg = all_docs_indices[neg_idx][doc_mask[neg_idx] == 1]
                        query_counter = Counter(query.reshape(-1).tolist())
                        doc_counter = Counter(doc.reshape(-1).tolist())
                        doc_counter_neg = Counter(doc_neg.reshape(-1).tolist())
                        pos_overlap += float(
                            sum((query_counter & doc_counter).values())
                        )
                        neg_overlap.append(
                            float(sum((query_counter & doc_counter_neg).values()))
                        )
                    print("pos_overlap", pos_overlap / n_rows)
                    print("neg overlap", sum(neg_overlap) / n_rows)
                    print("neg overlap max", max(neg_overlap))

                # ---- Optimizer step --------------------------------------
                # NOTE(review): this clips every parameter tensor to norm 2.0
                # on its own, not the global gradient norm; passing
                # autoencoder.parameters() in a single call would clip
                # globally. Left unchanged because it alters training.
                for param in autoencoder.parameters():
                    torch.nn.utils.clip_grad_norm_(param, 2.0)

                optimizer.step()
                cur_steps_total += 1

                # ---- Checkpointing ---------------------------------------
                if cur_steps_total % SAVE_EVERY == 0:
                    print("Saving")

                    checkpoint_path = os.path.join(
                        SAVE_PATH,
                        f"checkpoint-{cur_steps_total}",
                    )
                    print(checkpoint_path)
                    autoencoder.save_checkpoint(
                        checkpoint_path,
                    )
                    torch.save(
                        optimizer.state_dict(),
                        os.path.join(checkpoint_path, "optimizer.pt"),
                    )
                    saved_checkpoints.append(checkpoint_path)

                    # Keep only the newest SAVE_MAX_CHECKPOINTS checkpoints
                    # (a non-positive limit disables pruning).
                    while (SAVE_MAX_CHECKPOINTS > 0) and (
                        len(saved_checkpoints) > SAVE_MAX_CHECKPOINTS
                    ):
                        old_checkpoint = saved_checkpoints.popleft()
                        shutil.rmtree(old_checkpoint)
