from data.processor import DataProcessor
from data.collector import Collector
from forge.forge import Forge

import argparse


def run(unsupervised_instance_path, triplet_instance_path, lp_gap_instance_path):
    """Run the three training stages in sequence.

    1. Unsupervised pre-training: instances under ``unsupervised_instance_path``
       are processed by ``DataProcessor`` into an intermediate pickle, which is
       then used to train a plain ``Forge`` model.
    2. Warm-start fine-tuning: ``Collector.get_triplets`` gathers triplet data
       from ``triplet_instance_path`` using the unsupervised model. A ``Forge``
       model with ``prob_head=True`` is then trained from that checkpoint.
    3. LP-gap fine-tuning: ``Collector.get_cut_ratios`` gathers data from
       ``lp_gap_instance_path``. A ``Forge`` model with ``prob_head=True`` and
       ``cut_head=True`` is then trained from the warm-start checkpoint.

    Intermediate data is written under ``./data/intermediate_files/``, model
    checkpoints under ``./models/``, and the unsupervised training log to
    ``./data/log/``. All of these paths are relative to the current working
    directory.

    Args:
        unsupervised_instance_path: Input location for the pre-training stage.
        triplet_instance_path: Input location for triplet collection.
        lp_gap_instance_path: Input location for LP-gap data collection.
    """
    # Each file is named exactly once, so the stage that writes it and the
    # stage that reads it cannot drift apart.
    dgl_data_path = './data/intermediate_files/mips_to_dgl.pkl'
    triplet_data_path = './data/intermediate_files/mips_to_triplet.pkl'
    gap_data_path = './data/intermediate_files/mips_to_gaps.pkl'
    unsupervised_model_path = './models/unsupervised_model.pkl'
    warm_start_model_path = './models/warm_start_model.pkl'
    lp_gap_model_path = './models/lp_gap_model.pkl'

    ########################### Unsupervised Pre Training ###########################
    # Unsupervised training data
    dp = DataProcessor()
    dp.get_instance_dict(input_path=unsupervised_instance_path,
                         output_file=dgl_data_path,
                         perturb=[0.05, 0.01],
                         return_dict=False)

    # Read back the exact file that was just written above.
    train_list = dp.get_train_list([dgl_data_path])

    # Unsupervised pre-training
    model = Forge()
    model.train_unsupervised(model_save_path=unsupervised_model_path,
                             train_list=train_list,
                             epochs=10,
                             steps_per_instance=10,
                             lr=1e-4,
                             log_path='./data/log/unsupervised_train_log.pkl')
    #################################################################################

    ############################ Warm Start Fine Tuning  ############################
    # Triplet data collection (uses the unsupervised checkpoint)
    dc = Collector()
    dc.get_triplets(input_path=triplet_instance_path,
                    gnn_model_path=unsupervised_model_path,
                    output_file=triplet_data_path,
                    return_dict=False)

    # Supervised fine-tuning: warm-start prediction head on top of the
    # unsupervised checkpoint.
    model = Forge(prob_head=True)
    model.train_triplets(pretrained_path=unsupervised_model_path,
                         model_save_path=warm_start_model_path,
                         mips_to_triplet=None,
                         mips_to_triplet_path=triplet_data_path,
                         epochs=10,
                         steps_per_instance=10,
                         lr=1e-5,
                         batch_size=1024)
    #################################################################################

    ############################## LP Gap Fine Tuning ###############################
    # LP gap data collection
    dc.get_cut_ratios(input_path=lp_gap_instance_path,
                      output_file=gap_data_path,
                      return_dict=False)
    # LP gap prediction, fine-tuned from the warm-start checkpoint.
    model = Forge(prob_head=True,
                  cut_head=True)
    model.train_lp_gaps(pretrained_path=warm_start_model_path,
                        model_save_path=lp_gap_model_path,
                        mips_to_gaps=None,
                        mips_to_gaps_path=gap_data_path,
                        epochs=10,
                        steps_per_instance=10,
                        lr=1e-4)
    #################################################################################


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Run unsupervised pre-training, warm-start fine-tuning '
                    'and LP-gap fine-tuning in sequence.')
    parser.add_argument('--unsupervised_instance_path', type=str, default='./data/train/',
                        help='Instances for unsupervised pre-training (default: %(default)s).')
    parser.add_argument('--triplet_instance_path', type=str, default='./data/train/',
                        help='Instances for triplet collection (default: %(default)s).')
    parser.add_argument('--lp_gap_instance_path', type=str, default='./data/train/',
                        help='Instances for LP-gap data collection (default: %(default)s).')
    args = parser.parse_args()

    # For an interactive example run, see 'train.ipynb'.
    # With the default arguments, instances must be placed in ./data/train/.
    run(unsupervised_instance_path=args.unsupervised_instance_path,
        triplet_instance_path=args.triplet_instance_path,
        lp_gap_instance_path=args.lp_gap_instance_path)
