import argparse
import json
import os
import pickle

from EVE.EVE import VAE_model
from EVE.utils import data_utils

# Default MSA re-weighting parameter, used when --theta_reweighting is omitted.
_DEFAULT_THETA = 0.2

# Gap-filtering thresholds passed to data_utils.MSA_processing, keyed by the
# supported --protein_family values.
_FAMILY_GAP_THRESHOLDS = {
    "cm": {
        "threshold_sequence_frac_gaps": 0.5,
        "threshold_focus_cols_frac_gaps": 0.3,
    },
    "ppat": {
        "threshold_sequence_frac_gaps": 0.5,
        "threshold_focus_cols_frac_gaps": 0.3,
    },
    "tim": {
        "threshold_sequence_frac_gaps": 0.95,
        "threshold_focus_cols_frac_gaps": 0.3,
    },
}


def _build_parser():
    """Return the argument parser for the VAE training command line."""
    parser = argparse.ArgumentParser(description="VAE")
    parser.add_argument("--MSA_location", type=str, help="MSA FASTA file path_in")
    parser.add_argument("--protein_family", type=str, help="Protein family")
    parser.add_argument(
        "--MSA_weights_location",
        type=str,
        help="Location where weights for each sequence in the MSA will be stored",
    )
    # Required: the script cannot proceed without this path (it previously
    # failed later with a TypeError on a None path).
    parser.add_argument(
        "--dataset_pickle",
        type=str,
        required=True,
        help="Path for pickled dataset. Will be generated if not found.",
    )
    parser.add_argument(
        "--theta_reweighting",
        type=float,
        help="Parameters for MSA sequence re-weighting",
    )
    parser.add_argument(
        "--VAE_checkpoint_location",
        type=str,
        help="Location where VAE model checkpoints will be stored",
    )
    parser.add_argument(
        "--model_name_suffix",
        default="Jan1",
        type=str,
        help="model checkpoint name will be the protein name followed by this suffix",
    )
    # Required for the same reason as --dataset_pickle.
    parser.add_argument(
        "--model_parameters_location",
        type=str,
        required=True,
        help="Location of VAE model parameters",
    )
    parser.add_argument(
        "--training_logs_location", type=str, help="Location of VAE training logs"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    return parser


def _gap_thresholds(protein_family):
    """Return the gap-threshold keyword arguments for *protein_family*.

    Raises:
        ValueError: if the family has no entry in ``_FAMILY_GAP_THRESHOLDS``.
    """
    try:
        return _FAMILY_GAP_THRESHOLDS[protein_family]
    except KeyError:
        supported = ", ".join(sorted(_FAMILY_GAP_THRESHOLDS))
        raise ValueError(
            f"Unsupported protein family {protein_family!r}; "
            f"expected one of: {supported}"
        ) from None


def _ensure_dataset_pickle(args, theta):
    """Pre-process the MSA and pickle it to ``args.dataset_pickle`` if absent.

    Does nothing when the pickle file already exists.
    """
    if os.path.exists(args.dataset_pickle):
        return

    print("Pickled dataset not found. Pre-processing...")
    # Resolve thresholds first so an unsupported family fails before any
    # processing starts.
    thresholds = _gap_thresholds(args.protein_family)
    # The user-supplied theta is honoured here; it used to be overwritten
    # with 0.2 for every family, which silently ignored --theta_reweighting.
    weights_path = (
        f"{args.MSA_weights_location}{os.sep}"
        f"{args.protein_family}_theta_{theta}.npy"
    )
    data = data_utils.MSA_processing(
        MSA_location=args.MSA_location,
        protein_family=args.protein_family,
        theta=theta,
        use_weights=True,
        weights_location=weights_path,
        **thresholds,
    )
    with open(args.dataset_pickle, "wb") as handle:
        pickle.dump(data, handle)
        print("Dataset saved with pickle.")


def main(argv=None):
    """Train a VAE model on a (possibly cached) processed MSA and save it.

    Args:
        argv: Command-line arguments to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: if the dataset must be generated and ``--protein_family``
            is not one of the supported families.
    """
    args = _build_parser().parse_args(argv)
    protein_family = args.protein_family

    print("Protein family: " + str(protein_family))
    print("MSA file: " + str(args.MSA_location))

    theta = (
        _DEFAULT_THETA if args.theta_reweighting is None else args.theta_reweighting
    )
    print("Theta MSA re-weighting: " + str(theta))

    # If pickled dataset not found, generate and save
    _ensure_dataset_pickle(args, theta)

    # Load data.
    # NOTE(security): pickle.load can execute arbitrary code; only point
    # --dataset_pickle at files produced by this script or otherwise trusted.
    with open(args.dataset_pickle, "rb") as handle:
        data = pickle.load(handle)
        print("Dataset loaded with pickle.")

    model_name = f"{protein_family}_{args.model_name_suffix}"
    print(f"Model name: {model_name}")

    # Context manager so the parameter file is closed promptly.
    with open(args.model_parameters_location, encoding="utf-8") as params_file:
        model_params = json.load(params_file)

    model = VAE_model.VAE_model(
        model_name=model_name,
        data=data,
        encoder_parameters=model_params["encoder_parameters"],
        decoder_parameters=model_params["decoder_parameters"],
        random_seed=args.seed,
    )
    model = model.to(model.device)

    # Command-line locations override whatever the parameter file contains.
    training_parameters = model_params["training_parameters"]
    training_parameters["training_logs_location"] = args.training_logs_location
    training_parameters["model_checkpoint_location"] = args.VAE_checkpoint_location

    print("Starting to train model: " + model_name)
    model.train_model(data=data, training_parameters=training_parameters)

    print("Saving model: " + model_name)
    model.save(
        model_checkpoint=os.path.join(
            training_parameters["model_checkpoint_location"],
            model_name + "_final",
        ),
        encoder_parameters=model_params["encoder_parameters"],
        decoder_parameters=model_params["decoder_parameters"],
        training_parameters=training_parameters,
    )


if __name__ == "__main__":
    main()
