import argparse
import os
from typing import Dict, Any

import tqdm

import torch
import torch.cuda
from torch.utils.data import DataLoader
import torch.backends.cudnn
import torch.nn as nn

import cv_lib.utils as cv_utils

import dark_kg.graph as graph
from dark_kg.data import build_train_dataset
import dark_kg.utils as utils


@torch.no_grad()
def init_graph(
    dataloader: DataLoader,
    wrapper: utils.IngredientModelWrapper,
    graph: graph.RelationGraph,
    device: torch.device
):
    """Zero the relation graph's weights, then feed every batch of ``dataloader`` into it.

    Runs entirely under ``torch.no_grad()``. For each batch the inputs and
    targets are moved to ``device``, passed through ``wrapper``, and the
    wrapper's ``"ingredients"``, ``"attn"`` and ``"attn_cls"`` outputs are
    handed to ``graph.update`` together with ``targets["label"]``.

    This function does not switch ``wrapper`` to eval mode and does not
    finalize the graph; ``main`` calls ``wrapper.eval()`` beforehand and
    ``relation_graph.accumulate()`` afterwards.

    Note: inside this function the parameter ``graph`` shadows the
    ``dark_kg.graph`` module. The annotation is evaluated at definition
    time, so it still refers to the module.

    Args:
        dataloader: yields ``(inputs, targets)`` pairs; ``targets`` must
            support ``targets["label"]``.
        wrapper: model called as ``wrapper(inputs)``, returning a dict of tensors.
        graph: relation graph whose ``vertex_weights`` / ``edge_weights``
            tensors are reset to zero before the pass.
        device: device that each batch is moved to.
    """
    # Reset the graph weights, vertices first and then edges, so the pass starts from zero.
    for attr_name in ("vertex_weights", "edge_weights"):
        nn.init.zeros_(getattr(graph, attr_name).tensor)

    progress = tqdm.tqdm(dataloader, total=len(dataloader))
    for inputs, targets in progress:
        inputs, targets = utils.move_data_to_device(inputs, targets, device)
        outputs: Dict[str, torch.Tensor] = wrapper(inputs)
        graph.update(
            ingredients=outputs["ingredients"],
            attn=outputs["attn"],
            attn_cls=outputs["attn_cls"],
            label=targets["label"],
        )


def main(args):
    """Build a relation graph from the training split and save its state dict.

    Steps:
      1. Load the top-level config (``args.dark_kg_cfg``) and the dataset
         config it references.
      2. Build the training dataset and a ``DataLoader`` over it.
      3. Load the TorchScript backbone and codebook and wrap them.
      4. Create a ``RelationGraph`` sized from the wrapper and the dataset's
         class count, run ``init_graph`` over the data, then ``accumulate()``.
      5. Save ``relation_graph.state_dict()`` to ``args.save_fp``.

    Args:
        args: parsed CLI namespace. Reads ``dark_kg_cfg``, ``seed``,
            ``batch_size``, ``num_workers`` and ``save_fp``.
    """
    # split configs; cfg["dataset"] is passed to get_cfg again
    # (presumably a path to a separate dataset config -- confirm)
    cfg: Dict[str, Any] = cv_utils.get_cfg(args.dark_kg_cfg)
    data_cfg: Dict[str, Any] = cv_utils.get_cfg(cfg["dataset"])
    kg_cfg = cfg["relation_graph"]

    # set cuda
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # cudnn autotuning may select non-deterministic algorithms, so enable it
    # only when no seed (i.e. no reproducibility request) is given.
    torch.backends.cudnn.benchmark = args.seed is None
    # make deterministic
    if args.seed is not None:
        cv_utils.make_deterministic(args.seed)

    # get dataloader
    print("Building dataset...")
    # NOTE: the `--make_partial` CLI option is parsed but currently not
    # forwarded to the dataset config.
    train_dataset, _, n_classes, _ = build_train_dataset(data_cfg)
    dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True
    )

    # load jit models
    print("Loading jit models...")
    backbone: torch.jit.ScriptModule = torch.jit.load(kg_cfg["backbone_jit"], map_location=device)
    codebook: torch.jit.ScriptModule = torch.jit.load(kg_cfg["codebook_jit"], map_location=device)
    wrapper = utils.IngredientModelWrapper(backbone, codebook)

    # create relation graph sized from the wrapper and the dataset class count;
    # extra constructor kwargs come from kg_cfg["ir_atlas"]
    relation_graph = graph.RelationGraph(
        num_vertices=wrapper.num_ingredients,
        emb_dim=wrapper.emb_dim,
        num_classes=n_classes,
        **kg_cfg["ir_atlas"]
    ).to(device)
    wrapper.eval().to(device)
    init_graph(
        dataloader,
        wrapper,
        graph=relation_graph,
        device=device
    )
    relation_graph.accumulate()
    state_dict = relation_graph.state_dict()
    torch.save(state_dict, args.save_fp)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dark_kg_cfg", type=str)
    parser.add_argument("--save_fp", type=str)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--make_partial", type=float, default=None)
    args = parser.parse_args()
    save_path = os.path.dirname(args.save_fp)
    os.makedirs(save_path, exist_ok=True)
    main(args)
