import argparse
import concurrent.futures
import os
import shutil
import subprocess
import sys
import tempfile
import time
from collections import defaultdict
from subprocess import DEVNULL

import numpy as np
import tqdm
import wandb
from Levenshtein import distance as levenshtein_distance

from src.utils.hamming_distance import hamming_distance_postprocessed
from src.utils.helper_functions import create_fasta_file, read_fasta
from src.eval_pkg.reconstruction_algorithms.trellis_reconstruction.algorithms.trellis_bma import TrellisBMAParams, compute_trellis_bma_estimation
from src.eval_pkg.reconstruction_algorithms.VSAlgorithm.mainVS import alg
from src.eval_pkg.majority_vote import majority_merge

# GLOBAL CONFIG
# W&B entity (user or team) that owns the dataset artifacts and receives the
# evaluation runs. Read from the WANDB_ENTITY environment variable; None when
# the variable is unset.
ENTITY = os.environ.get("WANDB_ENTITY")
# W&B project holding the dataset artifacts.
PROJECT_ARTIFACT = "TRACE_RECONSTRUCTION"
# Local directory artifacts are downloaded into; also used as scratch space
# by the algorithm wrappers.
DOWNLOAD_DIR = "./downloaded_artifact"

# HELPERS 
# Line in reads.txt that separates consecutive clusters.
_CLUSTER_SEPARATOR = "==============================="


def _split_clusters(lines, separator=_CLUSTER_SEPARATOR):
    """Group read lines into clusters delimited by ``separator`` lines.

    Consecutive, leading or trailing separators do not produce empty clusters.
    """
    clusters, current = [], []
    for line in lines:
        if line == separator:
            if current:
                clusters.append(current)
                current = []
        else:
            current.append(line)
    if current:
        clusters.append(current)
    return clusters


def load_dataset(artifact_name):
    """Download a W&B dataset artifact and parse it into clusters.

    The ``:latest`` version of ``artifact_name`` in ``PROJECT_ARTIFACT``
    (prefixed by ``ENTITY`` when it is set) is downloaded into
    ``DOWNLOAD_DIR``. The artifact must contain:

    * ``reads.txt``: one read per line, clusters separated by a line of
      ``=`` characters;
    * ``ground_truth.txt``: one ground-truth strand per line, in cluster order.

    Args:
        artifact_name: name of the dataset artifact (without version/alias).

    Returns:
        List of ``(index, reads, ground_truth)`` tuples, where ``reads`` is
        the list of stripped read lines of that cluster.

    Raises:
        ValueError: if the number of clusters differs from the number of
            ground-truth lines.
    """
    wandb.login()
    api = wandb.Api()
    # Omit the entity prefix when ENTITY is unset, so W&B resolves the API's
    # default entity instead of looking up an entity literally named "None".
    artifact_path = f"{PROJECT_ARTIFACT}/{artifact_name}:latest"
    if ENTITY:
        artifact_path = f"{ENTITY}/{artifact_path}"
    artifact = api.artifact(artifact_path, type="dataset")
    artifact_dir = artifact.download(DOWNLOAD_DIR)

    with open(os.path.join(artifact_dir, "reads.txt")) as f:
        reads_lines = [line.strip() for line in f]
    with open(os.path.join(artifact_dir, "ground_truth.txt")) as f:
        gt_lines = [line.strip() for line in f]

    clusters = _split_clusters(reads_lines)
    # Explicit check instead of `assert`: it must also hold under `python -O`.
    if len(clusters) != len(gt_lines):
        raise ValueError(
            f"artifact {artifact_name!r}: {len(clusters)} clusters in reads.txt "
            f"but {len(gt_lines)} lines in ground_truth.txt"
        )
    return [(idx, reads, gt) for idx, (reads, gt) in enumerate(zip(clusters, gt_lines))]


def generate_temp_evyat_file(reads, gt, folder):
    """Write one cluster to ``<folder>/evyat.txt`` as input for the binaries.

    Layout: the ground-truth strand, a ``****`` separator line, then one read
    per line; every line ends with a newline. ``folder`` is created if it
    does not exist.
    """
    os.makedirs(folder, exist_ok=True)
    payload = [gt, "****"]
    payload.extend(reads)
    target = os.path.join(folder, "evyat.txt")
    with open(target, 'w') as handle:
        handle.write("\n".join(payload) + "\n")


def read_evyat(path):
    """Parse a result file written by the BMALA/Iterative binaries.

    The file is scanned for records of the form::

        Cluster Num ...
        <ground truth>
        <prediction>
        Distance: <int>      (optional)

    Blank lines are ignored and lines outside such records are skipped.

    Args:
        path: path of the result file.

    Returns:
        ``(gts, preds, dists)``, three lists aligned by record. ``dists``
        holds -1 for a record without a ``Distance:`` line. All three are
        empty if ``path`` does not exist.
    """
    if not os.path.exists(path):
        return [], [], []
    with open(path) as f:
        lines = [line.strip() for line in f if line.strip()]

    gts, preds, dists = [], [], []
    i, n = 0, len(lines)
    while i < n:
        # A record needs at least the header, ground truth and prediction.
        if not (lines[i].startswith("Cluster Num") and i + 2 < n):
            i += 1
            continue
        gts.append(lines[i + 1])
        preds.append(lines[i + 2])
        has_distance = i + 3 < n and lines[i + 3].startswith("Distance:")
        dists.append(int(lines[i + 3].split(":")[1].strip()) if has_distance else -1)
        # Only consume the distance line when it is really there; otherwise
        # the next record's "Cluster Num" header would be skipped.
        i += 4 if has_distance else 3
    return gts, preds, dists

# ALGORITHM WRAPPERS 
class BMALA:
    """Wrapper around the external ``BMALA`` reconstruction binary.

    Each :meth:`inference` call writes one cluster into a scratch folder under
    ``temp_dir``, runs the binary on it, reads the predicted strand from the
    binary's result files and removes the scratch folder again.
    """

    def __init__(self, temp_dir):
        self.base = temp_dir
        # The binary is resolved from the current working directory, so the
        # script is expected to be launched from the repository root (/TReconLM).
        self.binary = os.path.join(os.getcwd(), "src", "eval_pkg", "BMALA")
        # Make sure it is executable before any worker tries to run it.
        os.chmod(self.binary, 0o755)

    def inference(self, reads, gt, idx):
        """Reconstruct one cluster with the binary.

        Args:
            reads: reads of the cluster.
            gt: ground-truth strand; it heads the input file written by
                generate_temp_evyat_file.
            idx: cluster index, used to name the scratch folder.

        Returns:
            The first prediction in ``output-results-success.txt``, or, if
            that file has none, the first one in ``output-results-fail.txt``.

        Raises:
            subprocess.CalledProcessError: if the binary exits non-zero.
            RuntimeError: if neither result file contains a prediction.
        """
        folder = os.path.join(self.base, f"cluster_{idx}")
        generate_temp_evyat_file(reads, gt, folder)
        try:
            # Argument list instead of a shell string: paths containing spaces
            # or shell metacharacters are passed through verbatim.
            with open(os.path.join(folder, "out.txt"), "w") as log:
                subprocess.run(
                    [self.binary, os.path.join(folder, "evyat.txt"), folder],
                    stdout=log,
                    check=True,
                )
            _, succ, _ = read_evyat(os.path.join(folder, 'output-results-success.txt'))
            if succ:
                return succ[0]
            _, fail, _ = read_evyat(os.path.join(folder, 'output-results-fail.txt'))
            if fail:
                return fail[0]
            raise RuntimeError(f"{self.binary} produced no prediction for cluster {idx}")
        finally:
            # Remove the scratch folder even when the binary or parsing fails.
            shutil.rmtree(folder, ignore_errors=True)


class Iterative(BMALA):
    """Wrapper around the external ``Iterative`` reconstruction binary.

    Input format, invocation and output parsing are identical to BMALA, so
    ``inference`` is inherited unchanged; only the binary differs.
    """

    def __init__(self, temp_dir):
        # BMALA.__init__ is deliberately not called: it would chmod (and thus
        # require the presence of) the BMALA binary, which is never run here.
        # Set the same attributes the inherited inference() reads instead.
        self.base = temp_dir
        self.binary = os.path.join(os.getcwd(), "src", "eval_pkg", "Iterative")
        os.chmod(self.binary, 0o755)

class MuscleAlgorithm:
    """Reconstruction via MUSCLE multiple alignment followed by majority_merge.

    Args:
        temp_dir: directory under which per-call scratch directories are made.
        weight: ``weight`` argument forwarded to majority_merge (default 0.4).
    """

    def __init__(self, temp_dir, weight=0.4):
        self.temp = temp_dir
        self.weight = weight
        self.binary = os.path.join(os.getcwd(), "src", "eval_pkg", "muscle")
        os.chmod(self.binary, 0o755)

    def inference(self, reads, idx, gt=None):
        """Reconstruct one cluster.

        Note that the parameter order ``(reads, idx, gt)`` differs from the
        other wrappers' ``(reads, gt, idx)``; prefer keyword arguments.

        Args:
            reads: reads of the cluster, written to a FASTA file.
            idx: cluster identifier, only used in the scratch directory name.
            gt: unused; accepted for interface parity with the other wrappers.

        Returns:
            The result of majority_merge over the aligned sequences.

        Raises:
            subprocess.CalledProcessError: if muscle exits non-zero.
        """
        os.makedirs(self.temp, exist_ok=True)
        # Private scratch directory per call: names cannot collide between
        # concurrent workers, and it is removed even if a step below raises.
        with tempfile.TemporaryDirectory(prefix=f"muscle_{idx}_", dir=self.temp) as work:
            inp = os.path.join(work, "in.fasta")
            out = os.path.join(work, "out.fasta")
            create_fasta_file(reads, 'obs', inp)

            # run muscle, silencing its output
            subprocess.run(
                [self.binary, "-align", inp, "-output", out],
                stdout=DEVNULL, stderr=DEVNULL, check=True,
            )

            # Merge inside the `with` block in case read_fasta reads lazily.
            seqs = read_fasta(out)
            return majority_merge(seqs, weight=self.weight)


class TrellisBMAAlgorithm:
    """Trellis-BMA reconstruction with beta parameters chosen by cluster size.

    Args:
        temp_dir: stored as ``self.base``; this class does not use it.
        P_INS, P_DEL, P_SUB: insertion/deletion/substitution rates, passed
            unchanged to TrellisBMAParams.
        k: sweep index, stored as ``self.k``; it does not affect the rates.
    """

    # (beta_b, beta_e, beta_i) per clamped cluster size; sizes not listed
    # (i.e. 10) use _BETAS_DEFAULT.
    _BETAS_BY_SIZE = {
        2: (0.0, 0.1, 0.5),
        3: (0.0, 0.1, 0.5),
        4: (0.0, 1.0, 0.1),
        5: (0.0, 1.0, 0.1),
        6: (0.0, 0.5, 0.1),
        7: (0.0, 0.5, 0.1),
        8: (0.0, 0.5, 0.5),
        9: (0.0, 0.5, 0.5),
    }
    _BETAS_DEFAULT = (0.0, 0.5, 0.0)

    def __init__(self, temp_dir, P_INS=0.055, P_DEL=0.055, P_SUB=0.055, k=0):
        self.base = temp_dir
        self.k = k
        self.P_INS, self.P_DEL, self.P_SUB = P_INS, P_DEL, P_SUB

    def _select_beta_parameters(self, cluster_size):
        """Return TrellisBMAParams keyword arguments for a cluster of this size.

        The size is clamped to [2, 10] before the beta lookup.
        """
        clamped = min(10, max(2, cluster_size))
        beta_b, beta_e, beta_i = self._BETAS_BY_SIZE.get(clamped, self._BETAS_DEFAULT)
        return dict(
            beta_b=beta_b,
            beta_e=beta_e,
            beta_i=beta_i,
            P_INS=self.P_INS,
            P_DEL=self.P_DEL,
            P_SUB=self.P_SUB,
        )

    def select_beta(self, cluster_size):
        """Build the TrellisBMAParams for a cluster of ``cluster_size`` reads."""
        return TrellisBMAParams(**self._select_beta_parameters(cluster_size))

    def inference(self, reads, gt, idx):
        """Return the Trellis-BMA estimate for ``reads``; ``idx`` is ignored."""
        settings = self.select_beta(len(reads))
        _, estimate = compute_trellis_bma_estimation(reads, gt, settings)
        return estimate


class VSAlgorithm:
    """Wrapper around the VS reconstruction routine ``alg``.

    Hyper-parameters are fixed at gamma=0.75, l=5, r=2. The stored
    substitution rate is ``P_SUB + k * 0.005``, where ``k`` is the sweep index.
    """

    def __init__(self, P_SUB=0.055, k=0):
        adjusted_sub = P_SUB + k * 0.005
        self.cfg = dict(gamma=0.75, l=5, r=2, P_SUB=adjusted_sub)

    def inference(self, reads, gt=None, idx=None):
        """Reconstruct one cluster; ``idx`` is accepted for parity and ignored.

        ``alg`` receives the substitution setting as ``(1 + P_SUB) / 2``.
        """
        cfg = self.cfg
        sub_param = (1 + cfg['P_SUB']) / 2
        return alg(len(reads), reads, cfg['l'], sub_param, cfg['r'], cfg['gamma'], gt)

# Registry: --alg command-line name -> algorithm wrapper class.
ALGS = dict(
    bmala=BMALA,
    itr=Iterative,
    muscle=MuscleAlgorithm,
    trellisbma=TrellisBMAAlgorithm,
    vs=VSAlgorithm,
)

# Insertion / deletion / substitution rates for each --error-profile choice.
ERROR_PROFILES = dict(
    default=dict(P_INS=0.055, P_DEL=0.055, P_SUB=0.055),
    microsoft=dict(P_INS=0.017, P_DEL=0.02, P_SUB=0.022),
    noisy=dict(P_INS=0.057, P_DEL=0.06, P_SUB=0.026),
)


def process_example(args):
    """Run one algorithm on one cluster and score its prediction.

    Args:
        args: ``(idx, reads, gt, alg_name, alg_inst)`` packed into one tuple
            so the function can be used with ``Executor.map``; ``alg_name``
            is not used here.

    Returns:
        dict with the cluster index, ground truth, prediction, the
        hamming_distance_postprocessed score, the Levenshtein distance divided
        by the ground-truth length, the inference time in seconds and the
        number of reads.
    """
    idx, reads, gt, _alg_name, alg_inst = args
    start = time.perf_counter()
    # Keyword arguments: the wrappers disagree on positional order
    # (MuscleAlgorithm.inference is (reads, idx, gt)), so a positional call
    # would swap gt and idx for some of them.
    pred = alg_inst.inference(reads, gt=gt, idx=idx)
    elapsed = time.perf_counter() - start
    return {
        'idx': idx,
        'ground_truth': gt,
        'reconstructed': pred,
        'hamming_distance': hamming_distance_postprocessed(gt, pred),
        # Guard against an empty ground truth (would divide by zero).
        'levenshtein_distance': levenshtein_distance(gt, pred) / max(len(gt), 1),
        'time_taken': elapsed,
        'num_reads': len(reads),
    }



# MAIN 
def _build_arg_parser():
    """Return the command-line parser used by :func:`main`."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--alg', choices=ALGS.keys(), required=True)
    parser.add_argument(
        '--artifact',
        default="test_dataset_seed34721_gl110_bs1500_ds50000",
        help="Which W&B artifact to evaluate"
    )
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--project', type=str, default='Timing')
    parser.add_argument(
        '--error-profile',
        # Derived from ERROR_PROFILES so a new profile is only added there.
        choices=list(ERROR_PROFILES),
        default='default',
        help="Select insertion/deletion/substitution rates"
    )
    parser.add_argument(
        '--sweep',
        action='store_true',
        help='If set, run k=0..max-k instead of a single run'
    )
    parser.add_argument(
        '--max-k',
        type=int,
        default=10,
        help='When --sweep, iterate k from 0 to this (inclusive)'
    )
    parser.add_argument(
        '--subset',
        action='store_true',
        help='If set, only use the first 20% of the dataset'
    )
    return parser


def _make_algorithm(alg_name, rates, k):
    """Instantiate the wrapper selected by ``alg_name``.

    TrellisBMA and VS receive the error rates from ``rates`` and the sweep
    index ``k``; the other wrappers only take DOWNLOAD_DIR as scratch space.
    """
    if alg_name == 'trellisbma':
        return TrellisBMAAlgorithm(
            temp_dir=DOWNLOAD_DIR,
            P_INS=rates['P_INS'],
            P_DEL=rates['P_DEL'],
            P_SUB=rates['P_SUB'],
            k=k,
        )
    if alg_name == 'vs':
        return VSAlgorithm(P_SUB=rates['P_SUB'], k=k)
    return ALGS[alg_name](DOWNLOAD_DIR)


def _log_metrics(results_by_n, time_taken_all):
    """Log per-cluster-size, aggregate and timing metrics to the active W&B run.

    Args:
        results_by_n: maps the number of reads N to the process_example
            result dicts of clusters with N reads.
        time_taken_all: inference times (seconds) of all clusters.
    """
    # Per-N metrics; a cluster counts as a success when its Hamming distance is 0.
    for n, res in results_by_n.items():
        h_vals = [r['hamming_distance'] for r in res]
        ld_vals = [r['levenshtein_distance'] for r in res]
        wandb.log({
            f"success_rate_N={n}": sum(1 for x in h_vals if x == 0) / len(h_vals),
            f"avg_hamming_N={n}": np.mean(h_vals),
            f"std_hamming_N={n}": np.std(h_vals),
            f"avg_levenshtein_N={n}": np.mean(ld_vals),
            f"std_levenshtein_N={n}": np.std(ld_vals),
        })

    # Aggregate across all clusters, regardless of N.
    all_h = [r['hamming_distance'] for res in results_by_n.values() for r in res]
    all_ld = [r['levenshtein_distance'] for res in results_by_n.values() for r in res]
    if not all_h:
        # Empty dataset: the rates below would divide by zero.
        print("No clusters evaluated; skipping aggregate and timing metrics.")
        return

    success = sum(1 for x in all_h if x == 0)
    wandb.log({
        "success_rate_all": success / len(all_h),
        "avg_hamming_all": np.mean(all_h),
        "std_hamming_all": np.std(all_h),
        "avg_levenshtein_all": np.mean(all_ld),
        "std_levenshtein_all": np.std(all_ld),
    })

    # Timing per example
    wandb.log({
        "avg_time_per_example": np.mean(time_taken_all),
        "std_time_per_example": np.std(time_taken_all),
    })


def main():
    """Evaluate one reconstruction algorithm on W&B dataset artifact(s).

    Without --sweep, the artifact given by --artifact is evaluated once. With
    --sweep, the artifacts ``sweep{k}_seed{seed}_gl110_bs1500_ds5000`` are
    evaluated for k = 0..--max-k. Each evaluation is a separate W&B run in
    --project. Clusters are processed in a pool of --workers processes.
    """
    args = _build_arg_parser().parse_args()

    # pick the rates dictionary based on the flag
    rates = ERROR_PROFILES[args.error_profile]

    # if sweep, run k=0..max_k; otherwise exactly one pass (k=None)
    sweep_range = range(args.max_k + 1) if args.sweep else [None]
    base_seed = 34721

    for k in sweep_range:
        # Same seed for every k; use base_seed + (k or 0) if each sweep
        # artifact was generated with its own seed.
        seed = base_seed

        if args.sweep:
            art = f"sweep{k}_seed{seed}_gl110_bs1500_ds5000"
            run_name = f"{args.alg}_sweep{k}_seed{seed}"
        else:
            art = args.artifact
            run_name = f"{args.alg}_{args.artifact}"

        # load the dataset (and optionally take only the first 20%)
        dataset = load_dataset(art)
        if args.subset:
            cut = max(1, int(0.2 * len(dataset)))
            print(f"--subset set, using first {cut} / {len(dataset)} examples")
            dataset = dataset[:cut]

        wandb.init(entity=ENTITY, project=args.project, name=run_name)

        # The wrappers do not modify their own state in inference(), so one
        # instance per pass suffices (instead of one, plus a chmod, per cluster).
        alg_inst = _make_algorithm(args.alg, rates, k or 0)
        pending = [(idx, reads, gt, args.alg, alg_inst) for idx, reads, gt in dataset]

        # run them in parallel and collect results
        results_by_n = defaultdict(list)
        time_taken_all = []
        with concurrent.futures.ProcessPoolExecutor(max_workers=args.workers) as exe:
            for r in tqdm.tqdm(
                    exe.map(process_example, pending),
                    total=len(pending),
                    desc=f"Clusters k={k}",
                    file=sys.stdout
            ):
                results_by_n[r['num_reads']].append(r)
                time_taken_all.append(r['time_taken'])

        _log_metrics(results_by_n, time_taken_all)
        wandb.finish()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import (e.g. when
    # ProcessPoolExecutor workers re-import this module under the spawn start method).
    main()
