"""
Created new inference script for DNAformer for convenience. 

"""
import os
import time
import requests
from requests.exceptions import ChunkedEncodingError
from urllib3.exceptions import ProtocolError
import torch.distributed as dist
import numpy as np

import wandb
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from data_loader_IDS import PrecomputedDNAData, collate_dna 
from helper import save_results, evaluate_and_log
from datetime import datetime
import pickle

def safe_download_artifact(entity, project, artifact_name, max_retries=3):
    """Download the ``latest`` version of a W&B dataset artifact, with retries.

    The artifact ``{entity}/{project}/{artifact_name}:latest`` is requested via
    ``wandb.use_artifact`` (type ``'dataset'``) and then downloaded. Network
    errors are retried with a linearly growing pause (5 s, 10 s, ...).

    Args:
        entity: W&B entity that owns the artifact.
        project: W&B project that holds the artifact.
        artifact_name: Artifact name without the version/alias suffix.
        max_retries: Maximum number of download attempts.

    Returns:
        Whatever ``artifact.download()`` returns; callers in this file use it
        as the local directory that holds the downloaded files.

    Raises:
        RuntimeError: If every attempt failed (or ``max_retries`` < 1). The
            last network error, if any, is attached as ``__cause__``.
    """
    last_err = None
    for attempt in range(1, max_retries + 1):
        try:
            artifact = wandb.use_artifact(
                f'{entity}/{project}/{artifact_name}:latest',
                type='dataset'
            )
            return artifact.download()
        except (requests.exceptions.RequestException,
                ChunkedEncodingError,
                ProtocolError) as e:
            last_err = e
            print(f"Attempt {attempt} failed: {e}")
            # Back off only when another attempt follows; sleeping after the
            # final failure would just delay the error below.
            if attempt < max_retries:
                time.sleep(5 * attempt)
    raise RuntimeError(
        f"Failed to download {artifact_name} after {max_retries} attempts."
    ) from last_err

def _load_pretrained_weights(config, model):
    """Load the checkpoint at ``config.pretrained_path`` into ``model`` and put it in eval mode."""
    ckpt = torch.load(config.pretrained_path, map_location=config.device)
    raw_sd = ckpt['model_state_dict']

    # Strip the wrapper prefix from checkpoint keys so they match the bare
    # model (presumably left by DDP + torch.compile at training time -- TODO confirm).
    unwanted = "module._orig_mod."
    clean_sd = {
        (key[len(unwanted):] if key.startswith(unwanted) else key): value
        for key, value in raw_sd.items()
    }

    # strict=False tolerates key mismatches; report them so a prefix problem
    # does not silently leave parts of the model uninitialised.
    load_result = model.load_state_dict(clean_sd, strict=False)
    if load_result.missing_keys:
        print(f"Warning: {len(load_result.missing_keys)} model keys missing "
              f"from checkpoint, e.g. {load_result.missing_keys[:5]}")
    if load_result.unexpected_keys:
        print(f"Warning: {len(load_result.unexpected_keys)} unexpected "
              f"checkpoint keys ignored, e.g. {load_result.unexpected_keys[:5]}")

    model.to(config.device).eval()


def _load_vocab_meta():
    """Return the unpickled contents of ``<repo>/src/data_pkg/meta_nuc.pkl``."""
    # The repository root is taken to be two directories above this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    repo_path = os.path.dirname(os.path.dirname(script_dir))
    meta_path = os.path.join(repo_path, 'src', 'data_pkg', 'meta_nuc.pkl')
    # NOTE: pickle.load can execute arbitrary code; this file must come from
    # a trusted source (it is read from the repository itself).
    with open(meta_path, 'rb') as f:
        return pickle.load(f)


def _test_artifact_name(config, k):
    """Return the test-data artifact name for sweep index ``k`` (``None`` = no sweep)."""
    if k is None:
        return config.test_artifact_name
    seed = config.test_seed + k
    return (
        f"sweep{k}_seed{seed}"
        f"_gl{config.label_length}"
        f"_bs1500"
        f"_ds5000"
    )


def _predict_all(config, model, loader, k):
    """Run ``model`` over ``loader`` and return the per-example result list."""
    all_results = []
    pbar = tqdm(loader, desc=f"Inference (k={k})", total=len(loader),
                leave=False, dynamic_ncols=True)
    with torch.inference_mode():
        for batch in pbar:
            if config.model_config == 'single':
                inp = batch['model_input'].to(config.device)
            else:
                # Both inputs are stacked along dim 0 into one forward pass.
                left = batch['model_input']
                right = batch['model_input_right']
                inp = torch.cat([left, right], dim=0).to(config.device)

            # CUDA kernels run asynchronously: synchronise before and after
            # so the wall-clock interval covers the actual computation.
            if inp.is_cuda:
                torch.cuda.synchronize(inp.device)
            t0 = time.perf_counter()
            out = model(inp)
            probs = torch.softmax(out['pred'], dim=1)
            if inp.is_cuda:
                torch.cuda.synchronize(inp.device)
            dt = time.perf_counter() - t0
            # NOTE(review): in the non-'single' case inp.size(0) counts the
            # rows of both stacked inputs -- confirm this is the intended
            # denominator for a per-example time.
            per_example_time = dt / inp.size(0)

            batch_results = save_results(config, batch, probs)
            for res in batch_results:
                res["inf_time"] = per_example_time
            all_results.extend(batch_results)
    return all_results


def run_inference(config, model):
    """Run inference with a pretrained model and log each evaluation to W&B.

    Loads the checkpoint into ``model``, then for each test set (eleven sweep
    sets when ``config.sweep`` is truthy, otherwise the single artifact
    ``config.test_artifact_name``) starts a fresh W&B run, downloads the test
    artifact, runs the model over it, and passes the collected results to
    ``evaluate_and_log``.

    Args:
        config: Object whose attributes configure the run. Read here:
            ``pretrained_path``, ``device``, ``sweep`` (optional),
            ``test_seed`` (sweep only), ``test_artifact_name`` (non-sweep
            only), ``label_length``, ``out_dir``, ``project``, ``entity``,
            ``test_project``, ``test_batch_size``, ``num_workers``,
            ``model_config``. It is also forwarded to ``PrecomputedDNAData``
            and ``save_results``.
        model: Model to evaluate; it is modified in place (weights loaded,
            moved to ``config.device``, set to eval mode).

    Returns:
        None.
    """
    from functools import partial

    _load_pretrained_weights(config, model)
    meta = _load_vocab_meta()

    # Shared across all runs: timestamp and the plain config dict for W&B.
    now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    raw_cfg = {
        key: value for key, value in vars(config).items()
        if not key.startswith("__") and not callable(value)
    }

    # Requires config.sweep: bool, and config.test_seed: int
    ks = list(range(11)) if getattr(config, "sweep", False) else [None]

    # partial (unlike a lambda) stays picklable, which DataLoader workers
    # need when they are started with the "spawn" method.
    collate = partial(collate_dna, siamese=(config.model_config == 'siamese'))

    for k in ks:
        artifact_name = _test_artifact_name(config, k)

        # Unique run name and output directory for this test set.
        run_name = f"DNAformer_inference_{now_str}_gl{config.label_length}"
        if k is not None:
            run_name += f"_sweep_k{k}"
        out_dir_k = os.path.join(config.out_dir, config.project, run_name)
        os.makedirs(out_dir_k, exist_ok=True)

        # Start a fresh W&B run.
        wandb.init(
            project=config.project,
            entity=config.entity,
            dir=out_dir_k,
            name=run_name,
            resume=False,
            config=raw_cfg,
        )
        try:
            if k is not None:
                wandb.config.update({"sweep_index": k}, allow_val_change=True)

            # Download this run's test data.
            art_dir = safe_download_artifact(
                config.entity,
                config.test_project,
                artifact_name
            )

            # Load test examples.
            x_test = torch.load(os.path.join(art_dir, 'test_x.pt'),
                                map_location='cpu')
            with open(os.path.join(art_dir, 'ground_truth.txt')) as f:
                ground_truths = [line.strip() for line in f]

            ds = PrecomputedDNAData(x_test, ground_truths, config, meta)
            loader = DataLoader(
                ds,
                batch_size=config.test_batch_size,
                shuffle=False,
                num_workers=config.num_workers,
                pin_memory=True,
                collate_fn=collate,
            )

            all_results = _predict_all(config, model, loader, k)

            evaluate_and_log(
                all_results,
                out_dir_k,
                log_to_wandb=True
            )
        except BaseException:
            # Close the run (marked as failed) so it is not left open, then
            # let the original error propagate.
            wandb.finish(exit_code=1)
            raise

        # Finish this W&B run.
        wandb.finish()

    print("All sweep runs completed.")
