#!/usr/bin/env python

import argparse
import os
import pandas as pd
from project.data import load_fasta_to_df
from project.classifiers import AMPClassifier

def _load_labelled(fasta_path, label, high_quality):
    """Load one FASTA file and tag every row with a class label and quality flag.

    Args:
        fasta_path: Path handed to ``load_fasta_to_df``.
        label: Value written to the ``label`` column (1 = hemolytic, 0 = not).
        high_quality: Value written to the ``high_quality`` column
            (1 = curated source, 0 = general or synthetic source).

    Returns:
        The DataFrame produced by ``load_fasta_to_df`` with the two columns added.
    """
    df = load_fasta_to_df(fasta_path)
    df['label'] = label
    df['high_quality'] = high_quality
    return df


def _resolve_label_conflicts(sequence_label_df):
    """Remove rows so that no sequence is left with contradictory labels.

    For every sequence that appears with more than one distinct label, all of
    its non-curated rows (``high_quality == 0``) are dropped. If the sequence
    has curated rows they win; if it has none, the sequence disappears
    entirely because every one of its rows is non-curated.

    Raises:
        ValueError: If the curated rows of a sequence disagree with each other,
            since there is then no trustworthy label to fall back on.
    """
    labels_per_sequence = sequence_label_df.groupby('Sequence')['label'].nunique()
    conflicting_sequences = labels_per_sequence[labels_per_sequence > 1].index
    print(f"Sequences with conflicting labels found: {len(conflicting_sequences)}")

    conflicting_rows = sequence_label_df[
        sequence_label_df['Sequence'].isin(conflicting_sequences)
    ]

    # Only a disagreement *between curated labels* is fatal. Several curated
    # rows that agree (e.g. a sequence listed twice in a curated file) are
    # fine: they simply override the non-curated rows.
    curated_rows = conflicting_rows[conflicting_rows['high_quality'] == 1]
    curated_labels = curated_rows.groupby('Sequence')['label'].nunique()
    contradictory = curated_labels[curated_labels > 1].index
    if len(contradictory) > 0:
        seq = contradictory[0]
        raise ValueError(f"Sequence found with conflicting high-quality labels: {seq}")

    rows_to_drop = conflicting_rows.index[conflicting_rows['high_quality'] == 0]
    return sequence_label_df.drop(rows_to_drop)


def _drop_duplicate_sequences(sequence_label_df):
    """Keep one row per sequence, preferring curated rows, and renumber the index."""
    total_duplicates = len(sequence_label_df) - len(sequence_label_df['Sequence'].unique())
    print(f"Number of duplicate sequences found: {total_duplicates}")

    # A stable sort keeps tied rows in their original (concatenation) order,
    # so the surviving row for each sequence is deterministic rather than
    # dependent on the internals of the default quicksort.
    ordered = sequence_label_df.sort_values(
        'high_quality', ascending=False, kind='mergesort'
    )
    deduplicated = ordered.drop_duplicates(subset='Sequence', keep='first')
    return deduplicated.reset_index(drop=True)


def main(hemolytic_file_path: str, non_hemolytic_file_path: str,
         curated_hemolytic_path: str, curated_nonhemolytic_path: str,
         random_sequences_file_path: str, shuffled_sequences_file_path: str,
         mutated_sequences_file_path: str, output_csv: str) -> None:
    """Build the hemolytic-classification dataset and write it to a CSV file.

    The FASTA inputs are loaded, labelled, merged, cleaned of label conflicts
    and duplicate sequences, augmented with the features returned by
    ``AMPClassifier.get_input_features`` and saved to ``output_csv``.

    The code relies on ``load_fasta_to_df`` returning frames that contain at
    least the columns ``Sequence`` and ``Id``.

    Args:
        hemolytic_file_path: General hemolytic peptides (label 1).
        non_hemolytic_file_path: General non-hemolytic peptides (label 0).
        curated_hemolytic_path: Curated hemolytic peptides (label 1, high quality).
        curated_nonhemolytic_path: Curated non-hemolytic peptides (label 0, high quality).
        random_sequences_file_path: Synthetic random sequences (label 0).
        shuffled_sequences_file_path: Synthetic shuffled sequences (label 0).
        mutated_sequences_file_path: Synthetic mutated sequences (label 0).
        output_csv: Destination CSV path; missing parent directories are created.

    Raises:
        ValueError: If a sequence carries contradictory labels within the
            curated (high-quality) data.
    """
    # (path, label, high_quality) in the order the frames are concatenated.
    sources = [
        (hemolytic_file_path, 1, 0),
        (non_hemolytic_file_path, 0, 0),
        (curated_hemolytic_path, 1, 1),
        (curated_nonhemolytic_path, 0, 1),
        (random_sequences_file_path, 0, 0),
        (shuffled_sequences_file_path, 0, 0),
        (mutated_sequences_file_path, 0, 0),
    ]
    sequence_label_df = pd.concat(
        [_load_labelled(path, label, quality) for path, label, quality in sources],
        ignore_index=True,
    )

    sequence_label_df = _resolve_label_conflicts(sequence_label_df)
    sequence_label_df = _drop_duplicate_sequences(sequence_label_df)
    print(f"Final dataset size: {len(sequence_label_df)}")

    # Extract features
    classifier = AMPClassifier(model_path=None)
    sequences = sequence_label_df['Sequence'].tolist()
    print(f"Extracting features for {len(sequences)} sequences")
    features_df = classifier.get_input_features(sequences)

    # Combine features with labels.
    # NOTE(review): concat(axis=1) aligns on the index, so this assumes
    # features_df has one row per input sequence on a default 0..n-1 index
    # -- confirm against AMPClassifier.get_input_features.
    sequence_label_df = pd.concat([sequence_label_df, features_df], axis=1)

    # Save to CSV with Id as string
    sequence_label_df['Id'] = sequence_label_df['Id'].astype(str)
    output_dir = os.path.dirname(output_csv)
    if output_dir:
        # Create the target directory up front so the write cannot fail on a
        # missing folder after the feature extraction has already been done.
        os.makedirs(output_dir, exist_ok=True)
    sequence_label_df.to_csv(output_csv, index=False)

if __name__ == "__main__":
    # Command-line entry point: every option is a file path with a default
    # under data/toxicity-data/, listed here as (flag, default, help text) in
    # the order the options are registered with the parser.
    cli_options = [
        ('--hemolytic_file_path',
         'data/toxicity-data/hemolytic.fasta',
         'Path to the general hemolytic peptides FASTA file (default: data/toxicity-data/hemolytic.fasta)'),
        ('--non_hemolytic_file_path',
         'data/toxicity-data/nonhemolytic.fasta',
         'Path to the general non-hemolytic peptides FASTA file (default: data/toxicity-data/nonhemolytic.fasta)'),
        ('--curated_hemolytic_path',
         'data/toxicity-data/curated-hemolytic.fasta',
         'Path to the curated hemolytic peptides FASTA file (default: data/toxicity-data/curated-hemolytic.fasta)'),
        ('--curated_nonhemolytic_path',
         'data/toxicity-data/curated-nonhemolytic.fasta',
         'Path to the curated non-hemolytic peptides FASTA file (default: data/toxicity-data/curated-nonhemolytic.fasta)'),
        ('--random_sequences_file_path',
         'data/toxicity-data/synthetic-data/random-sequences.fasta',
         'Path to the random sequences FASTA file'),
        ('--shuffled_sequences_file_path',
         'data/toxicity-data/synthetic-data/shuffled-sequences.fasta',
         'Path to the shuffled sequences FASTA file'),
        ('--mutated_sequences_file_path',
         'data/toxicity-data/synthetic-data/mutated-sequences.fasta',
         'Path to the mutated sequences FASTA file'),
        ('--output_csv',
         'data/toxicity-data/hemolytic-dataset.csv',
         'Path where the hemolytic dataset will be saved (default: data/toxicity-data/hemolytic-dataset.csv)'),
    ]

    cli = argparse.ArgumentParser(
        description='Create a dataset for hemolytic peptide classification using both general and curated datasets.'
    )
    for flag, default_path, help_text in cli_options:
        cli.add_argument(flag, type=str, default=default_path, help=help_text)
    options = cli.parse_args()

    # Hand each parsed path to main() under the matching parameter name.
    main(
        hemolytic_file_path=options.hemolytic_file_path,
        non_hemolytic_file_path=options.non_hemolytic_file_path,
        curated_hemolytic_path=options.curated_hemolytic_path,
        curated_nonhemolytic_path=options.curated_nonhemolytic_path,
        random_sequences_file_path=options.random_sequences_file_path,
        shuffled_sequences_file_path=options.shuffled_sequences_file_path,
        mutated_sequences_file_path=options.mutated_sequences_file_path,
        output_csv=options.output_csv,
    )
