import os
import json
import torch
import random
import warnings
import numpy as np
import networkx as nx
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
from argparse import ArgumentParser
from collections import defaultdict
from utils.dataset import CustomDatasetP
from torch.optim import lr_scheduler
from torch_geometric.data import Batch
from torch_geometric.data import DataLoader
from torch_geometric.utils import from_networkx
from utils.utils import get_ori_task_types, pad_graph, get_logger
from models.cf_backbones import NCFPlus

# Node count handed to pad_graph() for every computation graph and passed to
# the model as ``max_node_num``.
MAX_NODES = 14
# Number of candidate choices per sample: passed to the model as
# ``model_zoo_num`` and used as the per-sample stride when indexing the flat
# training label tensor in train_model().
NUM_MODEL_ZOO = 70


def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus the current and all
    CUDA devices), puts cuDNN in deterministic mode with benchmarking
    disabled, and exports ``PYTHONHASHSEED`` and ``CUBLAS_WORKSPACE_CONFIG``
    into the process environment.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ.update(
        {
            "PYTHONHASHSEED": str(seed),
            "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
        }
    )
    # torch.use_deterministic_algorithms(True) is not enabled here.


def build_dataset(
    missing_choice_ratio,
    missing_sample_ratio,
    t_embs_file,
    v_embs_file,
    program_embs_file,
    program_file,
    instance_file,
):
    """Build the train / val / test datasets from precomputed artifacts.

    Every sample in ``instance_file`` that has at least one choice labelled
    ``1`` is turned into a tuple ``(batch_G, t_embs, v_embs, p_embs, time)``
    holding one padded graph per choice, and routed to the train, val or test
    split according to the index lists stored under ``data/``.

    Args:
        missing_choice_ratio: Probability with which the label of a single
            choice of a *training* sample is replaced by the placeholder
            value ``3``.
        missing_sample_ratio: Probability with which a whole training sample
            is dropped.
        t_embs_file: File name (inside ``embs/``) of the first embedding
            array, indexed by sample position.
        v_embs_file: File name (inside ``embs/``) of the second embedding
            array, indexed by sample position.
        program_embs_file: File name (inside ``embs/program_embs/``) of the
            program embedding array, indexed by sample position.
        program_file: Path of the JSON program list; it is passed through
            ``get_ori_task_types`` and indexed by ``int(sample_id) - 1``.
        instance_file: Path of the JSON mapping sample id -> list of per-choice
            result dicts. Each dict must hold exactly four values whose last
            two are the execution time and the label.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset, node2idx,
        val_structure_list, test_structure_list, val_group_list,
        test_group_list)``. The structure / group lists have one entry per
        val / test sample, in the same order as the matching dataset.

    Note:
        Relies on the module-level ``logger`` and draws from the global
        ``random`` state, so results depend on the seed set beforehand.
    """
    ## 1. load extracted features
    X_t = np.load(f"embs/{t_embs_file}")
    X_v = np.load(f"embs/{v_embs_file}")
    X_p = np.load(f"embs/program_embs/{program_embs_file}")

    ## 2. get merged program list
    with open(program_file, "r") as file:
        ori_program_list = json.load(file)

    ori_program_list = get_ori_task_types(ori_program_list)

    ## 3. get instance results
    with open(instance_file, "r") as file:
        instance_results = json.load(file)

    ## 4. dataset & subtest
    train_dataset, val_dataset, test_dataset = [], [], []
    val_structure_list, val_group_list = [], []
    test_structure_list, test_group_list = [], []

    ## 5. load preprocessed computation graph
    with open("data/preprocess/node2idx.json", "r") as file:
        node2idx = json.load(file)

    with open("data/preprocess/processed_graphs.json", "r") as file:
        processed_graphs = json.load(file)

    for key, graph in processed_graphs.items():
        processed_graphs[key] = nx.node_link_graph(graph)

    ## 6. load splitted train / val indices
    # Sets give O(1) membership tests; the lists are probed several times per
    # sample and once per choice below.
    with open("data/train_random_list_0.6.json", "r") as file:
        train_indices = set(json.load(file))
    with open("data/val_random_list_0.6.json", "r") as file:
        val_indices = set(json.load(file))

    ## 7. iteration
    logger.info("start data preprocess:")
    for cnt, (sample_id, item_list) in tqdm(enumerate(instance_results.items())):
        ## 7.1. get meta data
        meta_data = ori_program_list[int(sample_id) - 1]
        structure = meta_data["types"]["structural"]

        ## 7.2. only keep samples with at least one successful choice
        labels = [list(item.values())[-1] for item in item_list]
        if not any(label == 1 for label in labels):
            continue
        sample_executable_choice = sum(labels)

        ## 7.3. resolve the split once so that the metadata lists and the
        ## datasets can never disagree (train takes precedence, then val).
        in_train = cnt in train_indices
        in_val = (not in_train) and cnt in val_indices

        # Samples are bucketed by their number of successful choices, with a
        # bucket width of 14.
        group = sample_executable_choice // 14
        if in_val:
            val_structure_list.append(structure)
            val_group_list.append(group)
        elif not in_train:
            test_structure_list.append(structure)
            test_group_list.append(group)

        ## 7.4. get embeddings (shared by every choice of this sample)
        t_embs = torch.tensor(X_t[cnt].tolist(), dtype=torch.float)
        v_embs = torch.tensor(X_v[cnt].tolist(), dtype=torch.float)
        p_embs = torch.tensor(X_p[cnt].tolist(), dtype=torch.float)

        ## 7.5. one padded graph per choice
        num_choices = len(item_list)
        G_list, time_list = [], []
        for idx, item in enumerate(item_list):
            _, _, exec_time, y = item.values()
            uid = cnt * num_choices + idx
            G = from_networkx(pad_graph(processed_graphs[str(uid)], MAX_NODES))

            # Training labels are hidden at random; 3 is the placeholder
            # written in place of the real label.
            if in_train and random.random() <= missing_choice_ratio:
                G.y = torch.tensor(3, dtype=torch.float32)
            else:
                G.y = torch.tensor(y, dtype=torch.float32)

            G_list.append(G)
            time_list.append(exec_time)

        instance = (
            Batch.from_data_list(G_list),
            torch.stack([t_embs] * num_choices),
            torch.stack([v_embs] * num_choices),
            torch.stack([p_embs] * num_choices),
            torch.tensor(time_list, dtype=torch.float),
        )

        ## 7.6. route to the split
        if in_train:
            if random.random() > missing_sample_ratio:
                train_dataset.append(instance)
        elif in_val:
            val_dataset.append(instance)
        else:
            test_dataset.append(instance)

    logger.info(
        f"train : {len(train_dataset)}, val : {len(val_dataset)}, test: {len(test_dataset)}"
    )
    train_dataset = CustomDatasetP(train_dataset)
    val_dataset = CustomDatasetP(val_dataset)
    test_dataset = CustomDatasetP(test_dataset)
    return (
        train_dataset,
        val_dataset,
        test_dataset,
        node2idx,
        val_structure_list,
        test_structure_list,
        val_group_list,
        test_group_list,
    )


def evaluate_model(
    bestId, model, data_loader, structure_list, group_list, tag, epoch, verbose
):
    """Evaluate ``model`` on one split and compare it with a fixed baseline.

    For every sample the candidate with the highest model output is selected;
    its label and execution time feed the success-rate and time metrics.
    The same labels / times are also collected per candidate index so that
    "always pick candidate ``i``" baselines can be scored.

    Args:
        bestId: Candidate index used as the dataset-level baseline.
        model: Callable as ``model(G, t, v, p, None, False)``; its squeezed
            output must be a 1-D score tensor with one entry per candidate.
        data_loader: Iterable yielding ``(G, t, v, p, time)`` per sample,
            where ``G.y`` holds the per-candidate labels.
        structure_list: Structure tag of each sample, aligned with
            ``data_loader``.
        group_list: Group id of each sample, aligned with ``data_loader``.
        tag: Unused here; only referenced by the commented-out prediction
            dump at the end.
        epoch: Unused here; see ``tag``.
        verbose: When true, log a per-structure and per-group breakdown via
            the module-level ``logger``.

    Returns:
        Tuple ``(bestScore, bestTime, execution_success_rate,
        binary_classification_accuracy, average_execution_time)`` where the
        first two describe the ``bestId`` baseline.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    device = torch.device("cuda")
    model.eval()
    with torch.no_grad():
        sample_correct, sample_cnt, instance_correct, instance_cnt = 0, 0, 0, 0
        struct2SamCorr, struct2SamCnt = defaultdict(float), defaultdict(float)
        group2SamCorr, group2SamCnt = defaultdict(float), defaultdict(float)
        time_list = []
        baseline2score, baseline2time = defaultdict(list), defaultdict(list)
        struct2baseline2score = defaultdict(lambda: defaultdict(list))
        group2baseline2score = defaultdict(lambda: defaultdict(list))
        # Accumulated model score / rank per candidate; higher => better.
        baselineId2floatscore = defaultdict(float)
        baselineId2rank = defaultdict(float)
        # Only consumed by the optional prediction dump below.
        baseline2predictions = defaultdict(list)

        for (G, t, v, p, time), structure, group in zip(
            data_loader, structure_list, group_list
        ):
            group = "G" + str(group)
            sample_cnt += 1
            G, t, v, p = G.to(device), t.to(device), v.to(device), p.to(device)

            output = model(G, t, v, p, None, False).squeeze()
            # A plain int indexes tensors on any device.
            max_index = int(torch.argmax(output))

            scores = output.tolist()
            labels = G.y.tolist()
            times = time.squeeze().tolist()

            picked_label = labels[max_index]
            time_list.append(times[max_index])
            sample_correct += picked_label

            # 1-based rank of each score inside this sample; tied scores
            # share the lowest rank.
            sorted_scores = sorted(scores)
            element_ranks = [sorted_scores.index(s) + 1 for s in scores]

            struct2SamCnt[structure] += 1
            struct2SamCorr[structure] += picked_label
            group2SamCnt[group] += 1
            group2SamCorr[group] += picked_label

            ## baseline
            for i, score in enumerate(scores):
                baseline2predictions[i].append(score)
                baselineId2rank[i] += element_ranks[i]
                baselineId2floatscore[i] += score
                baseline2score[i].append(labels[i])
                baseline2time[i].append(times[i])
                struct2baseline2score[structure][i].append(labels[i])
                group2baseline2score[group][i].append(labels[i])

            # Per-candidate binary accuracy: a positive score predicts 1.
            instance_correct += sum(
                1
                for score, gt in zip(scores, labels)
                if (1.0 if score > 0 else 0.0) == gt
            )
            instance_cnt += len(labels)

        execution_success_rate = sample_correct / sample_cnt
        binary_classification_accuracy = instance_correct / instance_cnt
        average_execution_time = np.mean(time_list)
        ms_bestId = max(baselineId2floatscore, key=baselineId2floatscore.get)
        # Keyed on the rank totals (the original keyed on the score totals,
        # which made this identical to ms_bestId).
        rank_bestId = max(baselineId2rank, key=baselineId2rank.get)

        ## best baseline:
        bestTime = np.mean(baseline2time[bestId])
        bestScore = np.mean(baseline2score[bestId])

        if verbose:
            for structure in struct2SamCnt:
                score = struct2SamCorr[structure] / struct2SamCnt[structure]
                base_score = np.mean(struct2baseline2score[structure][bestId])
                ms_score = np.mean(struct2baseline2score[structure][ms_bestId])
                rank_score = np.mean(struct2baseline2score[structure][rank_bestId])
                content = f"[{structure}]({len(struct2baseline2score[structure][bestId])}): {base_score:.4f}, {ms_score:.4f}, {rank_score:.4f}, {score:.4f}, {(score-base_score)/base_score*100:.2f}%"
                logger.info(content)

            for group in group2SamCnt:
                score = group2SamCorr[group] / group2SamCnt[group]
                base_score = np.mean(group2baseline2score[group][bestId])
                ms_score = np.mean(group2baseline2score[group][ms_bestId])
                rank_score = np.mean(group2baseline2score[group][rank_bestId])
                content = f"[{group}]({len(group2baseline2score[group][bestId])}): {base_score:.4f}, {ms_score:.4f}, {rank_score:.4f}, {score:.4f}, {(score-base_score)/base_score*100:.2f}%"
                logger.info(content)

        # with open(f"predictions/{tag}_{epoch}_baseline_predictions.json", "w") as file:
        #     json.dump(baseline2predictions, file)

        # with open(f"predictions/{tag}_{epoch}_baseline_labels.json", "w") as file:
        #     json.dump(baseline2score, file)

        return (
            bestScore,
            bestTime,
            execution_success_rate,
            binary_classification_accuracy,
            average_execution_time,
        )


def _build_optimizer(model, opt, lr, wd):
    """Create the optimizer named ``opt`` ("adam", "sgd" or "adamw")."""
    optimizer_classes = {
        "adam": torch.optim.Adam,
        "sgd": torch.optim.SGD,
        "adamw": torch.optim.AdamW,
    }
    if opt not in optimizer_classes:
        raise ValueError(
            f"unsupported optimizer {opt!r}; expected one of {sorted(optimizer_classes)}"
        )
    return optimizer_classes[opt](model.parameters(), lr=lr, weight_decay=wd)


def _select_best_baseline(fixed_train_loader, num_models, missing_label=3):
    """Return the candidate index with the highest mean training label."""
    label_sums = [0.0] * num_models
    label_counts = [0] * num_models
    for data, _, _, _, time in fixed_train_loader:
        # data.y is flat: sample i owns entries [i * num_models, (i + 1) * num_models).
        flat_labels = data.y.tolist()
        for i in range(len(time)):
            offset = i * num_models
            for j in range(num_models):
                label = flat_labels[offset + j]
                # Skip the placeholder written for hidden training labels so
                # that it does not inflate the candidate's mean.
                if label == missing_label:
                    continue
                label_sums[j] += label
                label_counts[j] += 1

    bestId, bestScore = 1, 0
    for j in range(num_models):
        if not label_counts[j]:
            continue
        score = label_sums[j] / label_counts[j]
        if score > bestScore:
            bestId, bestScore = j, score
    return bestId


def train_model(
    model,
    val_structure_list,
    test_structure_list,
    val_group_list,
    test_group_list,
    fixed_train_loader,
    train_loader,
    val_loader,
    test_loader,
    args,
):
    """Train ``model`` and report val / test metrics after every epoch.

    Args:
        model: Network called as ``model(G, t, v, p, labels)``; that call must
            return the scalar training loss.
        val_structure_list: Per-sample structure tags aligned with
            ``val_loader``.
        test_structure_list: Per-sample structure tags aligned with
            ``test_loader``.
        val_group_list: Per-sample group ids aligned with ``val_loader``.
        test_group_list: Per-sample group ids aligned with ``test_loader``.
        fixed_train_loader: Training loader used only to pick the
            dataset-level baseline candidate.
        train_loader: Training loader iterated once per epoch.
        val_loader: Validation loader passed to ``evaluate_model``.
        test_loader: Test loader passed to ``evaluate_model``.
        args: Namespace providing ``epochs``, ``weight_decay``, ``optimizer``,
            ``gamma``, ``lr`` and ``verbose``.

    Raises:
        ValueError: If ``args.optimizer`` is not "adam", "sgd" or "adamw".

    Note:
        Results are written through the module-level ``logger``; nothing is
        returned. The test success rate reported at the end is the one
        measured at the epoch with the best validation success rate.
    """
    ## parse args
    epochs, wd, opt, gamma, lr, verbose = (
        args.epochs,
        args.weight_decay,
        args.optimizer,
        args.gamma,
        args.lr,
        args.verbose,
    )

    optimizer = _build_optimizer(model, opt, lr, wd)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=gamma)

    ## dataset-level baseline
    bestId = _select_best_baseline(fixed_train_loader, NUM_MODEL_ZOO)

    ## training epoch by epoch
    device = torch.device("cuda")
    # best_testing_ser starts at 0.0 so the final log line is always valid,
    # even when no epoch improves on a validation success rate of 0.
    current_max_val_ser, best_testing_ser, best_epoch = 0.0, 0.0, 0
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        loop = tqdm(enumerate(train_loader), total=len(train_loader))
        for idx, (G, t, v, p, _) in loop:
            ## to CUDA
            G, t, v, p = G.to(device), t.to(device), v.to(device), p.to(device)

            ## back propagation
            optimizer.zero_grad()
            labels = G.y
            loss = model(G, t, v, p, labels)
            loss.backward()
            optimizer.step()
            # Stepped once per batch, so StepLR's step_size=100 counts
            # optimizer steps rather than epochs.
            scheduler.step()
            total_loss += loss.item()

            loop.set_description(f"Epoch [{epoch+1}/{epochs}]")
            loop.set_postfix(loss=total_loss / (idx + 1))

        (
            val_bestScore,
            val_bestTime,
            val_execution_success_rate,
            val_binary_classification_accuracy,
            val_average_execution_time,
        ) = evaluate_model(
            bestId,
            model,
            val_loader,
            val_structure_list,
            val_group_list,
            "val",
            epoch,
            verbose,
        )
        (
            test_bestScore,
            test_bestTime,
            test_execution_success_rate,
            test_binary_classification_accuracy,
            test_average_execution_time,
        ) = evaluate_model(
            bestId,
            model,
            test_loader,
            test_structure_list,
            test_group_list,
            "test",
            epoch,
            verbose,
        )

        # Model selection is done on the validation split only.
        if val_execution_success_rate > current_max_val_ser:
            current_max_val_ser = val_execution_success_rate
            best_epoch = epoch
            best_testing_ser = test_execution_success_rate

        content1 = f"** [{epoch+1}] Val Execution Success Rate: {val_execution_success_rate:.4f} / {val_bestScore:.4f}, Test Execution Success Rate: {test_execution_success_rate:.4f} / {test_bestScore:.4f}"
        content2 = f"** [{epoch+1}] Val Avg. Execution Time: {val_average_execution_time:.4f} / {val_bestTime:.4f}, Test Avg. Execution Time: {test_average_execution_time:.4f} / {test_bestTime:.4f}"
        content3 = f"** [{epoch+1}] Val Instance Accuracy: {val_binary_classification_accuracy:.4f}, Test Instance Accuracy: {test_binary_classification_accuracy:.4f}"
        content4 = "-----------------" * 3

        for content in [content1, content2, content3, content4]:
            logger.info(content)

    logger.info(
        f"best epoch: {best_epoch+1}, best testing SER: {best_testing_ser:.4f}!\n\n"
    )


if __name__ == "__main__":
    # Script entry point: parse the CLI, set up logging, build the datasets
    # and loaders, create the model and train it.

    ## embedding files selectable from the command line
    encoder_files = {
        "standard": ("t_bert_embs_10k.npy", "v_vit_embs_10k.npy"),
        "vilt": ("vilt_embs_10k.npy", "vilt_embs_10k.npy"),
        "blip": ("blip_embs_10k.npy", "blip_embs_10k.npy"),
    }
    program_encoder_files = {
        "bert": "program_bert_embs_10k.npy",
        "roberta": "program_roberta_embs_10k.npy",
        # NOTE(review): "sber" may be a typo for "sbert" - confirm against the
        # file that actually exists under embs/program_embs/.
        "sbert": "program_sber_embs_10k.npy",
    }

    ## args
    parser = ArgumentParser()
    parser.add_argument("--log_tag", type=str, default="tmp")
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--hidden_size", type=int, default=32)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--gamma", type=float, default=0.6)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    # Unsupported names are rejected at parse time instead of surfacing later
    # as an unbound variable.
    parser.add_argument(
        "--optimizer", type=str, default="adam", choices=["adam", "sgd", "adamw"]
    )
    parser.add_argument(
        "--encoder", type=str, default="standard", choices=list(encoder_files)
    )
    parser.add_argument(
        "--program_encoder",
        type=str,
        default="bert",
        choices=list(program_encoder_files),
    )
    parser.add_argument("--missing_choice_ratio", type=float, default=0.0)
    parser.add_argument("--missing_sample_ratio", type=float, default=0.0)
    parser.add_argument("--loss", type=str, default="cce")
    # type=bool would turn every non-empty string (including "False") into
    # True, so the text is interpreted explicitly.
    parser.add_argument(
        "--verbose",
        type=lambda text: text.strip().lower() in ("1", "true", "t", "yes", "y"),
        default=False,
    )
    parser.add_argument("--epochs", type=int, default=10)

    args = parser.parse_args()
    log_tag = args.log_tag
    batch_size, seed, hidden_size, dropout, encoder, loss, program_encoder = (
        args.batch_size,
        args.seed,
        args.hidden_size,
        args.dropout,
        args.encoder,
        args.loss,
        args.program_encoder,
    )
    missing_choice_ratio, missing_sample_ratio = (
        args.missing_choice_ratio,
        args.missing_sample_ratio,
    )

    ## config
    result_file = "data/gqa_model_selection_instance_results.json"
    prompt_file = "data/gqa_computation_graph_descrption.json"
    LOG_FILE = f"logs/{log_tag}.txt"
    # build_dataset / evaluate_model / train_model read this module-level name.
    logger = get_logger(LOG_FILE)
    args_dict = vars(args)

    ## log the run configuration
    for key, value in args_dict.items():
        formatted_key = key.ljust(20)
        formatted_value = str(value).ljust(20)
        logger.info("%s: %s", formatted_key, formatted_value)

    t_embs_file, v_embs_file = encoder_files[encoder]
    program_embs_file = program_encoder_files[program_encoder]

    ## build the datasets and loaders
    setup_seed(seed)
    (
        train_dataset,
        val_dataset,
        test_dataset,
        node2idx,
        val_structure_list,
        test_structure_list,
        val_group_list,
        test_group_list,
    ) = build_dataset(
        missing_choice_ratio,
        missing_sample_ratio,
        t_embs_file,
        v_embs_file,
        program_embs_file,
        prompt_file,
        result_file,
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Unshuffled view of the training data, used to pick the baseline.
    fixed_train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    # Evaluation consumes one sample per step, zipped with the metadata lists.
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    ## create and train the model
    model = NCFPlus(
        hidden_dim=hidden_size,
        node_type_num=len(node2idx),
        dropout=dropout,
        model_zoo_num=NUM_MODEL_ZOO,
        max_node_num=MAX_NODES,
        loss=loss,
    ).to(torch.device("cuda"))

    train_model(
        model,
        val_structure_list,
        test_structure_list,
        val_group_list,
        test_group_list,
        fixed_train_loader,
        train_loader,
        val_loader,
        test_loader,
        args,
    )
