#!/usr/bin/env python

import argparse
import torch
from project.config import load_model_for_inference
from project.data import load_fasta_to_df, save_sequences_to_fasta
from hydra import compose, initialize
from project.conditioning import PartialConditioningTypes


# Top-level sampling modes accepted by main().
_SAMPLING_MODES = (
    'Unconditional',
    'SubsetConditional',
    'PartialConditional',
    'TemplateConditional',
    'AnalogConditional',
    'AdvancedConditional',
)

# Conditioning modes that may be combined under 'AdvancedConditional'.
_ADVANCED_SUBMODES = (
    'TemplateConditional',
    'AnalogConditional',
    'SubsetConditional',
    'PartialConditional',
)


def _parse_advanced_modes(advanced_conditioning_modes):
    """Turn the advanced-mode specification into a clean list of mode names.

    Accepts a comma-separated string, optionally wrapped in brackets and with
    whitespace or quotes around the names (e.g. ``'[A, B]'`` or ``'A,B'``), or
    an already-split iterable of names. Empty entries are dropped.
    """
    if isinstance(advanced_conditioning_modes, str):
        tokens = advanced_conditioning_modes.strip().strip('[]()').split(',')
    else:
        tokens = [str(token) for token in advanced_conditioning_modes]
    cleaned = (token.strip().strip('\'"') for token in tokens)
    return [token for token in cleaned if token]


def _partial_conditioning_from_specs(length, charge, hydrophobicity):
    """Build the partial-conditioning dict from the textual specifications."""
    return {
        'length': parse_specification(length),
        'charge': parse_specification(charge),
        'hydrophobicity_eisenberg': parse_specification(hydrophobicity),
        # The isAMP entry is always pinned to the defined value 1.
        'isAMP': (PartialConditioningTypes.DEFINED, 1),
    }


def _load_subset_sequences(subset_sequences):
    """Return the subset as a list of sequences.

    A list/tuple is passed through unchanged (so the subset is never loaded
    twice); anything else is treated as a FASTA path and its 'Sequence'
    column is returned as a list.
    """
    if isinstance(subset_sequences, (list, tuple)):
        return list(subset_sequences)
    return load_fasta_to_df(subset_sequences)['Sequence'].tolist()


def main(mode='Unconditional',
         length='-',
         charge='-',
         hydrophobicity='-',
         subset_sequences='data/activity-data/curated-AMPs.fasta',
         return_idxs=False,
         checkpoint_path='models/generative_model.ckpt',
         analog='GAAKRAKTAL',
         initial_timestep=250,
         conditioning_closeness=0.5,
         template='A_A_________',
         guidance_strength=1.0,
         output_fasta='results/generative-model-results/script-omegamp-generated-samples.fasta',
         conditioning_output_path='results/generative-model-results/script-omegamp-generated-conditioning.pt',
         num_samples=32,
         batch_size=32,
         advanced_conditioning_modes='TemplateConditional,AnalogConditional',
         model=None):
    """Sample sequences from the generative model and optionally save them.

    Args:
        mode: One of 'Unconditional', 'SubsetConditional',
            'PartialConditional', 'TemplateConditional', 'AnalogConditional'
            or 'AdvancedConditional'.
        length, charge, hydrophobicity: Textual specifications handed to
            ``parse_specification`` ('-', a single number, or 'low:high').
            Used by the partial-conditioning modes.
        subset_sequences: FASTA path (or an already loaded list of sequences)
            used by the subset-conditioning modes.
        return_idxs: If true, an ``{"idxs": tensor}`` holder is passed to the
            subset-conditioning calls and its content is used to look up the
            "inspiring" subset sequences when writing the FASTA file.
        checkpoint_path: Checkpoint to load when ``model`` is not given.
        analog: Analog sequence forwarded to the analog-conditioning calls.
        initial_timestep: Forwarded to the analog-conditioning calls.
        conditioning_closeness: Forwarded to ``model.sample_with_analog``
            (only used by the plain 'AnalogConditional' mode).
        template: Template string where '_' marks a free position.
        guidance_strength: Forwarded to the template-conditioning calls.
        output_fasta: Where to write the generated sequences; ``None`` skips
            writing.
        conditioning_output_path: Where to ``torch.save`` the conditioning;
            ``None`` skips saving.
        num_samples: Number of sequences requested from the model.
        batch_size: Passed to the model as the ``batch_size`` sampling kwarg.
        advanced_conditioning_modes: Comma-separated list (brackets and
            whitespace tolerated) of conditioning modes to combine when
            ``mode == 'AdvancedConditional'``.
        model: Optional pre-loaded model; when ``None`` the model is loaded
            from ``checkpoint_path``.

    Returns:
        Tuple ``(sequences, conditioning)`` as returned by the model.

    Raises:
        ValueError: If ``mode`` or one of the advanced conditioning modes is
            not recognised.
    """
    # Validate up front so a typo fails before the checkpoint is loaded.
    if mode not in _SAMPLING_MODES:
        raise ValueError(
            f"Unknown sampling mode {mode!r}; expected one of {', '.join(_SAMPLING_MODES)}"
        )

    advanced_modes = []
    if mode == 'AdvancedConditional':
        advanced_modes = _parse_advanced_modes(advanced_conditioning_modes)
        unknown = [name for name in advanced_modes if name not in _ADVANCED_SUBMODES]
        if unknown:
            raise ValueError(
                f"Unknown advanced conditioning mode(s) {unknown}; "
                f"expected a subset of {', '.join(_ADVANCED_SUBMODES)}"
            )

    if model is None:
        print("Loading model from checkpoint...")
        with initialize(version_base=None, config_path="../../../config"):
            config = compose(config_name="train")
        model = load_model_for_inference(config, checkpoint_path)

    kwargs = {'output_embedding': False, 'batch_size': batch_size}

    # Holder handed to the subset-conditioning calls; read back when saving.
    idxs = {"idxs": torch.tensor([])} if return_idxs else None

    # Handle unconditional or conditional sampling
    if mode == 'Unconditional':
        print("Sampling unconditional sequences...")
        sequences, conditioning = model.sample(num_samples, **kwargs)
    elif mode == 'SubsetConditional':
        print("Sampling sequences conditioned on a subset...")
        subset_sequences = _load_subset_sequences(subset_sequences)
        sequences, conditioning = model.sample_with_subset_conditioning(
            subset_sequences, num_samples, return_idxs=idxs, kwargs=kwargs)
    elif mode == 'PartialConditional':
        print("Sampling sequences conditioned on partial conditioning information...")
        partial_conditioning_info = _partial_conditioning_from_specs(length, charge, hydrophobicity)
        sequences, conditioning = model.sample_with_partial_conditioning(
            partial_conditioning_info, num_samples, kwargs=kwargs)
    elif mode == 'TemplateConditional':
        print("Sampling sequences conditioned on a template...")
        actual_template, template_mask = get_actual_template_and_mask(template)
        sequences, conditioning = model.sample_with_template(
            actual_template, template_mask, guidance_strength, num_samples, kwargs=kwargs)
    elif mode == 'AnalogConditional':
        print("Sampling sequences conditioned on an analog...")
        sequences, conditioning = model.sample_with_analog(
            analog, initial_timestep, conditioning_closeness, num_samples, batch_size, kwargs=kwargs)
    else:  # 'AdvancedConditional' (mode was validated above)
        print("Sampling sequences conditioned on advanced information...")
        # Each sub-mode folds its parameters into kwargs, in the order given.
        for advanced_mode in advanced_modes:
            if advanced_mode == 'TemplateConditional':
                actual_template, template_mask = get_actual_template_and_mask(template)
                kwargs = model.update_template_conditioning_params(
                    kwargs, actual_template, template_mask, guidance_strength, num_samples)
            elif advanced_mode == 'AnalogConditional':
                kwargs = model.update_analog_conditioning_params(
                    kwargs, analog, batch_size, initial_timestep)
            elif advanced_mode == 'SubsetConditional':
                subset_sequences = _load_subset_sequences(subset_sequences)
                kwargs = model.update_subset_conditioning_params(
                    kwargs, subset_sequences, num_samples, return_idxs=idxs)
            else:  # 'PartialConditional'
                partial_conditioning_info = _partial_conditioning_from_specs(
                    length, charge, hydrophobicity)
                kwargs = model.update_partial_conditioning_params(
                    kwargs, partial_conditioning_info, num_samples)
        sequences, conditioning = model.sample(num_samples, **kwargs)

    # Save conditioning information to a file
    if conditioning_output_path is not None:
        torch.save(conditioning, conditioning_output_path)
        print(f"Conditioning information saved to {conditioning_output_path}")

    # Save sequences to a fasta file
    if output_fasta is not None:
        if idxs is not None:
            # int() makes tensor elements usable as list indices regardless of dtype.
            inspired_sequences = [subset_sequences[int(idx)] for idx in idxs["idxs"]]
            save_sequences_to_fasta(sequences, output_fasta, inspired_sequences=inspired_sequences)
        else:
            save_sequences_to_fasta(sequences, output_fasta)
        print(f"Generated {len(sequences)} sequences saved to {output_fasta}")

    return sequences, conditioning

def get_actual_template_and_mask(template):
    """Expand a template string into a concrete sequence and a position mask.

    Every '_' in ``template`` is a free position: it is filled with the
    placeholder residue 'A' in the returned sequence and gets a 0 in the
    mask. Every other character is kept as-is and gets a 1 in the mask.

    Returns:
        Tuple ``(actual_template, template_mask)`` where ``actual_template``
        is a string of the same length as ``template`` and ``template_mask``
        is a 1-D tensor of that length.
    """
    is_fixed = [residue != '_' for residue in template]
    # Free positions receive a dummy 'A'; fixed positions keep their residue.
    actual_template = "".join(
        residue if fixed else "A" for residue, fixed in zip(template, is_fixed)
    )
    template_mask = torch.tensor([1.0 if fixed else 0.0 for fixed in is_fixed])
    return actual_template, template_mask

def parse_specification(spec):
    """Convert a textual property specification into a conditioning tuple.

    * ``'-'``          -> ``(PartialConditioningTypes.UNDEFINED, None)``
    * ``'low:high'``   -> ``(PartialConditioningTypes.INTERVAL, [floats...])``,
      one float per ':'-separated part
    * anything else    -> ``(PartialConditioningTypes.DEFINED, float(spec))``

    Raises:
        ValueError: If a numeric part cannot be converted with ``float``.
    """
    if spec == '-':
        return (PartialConditioningTypes.UNDEFINED, None)

    if ":" not in spec:
        # Single value: the property is pinned to exactly this number.
        return (PartialConditioningTypes.DEFINED, float(spec))

    bounds = [float(part) for part in spec.split(":")]
    return (PartialConditioningTypes.INTERVAL, bounds)

if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` calls ``bool()`` on the raw string, so any
        non-empty value (including 'False') would become True. This parser
        maps the usual spellings explicitly and rejects anything else.
        """
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    # Command-line entry point: every option mirrors a keyword argument of main().
    parser = argparse.ArgumentParser(description='Load a model checkpoint and generate sequences.')
    parser.add_argument('mode', type=str, choices=['Unconditional', 'SubsetConditional', 'PartialConditional', 'TemplateConditional', 'AnalogConditional', 'AdvancedConditional'], help='Mode of sampling: Unconditional, SubsetConditional, PartialConditional, TemplateConditional, AnalogConditional, AdvancedConditional')
    parser.add_argument('--length', type=str, default='-', help='Length specification for conditional sampling')
    parser.add_argument('--charge', type=str, default='-', help='Charge specification for conditional sampling')
    parser.add_argument('--hydrophobicity', type=str, default='-', help='Hydrophobicity specification for conditional sampling')
    parser.add_argument('--subset_sequences', type=str, default='data/activity-data/curated-AMPs.fasta', help='Path to the subset of sequences to condition on')
    # nargs='?' with const=True keeps "--return_idxs True/False" working and
    # also allows the bare flag "--return_idxs".
    parser.add_argument('--return_idxs', type=_str2bool, nargs='?', const=True, default=False, help='Return the indices of the sequences sampled (true/false)')
    parser.add_argument('--checkpoint_path', type=str, default='models/generative_model.ckpt', help='Path to the model checkpoint')
    parser.add_argument('--analog', type=str, default='GAAKRAKTAL', help='Analog sequence for conditional sampling')
    parser.add_argument('--initial_timestep', type=int, default=250, help='Initial timestep for conditional sampling')
    parser.add_argument('--conditioning_closeness', type=float, default=0.5, help='Closeness of conditioning for conditional sampling')
    parser.add_argument('--template', type=str, default='A_A_________', help='Template sequence for conditional sampling')
    parser.add_argument('--guidance_strength', type=float, default=1.0, help='Guidance strength for sampling')
    parser.add_argument('--output_fasta', type=str, default='results/generative-model-results/script-omegamp-generated-samples.fasta', help='Path to the output FASTA file for generated sequences')
    parser.add_argument('--conditioning_output_path', type=str, default='results/generative-model-results/script-omegamp-generated-conditioning.pt', help='Path to save the conditioning information')
    parser.add_argument('--num_samples', type=int, default=32, help='Number of sequences to sample')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for sampling')
    # main() splits this value on ',' — the default must be a plain
    # comma-separated list without brackets or spaces.
    parser.add_argument('--advanced_conditioning_modes', type=str, default='TemplateConditional,AnalogConditional', help='Comma-separated conditioning modes to combine in AdvancedConditional mode, e.g. TemplateConditional,AnalogConditional')

    args = parser.parse_args()

    main(
        mode=args.mode,
        length=args.length,
        charge=args.charge,
        hydrophobicity=args.hydrophobicity,
        subset_sequences=args.subset_sequences,
        return_idxs=args.return_idxs,
        checkpoint_path=args.checkpoint_path,
        analog=args.analog,
        initial_timestep=args.initial_timestep,
        conditioning_closeness=args.conditioning_closeness,
        template=args.template,
        guidance_strength=args.guidance_strength,
        output_fasta=args.output_fasta,
        conditioning_output_path=args.conditioning_output_path,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        advanced_conditioning_modes=args.advanced_conditioning_modes
    )