import time
import arg
import os
import yaml

# Point matplotlib's config/cache directory at "<cwd>/configs/". This is set at
# import time, before main() lazily imports the model code.
os.environ['MPLCONFIGDIR'] = f"{os.getcwd()}/configs/"

def main():
    parser = arg.get_parser()
    args = parser.parse_args()

    if args.embedder == 'LP':
        from models import LP
        embedder = LP(args)

    elif args.embedder == 'GCNMF': 
        from models import GCNMF
        embedder = GCNMF(args)
    
    elif args.embedder == 'PaGNN':
        from models import PaGNN
        embedder = PaGNN(args)

    elif args.embedder == 'Node2Vec':
        from models import node2vec
        embedder = node2vec(args)

    elif args.embedder == 'Node2Vec_x': 
        from models import node2vec_x
        embedder = node2vec_x(args)

    elif args.embedder == 'Node2Vec_loss': 
        from models import node2vec_loss
        embedder = node2vec_loss(args)

    elif args.embedder == 'Node2Vec_GNN':
        from models import node2vec_gnn
        embedder = node2vec_gnn(args)

    elif args.embedder == 'node2vec_gnn_concat_1':
        from models import node2vec_gnn_concat_1
        embedder = node2vec_gnn_concat_1(args)

    elif args.embedder == 'node2vec_gnn_concat_2':
        from models import node2vec_gnn_concat_2
        embedder = node2vec_gnn_concat_2(args)

    elif args.embedder == 'MLP': # Filling Variants: "zero", "random", "neighbor_mean", "fp"
        from models import mlp
        embedder = mlp(args)

    elif args.embedder == 'GNN': # Filling Variants: "zero", "random", "neighbor_mean", "fp"
        from models import gnn
        embedder = gnn(args)
    
    elif args.embedder == 'SAT': # Filling Variants: "zero", "random", "neighbor_mean", "fp"
        from models import SAT
        embedder = SAT(args)
    
    elif args.embedder == 'GCN_LPA': # Filling Variants: "zero", "random", "neighbor_mean", "fp"
        from models import GCN_LPA
        embedder = GCN_LPA(args)

    elif args.embedder == 'Correct_Smooth': # Filling Variants: "zero", "random", "neighbor_mean", "fp"
        from models import Correct_Smooth
        embedder = Correct_Smooth(args)

    elif args.embedder == 'TWIRLS':
        with open("hyperparameters_twirls.yaml", "r") as f:
            hyperparams = yaml.safe_load(f)
            if args.dataset in hyperparams:
                for k, v in hyperparams[args.dataset].items():
                    setattr(args, k, v)
        from models import TWIRLS
        embedder = TWIRLS(args)

    elif args.embedder == 'LP_label_trick':
        from models import LP_label_trick
        embedder = LP_label_trick(args)

    elif args.embedder == 'GOODIE':
        from models import GOODIE
        embedder = GOODIE(args)

    elif args.embedder == 'GOODIE_batch':
        with open(f"hyperparameters_goodie.yaml", "r") as f:
            hyperparams = yaml.safe_load(f)
            dataset = args.dataset + f'_{args.missing_type}'
            if dataset in hyperparams:
                for k, v in hyperparams[dataset].items():
                    setattr(args, k, v)      

        from models import GOODIE_batch
        embedder = GOODIE_batch(args)

    elif args.embedder == 'GOODIE_best':
        # scaled = '_scaled' if args.scaled else ''   # tmp
        with open(f"hyperparameters_goodie.yaml", "r") as f:
            hyperparams = yaml.safe_load(f)
            dataset = args.dataset + f'_{args.missing_type}'
            if dataset in hyperparams:
                for k, v in hyperparams[dataset].items():
                    setattr(args, k, v)        
        
        # args.embeer = 'GOODIE'
        from models import GOODIE
        embedder = GOODIE(args)



    elif args.embedder == 'GOODIE_best_wo_k':
        with open(f"hyperparameters_goodie_wo_k.yaml", "r") as f:
            hyperparams = yaml.safe_load(f)
            dataset = args.dataset + f'_{args.missing_type}'
            if dataset in hyperparams:
                for k, v in hyperparams[dataset].items():
                    setattr(args, k, v)        
        
        # args.embeer = 'GOODIE'
        from models import GOODIE
        embedder = GOODIE(args)


    elif args.embedder == 'GOODIE_best_wo_lamb':
        with open(f"hyperparameters_goodie_wo_lamb.yaml", "r") as f:
            hyperparams = yaml.safe_load(f)
            dataset = args.dataset + f'_{args.missing_type}'
            if dataset in hyperparams:
                for k, v in hyperparams[dataset].items():
                    setattr(args, k, v)        
        
        # args.embeer = 'GOODIE'
        from models import GOODIE
        embedder = GOODIE(args)



    elif args.embedder == 'GOODIE_mlp':
        from models import GOODIE_mlp
        embedder = GOODIE_mlp(args)

    elif args.embedder == 'GOODIE_attn':
        from models import GOODIE_attn
        embedder = GOODIE_attn(args)

    elif args.embedder == 'GOODIE_2_attn':
        from models import GOODIE_2_attn
        embedder = GOODIE_2_attn(args)

    elif args.embedder == 'GOODIE_pseudo':
        from models import GOODIE_pseudo
        embedder = GOODIE_pseudo(args)

    elif args.embedder == 'GOODIE_2':
        from models import GOODIE_2
        embedder = GOODIE_2(args)

    elif args.embedder == 'GOODIE_2_pseudo':
        from models import GOODIE_2_pseudo
        embedder = GOODIE_2_pseudo(args)

    t_total = time.time()
    embedder.training()
    print("Optimization Finished!")
    print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
