import argparse
from copy import deepcopy
from multiprocessing import Pool
import numpy as np
import logging
import os
import sys
# from dataManipulation import *
sys.path.append("..")
from utils import mcmc_treeprob,loadData, summary, namenum
from datasets import get_empdataloader
from models import VBPI

def _sort_by_taxa(unorderdata, unordertaxa):
    """Return ``(data, taxa)`` reordered so that the taxon names are sorted.

    Entry ``i`` of ``unorderdata`` stays paired with entry ``i`` of
    ``unordertaxa``.  The permutation is computed with a single stable sort
    rather than one ``list.index`` lookup per taxon: that lookup is quadratic
    and, when a name occurs more than once, maps every occurrence to the data
    of the first one.
    """
    order = sorted(range(len(unordertaxa)), key=lambda i: unordertaxa[i])
    taxa = [unordertaxa[i] for i in order]
    data = [unorderdata[i] for i in order]
    return data, taxa


def _run_name(args):
    """Build the run directory name from the method, layer, aggregator, projection and date settings."""
    run_name = '{}_hL_{}_aggr{}'.format(args.gradMethod, args.hLBranch, args.aggr)
    if args.proj:
        run_name += '_proj'
    return run_name + '_' + args.date


def main(args):
    """Train a VBPI model on one dataset and write the results to a run folder.

    Args:
        args: namespace of settings as produced by ``parse_args``.  Three
            attributes are added to it here: ``folder`` (the run directory),
            ``save_to_path`` (``final.pt`` inside it) and ``logpath``
            (``final.log`` inside it).

    Raises:
        FileExistsError: if the run directory already exists
            (``os.makedirs`` is called with ``exist_ok=False``).

    Files written by this function inside the run directory: ``final.log``,
    ``final_test_lb.npy`` and, when ``args.empFreq`` is set,
    ``final_kl_div.npy``.  ``save_to_path`` is handed to ``model.learn``;
    presumably that is where it stores the model -- confirm in ``models``.
    """
    data_path = '../data/hohna_datasets_fasta/'
    ###### Load Data
    unorderdata, unordertaxa = loadData(data_path + args.dataset + '.fasta', 'fasta')
    data, taxa = _sort_by_taxa(unorderdata, unordertaxa)
    del unorderdata, unordertaxa

    args.folder = os.path.join(args.workdir, args.dataset, _run_name(args))
    # exist_ok=False: refuse to reuse (and overwrite) an existing run directory.
    os.makedirs(args.folder, exist_ok=False)

    args.save_to_path = os.path.join(args.folder, 'final.pt')
    args.logpath = os.path.join(args.folder, 'final.log')

    # Everything is logged through the root logger into the run's log file.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    filehandler = logging.FileHandler(args.logpath)
    filehandler.setLevel(logging.INFO)
    logger.addHandler(filehandler)

    try:
        logger.info('Training with the following settings:')
        for key, value in vars(args).items():
            logger.info('{} : {}'.format(key, value))

        if args.empFreq:
            emp_tree_freq = get_empdataloader(args.dataset, batch_size=200)
            logger.info('Empirical estimates from MrBayes loaded')
        else:
            emp_tree_freq = None

        model = VBPI(
            taxa,
            data,
            pden=np.ones(4) / 4.,
            subModel=('JC', 1.0),
            emp_tree_freq=emp_tree_freq,
            hidden_dim_tree=args.hdimTree,
            hidden_dim_branch=args.hdimBranch,
            num_layers_branch=args.hLBranch,
            gnn_type=args.gnn_type,
            aggr=args.aggr,
            project=args.proj,
            nheads=args.nheads,
        )

        logger.info('Running on device: {}'.format(model.device))
        logger.info('Parameter Info:')
        for param in model.parameters():
            logger.info(param.dtype)
            logger.info(param.size())

        logger.info('\nVBPI running, results will be saved to: {}\n'.format(args.save_to_path))
        test_lb, test_kl_div = model.learn(
            {'tree': args.stepszTree, 'branch': args.stepszBranch},
            args.maxIter,
            test_freq=args.tf,
            lb_test_freq=args.lbf,
            n_particles=args.nParticle,
            anneal_freq_tree=args.afTree,
            anneal_freq_branch=args.afBranch,
            anneal_freq_tree_warm=args.afTreewarm,
            anneal_freq_branch_warm=args.afBranchwarm,
            anneal_rate_tree=args.arTree,
            anneal_rate_branch=args.arBranch,
            init_inverse_temp=args.invT0,
            save_freq=args.sf,
            warm_start_interval=args.nwarmStart,
            method=args.gradMethod,
            save_to_path=args.save_to_path,
            logger=logger,
            clip_grad=args.clip_grad,
            clip_value=args.clip_value,
            eps_max=args.eps_max,
            eps_period=args.eps_period,
        )

        np.save(args.save_to_path.replace('.pt', '_test_lb.npy'), test_lb)
        if args.empFreq:
            # The KL values are only saved when empirical frequencies were loaded.
            np.save(args.save_to_path.replace('.pt', '_kl_div.npy'), test_kl_div)
    finally:
        # Detach and close the per-run handler so that a later call in the
        # same process does not keep writing into this run's log file.
        logger.removeHandler(filehandler)
        filehandler.close()

def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: optional list of argument strings to parse.  When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour of calling this function without arguments.

    Returns:
        The ``argparse.Namespace`` holding one attribute per option below.
    """
    parser = argparse.ArgumentParser()

    ######### Data arguments
    parser.add_argument('--dataset', default='DS1', help=' DS1 | DS2 | DS3 | DS4 | DS5 | DS6 | DS7 | DS8 ')
    parser.add_argument('--empFreq', default=False, action='store_true', help='emprical frequence for KL computation') 
    ######### Model arguments
    parser.add_argument('--nf', type=int, default=2, help=' branch length feature embedding dimension')
    parser.add_argument('--hdimTree', type=int, default=100, help='hidden dimension for node embedding net')
    parser.add_argument('--hdimBranch', type=int, default=100, help='hidden dimension for node embedding net')
    parser.add_argument('--hLBranch',  type=int, default=2, help='number of hidden layers for node embedding net of branch model')
    parser.add_argument('--gnn_type', type=str, default='edge', help='gcn | sage | gin | ggnn')
    parser.add_argument('--aggr', type=str, default='sum', help='sum | mean | max')
    parser.add_argument('--nheads', default=4, type=int)
    parser.add_argument('--proj', default=False, action='store_true', help='use projection first in SAGEConv')
    ######### Optimizer arguments
    parser.add_argument('--stepszTree', type=float, default=0.0001, help=' step size for tree topology parameters ')
    parser.add_argument('--stepszBranch', type=float, default=0.001, help=' stepsz for branch length parameters ')
    parser.add_argument('--maxIter', type=int, default=400000, help=' number of iterations for training, default=400000')
    parser.add_argument('--invT0', type=float, default=0.001, help=' initial inverse temperature for annealing schedule, default=0.001')
    parser.add_argument('--nwarmStart', type=int, default=100000, help=' number of warm start iterations, default=100000')
    parser.add_argument('--nParticle', type=int, default=10, help='number of particles for variational objectives, default=10')
    parser.add_argument('--arTree', type=float, default=0.75, help='step size anneal rate, default=0.75')
    parser.add_argument('--arBranch', type=float, default=0.75, help='step size anneal rate, default=0.75')
    parser.add_argument('--afTreewarm', type=int, default=20000)
    parser.add_argument('--afBranchwarm', type=int, default=20000)
    parser.add_argument('--afTree', type=int, default=20000, help='step size anneal frequency, default=20000')
    parser.add_argument('--afBranch', type=int, default=20000)
    # The help text previously advertised default=1000 while the default is 100.
    parser.add_argument('--tf', type=int, default=100, help='monitor frequency during training, default=100')
    parser.add_argument('--lbf', type=int, default=5000, help='lower bound test frequency, default=5000')
    parser.add_argument('--sf', type=int, default=20000, help='Frequency of saving model')
    parser.add_argument('--gradMethod', type=str, default='vimco', help=' vimco | rws ')
    parser.add_argument('--clip_grad', default=False, action='store_true')
    parser.add_argument('--clip_value', type=float, default=500.0)

    parser.add_argument('--eps_max', default=0.0, type=float)
    parser.add_argument('--eps_period', default=20000, type=int)

    ######### Output arguments
    parser.add_argument('--workdir', default='results', type=str)
    parser.add_argument('--date', default='2023-03-16', type=str)
    parser.add_argument('--eval', default=False, action='store_true')

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    return args

if __name__ == '__main__':
    # Script entry point: read the command line, run the training, then exit.
    # main() returns None, which sys.exit() treats as a successful exit.
    args = parse_args()
    sys.exit(main(args))