import logging
import os
import time
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional, Tuple, Union
import pandas as pd
import numpy as np
import torch
from datasets import load_dataset, Dataset, DatasetDict

from torch.optim import Adam
from tqdm import tqdm
from train_subgraphrag import RetrieverDataset, create_logger, log_step, set_seed
from transformers import HfArgumentParser, TrainingArguments
from transformers.tokenization_utils_base import (
    PaddingStrategy,
    PreTrainedTokenizerBase,
)

import wandb
from llm_graph_walk import graph, text_encoder
from llm_graph_walk.sample_sr import END_REL, SampleSubgraphSR

logger = logging.getLogger(__name__)


@dataclass
class OurDataCollatorWithPadding:
    """Pad multi-sentence examples into tensors and optionally add MLM inputs.

    Every incoming feature stores, for each key listed in ``special_keys``, one
    entry per sentence. ``__call__`` flattens those sentences so that
    ``tokenizer.pad`` can pad them together and then views the padded tensors
    back as ``(batch, num_sent, -1)``.
    """

    # The annotations that name transformers types are written as forward
    # references (strings) so they are not evaluated when the class body runs.
    tokenizer: "PreTrainedTokenizerBase"
    # Per-token probability of being selected for masking in ``mask_tokens``.
    mlm_probability: float
    # The next three values are forwarded unchanged to ``tokenizer.pad``.
    padding: "Union[bool, str, PaddingStrategy]" = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    # When True, ``mlm_input_ids`` / ``mlm_labels`` are added to every batch.
    mlm: bool = True

    def __call__(
        self,
        features: List[Dict[str, Union[List[int], List[List[int]], torch.Tensor]]],
    ) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into a dict of padded tensors.

        Args:
            features: One dict per example. Values under ``special_keys`` are
                indexed per sentence; any other value is copied to every
                sentence of the example.

        Returns:
            A dict of tensors. Keys in ``special_keys`` have shape
            ``(batch, num_sent, -1)``; for other keys only the first sentence's
            copy is kept. ``label`` / ``label_ids`` are renamed to ``labels``.
            An empty ``features`` list yields an empty dict.
        """
        special_keys = [
            "input_ids",
            "attention_mask",
            "token_type_ids",
            "mlm_input_ids",
            "mlm_labels",
        ]
        bs = len(features)
        if bs == 0:
            # Nothing to collate; return an empty mapping to honour the
            # declared return type instead of an implicit None.
            return {}
        num_sent = len(features[0]["input_ids"])

        # One flat feature per (example, sentence) pair, in example-major order,
        # so the later ``view(bs, num_sent, -1)`` restores the grouping.
        flat_features = [
            {k: feature[k][i] if k in special_keys else feature[k] for k in feature}
            for feature in features
            for i in range(num_sent)
        ]

        batch = self.tokenizer.pad(
            flat_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        if self.mlm:
            batch["mlm_input_ids"], batch["mlm_labels"] = self.mask_tokens(
                batch["input_ids"]
            )

        batch = {
            k: (
                batch[k].view(bs, num_sent, -1)
                if k in special_keys
                else batch[k].view(bs, num_sent, -1)[:, 0]
            )
            for k in batch
        }

        if "label" in batch:
            batch["labels"] = batch.pop("label")
        if "label_ids" in batch:
            batch["labels"] = batch.pop("label_ids")

        return batch

    def mask_tokens(
        self,
        inputs: torch.Tensor,
        special_tokens_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        Args:
            inputs: Token-id tensor. It is NOT modified; masking is applied to a copy.
            special_tokens_mask: Optional mask of positions that must never be
                selected. When omitted it is derived from the tokenizer.

        Returns:
            ``(masked_inputs, labels)`` where ``labels`` is ``-100`` everywhere
            except at the positions selected for masking.
        """
        # Work on a copy: the caller keeps using ``inputs`` as the clean
        # ``input_ids`` next to the masked ``mlm_input_ids``.
        inputs = inputs.clone()
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(
                    val, already_has_special_tokens=True
                )
                for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.mask_token
        )

        # 10% of the time, we replace masked input tokens with random word
        # (half of the remaining 20% of selected positions).
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(
            len(self.tokenizer), labels.shape, dtype=torch.long
        )
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels


@torch.no_grad()
def run_validation(data_loader, model, valid_k_list):
    """Score a triple retriever on a validation loader and average the metrics.

    Args:
        data_loader: Iterable of indexable samples. This function reads
            ``sample[0]`` (query passed to the model), ``sample[1]``,
            ``sample[2]``, ``sample[3]`` (the three columns of the candidate
            triples), ``sample[5]`` (seed node ids), ``sample[-2]`` (non-zero
            entries mark the target triples) and ``sample[-1]`` (answer entity
            ids). NOTE(review): columns 1-3 are presumably head / relation /
            tail ids - confirm against ``RetrieverDataset``.
        model: Callable invoked as ``model(query=..., seed_node_id=...)``;
            element ``0`` of its return value is the ranked list of retrieved
            triples.
        valid_k_list: Cut-offs ``k`` for the ``@k`` metrics.

    Returns:
        Mapping from metric name to its mean over the evaluated samples, plus
        ``"eval_duration"`` in seconds. Samples with no retrieved triples or
        no target triples are skipped.
    """
    metric_dict = defaultdict(list)
    eval_start = time.time()
    for sample in tqdm(data_loader):
        pred_triples = torch.tensor(
            model(query=sample[0], seed_node_id=sample[5].numpy().tolist())[0]
        )
        if pred_triples.shape[0] == 0:
            continue

        # Triple ranking: gather the target triples as rows of three ids.
        target_triple_ids = sample[-2].to(pred_triples.device).nonzero()
        target_triples = torch.concat(
            [
                torch.tensor(x)[target_triple_ids]
                for x in [sample[1], sample[2], sample[3]]
            ],
            dim=1,
        )
        num_correct_triples = target_triples.shape[0]
        if num_correct_triples == 0:
            # Recall would divide by zero; there is nothing to rank against.
            continue
        num_retrieved_triples = pred_triples.shape[0]
        # Append (not assign) so the final value is a mean over all samples
        # rather than the count of the last sample only.
        metric_dict["num_retrieved_triples"].append(num_retrieved_triples)

        # ranks[i, j] is the 1-based position j+1 when retrieved triple j equals
        # target triple i, and inf otherwise.
        ranks = torch.where(
            (pred_triples.unsqueeze(0) == target_triples.unsqueeze(1)).all(dim=(-1)),
            torch.arange(1, pred_triples.shape[0] + 1, device=target_triple_ids.device),
            torch.inf,
        )
        # Best (smallest) position per target triple; inf if it was not retrieved.
        rank = ranks.min(dim=-1)[0].cpu()
        # Reciprocal rank after shifting the rank down by (num_correct - 1),
        # floored at 1.
        metric_dict["gt_triple_mrr"].append(
            torch.mean(
                1.0
                / torch.maximum(
                    torch.tensor(1.0),
                    rank - num_correct_triples + 1,
                )
            ).item()
        )

        a_entity_id_list = sample[-1].unsqueeze(-1)
        for k in valid_k_list:
            # Ranks are 1-based, so the top-k are the ranks <= k; this matches
            # the ``pred_triples[:k]`` slice used for answer recall below.
            true_positives = (rank <= k).sum().item()
            recall = true_positives / num_correct_triples
            # NOTE(review): precision is taken over ALL retrieved triples, not
            # over the top-k only - confirm this is the intended definition.
            precision = true_positives / num_retrieved_triples
            if precision + recall == 0.0:
                f1_score = 0.0
            else:
                f1_score = 2 * precision * recall / (precision + recall)
            metric_dict[f"gt_triple_recall@{k}"].append(recall)
            metric_dict[f"gt_triple_precision@{k}"].append(precision)
            metric_dict[f"gt_triple_f1@{k}"].append(f1_score)

            # An answer counts as recalled when it equals column 0 or column 2
            # of any of the top-k retrieved triples.
            retrieved_triples = pred_triples[:k].cpu()
            recall_answ = torch.any(
                torch.concat([retrieved_triples[:, 0], retrieved_triples[:, 2]])
                == a_entity_id_list,
                dim=-1,
            )
            metric_dict[f"answer_recall@{k}"].append(
                recall_answ.numpy().sum().item() / len(a_entity_id_list)
            )

    # Collapse every per-sample list into its mean (values only; keys unchanged).
    for key, val in metric_dict.items():
        metric_dict[key] = np.mean(val).item()

    metric_dict["eval_duration"] = time.time() - eval_start
    return metric_dict


def main(model_args, data_args, training_args):
    """Fine-tune the SR text encoder with a contrastive loss, validating periodically.

    Args:
        model_args: ``ModelArguments`` (encoder name, temperature, MLM switch).
        data_args: ``DataTrainingArguments`` (KG directory, data paths,
            tokenisation options).
        training_args: ``OurTrainingArguments`` (seed, device, learning rate,
            logging / eval cadence, checkpoint location).

    Side effects:
        Optionally starts a wandb run, reads the KG arrays and the train /
        validation files from disk, mutates ``training_args.num_train_epochs``
        to ``1`` and, when ``training_args.save_path`` is set, writes one
        checkpoint per epoch under ``output_dir/save_path``.
    """
    # Flatten all three argument dataclasses into one config dict (logged to
    # wandb and stored in every checkpoint).
    args = {}
    for part_arg in (model_args, data_args, training_args):
        args.update(asdict(part_arg))

    set_seed(training_args.seed)
    create_logger()
    if training_args.wandb:
        wandb.init(
            project="SR_training",
            config=args,
        )

    logger.info("Instantiating model")
    text_enc = text_encoder.TextEncoderSR(
        model_args.model_name_or_path,
        device=training_args.device,
    )
    tokenizer = text_enc.tokenizer

    logger.info("Loading data")

    def load_kg_array(file_name):
        # NOTE: allow_pickle=True lets np.load unpickle arbitrary objects, which
        # can execute code; only point wikikg_dir at trusted data.
        return np.load(
            os.path.join(data_args.wikikg_dir, file_name),
            allow_pickle=True,
        )

    node_labels = load_kg_array("node_labels.npy")
    relation_labels = load_kg_array("relation_labels.npy")
    # END_REL is appended after the KG's own relation labels.
    relation_labels = np.array(relation_labels.tolist() + [END_REL])
    edge_ids = load_kg_array("edge_ids.npy")
    relation_types = load_kg_array("relation_types.npy")
    knowledge_graph = graph.Graph(
        edge_ids, relation_types, node_labels, relation_labels
    )

    samplefunc = SampleSubgraphSR(
        graph.KGInterfaceFromGraph(knowledge_graph), text_encoder=text_enc
    )  # for validation

    val_set = RetrieverDataset(data_args.valid_data_path)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=None, shuffle=False)

    # Pre-build the data for every epoch: for each ("0", "1") group one random
    # integer in [0, max("2")] is drawn and only the rows whose "2" equals it
    # are kept; the per-epoch selections are shuffled and concatenated.
    train_df = pd.read_csv(data_args.train_data_path)
    epoch_frames = []
    for _ in range(int(training_args.num_train_epochs)):
        train_df["targ"] = train_df.groupby(["0", "1"])["2"].transform(
            lambda x: np.random.randint(1 + max(x))
        )
        epoch_frames.append(
            train_df[train_df["2"] == train_df["targ"]]
            .sample(frac=1)
            .reset_index(drop=True)
        )

    train_dataset = pd.concat(epoch_frames).reset_index(drop=True)
    datasets = DatasetDict(
        {
            "train": Dataset.from_pandas(
                train_dataset.drop(["0", "1", "2", "targ"], axis=1)
            )
        }
    )

    # All epochs are already materialised in the dataset, so iterate it once.
    training_args.num_train_epochs = 1

    column_names = datasets["train"].column_names
    sent0_cname = column_names[0]

    def prepare_features(examples):
        # padding = longest (default)
        #   If no sentence in the batch exceed the max length, then use
        #   the max sentence length in the batch, otherwise use the
        #   max sentence length in the argument and truncate those that
        #   exceed the max length.
        # padding = max_length (when pad_to_max_length, for pressure test)
        #   All sentences are padded/truncated to data_args.max_seq_length.
        bs = len(examples[sent0_cname])
        k = len(column_names)

        # Tokenise column by column: all sentences of column 0, then column 1, ...
        sentences = []
        for column_name in column_names:
            sentences += examples[column_name]

        sent_features = tokenizer(
            sentences,
            max_length=data_args.max_seq_length,
            truncation=True,
            padding="max_length" if data_args.pad_to_max_length else False,
        )
        # Manual transpose (awkward, but necessary): regroup the column-major
        # tokeniser output so each example holds its k sentences together.
        features = {}
        for key in sent_features:
            features[key] = [
                [sent_features[key][i + bs * j] for j in range(k)] for i in range(bs)
            ]
        return features

    train_dataset = datasets["train"].map(
        prepare_features,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    data_collator = OurDataCollatorWithPadding(
        tokenizer, mlm_probability=data_args.mlm_probability, mlm=model_args.do_mlm
    )
    train_dl = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=training_args.per_device_train_batch_size,
        collate_fn=data_collator,
        shuffle=True,
    )

    # Only resolve/create the checkpoint directory when checkpoints are wanted;
    # os.path.join would raise on a None save_path.
    save_dir = None
    if training_args.save_path:
        save_dir = os.path.join(training_args.output_dir, training_args.save_path)
        os.makedirs(save_dir, exist_ok=True)

    step = -1
    partial_loss = []
    optimizer = Adam(text_enc.model.parameters(), lr=training_args.learning_rate)
    cos = torch.nn.CosineSimilarity(dim=-1)
    for epoch in range(1, int(training_args.num_train_epochs) + 1):
        logger.info(f"Training -- epoch {epoch}")
        for batch in tqdm(train_dl):
            bs, num_sent, seq_len = batch["input_ids"].shape
            step += 1
            batch_start = time.time()
            # Encode every sentence of the batch in one pass, then regroup the
            # embeddings per example.
            embeddings = text_enc(
                **{
                    k: v.view((-1, seq_len)).to(text_enc.device)
                    for k, v in batch.items()
                }
            )
            embeddings = embeddings.view(bs, num_sent, -1)
            # Sentence 0 is the query; the remaining sentences are candidates.
            query, targets = embeddings[:, 0:1], embeddings[:, 1:]
            cos_sim = cos(query, targets) / model_args.temp

            # Class 0 is the correct one: the first candidate is the positive.
            labels = torch.zeros(
                cos_sim.size(0), dtype=torch.long, device=cos_sim.device
            )
            loss = torch.nn.functional.cross_entropy(cos_sim, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            partial_loss.append(loss.item())
            step_duration = time.time() - batch_start

            run_eval = step % training_args.eval_steps == 0
            if step % training_args.logging_steps == 0 or run_eval:
                # Mean loss since the previous log; duration of this step only.
                log_dict = {
                    "loss": np.mean(partial_loss),
                    "step_duration": step_duration,
                }
                partial_loss = []

                if run_eval:
                    text_enc.eval()
                    logger.info("Running validation")
                    log_dict.update(
                        run_validation(val_loader, samplefunc, training_args.valid_k)
                    )
                    text_enc.train()

                log_step(logger, log_dict, step, training_args.wandb)

        if save_dir is not None:
            state_dict = {"config": args, "model_state_dict": text_enc.state_dict()}
            torch.save(
                state_dict,
                os.path.join(save_dir, f"checkpoint_ep_{epoch}.pth"),
            )


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    Every field carries its command-line help text in ``metadata["help"]``.
    """

    # Huggingface's original arguments
    model_name_or_path: Optional[str] = field(
        default="roberta-base",
        metadata={
            # The trailing space keeps the two sentences apart once the
            # adjacent string literals are concatenated.
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
        },
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={
            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
        },
    )
    model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
        },
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )

    # SimCSE's arguments
    temp: float = field(default=0.05, metadata={"help": "Temperature for softmax."})
    pooler_type: str = field(
        default="cls",
        metadata={
            "help": "What kind of pooler to use (cls, cls_before_pooler, avg, avg_top2, avg_first_last)."
        },
    )
    hard_negative_weight: float = field(
        default=0,
        metadata={
            "help": "The **logit** of weight for hard negatives (only effective if hard negatives are used)."
        },
    )
    do_mlm: bool = field(
        default=False, metadata={"help": "Whether to use MLM auxiliary objective."}
    )
    mlm_weight: float = field(
        default=0.1,
        metadata={
            "help": "Weight for MLM auxiliary objective (only effective if --do_mlm)."
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    The three fields without a default (``wikikg_dir``, ``train_data_path``,
    ``valid_data_path``) are declared first: ``dataclasses`` rejects a
    non-default field that follows a defaulted one with a ``TypeError`` at
    class-creation time.

    Raises:
        ValueError: From ``__post_init__`` when neither a dataset name nor a
            data file is given, or when ``train_data_path`` does not end in
            ``csv``, ``json`` or ``txt``.
    """

    # Required arguments (no default).
    wikikg_dir: str = field(
        metadata={
            "help": "directory containing the processed wikikg2",
        },
    )
    # SimCSE's arguments
    train_data_path: Optional[str] = field(
        metadata={"help": "The training data file (.txt or .csv)."},
    )
    valid_data_path: Optional[str] = field(
        metadata={"help": "The validation data file (.txt or .csv)."},
    )

    kg: str = field(
        default="wikikg2",
        metadata={"help": "KG name"},
    )
    # Huggingface's original arguments.
    dataset: Optional[str] = field(
        default="synthetic",
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    overwrite_cache: bool = field(
        default=True,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

    # SimCSE's arguments
    max_seq_length: Optional[int] = field(
        default=52,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    mlm_probability: float = field(
        default=0.15,
        metadata={
            "help": "Ratio of tokens to mask for MLM (only effective if --do_mlm)"
        },
    )

    def __post_init__(self):
        """Validate that a data source exists and the training file type is supported."""
        if (
            self.dataset is None
            and self.train_data_path is None
            and self.valid_data_path is None
        ):
            raise ValueError(
                "Need either a dataset name or a training/validation file."
            )
        if self.train_data_path is not None:
            extension = self.train_data_path.split(".")[-1]
            # Raise explicitly instead of asserting: asserts are stripped
            # under ``python -O`` and would let a bad path through.
            if extension not in ("csv", "json", "txt"):
                raise ValueError(
                    "`train_data_path` should be a csv, a json or a txt file."
                )


@dataclass
class OurTrainingArguments(TrainingArguments):
    """``TrainingArguments`` with the extra settings and defaults this script uses."""

    # Validation runs on steps that are a multiple of this value.
    eval_steps: Optional[int] = 2000
    # Cut-offs k for the @k validation metrics; a factory is needed because
    # the default is a mutable list.
    valid_k: Optional[list[int]] = field(default_factory=lambda: [10, 200])
    # Whether to start a wandb run and log each step to it.
    wandb: Optional[bool] = True
    # Sub-directory of ``output_dir`` that receives the per-epoch checkpoints.
    save_path: Optional[str] = "SR_gt"
    output_dir: Optional[str] = "checkpoints/def/"
    per_device_train_batch_size: Optional[int] = 16


if __name__ == "__main__":
    # Parse the command line into the three argument dataclasses and train.
    arg_classes = (ModelArguments, DataTrainingArguments, OurTrainingArguments)
    model_args, data_args, training_args = HfArgumentParser(
        arg_classes
    ).parse_args_into_dataclasses()

    main(model_args, data_args, training_args)
