import os
import sys
import argparse
import logging
from tqdm import tqdm

import torch
import numpy as np

#. model
from TopoDiff.model.diffusion import Diffusion
from TopoDiff.model.latent_diffusion import LatentDiffusion
from TopoDiff.model.aux_head import SCHead

#. config
from TopoDiff.config.config import model_config
from TopoDiff.config.latent_config import latent_model_config

#. data
from TopoDiff.data.structure import StructureBuilder

#. pdb
from myopenfold.np import protein

#. sampler
from TopoDiff.experiment.sampler import Sampler

# Module-level logger. It is named after the module (the convention recommended
# by the logging docs) rather than after the file path: logger names are
# dot-separated hierarchies, so a path such as ".../sample.py" would create
# spurious parent loggers.
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse the command-line options of the TopoDiff sampling script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (``sys.argv[1:]``). Defaults to ``None``, which keeps
            the standard command-line behaviour; passing a list is useful for
            programmatic use and testing.

    Returns:
        argparse.Namespace with the attributes ``outdir``, ``start``, ``end``,
        ``interval``, ``num_samples``, ``min_sc``, ``max_sc``, ``seed``,
        ``gpu``, ``num_k`` and ``epsilon``. No range checks are done here;
        values are validated by the caller.
    """
    parser = argparse.ArgumentParser(description='Run sampling experiment with TopoDiff model')

    # outdir
    parser.add_argument('-o', '--outdir', type=str, default=None, help='The output directory')

    # length (both bounds are inclusive)
    parser.add_argument('-s', '--start', type=int, default=100, help='The start length of sampling (inclusive), must be at least 50, default: 100')
    parser.add_argument('-e', '--end', type=int, default=100, help='The end length of sampling (inclusive), must be at most 250, default: 100')
    parser.add_argument('-i', '--interval', type=int, default=10, help='The interval of sampling length, default: 10')
    parser.add_argument('-n', '--num_samples', type=int, default=5, help='The number of samples to generate for each length, default: 5')

    # designability cutoff
    parser.add_argument('--min_sc', type=float, default=0.0, help='The minimum predicted designability score of the latent, default: 0.0')
    parser.add_argument('--max_sc', type=float, default=10.0, help='The maximum predicted designability score of the latent, default: 10.0')

    # seed
    parser.add_argument('--seed', type=int, default=None, help='The random seed for sampling, default: None')

    # gpu
    parser.add_argument('--gpu', type=int, default=None, help='The gpu id for sampling, default: None')

    # length prediction
    parser.add_argument('--num_k', type=int, default=1, help='The number of k to decide the expected length of the latent, default: 1')
    parser.add_argument('--epsilon', type=float, default=0.2, help='The range of variation of the expected length of the latent, default: 0.2')

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)

def validate_args(args: argparse.Namespace) -> None:
    """Check the parsed command-line arguments for consistency.

    All bounds are inclusive: lengths from 50 to 250 are accepted,
    ``start == end`` samples a single length, and ``min_sc == max_sc`` as well
    as ``epsilon == 0`` are allowed.

    Args:
        args: Namespace produced by ``parse_args``.

    Raises:
        ValueError: If an argument is missing or out of range.
    """
    if args.outdir is None:
        raise ValueError('Output directory is not specified')

    if args.start < 50:
        raise ValueError('The start length must be at least 50')

    if args.end > 250:
        raise ValueError('The end length must be at most 250')

    if args.start > args.end:
        raise ValueError('The start length must not be larger than the end length')

    # A non-positive interval cannot step from `start` up to `end`; reject it
    # here instead of handing it to the sampler.
    if args.interval < 1:
        raise ValueError('The sampling length interval must be larger than 0')

    if args.num_samples < 1:
        raise ValueError('The number of samples must be larger than 0')

    if args.min_sc > args.max_sc:
        raise ValueError('The minimum designability score must not be larger than the maximum designability score')

    if args.num_k < 1:
        raise ValueError('The number of k for length prediction must be larger than 0')

    if args.epsilon < 0:
        raise ValueError('The epsilon for length prediction must be non-negative')


def build_sampling_config(args: argparse.Namespace, script_dir: str) -> dict:
    """Translate validated command-line arguments into the ``Sampler`` config.

    Args:
        args: Validated namespace produced by ``parse_args``.
        script_dir: Directory containing this script; the checkpoint path is
            resolved relative to it.

    Returns:
        The configuration dictionary passed to ``Sampler``.
    """
    return {
        # outdir (made absolute so it does not depend on the current directory)
        'output_dir': os.path.abspath(args.outdir),
        'script_dir': script_dir,

        # length
        'sample_per_length': args.num_samples,
        'length_start': args.start,
        'length_end': args.end,
        'length_step': args.interval,

        # seed
        'seed': args.seed,

        # gpu
        'gpu': args.gpu,

        # designability cutoff
        'min_sc': args.min_sc,
        'max_sc': args.max_sc,

        # length prediction
        'length_k': args.num_k,
        'length_epsilon': args.epsilon,

        'model_preset': 'ckpt_neurips_workshop',  # NOTE currently fixed
        'ckpt_path': os.path.join(script_dir, 'weight', 'neurips_workshop.pt'),  # NOTE currently fixed

        'latent_batch_size': 1000,  # NOTE currently fixed
        'structure_batch_size': 1,  # NOTE currently fixed
    }


def main() -> None:
    """Parse and validate the command line, then run the TopoDiff sampler."""
    cur_path = os.path.dirname(os.path.realpath(__file__))

    args = parse_args()
    validate_args(args)

    sampling_config = build_sampling_config(args, cur_path)
    print(f"Output directory: {sampling_config['output_dir']}")

    # set up and run the sampler
    sampler = Sampler(sampling_config)
    sampler.run()


if __name__ == '__main__':
    main()


