import os
import subprocess
import time

import numpy as np
import pandas as pd


class MMseqs:
    """Drive an MMseqs2 all-against-all search and clustering pipeline.

    Every step shells out to the ``mmseqs`` executable, which must be on
    ``PATH``.  All databases and result files for one parameter combination
    are written below ``<outputpath>/<dir_name>``, where ``dir_name`` encodes
    the sensitivity, coverage, identity-threshold and alignment-mode settings,
    so runs with different settings do not overwrite each other.

    Each step reads the database written by the previous one, so call
    ``make_output_dirs`` -> ``create_db`` -> ``prefilter`` -> ``align`` ->
    ``cluster`` in that order.  A step whose command exits non-zero raises
    ``subprocess.CalledProcessError`` instead of letting the next step run on
    a missing database.
    """

    def __init__(self, args):
        """Store output paths and MMseqs2 options from a parsed-arguments object.

        Args:
            args: namespace (e.g. from ``argparse``) providing ``inputpath``,
                ``outputpath``, ``s``, ``c``, ``cov_mode``, ``min_seq_id``,
                ``alignment_mode``, ``alignment_output_mode``,
                ``alignment_outputs`` (iterable of column names joined with
                commas for ``--format-output``), ``e``, ``seq_id_mode``,
                ``cluster_mode``, ``threads`` and ``v``.
        """
        self.input_path = args.inputpath
        dir_name = (
            f"s_{args.s}_covmode_{args.cov_mode}_cov_{args.c}"
            f"_idthr_{args.min_seq_id}_aln_{args.alignment_mode}"
        )
        self.output_dir = os.path.join(args.outputpath, dir_name)
        # Each database gets its own sub-directory; the steps below write the
        # database itself as '<sub-directory>/<name>'.
        self.seqdb = os.path.join(self.output_dir, 'seqdb')
        self.prefilterdb = os.path.join(self.output_dir, 'prefilterdb')
        self.aligndb = os.path.join(self.output_dir, 'aligndb')
        self.clustdb = os.path.join(self.output_dir, f"clustdb_mode_{args.cluster_mode}")
        self.alignfile = os.path.join(self.output_dir, 'align.tsv')
        # NOTE(review): unlike clustdb, this name does not include cluster_mode,
        # so clustering with different modes overwrites the same TSV.
        self.clustfile = os.path.join(self.output_dir, 'clust.tsv')
        # Options forwarded verbatim to the mmseqs command lines.
        self.verbose = args.v
        self.sensitivity = args.s
        self.coverage = args.c
        self.coverage_mode = args.cov_mode
        self.threads = args.threads
        self.alignment_mode = args.alignment_mode
        self.alignment_output_mode = args.alignment_output_mode
        self.alignment_outputs = args.alignment_outputs
        self.e_value = args.e
        self.min_seq_id = args.min_seq_id
        self.seq_id_mode = args.seq_id_mode
        self.cluster_mode = args.cluster_mode

    def _run_step(self, message, cmd) -> None:
        """Announce, run and time one external command.

        The announcement is flushed before the command starts so it appears
        ahead of the tool's own output even when stdout is redirected.
        Every element of ``cmd`` is converted with ``str``.

        Raises:
            FileNotFoundError: if the executable is not on ``PATH``.
            subprocess.CalledProcessError: if the command exits non-zero.
        """
        print(message, flush=True)
        start = time.perf_counter()
        subprocess.run([str(arg) for arg in cmd], check=True)
        print(f'{time.perf_counter() - start} s.', flush=True)

    def make_output_dirs(self) -> None:
        """Create the output directory and the per-database sub-directories."""
        for newdir in [self.output_dir, self.seqdb, self.prefilterdb, self.aligndb]:
            os.makedirs(newdir, exist_ok=True)

    def create_db(self) -> None:
        """Build the sequence database ``<seqdb>/seqdb`` from ``input_path``."""
        self._run_step(
            f'Creating sequence database at {self.seqdb}.',
            ['mmseqs', 'createdb', self.input_path, f'{self.seqdb}/seqdb',
             '--dbtype', '1',  # force amino-acid database type
             '-v', self.verbose],
        )

    def prefilter(self) -> None:
        """Prefilter the sequence database against itself into ``<prefilterdb>/prefilterdb``."""
        seqdb = f'{self.seqdb}/seqdb'
        self._run_step(
            f'Prefiltering sequence database. Saving prefilter database to {self.prefilterdb}.',
            ['mmseqs', 'prefilter', seqdb, seqdb, f'{self.prefilterdb}/prefilterdb',
             '-s', self.sensitivity,
             '-c', self.coverage,
             '--cov-mode', self.coverage_mode,
             '--threads', self.threads,
             '-v', self.verbose],
        )

    def align(self) -> None:
        """Align the prefiltered pairs, then export them as BLAST-tab with a header row.

        Writes the alignment database to ``<aligndb>/aligndb`` and the table
        (columns from ``alignment_outputs``) to ``alignfile``.
        """
        seqdb = f'{self.seqdb}/seqdb'
        aligndb = f'{self.aligndb}/aligndb'
        self._run_step(
            f'Aligning prefiltered sequences. Saving alignment database to {self.aligndb}.',
            ['mmseqs', 'align', seqdb, seqdb, f'{self.prefilterdb}/prefilterdb', aligndb,
             '-a',  # Add backtrace string (convert to alignments with mmseqs convertalis module)
             '--alignment-mode', self.alignment_mode,
             '--alignment-output-mode', self.alignment_output_mode,
             '-e', self.e_value,
             '--min-seq-id', self.min_seq_id,
             '--seq-id-mode', self.seq_id_mode,
             '-c', self.coverage,
             '--cov-mode', self.coverage_mode,
             '--threads', self.threads,
             '-v', self.verbose],
        )
        self._run_step(
            f'Convert alignment DB to BLAST-tab format. Save to {self.alignfile}.',
            ['mmseqs', 'convertalis', seqdb, seqdb, aligndb, self.alignfile,
             '--format-output', ','.join(self.alignment_outputs),
             '--format-mode', '4',  # BLAST-TAB + column headers
             '-v', self.verbose],
        )

    def cluster(self) -> None:
        """Cluster the alignment database and write the member-representative map.

        Writes the clustering database to ``<clustdb>/clustdb`` and the TSV
        produced by ``mmseqs createtsv`` to ``clustfile``.
        """
        seqdb = f'{self.seqdb}/seqdb'
        clustdb = f'{self.clustdb}/clustdb'
        os.makedirs(self.clustdb, exist_ok=True)
        self._run_step(
            f'Clustering. Saving clustering database to {self.clustdb}.',
            ['mmseqs', 'clust', seqdb, f'{self.aligndb}/aligndb', clustdb,
             '--cluster-mode', self.cluster_mode,
             '--threads', self.threads,
             '-v', self.verbose],
        )
        self._run_step(
            f'Saving cluster member-representative map to {self.clustfile}.',
            ['mmseqs', 'createtsv', seqdb, seqdb, clustdb, self.clustfile,
             '-v', self.verbose],
        )


class Foldseek(MMseqs):
    """Drive a Foldseek all-against-all structure search and clustering pipeline.

    Reuses the option attributes set by ``MMseqs.__init__`` and adds the
    Foldseek / TM-align specific ones.  Every artefact is written below its
    own ``<outputpath>/<dir_name>`` directory, whose name encodes the
    TM-score threshold and alignment type.  Commands run through the
    ``foldseek`` executable, which must be on ``PATH``.

    Two workflows are provided: ``do_easy_search`` (one-shot search straight
    from the input structures), or ``make_output_dirs`` -> ``create_db`` ->
    ``search`` -> ``cluster``.  A command that exits non-zero raises
    ``subprocess.CalledProcessError``.
    """

    def __init__(self, args, **kwargs):
        """Store output paths and Foldseek options from a parsed-arguments object.

        Args:
            args: the namespace accepted by ``MMseqs`` plus ``tm_thresh``,
                ``alignment_type``, ``tmalign_hit_order``, ``tmalign_fast``
                and ``lddt_threshold``.
            **kwargs: forwarded to ``MMseqs.__init__``.
        """
        super().__init__(args, **kwargs)
        dir_name = (
            f"s_{args.s}_covmode_{args.cov_mode}_cov_{args.c}"
            f"_tmthr_{args.tm_thresh}_aln_{args.alignment_type}_exh"
        )
        self.output_dir = os.path.join(args.outputpath, dir_name)
        # The parent constructor rooted these paths in its own (MMseqs-named)
        # directory; re-root them so the alignment DB and the TSV outputs land
        # next to structdb/clustdb instead of in a different directory.
        self.seqdb = os.path.join(self.output_dir, 'seqdb')
        self.prefilterdb = os.path.join(self.output_dir, 'prefilterdb')
        self.aligndb = os.path.join(self.output_dir, 'aligndb')
        self.alignfile = os.path.join(self.output_dir, 'align.tsv')
        self.clustfile = os.path.join(self.output_dir, 'clust.tsv')
        self.structdb = os.path.join(self.output_dir, 'structdb')
        self.clustdb = os.path.join(self.output_dir, f"clustdb_mode_{self.cluster_mode}")
        # Structure-alignment options forwarded verbatim to foldseek.
        self.alignment_type = args.alignment_type
        self.tmalign_hit_order = args.tmalign_hit_order
        self.tmalign_fast = args.tmalign_fast
        self.tm_thr = args.tm_thresh
        self.lddt_threshold = args.lddt_threshold

    def _run_step(self, message, cmd) -> None:
        """Announce, run and time one external command.

        The announcement is flushed before the command starts so it appears
        ahead of the tool's own output even when stdout is redirected.
        Every element of ``cmd`` is converted with ``str``.

        Raises:
            FileNotFoundError: if the executable is not on ``PATH``.
            subprocess.CalledProcessError: if the command exits non-zero.
        """
        print(message, flush=True)
        start = time.perf_counter()
        subprocess.run([str(arg) for arg in cmd], check=True)
        print(f'{time.perf_counter() - start} s.', flush=True)

    def _search_options(self) -> list:
        """Return the alignment options shared by ``easy-search`` and ``search``."""
        return [
            '--exhaustive-search', '1',  # Skip prefilter and perform an exhaustive alignment (slower but more sensitive)
            '--remove-tmp-files', '0',
            '--tmscore-threshold', self.tm_thr,
            '--alignment-type', self.alignment_type,
            '--tmalign-hit-order', self.tmalign_hit_order,
            '--tmalign-fast', self.tmalign_fast,
            '--lddt-threshold', self.lddt_threshold,
            '-s', self.sensitivity,
            '-c', self.coverage,
            '--cov-mode', self.coverage_mode,
            '-a',  # Add backtrace string (needed to convert to alignments with convertalis)
            '--alignment-mode', self.alignment_mode,
            '--alignment-output-mode', self.alignment_output_mode,
            '-e', self.e_value,
            '--min-seq-id', self.min_seq_id,
            '--seq-id-mode', self.seq_id_mode,
        ]

    def make_output_dirs(self) -> None:
        """Create the output directory and the structure/alignment sub-directories."""
        for newdir in [self.output_dir, self.structdb, self.aligndb]:
            os.makedirs(newdir, exist_ok=True)

    def do_easy_search(self) -> None:
        """Search the input structures against themselves with ``foldseek easy-search``.

        Writes a BLAST-tab table with a header row (columns from
        ``alignment_outputs``) to ``alignfile``.  Intermediate files are kept
        (``--remove-tmp-files 0``) in ``tmpFolder`` relative to the current
        working directory.
        """
        print(self.output_dir)
        os.makedirs(self.output_dir, exist_ok=True)
        print("Easy search.")
        self._run_step(
            f"Aligning structures all-against-all. Saving alignment results to {self.alignfile}.",
            ['foldseek', 'easy-search', self.input_path, self.input_path, self.alignfile, 'tmpFolder',
             *self._search_options(),
             '--format-output', ','.join(self.alignment_outputs),
             '--format-mode', '4',  # BLAST-TAB + column headers
             '--threads', self.threads,
             '-v', self.verbose],
        )

    def create_db(self) -> None:
        """Build the structure database ``<structdb>/structdb`` from ``input_path``."""
        self._run_step(
            f'Creating structure database at {self.structdb}.',
            ['foldseek', 'createdb', self.input_path, f'{self.structdb}/structdb',
             '-v', self.verbose],
        )

    def search(self) -> None:
        """Align the structure database against itself, then export BLAST-tab with headers.

        Writes the alignment database to ``<aligndb>/aligndb`` and the table to
        ``alignfile``.  Intermediate files are kept in ``tmpFolder`` relative
        to the current working directory.
        """
        structdb = f'{self.structdb}/structdb'
        aligndb = f'{self.aligndb}/aligndb'
        print("Search.")
        self._run_step(
            f"Aligning structures all-against-all. Saving alignment DB to {self.aligndb}.",
            ['foldseek', 'search', structdb, structdb, aligndb, 'tmpFolder',
             *self._search_options(),
             '--threads', self.threads,
             '-v', self.verbose],
        )
        self._run_step(
            f'Convert alignment DB to BLAST-tab format. Save to {self.alignfile}.',
            ['foldseek', 'convertalis', structdb, structdb, aligndb, self.alignfile,
             '--format-output', ','.join(self.alignment_outputs),
             '--format-mode', '4',  # BLAST-TAB + column headers
             '-v', self.verbose],
        )

    def cluster(self) -> None:
        """Cluster the alignment database and write the member-representative map.

        Writes the clustering database to ``<clustdb>/clustdb`` and the TSV
        produced by ``foldseek createtsv`` to ``clustfile``.
        """
        structdb = f'{self.structdb}/structdb'
        clustdb = f'{self.clustdb}/clustdb'
        os.makedirs(self.clustdb, exist_ok=True)
        self._run_step(
            f'Clustering. Saving clustering database to {self.clustdb}.',
            ['foldseek', 'clust', structdb, f'{self.aligndb}/aligndb', clustdb,
             '--cluster-mode', self.cluster_mode,
             # Corrects criteria violations of cascaded merging.
             # NOTE(review): this option belongs to the 'cluster' workflow in
             # MMseqs2; confirm the 'clust' module accepts it.
             '--cluster-reassign',
             '--threads', self.threads,
             '-v', self.verbose],
        )
        self._run_step(
            f'Saving cluster member-representative map to {self.clustfile}.',
            ['foldseek', 'createtsv', structdb, structdb, clustdb, self.clustfile,
             '-v', self.verbose],
        )

