import torch
from ml_collections import ConfigDict
import argparse

from eval import init_model


def process_cs3(state_dict):
    """Set the tokenizer's ``vq_codebook_dim`` to 512, mutating *state_dict* in place.

    The value is written into ``state_dict["cfg"]["model_config"][1].tokenizer_config``.
    Returns None.
    """
    tokenizer_cfg = state_dict["cfg"]["model_config"][1].tokenizer_config
    tokenizer_cfg["vq_codebook_dim"] = 512


def fix_checkpoint(src, target, do_test, device, signal_length, new_token_path):
    """Rewrite the config of the checkpoint at *src* into the new layout and save it to *target*.

    In ``cfg["model_config"][1]["transformer_config"]`` (on a resolved copy), the
    ``causal_encoder`` / ``causal_decoder`` / ``causal_cross`` flags are replaced by
    ``encoder_causality`` / ``decoder_causality`` / ``cross_causality`` tuples:

    * encoder/decoder: ``(None, 0)`` when the flag is truthy, else ``(None, None)``;
    * cross: ``(None, None)`` when ``causal_cross`` is absent or falsy, otherwise
      ``(None, lookahead_tokens)``, where ``lookahead_tokens`` defaults to 0 if absent.

    Args:
        src: path of the checkpoint to read.
        target: path the rewritten checkpoint is saved to.
        do_test: if truthy, reload *target* and call ``init_model`` on it as a smoke test.
        device: device forwarded to ``init_model`` during the smoke test.
        signal_length: if not None, stored as the tokenizer config's ``signal_length``.
        new_token_path: if not None, stored as the model config's ``tokenizer_path``.

    ``cfg["soi_type"]`` is set to ``"old"`` when the checkpoint has no such key.
    """
    print("Loading checkpoint")
    # NOTE(review): weights_only=False unpickles arbitrary objects; only use on trusted files.
    state_dict = torch.load(src, weights_only=False)
    model_cfg = state_dict["cfg"]["model_config"][1]

    # Work on a resolved copy; it replaces the stored transformer config below.
    t_cfg = model_cfg["transformer_config"].copy_and_resolve_references()
    t_cfg["encoder_causality"] = (None, 0) if t_cfg["causal_encoder"] else (None, None)
    t_cfg["decoder_causality"] = (None, 0) if t_cfg["causal_decoder"] else (None, None)
    if t_cfg.get("causal_cross"):
        lookahead = t_cfg["lookahead_tokens"] if "lookahead_tokens" in t_cfg else 0
        t_cfg["cross_causality"] = (None, lookahead)
    else:
        t_cfg["cross_causality"] = (None, None)

    # Remove the legacy flags now that they are encoded in the *_causality tuples.
    delattr(t_cfg, "causal_encoder")
    delattr(t_cfg, "causal_decoder")
    if "causal_cross" in t_cfg:
        delattr(t_cfg, "causal_cross")

    if signal_length is not None:
        model_cfg.tokenizer_config["signal_length"] = signal_length
    model_cfg["transformer_config"] = t_cfg
    if new_token_path is not None:
        model_cfg.tokenizer_path = new_token_path
    state_dict["cfg"].setdefault("soi_type", "old") if False else None
    if "soi_type" not in state_dict["cfg"]:
        state_dict["cfg"]["soi_type"] = "old"

    print("Saving checkpoint")
    torch.save(state_dict, target)

    if not do_test:
        return

    print("Re-loading checkpoint")
    reloaded = torch.load(target, weights_only=False)
    init_model(reloaded, ConfigDict(reloaded["cfg"]), device)
    print("Successfully loaded model")


def parse_args(argv=None):
    """Parse the command-line arguments for the checkpoint rewrite.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``src``, ``target``, ``do_test``, ``device``,
        ``signal_length`` and ``new_token_path``.
    """

    def str2bool(text):
        # ``type=bool`` would call bool("False") -> True, so any explicit value
        # (including "False"/"0") enabled the smoke test. Parse spellings instead.
        value = text.strip().lower()
        if value in ("1", "true", "t", "yes", "y", "on"):
            return True
        if value in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("src", type=str)
    parser.add_argument("target", type=str)
    parser.add_argument(
        "--do_test",
        type=str2bool,
        default=True,
        help="reload the rewritten checkpoint and build the model (true/false)",
    )
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--signal_length", type=int)
    parser.add_argument("--new_token_path", type=str)
    return parser.parse_args(argv)


def main():
    """Read the command-line options and run ``fix_checkpoint`` with them."""
    opts = parse_args()
    fix_checkpoint(
        src=opts.src,
        target=opts.target,
        do_test=opts.do_test,
        device=opts.device,
        signal_length=opts.signal_length,
        new_token_path=opts.new_token_path,
    )


# Run the conversion only when executed as a script, not on import.
if __name__ == "__main__":
    main()
