#!/usr/bin/env python

import argparse

from project.scripts.inference import filter_generated_sequences, generate_samples

def _resolve_conditioning(mode, strain_species):
    """Map a sampling mode to the FASTA subset to condition on and the filter label.

    Args:
        mode: 'Unconditional' or 'Conditional'.
        strain_species: conditioning target such as 'species-<name>' or
            'strains-<name>-<id>' (see the CLI choices); must not be
            'unconditional' (or empty/None) in Conditional mode.

    Returns:
        (subset_fasta_path, label): ``subset_fasta_path`` is passed to
        ``generate_samples.main`` as ``subset_sequences``; ``label`` is passed
        to ``filter_generated_sequences.main`` as its strain/species argument.

    Raises:
        ValueError: if ``mode`` is not recognized, or if Conditional mode is
            requested without a specific strain/species.
    """
    if mode == 'Unconditional':
        # Even "unconditional" sampling runs the generator in SubsetConditional
        # mode, seeded from the full curated AMP set.
        return "data/activity-data/curated-AMPs.fasta", "unconditional"

    if mode == 'Conditional':
        if not strain_species or strain_species == "unconditional":
            raise ValueError("Strain or species must be specified in conditional mode")
        # The first dash-separated token selects the sub-directory; the
        # remaining tokens are concatenated (dashes removed) into the file stem.
        taxonomy, *name_parts = strain_species.split('-')
        strain_species_name = "".join(name_parts)
        subset_fasta = (f"data/activity-data/strain-species-data/{taxonomy}/"
                        f"{strain_species_name}_positive.fasta")
        return subset_fasta, strain_species

    raise ValueError(
        f"Unknown mode {mode!r}; expected 'Unconditional' or 'Conditional'")


def main(mode, strain_species, path_to_training_set, checkpoint_path, max_length, min_length, 
         num_samples, batch_size, output_fasta, output_strain_specific_csv, generative_model=None):
    """Generate sequences with the model and filter them.

    Args:
        mode: 'Unconditional' or 'Conditional'.
        strain_species: conditioning target ('species-<name>' or
            'strains-<name>-<id>'); ignored in Unconditional mode.
        path_to_training_set: training-set FASTA forwarded to the filter step.
        checkpoint_path: model checkpoint forwarded to ``generate_samples.main``.
        max_length: maximum sequence length forwarded to the filter step.
        min_length: minimum sequence length forwarded to the filter step.
        num_samples: number of sequences to sample.
        batch_size: sampling batch size.
        output_fasta: path for the filtered FASTA output (forwarded to the filter).
        output_strain_specific_csv: path for the ranked CSV output (forwarded
            to the filter).
        generative_model: forwarded as ``model=`` to ``generate_samples.main``;
            presumably an already-loaded model to reuse instead of loading
            ``checkpoint_path`` — confirm against ``generate_samples``.

    Returns:
        Whatever ``filter_generated_sequences.main`` returns.

    Raises:
        ValueError: for an unknown ``mode`` or Conditional mode without a
            specific strain/species.
    """
    subset_sequences, label = _resolve_conditioning(mode, strain_species)

    # Only the sampled sequences are used downstream; the conditioning info
    # returned alongside them is discarded, and nothing is written here.
    samples, _ = generate_samples.main(
        mode="SubsetConditional",
        subset_sequences=subset_sequences,
        checkpoint_path=checkpoint_path,
        num_samples=num_samples,
        batch_size=batch_size,
        output_fasta=None,
        conditioning_output_path=None,
        model=generative_model,
    )

    # The first positional argument is None because the sequences are passed
    # in memory via ``input_sequences`` (presumably it is an input path).
    return filter_generated_sequences.main(
        None, path_to_training_set, max_length, min_length, label,
        output_fasta, output_strain_specific_csv, input_sequences=samples,
    )



# Conditioning targets accepted on the command line. 'unconditional' is the
# default and is rejected in Conditional mode.
_STRAIN_SPECIES_CHOICES = (
    'unconditional',
    'species-acinetobacterbaumannii', 'species-escherichiacoli',
    'species-klebsiellapneumoniae', 'species-pseudomonasaeruginosa',
    'species-staphylococcusaureus',
    'strains-acinetobacterbaumannii-atcc19606',
    'strains-escherichiacoli-atcc25922',
    'strains-klebsiellapneumoniae-atcc700603',
    'strains-pseudomonasaeruginosa-atcc27853',
    'strains-staphylococcusaureus-atcc25923',
    'strains-staphylococcusaureus-atcc33591',
    'strains-staphylococcusaureus-atcc43300',
)


def _parse_args(argv=None):
    """Parse the sampling script's command-line arguments.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits via ``parser.error`` (status 2, usage message on stderr) when
    Conditional mode is requested without a specific ``--strain_species``,
    instead of letting ``main`` fail with a traceback.
    """
    parser = argparse.ArgumentParser(description='Load a model checkpoint and generate sequences.')
    parser.add_argument('mode', type=str, choices=['Unconditional', 'Conditional'], help='Mode of sampling: Unconditional, Conditional')
    parser.add_argument('--strain_species', type=str, choices=_STRAIN_SPECIES_CHOICES,
                        default='unconditional', help='Strain or species to condition on')
    parser.add_argument('--path_to_training_set', type=str, default='data/generative-model-data/AMPs.fasta', help='Path to the training set in FASTA format')
    parser.add_argument('--checkpoint_path', type=str, default='models/generative_model.ckpt', help='Path to the model checkpoint')
    parser.add_argument('--max_length', type=int, default=30, help='Maximum length of the sequences')
    parser.add_argument('--min_length', type=int, default=6, help='Minimum length of the sequences')
    parser.add_argument('--num_samples', type=int, default=32, help='Number of sequences to sample')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for sampling')
    parser.add_argument('--output_fasta', type=str, default='results/framework-results/script-filtered-sequences.fasta', help='Path to save the filtered sequences in FASTA format')
    parser.add_argument('--output_strain_specific_csv', type=str, default='results/framework-results/script-ranked-sequences.csv', help='Path to save the ranked sequences in csv format')

    args = parser.parse_args(argv)
    if args.mode == 'Conditional' and args.strain_species == 'unconditional':
        parser.error('--strain_species must be specified in Conditional mode')
    return args


if __name__ == "__main__":
    args = _parse_args()
    main(
        mode=args.mode,
        strain_species=args.strain_species,
        path_to_training_set=args.path_to_training_set,
        checkpoint_path=args.checkpoint_path,
        max_length=args.max_length,
        min_length=args.min_length,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        output_fasta=args.output_fasta,
        output_strain_specific_csv=args.output_strain_specific_csv,
    )