import torch
import torch.nn as nn
import numpy as np

from clego_cl.ppcl import infer_ppcl_mix_from_inputs
from clego_cl.task_map import normalize_video_id

# Optional: PPCL modules injected by continual runners.
# These are module-level switches that external code is expected to overwrite
# before calling predict(); the defaults below leave PPCL switched off.
ppcl_enabled = False  # predict() skips the PPCL path unless this is truthy
ppcl_state = None  # clego_cl.ppcl.PPCLState; predict() reads router, adapter_bank, router_type, router_M, topL and gamma from it
ppcl_mode = "none"  # "train" | "infer" | "none"; predict() only acts on "infer"
ppcl_video_to_task = None  # injected mapping {video_uid: task_id} for oracle; looked up with normalize_video_id(...) keys

# Optional: L2P modules injected by continual runners.
# Same injection pattern as above; the defaults leave L2P switched off.
l2p_enabled = False  # _l2p_apply_bct() returns its inputs untouched while this is falsy
l2p_mode = "none"  # "train" | "infer" | "none"; predict() only applies L2P on "infer"
l2p_pool = None  # clego_cl.l2p.L2PPool; provides cosine_match / select_topk / apply_adapters
# NOTE(review): the next knob and the last three below are not read anywhere in
# this part of the file; presumably they are consumed by the pool or by training
# code elsewhere - confirm before changing or removing them.
l2p_topk = 2
l2p_router_M = 1  # forwarded as `M` to skill_benchmark.task_router.extract_r
l2p_sim_lambda = 0.5
l2p_diversed_selection = True
l2p_batchwise_selection = False


def _l2p_apply_bct(x_source: torch.Tensor, x_target: torch.Tensor):
    """Run the L2P adapter pool over a (B,C,T) source/target pair.

    Both tensors are transposed to (B,T,C) for the pool and transposed back
    before returning. When L2P is disabled, its mode is "none", or no pool has
    been injected, the two inputs are returned unchanged.

    Returns:
        Tuple ``(source, target)`` of tensors in the original (B,C,T) layout.
    """
    l2p_active = l2p_enabled and l2p_mode != "none" and l2p_pool is not None
    if not l2p_active:
        return x_source, x_target

    # Imported lazily so the module stays importable when L2P is never used.
    from skill_benchmark.task_router import extract_r

    source_btc = x_source.transpose(1, 2).contiguous()
    target_btc = x_target.transpose(1, 2).contiguous()

    source_repr = extract_r(source_btc, M=int(l2p_router_M))
    target_repr = extract_r(target_btc, M=int(l2p_router_M))
    source_scores = l2p_pool.cosine_match(source_repr)
    target_scores = l2p_pool.cosine_match(target_repr)

    n_source = int(source_scores.shape[0])
    n_target = int(target_scores.shape[0])

    if n_source != n_target:
        # Batch sizes differ: pick adapters from the batch-averaged scores of
        # both sides, then look up each side's own scores for the chosen indices.
        from clego_cl.l2p import L2PSelection

        pooled_scores = 0.5 * (source_scores.mean(dim=0) + target_scores.mean(dim=0))  # [P]
        shared_scores = pooled_scores.view(1, -1).expand(n_source + n_target, -1)  # [B1+B2, P]
        joint_selection = l2p_pool.select_topk(shared_scores, training=False)

        source_idx = joint_selection.indices[:n_source]
        target_idx = joint_selection.indices[n_source : n_source + n_target]
        source_match = source_scores.gather(1, source_idx)
        target_match = target_scores.gather(1, target_idx)
        source_selection = L2PSelection(indices=source_idx, match=source_match)
        target_selection = L2PSelection(indices=target_idx, match=target_match)
    else:
        # Same batch size: one selection from the element-wise mean of the two
        # score matrices is shared by source and target.
        mean_scores = 0.5 * (source_scores + target_scores)
        source_selection = target_selection = l2p_pool.select_topk(mean_scores, training=False)

    source_btc = l2p_pool.apply_adapters(source_btc, source_selection)
    target_btc = l2p_pool.apply_adapters(target_btc, target_selection)
    return (
        source_btc.transpose(1, 2).contiguous(),
        target_btc.transpose(1, 2).contiguous(),
    )

def _read_video_entries(list_files, feat_suffix, features_path):
    """Build one ``{"vid", "feat_file"}`` dict per non-blank line of each list file."""
    entries = []
    for i, list_file in enumerate(list_files):
        with open(list_file, 'r') as handle:
            # Blank lines (including a fully empty file) are skipped instead of
            # producing an entry with an empty video name.
            names = [line.strip() for line in handle if line.strip()]
        entries.extend(
            dict(vid=name,
                 feat_file=f"{features_path}{name.split('.')[0]}{feat_suffix[i]}.pt")
            for name in names
        )
    return entries


def _decode_frame_labels(class_ids, actions_dict, repeat):
    """Map class ids to action names, emitting each name ``repeat`` times."""
    # Invert the mapping once; the first key wins for duplicated ids, which is
    # what the former ``list(values).index(...)`` lookup returned.
    id_to_name = {}
    for name, class_id in actions_dict.items():
        id_to_name.setdefault(class_id, name)

    labels = []
    for class_id in class_ids:
        if class_id not in id_to_name:
            # Same exception type as the former list.index() lookup.
            raise ValueError(f"{class_id} is not in actions_dict values")
        labels.extend([id_to_name[class_id]] * repeat)
    return labels


def _ppcl_adapt_inputs(state, video_to_task, input_x, vid):
    """Apply the PPCL adapter mixture to one (B,C,T) input; return (source, target)."""
    src_btc_orig = input_x.transpose(1, 2).contiguous()
    extra_kwargs = {}

    rt = str(getattr(state, "router_type", "subspace")).strip().lower()
    if rt in ("oracle", "ppcl_oracle", "gt"):
        if video_to_task is None:
            raise RuntimeError("ppcl_router_type=oracle requires ppcl_video_to_task to be injected by continual runner.")
        uid = normalize_video_id(str(vid.get("vid", "")))
        tid = video_to_task.get(uid, None)
        # If task id is missing or adapter not available, skip PPCL for this video.
        if tid is None or (not state.adapter_bank.has_task(int(tid))):
            return input_x, input_x
        extra_kwargs["gt_task_ids"] = torch.tensor(
            [int(tid)], device=src_btc_orig.device, dtype=torch.long)

    mix = infer_ppcl_mix_from_inputs(
        router=state.router,
        router_type=state.router_type,
        x1=src_btc_orig,
        x2=None,
        M=int(state.router_M),
        topL=int(state.topL),
        gamma=float(state.gamma),
        **extra_kwargs,
    )
    src_btc = state.adapter_bank.forward_mixture(src_btc_orig, mix)
    # The target is derived from the un-adapted source, not from the adapted one.
    if bool(getattr(state, "apply_to_target", True)):
        tgt_btc = state.adapter_bank.forward_mixture(src_btc_orig, mix)
    else:
        tgt_btc = src_btc_orig
    return src_btc.transpose(1, 2).contiguous(), tgt_btc.transpose(1, 2).contiguous()


def predict(model, model_dir, results_dir, features_path, vid_list_file,
            feat_suffix, feat_sample_rate, all_sample_rate, epoch,
            actions_dict, device, args, load_model: bool = True):
    """Run frame-level inference on every listed video and write one result file each.

    For each video the feature file is loaded, transposed, sub-sampled by
    ``feat_sample_rate`` and then ``all_sample_rate``, optionally passed through
    the L2P and/or PPCL adapters (controlled by the module-level ``l2p_*`` /
    ``ppcl_*`` globals), and fed to ``model``. The arg-max class of the last
    prediction stage is mapped back to an action name through ``actions_dict``
    and every label is repeated ``all_sample_rate * feat_sample_rate`` times.

    Args:
        model: network called as ``model(x, x_target, mask, mask, [0, 0],
            reverse=False)``; its first output is indexed as
            ``predictions[:, -1, :, :]``.
        model_dir: directory holding ``acc_best_source.model``,
            ``acc_best_target.model`` or ``epoch-<epoch>.model``.
        results_dir: directory that receives one text file per video.
        features_path: prefix prepended to every feature file name.
        vid_list_file: path, or list of paths, of text files with one video
            name per line.
        feat_suffix: sequence with one feature-file suffix per list file.
        feat_sample_rate: temporal stride applied to the loaded features.
        all_sample_rate: second temporal stride applied after the first one.
        epoch: epoch number used to pick the checkpoint when
            ``args.use_best_model`` is neither ``'source'`` nor ``'target'``.
        actions_dict: mapping from action name to class id.
        device: device the model and inputs are moved to.
        args: object exposing ``verbose``, ``use_best_model`` and ``multi_gpu``.
        load_model: when false, the weights already held by ``model`` are used.

    Raises:
        AssertionError: if ``feat_suffix`` and the list files differ in length.
        RuntimeError: if oracle PPCL routing is requested without
            ``ppcl_video_to_task`` being injected.
        ValueError: if a predicted class id has no entry in ``actions_dict``.
    """
    # collect arguments
    verbose = args.verbose
    use_best_model = args.use_best_model

    # multi-GPU
    if args.multi_gpu and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.eval()

    with torch.no_grad():
        model.to(device)
        if load_model:
            if use_best_model == 'source':
                model.load_state_dict(
                    torch.load(model_dir + "/acc_best_source.model"))
                print("load best source model")
            elif use_best_model == 'target':
                model.load_state_dict(
                    torch.load(model_dir + "/acc_best_target.model"))
                print("load best target model")
            else:
                model.load_state_dict(
                    torch.load(model_dir + "/epoch-" + str(epoch) + ".model"))
                print("load epoch-" + str(epoch) + " model")

        if isinstance(vid_list_file, str):
            vid_list_file = [vid_list_file]
        assert len(feat_suffix) == len(vid_list_file), \
            "feat_suffix and vid_list_file must have the same length"
        list_of_videos = _read_video_entries(vid_list_file, feat_suffix, features_path)

        ppcl_active = (
            ppcl_enabled
            and ppcl_mode == "infer"
            and ppcl_state is not None
            and ppcl_state.router is not None
            and ppcl_state.adapter_bank is not None
        )
        label_repeat = all_sample_rate * feat_sample_rate

        for vid in list_of_videos:
            if verbose:
                print(vid)
            # NOTE: torch.load unpickles the file; only point this at trusted feature files.
            features = torch.load(vid["feat_file"])
            features = features.transpose(1, 0)
            features = features[:, ::feat_sample_rate]
            features = features[:, ::all_sample_rate]
            # `features` may already be a Tensor; `as_tensor` preserves tensors and avoids
            # unnecessary copies, while keeping behavior for numpy/array-likes.
            input_x = torch.as_tensor(features, dtype=torch.float).unsqueeze(0).to(device)
            mask = torch.ones_like(input_x)
            # `input_target` is always defined before the optional adapters run, so
            # L2P-only inference keeps its adapted target when PPCL is disabled.
            input_target = input_x
            if l2p_enabled and l2p_mode == "infer":
                input_x, input_target = _l2p_apply_bct(input_x, input_target)
            if ppcl_active:
                input_x, input_target = _ppcl_adapt_inputs(
                    ppcl_state, ppcl_video_to_task, input_x, vid)

            predictions, *_ = model(
                input_x, input_target, mask, mask, [0, 0], reverse=False)
            _, predicted = torch.max(predictions[:, -1, :, :].data, 1)
            # Drop only the batch dimension: a bare squeeze() would also collapse a
            # single-frame sequence to a 0-d tensor and break the per-frame decoding.
            predicted = predicted.squeeze(0)
            recognition = _decode_frame_labels(
                predicted.tolist(), actions_dict, label_repeat)

            f_name = vid["vid"].split('/')[-1].split('.')[0]
            with open(results_dir + "/" + f_name, "w") as f_ptr:
                f_ptr.write("### Frame level recognition: ###\n")
                f_ptr.write(' '.join(recognition))
