import argparse
import json
import os
import subprocess
import sys

import datasets
import torch
from fire import Fire
from sentence_transformers import SentenceTransformer
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataloaders import P3CLDataModule, BBCLDataModule
from dataloaders.constants import TAG2TASK_LIST
from trainers import training_state
from trainers.evaluator import Evaluator
from utils.config import Config, ParseKwargs
from utils.get_model import hf_tokenizer
from utils.util import get_logger, setup_wandb_logger

# Make modules under the current working directory importable (inserted at
# index 1, i.e. right after sys.path[0]).
# NOTE(review): this runs AFTER the project-local imports above (dataloaders,
# trainers, utils), so it cannot influence how those resolve — only imports
# performed later. Move it above them if they rely on it.
sys.path.insert(1, os.getcwd())
# Silence `datasets` progress bars to keep the console output clean.
datasets.disable_progress_bar()


def main(config, loggers):
    """Embed each task's ChatGPT-generated instruction and save the embeddings.

    For every task in ``config.dataset`` (expanded through ``TAG2TASK_LIST`` when
    a single tag is given), the instruction string is read from the task's
    dataset object and encoded with ``nomic-ai/nomic-embed-text-v1.5``. The
    resulting ``{task_name: embedding}`` dict, with embeddings moved to CPU, is
    saved to ``<config.checkpoint_dir>/nomic_instruction_embedding.pt``. Tasks
    that hit CUDA OOM are skipped and reported at the end.

    Args:
        config: experiment Config. Reads ``eval_split``, ``origin_model``,
            ``dataset``, ``name`` and ``checkpoint_dir``. Note that
            ``config.dataset`` is overwritten in place when it is a single tag
            found in ``TAG2TASK_LIST``.
        loggers: dict of loggers forwarded to the datamodule and the Evaluator.

    Raises:
        AssertionError: if ``config.eval_split`` is not "train"/"val", or a
            task's instruction attribute is None.
        KeyError: if a P3 run name does not end with an instruction-version
            suffix starting with "v" (e.g. "v3").
    """
    assert config.eval_split in ["train", "val"], "Should be train or val split when inferring sub-samples."

    # Per-task cap on examples loaded by the datamodule. Only the instruction
    # string is embedded below; the examples themselves are not used.
    num_samples = 128

    tokenizer = hf_tokenizer(config.origin_model)

    if "bigbench" in config.dataset[0]:
        data_type = "bigbench"
        instruction_version_postfix = None
    else:
        data_type = "p3"
        # P3 runs encode the instruction version in the last two characters of
        # the run name, e.g. "...v3" selects attribute "chatgpt_instruction_v3".
        instruction_version_postfix = config.name[-2:]
        # startswith() also copes with an empty run name (indexing [0] would
        # raise IndexError instead of the intended error).
        if not instruction_version_postfix.startswith("v"):
            raise KeyError(f"Run name is not correct ({config.name}).")

    # A single dataset tag expands to its full task list.
    if len(config.dataset) == 1 and config.dataset[0] in TAG2TASK_LIST:
        config.dataset = TAG2TASK_LIST[config.dataset[0]]
    all_tasks = config.dataset

    datamodule_cls = BBCLDataModule if data_type == "bigbench" else P3CLDataModule
    datamodule = datamodule_cls(
        config, tokenizer, loggers, stage=config.eval_split, max_examples_per_dataset=num_samples
    )

    # Resolve which dataset attribute holds the instruction once, outside the loop.
    if data_type == "p3":
        instruction_attr = f"chatgpt_instruction_{instruction_version_postfix}"
        dataset_label = "P3"
    else:
        instruction_attr = "chatgpt_instruction"
        dataset_label = "Bigbench"

    skipped_tasks = []

    # NOTE(review): the nomic model is used here without a task prefix (earlier
    # experimental code in this file used "clustering: "); confirm intended.
    model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True).cuda()
    instruction_embeddings = {}

    for infer_task in all_tasks:
        # In this function the Evaluator is only used to resolve the task's dataset.
        evaluator = Evaluator(
            config=config,
            eval_tasks=[infer_task],
            tokenizer=tokenizer,
            datamodule=datamodule,
            loggers=loggers,
        )
        try:
            task_dataset = evaluator.datamodule(infer_task).dataset[config.eval_split]
            instruction_str = getattr(task_dataset, instruction_attr)
            print(f"Using {instruction_attr} for {dataset_label} task {infer_task}")
            assert instruction_str is not None, f"Instruction not found for task {infer_task}"
            embeddings = model.encode(instruction_str, convert_to_tensor=True)
            instruction_embeddings[infer_task] = embeddings.cpu()
        except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
            # Only CUDA OOM is treated as recoverable; anything else propagates
            # with its original traceback.
            if "CUDA out of memory" not in str(e):
                raise
            print(f"Skip task {infer_task} due to CUDA out of memory.")
            skipped_tasks.append(infer_task)

    # Ensure the output directory exists, matching main_simple's behavior.
    os.makedirs(config.checkpoint_dir, exist_ok=True)
    torch.save(instruction_embeddings, os.path.join(config.checkpoint_dir, "nomic_instruction_embedding.pt"))

    if skipped_tasks:
        print(f"Skipped tasks due to CUDA OOM: {skipped_tasks}")


def main_simple(
        checkpoint_dir: str,
        instruction_dict_path: str
):
    """Embed instructions listed in a JSON file and save them to ``checkpoint_dir``.

    Args:
        checkpoint_dir: output directory, created if it does not exist. The
            result is written to ``<checkpoint_dir>/nomic_instruction_embedding.pt``
            as a ``{task_name: embedding}`` dict with CPU tensors.
        instruction_dict_path: path to a JSON object whose keys are task names;
            each value is passed as-is to ``model.encode`` (presumably an
            instruction string).
    """
    # JSON is UTF-8 by specification; don't depend on the platform's default encoding.
    with open(instruction_dict_path, "r", encoding="utf-8") as f:
        instruction_dict = json.load(f)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(checkpoint_dir, exist_ok=True)

    model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True).cuda()
    instruction_embeddings = {}
    for task, instruction in tqdm(instruction_dict.items()):
        embeddings = model.encode(instruction, convert_to_tensor=True)
        instruction_embeddings[task] = embeddings.cpu()

    torch.save(instruction_embeddings, os.path.join(checkpoint_dir, "nomic_instruction_embedding.pt"))


def main_setup():
    """Build the experiment Config from the command line and set up loggers.

    Command-line flags:
        -c/--config_files: required; passed to ``Config`` as its first argument.
        -k/--kwargs: optional extra overrides, parsed by the project's
            ``ParseKwargs`` action and passed to ``Config`` as its second argument.

    Returns:
        Tuple ``(config, loggers)``. ``config.device`` is CUDA when available,
        otherwise CPU. ``loggers`` always holds ``"logger"``. Unless
        ``config.debug`` is set, it also holds ``"tb"`` (a TensorBoard
        SummaryWriter) and ``"wandb"`` (from ``setup_wandb_logger``).
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-c", "--config_files", required=True)
    arg_parser.add_argument("-k", "--kwargs", nargs="*", action=ParseKwargs, default={})
    cli_args = arg_parser.parse_args()

    config = Config(cli_args.config_files, cli_args.kwargs)

    # The logger's own configuration lives under <project_dir>/utils/.
    logger_config_dir = os.path.join(config.project_dir, "utils/")
    logger = get_logger("log.txt", f"{config.log_dir}/", logger_config_dir)

    logger.info(f"Start experiment {config.project_name}/{config.name}")
    logger.info(config.to_json())

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    config.device = torch.device(device_name)

    loggers = {"logger": logger}
    if not config.debug:
        tb_writer = SummaryWriter(config.run_output_dir)
        loggers["tb"] = tb_writer
        wandb_logger, _, _ = setup_wandb_logger(config.__dict__)
        loggers["wandb"] = wandb_logger
        # Record the exact launch command on the wandb run for reproducibility.
        wandb_logger.summary["command"] = subprocess.list2cmdline(["python", *sys.argv])

    # Reset the module-level step counter in trainers.training_state.
    training_state.global_training_step = 0

    return config, loggers


if __name__ == "__main__":
    # Alternative entry point: to run the config-driven pipeline (main_setup +
    # main) instead of the JSON-driven one, uncomment the two lines below and
    # remove the Fire call.
    # config, loggers = main_setup()
    # main(config, loggers)
    # Fire exposes main_simple's parameters as CLI arguments, e.g.
    #   --checkpoint_dir DIR --instruction_dict_path FILE.json
    Fire(main_simple)
