import os
import argparse
import subprocess
import pandas
import numpy as np

from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO

class BLASTdistance:
    """
    All-vs-all BLASTp search over protein sequences stored in a CSV table.

    Calling an instance runs the whole pipeline:

    1. read the CSV (``load_seqs``),
    2. write the sequences to ``<working_dir>/blast_db/seqsblastdb.fasta``,
    3. build a protein BLAST database from that FASTA with ``makeblastdb``,
    4. query the same FASTA against the database with ``blastp``,
    5. write the tabular hits to ``<working_dir>/<csv basename>_similarity.csv``.
    """

    def __init__(self, seqs_csv, seq_id_col, seq_col, working_dir, blastp_executable_path = None, makeblastdb_executable_path = None, max_workers = None, sep = ';'):
        """
        Configure paths and create ``<working_dir>/blast_db`` if needed.

        Originally written to run BLAST on the M2OR main_receptors table, but any
        CSV with an id column and a protein-sequence column works.

        Parameters
        ----------
        seqs_csv : str
            Path to the CSV file holding the sequences.
        seq_id_col : str
            Name of the column used as FASTA record id.
        seq_col : str
            Name of the column holding the protein sequences.
        working_dir : str
            Directory receiving the FASTA file, BLAST database and result CSV.
        blastp_executable_path : str, optional
            Path to ``blastp``. If None, ``'blastp'`` is resolved on PATH when
            the search runs.
        makeblastdb_executable_path : str, optional
            Path to ``makeblastdb``. If None, ``'makeblastdb'`` is resolved on
            PATH when the database is built.
        max_workers : int, optional
            Forwarded to blastp as ``-num_threads`` when not None.
        sep : str
            Field separator of ``seqs_csv``.

        Raises
        ------
        ValueError
            If ``working_dir`` is None.
        """
        if working_dir is None:
            raise ValueError('working_dir must be explicitly specified.')

        self.working_dir = working_dir
        self.blastp_executable_path = blastp_executable_path
        self.makeblastdb_executable_path = makeblastdb_executable_path
        self.max_workers = max_workers

        self.seqs_csv = seqs_csv
        self.seq_id_col = seq_id_col
        self.seq_col = seq_col

        # Base name of the input CSV without extension; used to name the output file.
        self.seqs_csv_filename = os.path.splitext(os.path.basename(seqs_csv))[0]

        self.sep = sep

        self.blast_db_title = 'seqsblastdb'
        blast_db_dir = os.path.join(self.working_dir, 'blast_db')
        os.makedirs(blast_db_dir, exist_ok=True)
        self.seqs_fasta_path = os.path.join(blast_db_dir, self.blast_db_title + ".fasta")
        self.blast_db_path = os.path.join(blast_db_dir, self.blast_db_title)

    def load_seqs(self):
        """Read ``self.seqs_csv`` (separator ``self.sep``, first row as header) into a DataFrame."""
        return pandas.read_csv(self.seqs_csv, sep = self.sep, index_col = None, header = 0)

    @staticmethod
    def create_seq_record_from_row(row, seq_id_col, seq_col):
        """
        Build a ``SeqRecord`` from one table row.

        The id is coerced to ``str`` so numeric id columns still produce a valid
        FASTA header.
        """
        sequence = row[seq_col]
        seq_id = str(row[seq_id_col])
        return SeqRecord(Seq(sequence), id=seq_id, description = 'olfactory receptor')

    def make_seqs_fasta(self, seqs):
        """Write every row of ``seqs`` as a FASTA record to ``self.seqs_fasta_path``."""
        # A list comprehension (rather than DataFrame.apply(axis=1)) keeps this
        # working for an empty table, where apply would return a DataFrame.
        records = [self.create_seq_record_from_row(row, self.seq_id_col, self.seq_col)
                   for _, row in seqs.iterrows()]
        with open(self.seqs_fasta_path, "w") as output_handle:
            SeqIO.write(records, output_handle, "fasta")

    @staticmethod
    def make_blast_db(makeblastdb_executable, seqs_fasta_path, blast_db_title, blast_db_path):
        """
        Build a protein BLAST database from ``seqs_fasta_path`` at ``blast_db_path``.

        Raises
        ------
        subprocess.CalledProcessError
            If ``makeblastdb`` exits with a non-zero status (fail fast instead
            of letting the later blastp call fail on a missing database).
        """
        makeblastdb_cmd = [makeblastdb_executable, "-in", seqs_fasta_path, "-title", blast_db_title,
                           "-dbtype", "prot", "-out", blast_db_path, "-parse_seqids"]
        subprocess.run(makeblastdb_cmd, check=True)

    @staticmethod
    def blast_search(blastp_executable, database_path, seqs_fasta_path, n_outputs, num_threads = None):
        """
        Query ``seqs_fasta_path`` against ``database_path`` with blastp.

        Parameters
        ----------
        blastp_executable : str
            Path or name of the ``blastp`` binary.
        database_path : str
            BLAST database prefix (as given to ``makeblastdb -out``).
        seqs_fasta_path : str
            FASTA file with the query sequences.
        n_outputs : int
            Value passed to ``-num_alignments``.
        num_threads : int, optional
            Passed as ``-num_threads`` when not None.

        Returns
        -------
        pandas.DataFrame
            One row per hit with string columns
            ``qseqid, sseqid, pident, length, evalue, score``; empty (with those
            columns) when blastp reports no hits.

        Raises
        ------
        subprocess.CalledProcessError
            If blastp exits with a non-zero status.
        """
        columns = ['qseqid', 'sseqid', 'pident', 'length', 'evalue', 'score']
        blastp_cmd = [blastp_executable, "-db", database_path, "-query", seqs_fasta_path,
                      "-evalue", "10", "-num_alignments", str(n_outputs),
                      "-outfmt", "6 " + " ".join(columns)]
        if num_threads is not None:
            blastp_cmd += ["-num_threads", str(num_threads)]
        blastp_output = subprocess.check_output(blastp_cmd).decode()
        # Skip blank lines so that "no hits" yields an empty frame instead of a
        # single malformed one-field row.
        rows = [line.split('\t') for line in blastp_output.splitlines() if line.strip()]
        return pandas.DataFrame(rows, columns = columns)

    def __call__(self):
        """
        Run the full pipeline and write ``<working_dir>/<csv basename>_similarity.csv``.

        The result file is always ';'-separated, independent of ``self.sep``.
        """
        seqs = self.load_seqs()
        self.make_seqs_fasta(seqs)
        self.make_blast_db(makeblastdb_executable = self.makeblastdb_executable_path or 'makeblastdb',
                           seqs_fasta_path = self.seqs_fasta_path,
                           blast_db_title = self.blast_db_title,
                           blast_db_path = self.blast_db_path)
        hit_data = self.blast_search(blastp_executable = self.blastp_executable_path or 'blastp',
                                     database_path = self.blast_db_path,
                                     seqs_fasta_path = self.seqs_fasta_path,
                                     n_outputs = len(seqs)**2,
                                     num_threads = self.max_workers)
        output_path = os.path.join(self.working_dir, self.seqs_csv_filename + '_similarity.csv')
        hit_data.to_csv(output_path, sep=';', index = False)
