import argparse
import json
from itertools import islice
from pathlib import Path

import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from MPE.utils import pytorch_utils

from .data_utils import MemoryEfficientDepthDataset
from .loss_functions import (
    BBListMLClassficationLossFunction,
    BBListRankingLossFunctionV2,
)
from .metrics import list_bc_metric_v2, list_ranking_metric
from .models import BBAttentionNetwork

# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Dataset root, resolved against the directory the process was started from.
DATASET_DIR = Path.cwd().joinpath("MPE", "dataset")

# Named presets of loss weights; each preset provides the weight of the
# multi-label classification loss and of the scoring (ranking) loss.
CONFIG = {
    "l2c": {
        "ml_weight": 0.45,
        "scoring_weight": 0.55,
    },
}


def load_json_files(dataset_path, network_name, max_files=1000):
    """Load the JSON samples in a directory that belong to one network.

    Only entries whose file name ends in ``.json`` are opened. Files whose
    content cannot be decoded as JSON are skipped (best effort), and parsed
    files are kept only when ``data["props"]["network_name"]`` equals
    ``network_name``.

    Args:
        dataset_path: Directory (a ``pathlib.Path``) to scan, non-recursively.
        network_name: Value that ``props.network_name`` must match.
        max_files: Maximum number of directory entries to inspect, counted
            before any filtering (so non-JSON entries consume the budget
            too). ``None`` inspects every entry. Defaults to 1000.

    Returns:
        A dict mapping file name to the parsed JSON document.

    Raises:
        OSError: If the directory cannot be listed or a file cannot be opened.
        KeyError: If a parsed document lacks ``props`` or ``network_name``.
    """
    json_data = {}
    # islice stops after ``max_files`` entries without first building the
    # complete directory listing as a list.
    for file in islice(dataset_path.iterdir(), max_files):
        name = file.name
        if not name.endswith(".json"):
            continue
        with open(file, encoding="utf-8") as f:
            try:
                data = json.load(f)
            except (json.JSONDecodeError, UnicodeDecodeError):
                # Unreadable sample: skip it, but let KeyboardInterrupt and
                # other unrelated errors propagate.
                continue
        if data["props"]["network_name"] == network_name:
            json_data[name] = data
    return json_data


def main(args):
    """Train a ``BBAttentionNetwork`` on the l2c dataset of one network.

    Loads the JSON samples under
    ``DATASET_DIR / <network_name> / "l2c-dataset"``, splits them into
    training and validation sets, builds the datasets, model, losses,
    optimizer and scheduler, and hands everything to
    ``pytorch_utils.train_test_loop``.

    Args:
        args: Parsed command-line namespace. Reads ``network_name`` and
            ``config_name`` (a key of ``CONFIG``).

    Raises:
        ValueError: If fewer than two matching samples are found, since no
            train/validation split can be formed.
        KeyError: If ``args.config_name`` is not a key of ``CONFIG``.
    """
    dataset_path = DATASET_DIR / args.network_name / "l2c-dataset"
    json_data = load_json_files(dataset_path, args.network_name)
    json_ids = list(json_data)
    if len(json_ids) < 2:
        raise ValueError(
            f"Need at least 2 samples for network {args.network_name!r} in "
            f"{dataset_path} to build a train/validation split, "
            f"found {len(json_ids)}"
        )

    max_root_evid_vars = max(j["props"]["num_evid_vars"] for j in json_data.values())
    print(max_root_evid_vars)

    # train_test_split rejects an absolute test_size that is not strictly
    # smaller than the number of samples. Keep the 1000-sample validation set
    # when the dataset is large enough; otherwise hold out 20% instead of
    # failing.
    val_size = 1000 if len(json_ids) > 1000 else 0.2
    train_json_id, val_json_id = train_test_split(json_ids, test_size=val_size)
    train_json = {json_id: json_data[json_id] for json_id in train_json_id}
    val_json = {json_id: json_data[json_id] for json_id in val_json_id}
    print(len(train_json), len(val_json))

    train_data = MemoryEfficientDepthDataset(train_json, max_root_evid_vars, train=True)
    val_data = MemoryEfficientDepthDataset(val_json, max_root_evid_vars, train=False)
    print(len(train_data), len(val_data))
    train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)
    # NOTE(review): the validation loader is shuffled as well; confirm that
    # the metrics do not depend on batch order.
    val_dataloader = DataLoader(val_data, batch_size=128, shuffle=True)
    ml_label_loss_fn = BBListMLClassficationLossFunction(reduction="none")
    scoring_loss_fn = BBListRankingLossFunctionV2(reduction="none", log_target=True)

    config_name = args.config_name
    config_values = CONFIG[config_name]
    num_var_val = len(train_data.var_values_to_idx)
    # One slot more than the number of known variable values: the extra
    # index ``num_var_val`` is the one passed as ``padding_idx``.
    bb_attention_network = BBAttentionNetwork(
        num_var_val + 1,
        embed_dim=256,
        num_layers=15,
        hidden_dim=512,
        output_size=1,
        padding_idx=num_var_val,
    ).to(device)

    # Loss functions and their weights are passed as parallel lists.
    loss_fn_list = [ml_label_loss_fn, scoring_loss_fn]
    loss_weights = [config_values["ml_weight"], config_values["scoring_weight"]]
    learning_rate = 8e-4
    optimizer = torch.optim.Adam(
        bb_attention_network.parameters(),
        lr=learning_rate,
    )
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)
    artifacts_dir = DATASET_DIR / args.network_name / "training-artifacts" / config_name
    pytorch_utils.train_test_loop(
        train_dataloader,
        val_dataloader,
        bb_attention_network,
        loss_fn_list,
        loss_weights,
        optimizer,
        scheduler,
        metrics=[list_ranking_metric, list_bc_metric_v2],
        artifacts_dir=artifacts_dir,
        epochs=50,
    )


if __name__ == "__main__":
    # Script entry point: read the target network and the loss-weight preset
    # from the command line, then start training.
    cli = argparse.ArgumentParser(
        description="Train cross-attention neural network for decimation",
    )
    cli.add_argument("--network-name", dest="network_name", required=True)
    cli.add_argument(
        "--config-name",
        dest="config_name",
        type=str,
        choices=list(CONFIG),
        default="l2c",
    )
    main(cli.parse_args())
