import os
import torch
import pickle
import wandb
import numpy as np
import math
from datetime import datetime
from omegaconf import DictConfig, OmegaConf
import hydra
from collections import defaultdict
from tqdm import tqdm
import torch.distributed as dist
from Levenshtein import distance as levenshtein_distance
from requests.exceptions import ChunkedEncodingError
from urllib3.exceptions import ProtocolError
import time

from src.utils.helper_functions import filter_string
from src.utils.hamming_distance import hamming_distance_postprocessed
from src.eval_pkg.GPT_Inference import GPT_Inference
from src.gpt_pkg.model import GPT, GPTConfig
from src.utils.wandb_utils import wandb_kwargs_via_cfg

def safe_download_artifact(entity, project, artifact_name, max_retries=3):
    """Download the ``:latest`` version of a W&B dataset artifact, retrying on failure.

    Args:
        entity: W&B entity that owns *project*.
        project: W&B project containing the artifact.
        artifact_name: artifact name; the ``:latest`` alias is appended.
        max_retries: total number of attempts before giving up.

    Returns:
        The local directory path returned by ``Artifact.download()``.

    Raises:
        RuntimeError: if every attempt fails. The last underlying error is
            attached as ``__cause__`` so the real failure is not lost.
    """
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            art = wandb.use_artifact(f'{entity}/{project}/{artifact_name}:latest', type='dataset')
            return art.download()
        except Exception as e:  # network / API hiccups: retry any failure
            last_error = e
            print(f"Attempt {attempt} failed: {e}")
            # Linear backoff between attempts; no point sleeping after the last one.
            if attempt < max_retries:
                time.sleep(5 * attempt)
    raise RuntimeError(f"Failed to download {artifact_name}") from last_error

def split_list(lst, n):
    """Split sliceable *lst* into *n* contiguous parts of near-equal length.

    The first ``len(lst) % n`` parts receive one extra element. Parts may be
    empty when ``n > len(lst)``. Raises ZeroDivisionError when ``n == 0``.
    """
    base, extra = divmod(len(lst), n)
    parts = []
    offset = 0
    for idx in range(n):
        # Parts with index below the remainder absorb one leftover element.
        size = base + 1 if idx < extra else base
        parts.append(lst[offset:offset + size])
        offset += size
    return parts

class SequenceEncoder:
    """Callable mapping each character of a string to its vocabulary id.

    Characters absent from ``stoi`` map to the id stored under ``'<unk>'``,
    or to ``0`` when the vocabulary has no ``'<unk>'`` entry.
    """

    def __init__(self, stoi):
        # Character -> integer id lookup table.
        self.stoi = stoi

    def __call__(self, seq: str) -> list[int]:
        token_ids = []
        for char in seq:
            unknown_id = self.stoi.get('<unk>', 0)
            token_ids.append(self.stoi.get(char, unknown_id))
        return token_ids

def run_one_batch(chunk, start, batch_size, cfg_dict, model, meta, device, ctx):
    """Run inference on ``chunk[start:start+batch_size]`` and score each prediction.

    Args:
        chunk: sequence of ``(x_tensor, ground_truth_str)`` pairs.
        start: index of the first example of this batch within *chunk*.
        batch_size: maximum number of examples taken from *chunk*.
        cfg_dict: sampling/inference options merged into the parameter dict
            handed to ``GPT_Inference``.
        model: model passed to ``GPT_Inference``.
        meta: vocabulary metadata providing ``'stoi'`` and ``'itos'``.
        device: device string forwarded to ``GPT_Inference``.
        ctx: autocast context entered around inference.

    Returns:
        List of ``(N, reads, gt, pred, ham, lev, per_example_time)`` tuples:

        * ``N``: number of '|'-separated fields before the first ':'.
        * ``reads``: that prefix.
        * ``pred``: the filtered candidate truncated to ``len(gt)``.
        * ``ham``: the post-processed Hamming distance.
        * ``lev``: the Levenshtein distance divided by ``len(gt)`` (by 1 for
          an empty ``gt``).
        * ``per_example_time``: the batch wall-clock time divided by the
          number of examples sent to the model.

        Examples whose decoded text has no ':' are skipped, with a warning.
    """
    stoi, itos = meta['stoi'], meta['itos']

    def decoder(ids):
        return ''.join(itos[i] for i in ids)

    encoder = SequenceEncoder(stoi)

    batch = chunk[start:start + batch_size]
    examples = [decoder(x.tolist()) for x, _ in batch]
    gts      = [gt for _, gt in batch]

    # Inputs without the ':' separator cannot be split into reads/prompt.
    valid = [(ex, gt) for ex, gt in zip(examples, gts) if ':' in ex]
    n_skipped = len(examples) - len(valid)
    if n_skipped:
        print(f"[WARN] skipping {n_skipped} example(s) without ':' separator")
    if not valid:
        return []

    inputs, gts = zip(*valid)
    alignment_sizes = [len(ex.split(':', 1)[0].split('|')) for ex in inputs]

    inf_params = {
        'model': model, 'ctx': ctx,
        **cfg_dict,
        'device': device,
        'stoi': stoi, 'itos': itos,
        'encode': encoder, 'decode': decoder,
    }

    with torch.no_grad(), ctx:
        # perf_counter is monotonic and high-resolution; time.time() is not.
        t0 = time.perf_counter()
        out = GPT_Inference(inf_params).inference(
            list(inputs), alignment_size=alignment_sizes
        )
        dt = time.perf_counter() - t0
    cands = out['candidate_sequences']
    if len(cands) != len(inputs):
        # zip() below would otherwise silently drop the unmatched tail.
        print(f"[WARN] got {len(cands)} candidates for {len(inputs)} inputs; "
              f"only {min(len(cands), len(inputs))} will be scored")

    # Batch time amortised over the examples that were sent to the model.
    per_example_time = dt / len(inputs)

    results = []
    for ex, gt, cand, N in zip(inputs, gts, cands, alignment_sizes):
        pred  = filter_string(cand)[:len(gt)]
        reads = ex.split(':', 1)[0]
        ham   = hamming_distance_postprocessed(gt, pred)
        # Length-normalised edit distance; guard against an empty ground truth.
        lev   = levenshtein_distance(gt, pred) / max(len(gt), 1)

        print(f"Cluster size {N}")
        print(f"[GT]   {gt}")
        print(f"[CAND] {pred}")
        print(f"[HAM]  {ham}")
        print(f"[LEV]  {lev}")
        results.append((N, reads, gt, pred, ham, lev, per_example_time))

    return results

# Number of sweep points evaluated when cfg.data.sweep is enabled; sweep
# point k uses seed cfg.data.test_seed + k.
_NUM_SWEEP_POINTS = 11


def _resolve_local_rank(rank):
    """Return the CUDA device index this process should bind to.

    torchrun exports LOCAL_RANK (the rank within the node). The global rank
    exceeds the per-node GPU count on multi-node jobs, so it is only used as
    a fallback, wrapped onto the visible devices.
    """
    local_rank = os.environ.get('LOCAL_RANK')
    if local_rank is not None:
        return int(local_rank)
    return rank % max(torch.cuda.device_count(), 1)


def _init_wandb_run(cfg):
    """Create the results directory and start the W&B run (called on rank 0 only)."""
    run_cfg  = wandb_kwargs_via_cfg(cfg)
    now      = datetime.now().strftime("%Y%m%d_%H%M%S")
    suffix   = "_sweep" if cfg.data.get("sweep", False) else ""
    base     = f"TReconLM_inference_{now}{suffix}"
    run_name = f"{base}_{cfg.experiment}" if cfg.experiment else base
    out_dir  = os.path.join(cfg.general.results_path,
                            'model_evaluation',
                            cfg.wandb.wandb_project,
                            run_name)
    os.makedirs(out_dir, exist_ok=True)
    wandb.init(project=cfg.wandb.wandb_project,
               entity=cfg.wandb.wandb_entity,
               name=run_name,
               config=run_cfg,
               dir=out_dir)


def _load_model(checkpoint_path, device, verbose=False):
    """Load a GPT checkpoint onto *device* as an fp16 model in eval mode.

    Keys saved from a torch.compile'd model carry an ``_orig_mod.`` prefix,
    which is stripped. Loading is non-strict; with *verbose*, any missing or
    unexpected keys are printed instead of being silently ignored.
    """
    ckpt       = torch.load(checkpoint_path, map_location='cpu')
    model_args = ckpt['model_args']
    state_dict = ckpt['model']
    prefix = '_orig_mod.'
    for key in list(state_dict):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)

    # NOTE(review): weights are cast to fp16 while the autocast context may
    # pick bf16 — confirm this mix is intended.
    model = GPT(GPTConfig(**model_args)).half().to(device).eval()
    incompatible = model.load_state_dict(state_dict, strict=False)
    if verbose and (incompatible.missing_keys or incompatible.unexpected_keys):
        print(f"[WARN] load_state_dict: missing={list(incompatible.missing_keys)} "
              f"unexpected={list(incompatible.unexpected_keys)}")
    return model


def _load_meta(sequence_type):
    """Load the pickled vocabulary metadata stored in data_pkg/ next to this script."""
    data_pkg_dir = os.path.join(os.path.dirname(__file__), 'data_pkg')
    meta_path    = os.path.join(data_pkg_dir, f"meta_{sequence_type}.pkl")
    # Repository-local file; pickle must never be pointed at untrusted data.
    with open(meta_path, 'rb') as f:
        return pickle.load(f)


def _build_sampling_dict(cfg):
    """Merge the sampling config with the data settings the inference needs."""
    sampling_dict = OmegaConf.to_container(cfg.model.sampling, resolve=True)
    sampling_dict.update({
        'block_size':          cfg.data.block_size,
        'target_type':         cfg.data.target_type,
        'ground_truth_length': cfg.data.ground_truth_length,
        'greedy':              cfg.model.sampling.strategy == 'greedy',
    })
    return sampling_dict


def _sweep_artifact_name(cfg, k):
    """Return the W&B artifact name of sweep point *k*."""
    seed = cfg.data.test_seed + k
    return (f"sweep{k}_seed{seed}_gl"
            f"{cfg.data.ground_truth_length}_bs"
            f"{cfg.data.block_size}_ds"
            f"{cfg.data.test_dataset_size}_"
            f"fixedN{cfg.data.observation_size}")


def _fetch_artifact_dir(cfg, art_name, rank):
    """Download *art_name* on rank 0 and broadcast its local path to all ranks.

    If the download fails on rank 0, None is broadcast so that every rank
    raises instead of the other ranks blocking in the collective.
    """
    art_dir, error = None, None
    if rank == 0:
        try:
            art_dir = safe_download_artifact(cfg.wandb.wandb_entity,
                                             cfg.data.data_project,
                                             art_name)
        except Exception as err:  # re-raised below after the broadcast
            error = err

    holder = [art_dir]
    dist.broadcast_object_list(holder, src=0)
    art_dir = holder[0]
    if art_dir is None:
        if error is not None:
            raise error
        raise RuntimeError(f"Rank 0 failed to download artifact {art_name}")
    return art_dir


def _load_sorted_examples(art_dir, itos, cleaned):
    """Load ``(x_tensor, ground_truth)`` pairs sorted by decoded prefix length.

    The prefix is the decoded text before the first ':'. Sorting groups
    similarly sized inputs into the same batches.

    Raises:
        ValueError: if the inputs and ground-truth file differ in length.
    """
    x_test = torch.load(os.path.join(art_dir, 'test_x.pt'), map_location='cpu')
    gt_file = 'ground_truth_cleaned.txt' if cleaned else 'ground_truth.txt'
    with open(os.path.join(art_dir, gt_file)) as f:
        gts = [line.strip() for line in f]
    if len(x_test) != len(gts):
        raise ValueError(f"{len(x_test)} test inputs but {len(gts)} "
                         f"ground-truth lines in {gt_file}")

    data_with_len = []
    for x_tensor, gt in zip(x_test, gts):
        decoded = ''.join(itos[i] for i in x_tensor.tolist())
        data_with_len.append((x_tensor, gt, len(decoded.split(':', 1)[0])))
    data_with_len.sort(key=lambda t: t[2])
    return [(x, gt) for x, gt, _ in data_with_len]


def _run_shard(my_chunk, batch_size, sampling_dict, model, meta, device, ctx, rank):
    """Run run_one_batch over this rank's shard and collect all result tuples."""
    local_results = []
    num_batches = math.ceil(len(my_chunk) / batch_size)
    for batch_idx in tqdm(range(num_batches), desc=f"Rank {rank}"):
        local_results.extend(
            run_one_batch(my_chunk, batch_idx * batch_size, batch_size,
                          sampling_dict, model, meta, device, ctx))
    return local_results


def _log_sweep_metrics(k, all_results):
    """Log aggregate Hamming/Levenshtein/success metrics for sweep point *k*."""
    n_ex = len(all_results)
    if n_ex == 0:
        print(f"[WARN] no results for sweep point k={k}; nothing logged")
        return
    h_vals = np.array([r[4] for r in all_results])
    l_vals = np.array([r[5] for r in all_results])
    success_count = sum(1 for r in all_results if r[4] == 0)
    wandb.log({
        f"avg_hamming_k={k}":     float(h_vals.mean()),
        f"avg_levenshtein_k={k}": float(l_vals.mean()),
        f"success_rate_k={k}":    success_count / n_ex,
        f"failure_rate_k={k}":    1 - (success_count / n_ex),
    })


def _log_breakdown_metrics(all_results):
    """Log metrics per cluster size N, then overall metrics and per-example timing."""
    n_ex = len(all_results)
    if n_ex == 0:
        print("[WARN] no results for this artifact; nothing logged")
        return

    count, success = defaultdict(int), defaultdict(int)
    h_per_N, l_per_N = defaultdict(list), defaultdict(list)
    for N, _reads, _gt, _pred, ham, lev, _dt in all_results:
        count[N]   += 1
        success[N] += (ham == 0)
        h_per_N[N].append(ham)
        l_per_N[N].append(lev)
    for N in sorted(count):
        h_arr, l_arr = np.array(h_per_N[N]), np.array(l_per_N[N])
        wandb.log({
            f"count_N={N}":           count[N],
            f"success_rate_N={N}":    success[N] / count[N],
            f"avg_hamming_N={N}":     float(h_arr.mean()),
            f"std_hamming_N={N}":     float(h_arr.std()),
            f"avg_levenshtein_N={N}": float(l_arr.mean()),
            f"std_levenshtein_N={N}": float(l_arr.std()),
        })

    h_vals = np.array([r[4] for r in all_results])
    l_vals = np.array([r[5] for r in all_results])
    example_times = np.array([r[6] for r in all_results])
    wandb.log({
        'count_all':            n_ex,
        'success_rate_all':     sum(success.values()) / n_ex,
        'avg_hamming_all':      float(h_vals.mean()),
        'avg_levenshtein_all':  float(l_vals.mean()),
        'avg_time_per_example': float(example_times.mean()),
        'std_time_per_example': float(example_times.std()),
    })


@hydra.main(config_path='hydra/inference_config',
            config_name='inference_config.yaml',
            version_base=None)
def main(cfg: DictConfig):
    """Run distributed inference over one test artifact or a sweep, logging to W&B.

    Each rank loads the model on its own GPU and processes a contiguous shard
    of the length-sorted test set. Results are gathered on rank 0, which
    alone talks to W&B. The W&B run and the process group are always torn
    down, even when an error occurs.
    """
    dist.init_process_group(backend='nccl')
    rank       = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = _resolve_local_rank(rank)
    torch.cuda.set_device(local_rank)
    device = f'cuda:{local_rank}'

    try:
        if rank == 0:
            _init_wandb_run(cfg)

        model = _load_model(cfg.model.checkpoint_path, device, verbose=(rank == 0))
        ctx = torch.amp.autocast('cuda',
                dtype=torch.bfloat16 if torch.cuda.is_bf16_supported()
                                 else torch.float16)
        meta = _load_meta(cfg.data.sequence_type)
        sampling_dict = _build_sampling_dict(cfg)  # identical for every artifact

        ks = (list(range(_NUM_SWEEP_POINTS)) if cfg.data.get("sweep", False)
              else [None])
        for k in ks:
            art_name = (_sweep_artifact_name(cfg, k) if k is not None
                        else cfg.data.artifact_name)
            art_dir = _fetch_artifact_dir(cfg, art_name, rank)

            all_data = _load_sorted_examples(art_dir, meta['itos'],
                                             cfg.data.get("cleaned", False))
            my_chunk = split_list(all_data, world_size)[rank]

            local_results = _run_shard(my_chunk, cfg.data.batch_size,
                                       sampling_dict, model, meta, device,
                                       ctx, rank)

            gathered = [None] * world_size
            dist.all_gather_object(gathered, local_results)

            if rank == 0:
                all_results = [r for shard in gathered for r in shard]
                if k is not None:
                    _log_sweep_metrics(k, all_results)
                else:
                    _log_breakdown_metrics(all_results)
    finally:
        if rank == 0:
            wandb.finish()
        dist.destroy_process_group()

# Script entry point. hydra.main (on main) composes the config and passes it
# in. main() calls dist.init_process_group, so the script presumably has to be
# launched by a distributed launcher such as torchrun — NOTE(review): confirm.
if __name__ == "__main__":
    main()
