from __future__ import annotations
import argparse
import torch
from recbole.quick_start import load_data_and_model
from recbole.config import Config
from recbole.data import create_dataset, data_preparation

from .train import train_cul, freeze_all_params


def _torch_load_allow_pickle_patch() -> None:
    """Monkeypatch ``torch.load`` so ``weights_only`` defaults to ``False``.

    After this call, any ``torch.load(...)`` that does not pass
    ``weights_only`` explicitly is forwarded with ``weights_only=False``;
    callers that do pass it are left untouched. The patch is process-wide
    because it rebinds the attribute on the ``torch`` module itself.

    The function is idempotent: calling it again once the patch is installed
    is a no-op, so wrappers never stack on top of each other.

    SECURITY: ``weights_only=False`` enables full unpickling, which can
    execute arbitrary code embedded in the file. Only load checkpoints from
    trusted sources.
    """
    import functools

    import torch as _t

    _old = _t.load

    # Already patched by an earlier call -> do not wrap the wrapper again.
    if getattr(_old, "_allow_pickle_patched", False):
        return

    @functools.wraps(_old)
    def _torch_load_allow_pickle(f, *args, **kwargs):
        # Respect an explicit caller choice; only fill in the default.
        kwargs.setdefault("weights_only", False)
        return _old(f, *args, **kwargs)

    # Marker consulted by the idempotence guard above.
    _torch_load_allow_pickle._allow_pickle_patched = True
    _t.load = _torch_load_allow_pickle


def build_cul_config_from_base(
    base_config, batch_size: int, eval_batch_size: int
) -> Config:
    """Build a new RecBole ``Config`` for CUL training from a base config.

    Data location, field names, separators and ``eval_args`` are copied from
    ``base_config``; the batch sizes come from the arguments; the loss type
    and negative-sampling settings are fixed to BPR with one uniformly
    sampled negative.

    Args:
        base_config: Mapping-like config of the base run (indexed by key).
            It must provide ``model``, ``dataset``, ``data_path``, the
            ``*_FIELD`` names, both separators and ``eval_args``.
        batch_size: Value stored as ``train_batch_size``.
        eval_batch_size: Value stored as ``eval_batch_size``.

    Returns:
        A ``Config`` built for the same model and dataset as ``base_config``.
    """
    uid_field = base_config["USER_ID_FIELD"]
    iid_field = base_config["ITEM_ID_FIELD"]
    time_field = base_config["TIME_FIELD"]

    cul_cfg_dict = {
        "data_path": base_config["data_path"],
        "dataset": base_config["dataset"],
        "USER_ID_FIELD": uid_field,
        "ITEM_ID_FIELD": iid_field,
        "TIME_FIELD": time_field,
        # Load the configured id/time columns (rather than hard-coded names,
        # which would disagree with the *_FIELD settings above whenever the
        # base config is non-default) plus the two extra interaction columns.
        "load_col": {
            "inter": [uid_field, iid_field, time_field, "req_text", "sid"]
        },
        "field_separator": base_config["field_separator"],
        "seq_separator": base_config["seq_separator"],
        "eval_args": base_config["eval_args"],
        "train_batch_size": batch_size,
        "loss_type": "BPR",
        "train_neg_sample_args": {"distribution": "uniform", "sample_num": 1},
        "eval_batch_size": eval_batch_size,
    }
    # Config's first positional parameter is ``model``; the overrides must be
    # passed as ``config_dict`` or they would be mistaken for the model.
    return Config(
        model=base_config["model"],
        dataset=base_config["dataset"],
        config_dict=cul_cfg_dict,
    )


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the CUL training entry point."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", type=str, default="saved.pth")
    ap.add_argument("--epochs", type=int, default=30)
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--eval_batch_size", type=int, default=1024)
    ap.add_argument("--lr", type=float, default=1e-4)
    ap.add_argument("--plm", type=str, default="answerdotai/ModernBERT-base")
    ap.add_argument("--rho", type=float, default=0.10)
    ap.add_argument("--tau", type=float, default=1.0)
    ap.add_argument("--lambda_kl", type=float, default=1.0)
    ap.add_argument("--device", type=str, default="cuda")
    ap.add_argument("--shared_mask", action="store_true")
    ap.add_argument(
        "--plm_tune", type=str, default="all", choices=["none", "last", "all"]
    )
    ap.add_argument("--plm_last_layers", type=str, default="-1:")
    ap.add_argument(
        "--mask_scope", type=str, default="ffn", choices=["ffn", "attn", "both"]
    )
    ap.add_argument(
        "--backend", type=str, default="sasrec", choices=["sasrec", "bert4rec"]
    )
    ap.add_argument(
        "--tau_schedule",
        type=str,
        default="linear",
        choices=["linear", "exp", "cosine"],
    )
    ap.add_argument("--tau_start", type=float, default=0.7)
    ap.add_argument("--tau_end", type=float, default=0.3)
    ap.add_argument("--save_root", type=str, default=None)
    return ap


def _resolve_device(requested: str) -> torch.device:
    """Return the requested device, using CPU only if CUDA was asked for but is missing."""
    device = torch.device(requested)
    # Only a CUDA request is downgraded; an explicit non-CUDA device
    # (e.g. "cpu" or "mps") is honoured as given.
    if device.type == "cuda" and not torch.cuda.is_available():
        print("[Warn] CUDA is not available; falling back to CPU")
        return torch.device("cpu")
    return device


def main():
    """Command-line entry point: load the base model, rebuild data, train CUL.

    Steps:
      1. Patch ``torch.load`` so the checkpoint can be unpickled, then load
         the base model from ``--ckpt`` and freeze its parameters.
      2. Build a fresh config/dataset/dataloaders that additionally load the
         ``req_text`` and ``sid`` interaction columns.
      3. Forward everything, together with the CLI hyper-parameters, to
         ``train_cul`` and print the resulting run id.
    """
    _torch_load_allow_pickle_patch()

    args = _build_arg_parser().parse_args()
    device = _resolve_device(args.device)

    print("========== Load frozen base model ==========")
    # Only the config and model are used; the base dataset and dataloaders
    # are rebuilt below with the extra columns.
    base_config, base_model, _base_dataset, _base_train, _base_valid, _base_test = (
        load_data_and_model(args.ckpt)
    )
    base_model.to(device)
    # NOTE(review): the frozen base model is deliberately left in train mode
    # here, as in the original code; confirm train_cul expects this rather
    # than eval mode.
    base_model.train(True)
    freeze_all_params(base_model)

    print("========== Rebuild dataset/dataloaders with req_text & sid ==========")
    cul_config = build_cul_config_from_base(
        base_config, args.batch_size, args.eval_batch_size
    )
    dataset = create_dataset(cul_config)
    train_data, valid_data, test_data = data_preparation(cul_config, dataset)
    print(
        f"[Data] dataset={cul_config['dataset']} users={dataset.user_num} items={dataset.item_num}"
    )

    print("========== Build specs, init CUL, train ==========")
    run_id, _meta = train_cul(
        base_model,
        dataset,
        train_data,
        valid_data,
        test_data,
        epochs=args.epochs,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        lr=args.lr,
        plm=args.plm,
        rho=args.rho,
        tau=args.tau,
        lambda_kl=args.lambda_kl,
        device=device,
        shared_mask=args.shared_mask,
        plm_tune=args.plm_tune,
        plm_last_layers=args.plm_last_layers,
        mask_scope=args.mask_scope,
        backend=args.backend,
        tau_schedule=args.tau_schedule,
        tau_start=args.tau_start,
        tau_end=args.tau_end,
        data_path=cul_config["data_path"],
        dataset_name=cul_config["dataset"],
        base_config=base_config,
        save_root=args.save_root,
    )
    print(f"[Done] run_id={run_id}")


# Script entry point. This module uses a relative import (``from .train``),
# so it has to be executed as part of its package (``python -m <pkg>.<module>``)
# rather than as a standalone file.
if __name__ == "__main__":
    main()
