import os
import tempfile
from io import StringIO

import pandas as pd
import streamlit as st
import torch
from Bio import SeqIO
from hydra import compose, initialize

from project.classifiers import AMPClassifier
from project.conditioning import PartialConditioningTypes
from project.config import load_model_for_inference
from project.constants import AMINO_ACIDS, CLASSIFIER_MODELS
from project.data import save_sequences_to_fasta
from project.scripts.inference import generate_samples
from project.scripts.inference import run_framework
from project.sequence_properties import calculate_length, calculate_charge, calculate_hydrophobicity

def format_classifier_name(name):
    """Turn a classifier key into a human-readable, title-cased label.

    Hyphens and underscores both become spaces, e.g. ``"amp-classifier_v2"``
    -> ``"Amp Classifier V2"``.
    """
    spaced = name.translate(str.maketrans("-_", "  "))
    return spaced.title()

@st.cache_resource
def load_classifiers():
    """Instantiate every classifier listed in ``CLASSIFIER_MODELS``.

    ``CLASSIFIER_MODELS`` maps a classifier key to its model path. Each model is
    built with ``AMPClassifier(model_path=...)`` and switched to eval mode.
    The ``st.cache_resource`` decorator lets Streamlit reruns reuse the returned
    dict instead of reloading the models.

    Returns:
        dict mapping each ``CLASSIFIER_MODELS`` key to its loaded classifier.
    """
    loaded = {}
    for key, model_path in CLASSIFIER_MODELS.items():
        model = AMPClassifier(model_path=model_path)
        model.eval()
        loaded[key] = model
    return loaded

@st.cache_resource
def load_generative_model(generator_path="models/generative_model.ckpt"):
    """Load the generative model checkpoint for inference.

    Composes the Hydra ``train`` config from ``config/``. Hydra resolves that
    path relative to this file. The config and ``generator_path`` are then
    passed to ``load_model_for_inference``. The result is cached with
    ``st.cache_resource`` so reruns reuse it.
    """
    with initialize(version_base=None, config_path="config/"):
        train_config = compose(config_name="train")
        model = load_model_for_inference(train_config, generator_path)
    return model

def validate_sequences(sequences):
    """Normalise user sequences and keep only well-formed ones.

    Each entry is stripped and upper-cased. It is kept only if it is non-empty
    and every character appears in ``AMINO_ACIDS``. Invalid entries are dropped
    silently, and input order is preserved.
    """
    allowed = set(AMINO_ACIDS)
    normalised = (raw.strip().upper() for raw in sequences)
    return [seq for seq in normalised if seq and set(seq) <= allowed]

def check_template(template):
    """Check if template is valid.

    A template is valid when it is non-empty and every character is either an
    entry of ``AMINO_ACIDS`` or ``'_'``. In templates, ``'_'`` presumably marks a
    position for the model to fill. The check is case-sensitive; the template
    is passed to the generator unchanged.

    Args:
        template: the template string entered by the user (``None`` is
            treated as empty).

    Returns:
        ``True`` if the template is valid, ``False`` otherwise.
    """
    if not template:
        return False
    # Build the alphabet once rather than per character. Using set() accepts
    # AMINO_ACIDS as either a list of residues or a string of residues.
    allowed = set(AMINO_ACIDS) | {'_'}
    return all(c in allowed for c in template)

def _split_manual_sequences(text):
    """Split a free-text box into stripped, non-empty lines (one sequence per line)."""
    return [line.strip() for line in text.split("\n") if line.strip()]


def _read_sequence_inputs(fasta_input, sequence_input):
    """Collect raw sequences from an uploaded FASTA file or a free-text box.

    The uploaded file takes precedence over the text box.

    Returns:
        A list of sequence strings (possibly empty), or ``None`` when neither
        input was provided.
    """
    if fasta_input:
        text = fasta_input.getvalue().decode("utf-8")
        return [str(record.seq) for record in SeqIO.parse(StringIO(text), "fasta")]
    if sequence_input:
        return _split_manual_sequences(sequence_input)
    return None


def _show_generated_sequences(sequences, file_name):
    """Display generated sequences and offer them as a newline-separated text download."""
    st.text("Generated Sequences:")
    st.code('\n'.join(sequences))
    st.download_button(
        '📥 Download Sequences',
        data='\n'.join(sequences),
        file_name=file_name,
        mime='text/plain'
    )


def _show_sequences_with_properties(sequences):
    """Display generated sequences next to their computed length, charge and hydrophobicity."""
    properties_df = pd.DataFrame({
        "Sequence": sequences,
        "Length": calculate_length(sequences),
        "Charge": calculate_charge(sequences),
        "Hydrophobicity": calculate_hydrophobicity(sequences),
    })

    col1, col2 = st.columns(2)

    with col1:
        st.text("Generated Sequences:")
        st.code('\n'.join(sequences))

    with col2:
        st.text("Sequence Properties:")
        st.dataframe(properties_df)
        st.download_button(
            '📥 Download Sequences with Properties',
            data=properties_df.to_csv(index=False),
            file_name='sampled_sequences_with_properties.csv',
            mime='text/csv'
        )


def _predict_page(classifiers):
    """Render the prediction page: score sequences with one or all classifiers.

    Args:
        classifiers: mapping of ``CLASSIFIER_MODELS`` key -> callable classifier
            that takes a list of sequences and returns one prediction per sequence.
    """
    st.markdown("<div class='section-header'>🔍 Predict Antimicrobial Activity of Protein Sequences</div>", unsafe_allow_html=True)

    # Entry 0 means "run every classifier". Entry i > 0 corresponds to the
    # (i - 1)-th key of CLASSIFIER_MODELS.
    classifier_keys = list(CLASSIFIER_MODELS.keys())
    classifier_names = ["🎯 Run All Classifiers"] + [f"🧪 {format_classifier_name(name)}" for name in classifier_keys]
    selected_classifier = st.selectbox("Select Classifier", classifier_names)
    selected_index = classifier_names.index(selected_classifier)

    col1, col2 = st.columns(2)
    with col1:
        fasta_input = st.file_uploader("📄 Upload a .fasta file", type=["fasta"])
    with col2:
        sequence_input = st.text_area("✍️ Or enter amino acid sequence(s) manually")

    if not st.button("🔍 Predict"):
        return

    with st.spinner('🧬 Running predictions...'):
        sequences = _read_sequence_inputs(fasta_input, sequence_input)
        if sequences is None:
            st.warning("Please upload a .fasta file or enter sequences manually.")
            return

        valid_sequences = validate_sequences(sequences)
        if not valid_sequences:
            st.warning("No valid sequences found. Please ensure your sequences only contain valid amino acids.")
            return
        if len(valid_sequences) > 5000:
            st.warning("Too many sequences. Please upload a file with less than 5000 sequences.")
            return

        if selected_index == 0:
            selected = classifiers
        else:
            key = classifier_keys[selected_index - 1]
            selected = {key: classifiers[key]}

        # Inference only: skip autograd bookkeeping for potentially thousands of sequences.
        results = {}
        with torch.no_grad():
            for name, model in selected.items():
                results[format_classifier_name(name)] = model(valid_sequences)

        # One row per sequence, one column per classifier.
        combined_results = {seq: {} for seq in valid_sequences}
        for name, preds in results.items():
            for seq, pred in zip(valid_sequences, preds):
                combined_results[seq][name] = pred

        result_df = pd.DataFrame.from_dict(combined_results, orient='index').reset_index()
        result_df.rename(columns={'index': 'Sequence'}, inplace=True)

        st.markdown("### 📊 Results")
        st.write(result_df)
        st.download_button(
            '📥 Download Predictions',
            data=result_df.to_csv(index=False),
            file_name='prediction_results.csv',
            mime='text/csv'
        )
        st.success("✨ Predictions ready for download!")


def _generate_page(generative_model):
    """Render the generation page for the four generation modes.

    Args:
        generative_model: model passed through to ``generate_samples.main``.
    """
    st.markdown("<div class='section-header'>✨ Generate Antimicrobial Peptide Sequences</div>", unsafe_allow_html=True)

    generation_type = st.radio("Generation Type",
        ["🎲 Unconditional", "🎯 Partial Conditional", "📝 Template Conditional", "🔄 Subset Conditional"])

    no_samples = st.number_input("Number of samples", min_value=1, max_value=5, value=1)

    # Mode-specific inputs. Each variable below is only read by its own mode.
    if generation_type == "🎯 Partial Conditional":
        length_value = st.text_input("Length value (e.g., '1:100', '8', or '-')", value='-')
        charge_value = st.text_input("Charge value (e.g., '4:6', '-2', or '-')", value='-')
        hydrophobicity_value = st.text_input("Hydrophobicity value (e.g., '0:1', '-1', or '-')", value='-')
    elif generation_type == "📝 Template Conditional":
        template = st.text_input("Enter a template sequence", value='A_A_________')
        guidance_strength = st.number_input("Guidance strength", min_value=0.1, max_value=10.0, value=1.0)
    elif generation_type == "🔄 Subset Conditional":
        st.markdown("<div class='section-header'>📚 Input Reference Sequences</div>", unsafe_allow_html=True)

        col1, col2 = st.columns(2)
        with col1:
            fasta_input = st.file_uploader("📄 Upload a .fasta file", type=["fasta"])
        with col2:
            sequence_input = st.text_area("✍️ Or enter amino acid sequence(s) manually")

    if not st.button("✨ Generate"):
        return

    mode_map = {
        "🎲 Unconditional": "Unconditional",
        "🎯 Partial Conditional": "PartialConditional",
        "📝 Template Conditional": "TemplateConditional",
        "🔄 Subset Conditional": "SubsetConditional"
    }
    mode = mode_map[generation_type]
    batch_size = no_samples
    sequences = []

    with st.spinner('🧬 Generating sequences...'):
        try:
            if mode == "TemplateConditional":
                if not check_template(template):
                    st.warning("Invalid template sequence.")
                    return
                sequences, _ = generate_samples.main(
                    mode=mode,
                    template=template,
                    guidance_strength=guidance_strength,
                    num_samples=no_samples,
                    batch_size=batch_size,
                    model=generative_model,
                )
                if sequences:
                    _show_generated_sequences(sequences, 'template_generated_sequences.txt')

            elif mode == "PartialConditional":
                sequences, _ = generate_samples.main(
                    mode=mode,
                    length=length_value,
                    charge=charge_value,
                    hydrophobicity=hydrophobicity_value,
                    num_samples=no_samples,
                    batch_size=batch_size,
                    model=generative_model,
                )
                if sequences:
                    _show_sequences_with_properties(sequences)

            elif mode == "SubsetConditional":
                manual_sequences = _split_manual_sequences(sequence_input) if sequence_input else []
                if not fasta_input and not manual_sequences:
                    st.warning("Please provide subset sequences.")
                    return

                # Use a per-request temp file. A fixed name in the CWD would be
                # shared and deleted by concurrent sessions, and the finally
                # clause removes it even if generation raises.
                fd, subset_path = tempfile.mkstemp(prefix="subset_", suffix=".fasta")
                os.close(fd)
                try:
                    if fasta_input:
                        with open(subset_path, "w", encoding="utf-8") as f:
                            f.write(fasta_input.getvalue().decode("utf-8"))
                    else:
                        save_sequences_to_fasta(manual_sequences, subset_path)

                    sequences, _ = generate_samples.main(
                        mode=mode,
                        subset_sequences=subset_path,
                        num_samples=no_samples,
                        batch_size=batch_size,
                        model=generative_model,
                    )
                finally:
                    if os.path.exists(subset_path):
                        os.remove(subset_path)

                if sequences:
                    _show_generated_sequences(sequences, 'subset_generated_sequences.txt')

            elif mode == "Unconditional":
                sequences, _ = generate_samples.main(
                    mode=mode,
                    num_samples=no_samples,
                    batch_size=batch_size,
                    model=generative_model,
                )
                if sequences:
                    _show_generated_sequences(sequences, 'unconditional_generated_sequences.txt')

            if sequences:
                st.success(f"✨ Successfully generated {len(sequences)} sequences!")
            else:
                st.info("Generation complete, but no sequences were produced with the current settings.")

        except Exception as e:
            # UI boundary: surface any generation failure to the user instead of crashing the page.
            st.error(f"❌ Error: {str(e)}")


def _framework_page(generative_model):
    """Render the full generate-and-filter framework page.

    Args:
        generative_model: model passed through to ``run_framework.main``.
    """
    st.markdown("<div class='section-header'>🚀 Run Complete Framework</div>", unsafe_allow_html=True)

    framework_mode = st.radio("Framework Mode", ["🎲 Unconditional", "🎯 Conditional"])

    framework_mode_map = {
        "🎲 Unconditional": "Unconditional",
        "🎯 Conditional": "Conditional"
    }
    mode = framework_mode_map[framework_mode]

    col1, col2 = st.columns(2)
    with col1:
        no_samples = st.number_input("Number of sequences", min_value=1, max_value=128, value=128)
        min_length = st.number_input("Minimum length", min_value=1, max_value=50, value=6)
    with col2:
        max_length = st.number_input("Maximum length", min_value=1, max_value=100, value=30)

    # Target argument for run_framework: a strain/species id in conditional
    # mode, the literal "unconditional" otherwise.
    target = "unconditional"
    if mode == "Conditional":
        st.markdown("### 🦠 Target Selection")
        strain_species = st.selectbox("Select Target Strain/Species", [
            "🔬 species-acinetobacterbaumannii",
            "🔬 species-escherichiacoli",
            "🔬 species-klebsiellapneumoniae",
            "🔬 species-pseudomonasaeruginosa",
            "🔬 species-staphylococcusaureus",
            "🔬 strains-acinetobacterbaumannii-atcc19606",
            "🔬 strains-escherichiacoli-atcc25922",
            "🔬 strains-klebsiellapneumoniae-atcc700603",
            "🔬 strains-pseudomonasaeruginosa-atcc27853",
            "🔬 strains-staphylococcusaureus-atcc25923",
            "🔬 strains-staphylococcusaureus-atcc33591",
            "🔬 strains-staphylococcusaureus-atcc43300"
        ])
        target = strain_species.split(" ", 1)[1]  # Remove emoji for processing

    if not st.button("🚀 Generate and Filter"):
        return

    # The two number inputs are independent, so an inverted range is possible.
    if min_length > max_length:
        st.warning("⚠️ Minimum length cannot be greater than maximum length.")
        return

    with st.spinner('🧬 Running framework...'):
        try:
            batch_size = min(32, no_samples)
            filtered_sequences, ranked_sequences = run_framework.main(
                mode,
                target,
                "data/generative-model-data/AMPs.fasta",
                "models/generative_model.ckpt",
                max_length,
                min_length,
                no_samples,
                batch_size,
                None,
                None,
                generative_model=generative_model,
            )

            if len(filtered_sequences) > 0:
                col1, col2 = st.columns(2)

                with col1:
                    st.markdown("### 🧬 Filtered Sequences")
                    st.code('\n'.join(filtered_sequences))

                with col2:
                    st.markdown("### 📊 Download Options")
                    st.download_button(
                        '📥 Download Ranked Sequences',
                        data=ranked_sequences.to_csv(index=False),
                        file_name='ranked_sequences.csv',
                        mime='text/csv'
                    )

                st.success(f"✨ Successfully generated and filtered {len(filtered_sequences)} sequences!")
            else:
                st.warning("⚠️ No sequences passed the filtering criteria.")

        except Exception as e:
            # UI boundary: surface any framework failure to the user instead of crashing the page.
            st.error(f"❌ Error: {str(e)}")


def app(generative_model, classifiers):
    """Render the OmegAMP Streamlit UI and dispatch to the page chosen in the sidebar.

    Args:
        generative_model: model used by the generation and framework pages
            (as returned by ``load_generative_model``).
        classifiers: mapping of ``CLASSIFIER_MODELS`` key -> classifier, used by
            the prediction page (as returned by ``load_classifiers``).
    """
    st.title("🧬 OmegAMP: Targeted AMP Discovery via Biologically Informed Generation and High-Confidence Classification")

    menu = ["🔍 Predict AMP Activity", "✨ Generate AMP Sequences", "🚀 Run Framework"]
    choice = st.sidebar.selectbox("Menu", menu)

    # Add consistent styling
    st.markdown("""
        <style>
        @keyframes pulse {
            0% { opacity: 1; }
            50% { opacity: 0.3; }
            100% { opacity: 1; }
        }
        .loading-pulse {
            animation: pulse 1.5s ease-in-out infinite;
            padding: 10px;
            border-radius: 5px;
            background-color: #f0f2f6;
            margin: 10px 0;
        }
        .section-header {
            padding: 10px;
            border-radius: 5px;
            margin: 10px 0;
            background-color: #f8f9fa;
        }
        .stSelectbox {
            margin-bottom: 20px;
        }
        .stButton > button {
            width: 100%;
        }
        </style>
    """, unsafe_allow_html=True)

    if choice == "🔍 Predict AMP Activity":
        _predict_page(classifiers)
    elif choice == "✨ Generate AMP Sequences":
        _generate_page(generative_model)
    elif choice == "🚀 Run Framework":
        _framework_page(generative_model)

if __name__ == '__main__':
    # Page config is set before any other Streamlit command
    # (older Streamlit versions require this ordering).
    st.set_page_config(
        layout="wide",
        page_title="OmegAMP",
        page_icon="🧬",
        initial_sidebar_state="expanded",
    )

    # Both loaders are wrapped in st.cache_resource, so reruns reuse the models.
    classifiers = load_classifiers()
    generative_model = load_generative_model()
    app(generative_model=generative_model, classifiers=classifiers)
