import os
import numpy as np
import random
import torch
import re

from clip import clip
from torch.distributions.multivariate_normal import MultivariateNormal
from typing import List, Tuple, Dict, Optional
import json
import os.path as osp
import matplotlib.pyplot as plt

import torch.nn.functional as F


def cosine_schedule_warmup(total_step, value, final_value=0, warmup_step=0, warmup_value=0):
    """Build a per-step value schedule: linear warmup followed by cosine decay.

    The first ``warmup_step`` entries rise linearly from ``warmup_value`` toward
    ``value`` (both endpoints excluded). The remaining ``n = total_step - warmup_step``
    entries are ``final_value + 0.5 * (value - final_value) * (1 + cos(pi * t / n))``
    for ``t = 0 .. n-1``; they start at ``value`` and approach ``final_value``
    (the ``t = n`` endpoint itself is not included).

    Returns:
        A numpy array of length ``total_step`` (checked with an assert).
    """
    if warmup_step <= 0:
        ramp = np.array([])
    else:
        # Drop both endpoints so warmup neither repeats warmup_value nor the first decay value.
        ramp = np.linspace(warmup_value, value, warmup_step + 2)[1:-1]

    t = np.arange(total_step - warmup_step)
    span = value - final_value
    decay = final_value + 0.5 * span * (1 + np.cos(np.pi * t / len(t)))

    full = np.concatenate((ramp, decay))
    assert len(full) == total_step
    return full

class build_cosine_scheduler:
    """Per-step learning-rate setter driven by a precomputed cosine schedule.

    The schedule warms up from 0 over ``lr_warmup_step`` steps to ``lr`` and then
    decays toward ``lr * 1e-3`` over the remaining steps (see
    ``cosine_schedule_warmup``). ``step(idx)`` writes ``lrs[idx]`` into every
    optimizer param group and records it as ``self.lr``.
    """

    def __init__(self, optimizer, lr, total_step, lr_warmup_step=0):
        floor_lr = lr * 1e-3
        self.lrs = cosine_schedule_warmup(total_step, lr, floor_lr, lr_warmup_step, 0)
        self.optimizer = optimizer

    def step(self, idx):
        """Apply the scheduled learning rate for step ``idx`` to all param groups."""
        current = self.lrs[idx]
        for group in self.optimizer.param_groups:
            group["lr"] = current
        self.lr = current


def get_transform(cfg):
    """Return ``clip._transform`` built for the first entry of ``cfg.input_size``."""
    resolution = cfg.input_size[0]
    return clip._transform(resolution)


def set_random_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG, torch CPU and all CUDA devices."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def cosine_loss_3d(pred, target):
    """Return ``1 - mean(cosine_similarity(pred, target))`` over the last dimension.

    Both inputs are L2-normalized along the last dimension (norms clamped to
    1e-12 via ``F.normalize``), so an all-zero vector yields similarity 0 instead
    of NaN. The per-vector similarities are averaged over every leading
    dimension; for 3-D inputs this matches the original ``dim=2`` reduction.
    """
    pred = F.normalize(pred, dim=-1)
    target = F.normalize(target, dim=-1)
    cos_sim = torch.sum(pred * target, dim=-1)
    return 1 - torch.mean(cos_sim)

def cal_MTIL_metrics(acc_list):
    """Summarize an accuracy matrix as transfer / avg / last scores in percent.

    ``acc_list`` is converted to an array and multiplied by 100. Rows are read as
    successive evaluation points and columns as tasks (presumably row j = after
    training stage j; confirm with the caller). Computed per column:

      * avg:      column means over all rows
      * last:     the final row
      * transfer: for column i >= 1, mean of rows 0..i-1 in that column

    Returns a dict holding each per-task list rounded to one decimal, plus
    ``"results_mean"`` with their rounded means (-1 for an empty list).
    """
    scores = np.array(acc_list)
    scores *= 100

    def _rounded_mean(values):
        if len(values) > 0:
            return np.around(values.mean(), decimals=1)
        return -1

    def _rounded_each(values):
        return [np.around(v, decimals=1) for v in values]

    column_avg = scores.mean(axis=0)
    final_row = np.array(scores[-1, :])
    transfer = np.array([
        np.mean([scores[row, col] for row in range(col)])
        for col in range(1, scores.shape[1])
    ])

    return {
        "transfer": {"transfer": _rounded_each(transfer)},
        "avg": {"avg": _rounded_each(column_avg)},
        "last": {"last": _rounded_each(final_row)},
        "results_mean": {
            "transfer": _rounded_mean(transfer),
            "avg": _rounded_mean(column_avg),
            "last": _rounded_mean(final_row),
        },
    }


def select_task_by_gaussian_center(
    image_features: torch.Tensor,
    means: List[torch.Tensor],
    covars: List[torch.Tensor],
    task_learnt: int,
    prompt_processor,
    batchwise_prompt: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Pick, per sample, the learned task whose Gaussian scores it highest.

    One ``MultivariateNormal(means[i], covars[i])`` is built for each of the first
    ``task_learnt`` tasks and every row of ``image_features`` is scored under each.

    Args:
        image_features: [bs, D] features; each row is scored independently.
        means, covars: per-task Gaussian parameters (only the first
            ``task_learnt`` entries are used).
        task_learnt: number of tasks to consider.
        prompt_processor: object exposing ``cur_n_cls``; only used to size
            ``text_batch_weight``.
        batchwise_prompt: if True, ``indices`` is the batch majority task.

    Returns:
        raw_indices: [bs, 1] per-sample argmax task id.
        indices: ``raw_indices``, or the [1, 1] majority task id when
            ``batchwise_prompt`` is True.
        batch_weight: [bs] ``sigmoid(max_log_prob / D - 1)``.
        text_batch_weight: [cur_n_cls] batch-mean of ``batch_weight``, repeated.
        domain_pred_task_id: majority-vote task id as a Python int (for bs == 1
            this is simply the single sample's task).
    """
    dists = [MultivariateNormal(means[i], covars[i]) for i in range(task_learnt)]

    # [bs, task_learnt]: log-likelihood of each sample under each task Gaussian.
    log_probs = torch.vstack([dist.log_prob(image_features) for dist in dists]).t()

    topk, raw_indices = log_probs.topk(k=1, dim=1)  # [bs, 1]

    # Divide the best log-likelihood by the feature dimension before squashing to (0, 1).
    exp_part = topk.squeeze(1) / image_features.size(1) - 1.0
    batch_weight = torch.sigmoid(exp_part)  # [bs]

    text_batch_weight = batch_weight.mean(dim=0, keepdim=True).repeat(prompt_processor.cur_n_cls)

    # Majority vote over the batch. It is computed in both modes so the int task id
    # is well defined for bs > 1 (calling .item() on the [bs, 1] raw_indices only
    # works when bs == 1).
    prompt_id, id_counts = torch.unique(raw_indices, return_counts=True, sorted=True)
    _, major_idx = torch.topk(id_counts, k=1)
    majority = prompt_id[major_idx].unsqueeze(0)  # [1, 1]
    domain_pred_task_id = majority.item()

    indices = majority if batchwise_prompt else raw_indices

    return raw_indices, indices, batch_weight, text_batch_weight, domain_pred_task_id

def load_description_dicts(dataset_names: List[str], description_dir: str = "description") -> Dict[str, dict]:
    """Load ``<description_dir>/<name>.json`` for each dataset name.

    Missing files are reported with a printed warning and skipped, so the
    returned dict only contains datasets whose file exists.
    """
    loaded = {}
    for name in dataset_names:
        path = os.path.join(description_dir, f"{name}.json")
        if os.path.exists(path):
            with open(path, "r", encoding="utf-8") as fh:
                loaded[name] = json.load(fh)
        else:
            print(f"[Warning] Description file not found: {path}, skipped.")
    return loaded

def count_description_keys(description_dicts):
    """For each task, list how many entries each class's ``"descriptions"`` holds.

    Returns a list (one per task, in dict order) of lists (one int per class);
    a class without a ``"descriptions"`` key counts as 0.
    """
    return [
        [len(class_info.get("descriptions", {})) for class_info in classes.values()]
        for classes in description_dicts.values()
    ]

def average_logits_per_class(logits: torch.Tensor, class_ids_per_task: list):
    """Split per-prompt logits into per-class default and averaged external logits.

    Columns of ``logits`` are consumed left to right in groups whose sizes come
    from ``class_ids_per_task``. In each group, column 0 is taken as the class's
    default logit and the remaining columns are averaged; a single-column group
    gets an external value of 0.

    Returns:
        (default_logits, external_logits_avg), each [bs, len(class_ids_per_task)].
    """
    batch = logits.size(0)
    first_cols = []
    ext_means = []
    offset = 0
    for width in class_ids_per_task:
        group = logits[:, offset:offset + width]  # [bs, width]
        offset += width
        first_cols.append(group[:, 0:1])
        if width > 1:
            ext_means.append(group[:, 1:].mean(dim=1, keepdim=True))
        else:
            ext_means.append(torch.zeros(batch, 1, device=logits.device, dtype=logits.dtype))
    return torch.cat(first_cols, dim=1), torch.cat(ext_means, dim=1)

def average_logits_per_description(image_features: torch.Tensor, 
                                   text_features: torch.Tensor, 
                                   n_descriptions: int = 5) -> torch.Tensor:
    """Average description-wise dot products between image and class features.

    ``image_features`` is viewed as [B, n_descriptions, D] and ``text_features``
    as [n_cls, n_descriptions, D]. Description k of an image is paired only with
    description k of each class, and the n_descriptions dot products are averaged.

    Returns:
        [B, n_cls] averaged logits.
    """
    n_desc = n_descriptions
    img = image_features.view(image_features.shape[0] // n_desc, n_desc, -1)  # [B, n_desc, D]
    txt = text_features.view(text_features.shape[0] // n_desc, n_desc, -1)    # [n_cls, n_desc, D]
    # per_desc[b, c, k] = <img[b, k], txt[c, k]>
    per_desc = torch.einsum('bkd,ckd->bck', img, txt)
    return per_desc.mean(dim=2)

def check_optimizer_params(model, verbose: bool = True):
    """Return the names of ``model`` parameters that have ``requires_grad`` set.

    With ``verbose``, prints a header, one "name shape" line per trainable
    parameter, and a warning when none are trainable.
    """
    if verbose:
        print("=== Parameters in optimizer ===")
    trainable = []
    for pname, param in model.named_parameters():
        if not param.requires_grad:
            continue
        trainable.append(pname)
        if verbose:
            print(f"{pname:50s} {tuple(param.shape)}")
    if verbose and len(trainable) == 0:
        print("⚠️ No parameters require grad!")
    return trainable

def check_gradients(model, verbose: bool = True):
    """Collect a detached CPU copy of each trainable parameter's ``.grad``.

    Returns a dict keyed by parameter name; the value is None when the
    parameter has no gradient yet. With ``verbose``, prints each name with its
    gradient mean (or "None").
    """
    if verbose:
        print("=== Gradients after backward ===")
    collected = {}
    for pname, param in model.named_parameters():
        if not param.requires_grad:
            continue
        grad = param.grad
        collected[pname] = grad.detach().cpu() if grad is not None else None
        if verbose:
            summary = "None" if grad is None else f"mean={grad.mean().item():.6f}"
            print(f"{pname:50s} {summary}")
    return collected

def check_optimizer_contents(optimizer, model=None, verbose: bool = True):
    """Describe every param group of ``optimizer``.

    Each group becomes ``{"group_id", "lr", "params"}``; each param entry holds
    its name (resolved through ``model.named_parameters()`` by object id, else
    None), shape, ``requires_grad`` flag and ``id``. With ``verbose``, the same
    information is printed.
    """
    names_by_id = {} if model is None else {id(p): n for n, p in model.named_parameters()}

    if verbose:
        print("=== Optimizer contents ===")

    summary = []
    for gid, group in enumerate(optimizer.param_groups):
        group_lr = group.get("lr", None)
        if verbose:
            print(f"Optimizer group {gid}, lr={group_lr}")

        entries = []
        for param in group["params"]:
            pid = id(param)
            entry = {
                "name": names_by_id.get(pid, None),
                "shape": tuple(param.shape),
                "requires_grad": param.requires_grad,
                "id": pid,
            }
            entries.append(entry)
            if not verbose:
                continue
            if entry["name"] is None:
                print(f"  <no name> {entry['shape']} requires_grad={entry['requires_grad']} (id={pid})")
            else:
                print(f"  {entry['name']:40s} {entry['shape']} requires_grad={entry['requires_grad']}")

        summary.append({"group_id": gid, "lr": group_lr, "params": entries})

    return summary

def set_trainable_params(model, names_to_update=None):
    """Enable gradients only for parameters whose name contains one of the given substrings.

    Every parameter is updated: ``requires_grad`` becomes True iff any entry of
    ``names_to_update`` (default: none) is a substring of its name.

    Returns:
        (model, set of names that now require grad).
    """
    patterns = [] if names_to_update is None else names_to_update

    for pname, param in model.named_parameters():
        param.requires_grad_(any(pat in pname for pat in patterns))

    # Re-read the flags so the returned set reflects the model's actual state.
    enabled = {pname for pname, param in model.named_parameters() if param.requires_grad}
    return model, enabled

def log_trainable_params(model, cfg, verbose: bool = True):
    """Record which parameters require grad.

    The message "Parameters to be updated: {names}" is printed when ``verbose``
    and always appended (with a newline) to ``<cfg.log_path>/output.txt``.
    """
    trainable = {pname for pname, param in model.named_parameters() if param.requires_grad}
    message = f"Parameters to be updated: {trainable}"

    if verbose:
        print(message)

    with open(osp.join(cfg.log_path, 'output.txt'), 'a') as log_file:
        log_file.write(message + '\n')
    
def visualize_logits(
    logits: torch.Tensor,
    default_logits: torch.Tensor,
    external_logits_avg: torch.Tensor,
    prompts_per_class: list,
    inputs,
    test_cur_train_task_id,
    test_cur_test_task_id,
    log_path,
    targets: torch.Tensor,
    class_names: list,
    batch_id: int,
    max_save: int = 4,
    only_save_wrong: bool = False
):
    """Save a 3-panel bar-chart PDF of logits for the first samples of a batch.

    For each of the first ``min(bs, max_save)`` samples, writes
    ``<log_path>/logits/<train_task>-<test_task>/batch{batch_id}_sample{i}.pdf``:
      * panel 0: ``default_logits[i]``, one bar per class;
      * panel 1: ``external_logits_avg[i]``, one bar per class;
      * panel 2: raw ``logits[i]``, sliced consecutively by ``prompts_per_class``
        and drawn class by class with a one-bar gap between classes.
    A sample is "wrong" when either the default or the external argmax differs
    from ``targets[i]``. For wrong samples with ``inputs`` given, the
    de-normalized input image is also saved as ``..._original_wrong.png``.

    Args:
        logits: per-prompt logits; row i is split by ``prompts_per_class``.
        default_logits, external_logits_avg: per-class logits, indexed by sample.
        prompts_per_class: number of prompt columns in ``logits`` per class;
            its length is taken as the number of classes.
        inputs: image batch (each sample permuted C,H,W -> H,W,C for saving) or None.
        test_cur_train_task_id, test_cur_test_task_id: only used in the output dir name.
        log_path: root output directory.
        targets: ground-truth class index per sample.
        class_names: label strings indexed by class id.
        batch_id: only used in output file names.
        max_save: maximum number of samples plotted from this batch.
        only_save_wrong: if True, skip samples where both predictions are correct.
    """

    save_dir = os.path.join(log_path, "logits", f"{test_cur_train_task_id}-{test_cur_test_task_id}")
    os.makedirs(save_dir, exist_ok=True)

    bs = logits.size(0)
    n_cls = len(prompts_per_class)
    save_num = min(bs, max_save)

    # Per-channel statistics used to undo input normalization before saving the image
    # (presumably CLIP's preprocessing constants — confirm against the transform).
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3,1,1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3,1,1)

    for i in range(save_num):
        gt_label = int(targets[i].item())
        default_pred = int(default_logits[i].argmax())
        external_pred = int(external_logits_avg[i].argmax())
        is_wrong = (default_pred != gt_label) or (external_pred != gt_label)
        if only_save_wrong and not is_wrong:
            continue

        # Widen the figure for datasets with many classes.
        base_width = 12        
        width_per_class = 0.1   
        fig_width = max(base_width, n_cls * width_per_class)

        fig, axes = plt.subplots(3, 1, figsize=(fig_width, 8), constrained_layout=True)

        # Panel 0: default-prompt logits per class.
        dl = default_logits[i].detach().cpu().numpy()
        pred_idx = int(dl.argmax())
        pred_name = class_names[pred_idx]
        axes[0].bar(np.arange(n_cls), dl, color='skyblue')
        axes[0].set_title(f"Default Prompts | GT={gt_label}: {class_names[gt_label]} | Pred={pred_idx}: {pred_name}")
        axes[0].set_ylabel("Logit")
        axes[0].set_xticks(np.arange(n_cls))
        # Above 150 classes, tick labels are class indices instead of names.
        if n_cls > 150:
            axes[0].set_xticklabels(np.arange(n_cls), rotation=90, fontsize=6)
        else:
            axes[0].set_xticklabels(class_names, rotation=90, fontsize=8)

        # Panel 1: averaged external-prompt logits per class.
        el = external_logits_avg[i].detach().cpu().numpy()
        pred_idx = int(el.argmax())
        pred_name = class_names[pred_idx]
        axes[1].bar(np.arange(n_cls), el, color='lightgreen')
        axes[1].set_title(f"External Prompts Avg | GT={gt_label}: {class_names[gt_label]} | Pred={pred_idx}: {pred_name}")
        axes[1].set_ylabel("Logit")
        axes[1].set_xticks(np.arange(n_cls))
        if n_cls > 150:
            axes[1].set_xticklabels(np.arange(n_cls), rotation=90, fontsize=6)
        else:
            axes[1].set_xticklabels(class_names, rotation=90, fontsize=8)

        # Panel 2: every raw prompt logit, grouped by class.
        raw_logits = logits[i].detach().cpu().numpy()
        x_coord = []
        y = []
        tick_pos = []
        tick_labels = []

        cur_x = 0   # x position on the plot (includes one-bar gaps between classes)
        cur_idx = 0 # column index into raw_logits
        for c_idx, num in enumerate(prompts_per_class):
            x_coord.extend(range(cur_x, cur_x + num))
            y.extend(raw_logits[cur_idx:cur_idx + num])

            # One tick per class, placed at the middle of its group.
            tick_pos.append(cur_x + num // 2)
            tick_labels.append(class_names[c_idx])
            cur_x += num + 1  
            cur_idx += num   
        axes[2].bar(x_coord, y, color='salmon')
        axes[2].set_xticks(tick_pos)
        if n_cls > 150:
            axes[2].set_xticklabels(np.arange(n_cls), rotation=90, fontsize=6)
        else:
            axes[2].set_xticklabels(tick_labels, rotation=90, fontsize=8)

        # plt.savefig writes the current figure, i.e. the one created above.
        save_name = os.path.join(save_dir, f"batch{batch_id}_sample{i}.pdf")
        plt.savefig(save_name)
        plt.close(fig)

        # Also dump the de-normalized input image for misclassified samples.
        if is_wrong and inputs is not None:
            img = inputs[i].detach().cpu()
            img = img * std + mean
            img = img.clamp(0,1)
            img = img.permute(1,2,0).numpy()  # C,H,W -> H,W,C
            save_file = os.path.join(save_dir, f"batch{batch_id}_sample{i}_original_wrong.png")
            plt.imsave(save_file, img)

def build_prompt(desc_text: str, cls_name) -> str:
    """Turn a free-text class description into an "It is ..." sentence.

    Surrounding whitespace and trailing periods are stripped, the first
    character is lower-cased, and the text is wrapped as ``"It is {desc}."``.

    ``cls_name`` is accepted for interface compatibility but is not used.

    NOTE(review): verb-initial descriptions (e.g. "has wings") come out as
    "It is has wings."; confirm whether a different template was intended for them.
    """
    desc = desc_text.strip().rstrip('.')
    if desc:
        desc = desc[0].lower() + desc[1:]
    # Every sentence starts with the literal "It", so no further capitalization is needed.
    return f"It is {desc}."

def build_default_prompt(task_idx: int, cls_name: str) -> str:
    """Return the hand-written default prompt for ``cls_name`` in task ``task_idx``.

    Task indices 0 and 3-9 have dataset-specific templates; any other index
    uses the generic "a photo of a {cls_name}." template.
    """
    task_templates = (
        (0, "a photo of a {}, a type of aircraft."),
        (3, "a photo of a {} texture."),
        (4, "a centered satellite photo of {}."),
        (5, "a photo of a {}, a type of flower."),
        (6, "a photo of a {}, a type of food."),
        (7, 'a photo of the number: "{}".'),
        (8, "a photo of a {}, a type of pet."),
        (9, "a photo of a {}, a type of car."),
    )
    # Compare with == (not a dict lookup) so matching follows the same equality rules as an if/elif chain.
    for idx, pattern in task_templates:
        if task_idx == idx:
            return pattern.format(cls_name)
    return "a photo of a {}.".format(cls_name)

def fuse_logits(default_logits: torch.Tensor,
                external_logits: torch.Tensor,
                alpha: float = 0.5) -> torch.Tensor:
    """Blend two logit tensors: ``alpha * default + (1 - alpha) * external``."""
    weighted_default = alpha * default_logits
    weighted_external = (1 - alpha) * external_logits
    return weighted_default + weighted_external


def build_positive_prototypes_for_images(
    image_class_ids: torch.Tensor,
    class_proto_map: Dict[str, Dict[str, List[int]]]
) -> List[List[int]]:
    """Map each image's class id to the list of its positive prototype ids.

    ``image_class_ids`` is flattened, so [B], [B, 1] and 0-d tensors are all
    accepted; one output list is produced per element. A class is looked up
    under its str key first, then its int key. Classes absent from the map
    yield an empty list.
    """
    positives: List[List[int]] = []
    # One host transfer for the whole batch instead of an .item() call per sample.
    for raw_id in image_class_ids.reshape(-1).tolist():
        cid = int(raw_id)
        entry = class_proto_map.get(str(cid)) or class_proto_map.get(cid)
        positives.append([] if entry is None else list(entry["proto_ids"]))
    return positives


def l2_normalize(t: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Scale ``t`` to unit L2 norm along its last dim (norms clamped below at ``eps``)."""
    norms = t.norm(dim=-1, keepdim=True)
    return t / norms.clamp(min=eps)

def prototype_loss_soft_target(
    img_embeds: torch.Tensor,        # [B, D]
    prototypes: torch.Tensor,        # [K, D]
    positives: List[List[int]],      # length B, each is list of proto ids
    temperature: float = 0.07,
    reduction: str = "mean"          # "mean" or "sum" or "none"
) -> torch.Tensor:
    """
    Soft-target cross-entropy between image embeddings and prototypes.

    Logits are cosine similarities (both sides L2-normalized along the last dim,
    norms clamped at 1e-12) divided by ``temperature``. For sample b the target
    is uniform over the *distinct* prototype ids in ``positives[b]``:

        loss_b = - sum_p target_bp * log_softmax(logits_b)_p

    Samples with an empty positive list contribute 0. ``reduction="mean"``
    divides the sum by the number of samples that have positives (at least 1);
    ``"sum"`` returns the plain sum; any other value returns the [B] vector.
    """
    device = img_embeds.device
    B, _ = img_embeds.shape
    K = prototypes.shape[0]

    img = F.normalize(img_embeds, dim=-1, eps=1e-12)
    # Cast so mixed-precision prototypes can be multiplied with the image embeddings.
    prot = F.normalize(prototypes, dim=-1, eps=1e-12).to(img.dtype)

    logits = (img @ prot.t()) / temperature   # [B, K]

    # Per-row target distribution; rows without positives stay all-zero and are masked out.
    target = logits.new_zeros((B, K))
    mask_has_pos = torch.zeros(B, dtype=torch.bool, device=device)
    for i, pos in enumerate(positives):
        if not pos:
            continue
        mask_has_pos[i] = True
        # Deduplicate so repeated ids cannot leave the target summing to less than 1.
        pos_idx = torch.unique(torch.as_tensor(pos, dtype=torch.long, device=device))
        target[i, pos_idx] = 1.0 / pos_idx.numel()

    logp = F.log_softmax(logits, dim=1)                  # [B, K]
    per_sample_loss = -(target * logp).sum(dim=1)        # [B]
    per_sample_loss = per_sample_loss * mask_has_pos.float()

    if reduction == "mean":
        # Average only over samples that actually have positives.
        denom = mask_has_pos.float().sum().clamp(min=1.0)
        return per_sample_loss.sum() / denom
    if reduction == "sum":
        return per_sample_loss.sum()
    return per_sample_loss  # [B]


def prototype_loss_multipos_infoNCE(
    img_embeds: torch.Tensor,        # [B, D]
    prototypes: torch.Tensor,        # [K, D]
    positives: List[List[int]],      # length B, each is list of proto ids
    temperature: float = 0.07,
    reduction: str = "mean"
) -> torch.Tensor:
    """
    Multi-positive InfoNCE between image embeddings and prototypes.

    Logits are cosine similarities (both sides L2-normalized along the last dim,
    norms clamped at 1e-12) divided by ``temperature``. For sample b:

        loss_b = logsumexp(logits_b) - logsumexp(logits_b[distinct positives of b])

    Samples with an empty positive list contribute exactly 0. ``reduction="mean"``
    divides the sum by the number of samples with positives (at least 1);
    ``"sum"`` returns the plain sum; any other value returns the [B] vector.

    NOTE(review): the ~1% debug print draws from torch's global RNG on every
    call, which shifts the random stream seen by later code.
    """
    device = img_embeds.device
    B, _ = img_embeds.shape

    img = F.normalize(img_embeds, dim=-1, eps=1e-12)
    prot = F.normalize(prototypes, dim=-1, eps=1e-12).to(img.dtype)
    logits = (img @ prot.t()) / temperature   # [B, K]

    if torch.rand(1).item() < 0.01:  
        pos_sims = []
        neg_sims = []
        for i, pos in enumerate(positives):
            if pos:
                pos_sims.append(logits[i, pos].mean().item())
                neg_sims.append(logits[i, :].mean().item())
        if pos_sims:
            print("pos_sim:", np.mean(pos_sims), "neg_sim:", np.mean(neg_sims))

    lse_all = torch.logsumexp(logits, dim=1)  # [B]

    # Rows without positives keep 0 here and are selected out below, so no
    # -inf/inf value is ever created for them.
    lse_pos = logits.new_zeros((B,))
    mask_has_pos = torch.zeros(B, dtype=torch.bool, device=device)
    for i, pos in enumerate(positives):
        if not pos:
            continue
        mask_has_pos[i] = True
        # Deduplicate so a repeated id is not counted twice in the positive logsumexp.
        pos_idx = torch.unique(torch.as_tensor(pos, dtype=torch.long, device=device))
        lse_pos[i] = torch.logsumexp(logits[i, pos_idx], dim=0)

    # Select rather than multiply by the mask: masking via multiplication would
    # turn any non-finite value into NaN (inf * 0) and poison the reduction.
    per_sample_loss = torch.where(mask_has_pos, lse_all - lse_pos, torch.zeros_like(lse_all))

    if reduction == "mean":
        denom = mask_has_pos.float().sum().clamp(min=1.0)
        return per_sample_loss.sum() / denom
    if reduction == "sum":
        return per_sample_loss.sum()
    return per_sample_loss  # [B]


# -----------------------
# Optional: compute simple hit@1/ hit@k for debugging
# -----------------------
def prototype_hit_rates(
    img_embeds: torch.Tensor,
    prototypes: torch.Tensor,
    positives: List[List[int]],
    k: int = 1
):
    """
    Return hit@k: the fraction of samples whose top-k prototypes contain any positive.

    Similarity is cosine (both sides L2-normalized along the last dim). ``k`` is
    capped at the number of prototypes. Samples with an empty positive list are
    excluded from the metric; if none remain, 0.0 is returned.
    """
    img = F.normalize(img_embeds, dim=-1, eps=1e-12)
    prot = F.normalize(prototypes, dim=-1, eps=1e-12).to(img.dtype)
    logits = img @ prot.t()                # [B, K]
    # torch.topk raises when k exceeds the dimension size; cap it instead.
    k = min(k, logits.shape[1])
    topk_ids = torch.topk(logits, k, dim=1).indices.tolist()  # B lists of k ids

    hits = 0
    counted = 0
    for i, pos in enumerate(positives):
        if not pos:
            continue  # no positives: excluded from the metric
        counted += 1
        pos_set = set(pos)
        if any(idx in pos_set for idx in topk_ids[i]):
            hits += 1

    if counted == 0:
        return 0.0
    return float(hits / counted)


def compute_default_and_external_logits(
    image_features: torch.Tensor,
    image_features_prom: torch.Tensor,
    text_features: torch.Tensor,
    class_ids_per_task: List[int],
    logit_scale
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Score images against each class's default prompt and its external prompts.

    ``text_features`` must be ordered class by class: class i owns
    ``class_ids_per_task[i]`` consecutive rows, the first being its default
    prompt and the rest its external prompts. All features are L2-normalized.

    Args:
        image_features: [bs, D] features scored against the default prompts.
        image_features_prom: [bs, D] features scored against the external prompts.
        text_features: [sum(class_ids_per_task), D] prompt embeddings.
        class_ids_per_task: number of prompts per class; each must be >= 1.
        logit_scale: multiplier applied to both similarity matrices.

    Returns:
        default_logits: [bs, n_cls] scaled cosine similarity to each default prompt.
        external_logits_avg: [bs, n_cls] scaled mean cosine similarity to each
            class's external prompts (0 for classes that only have a default prompt).

    Raises:
        ValueError: if the prompt counts do not sum to ``text_features.shape[0]``,
            if any class has no prompt, or if the feature dimensions differ.
    """
    device = image_features.device
    text_features = text_features.to(device)
    image_features_prom = image_features_prom.to(device)

    bs, D_img = image_features.shape
    total_prompts, D_text = text_features.shape

    if sum(class_ids_per_task) != total_prompts:
        raise ValueError(
            f"sum(class_ids_per_task)={sum(class_ids_per_task)} != text_features.shape[0]={total_prompts}."
            " Ensure text_features are ordered class-by-class and class_ids_per_task is correct."
        )

    counts = [int(p) for p in class_ids_per_task]
    total_classes = len(counts)
    # A class with 0 prompts has no default row; its start index would silently
    # point at the next class's default prompt (or past the end).
    if any(c < 1 for c in counts):
        raise ValueError(
            f"Every class needs at least one (default) prompt; got class_ids_per_task={counts}."
        )

    if D_img != D_text:
        raise ValueError(
            f"image_features dim ({D_img}) != text_features dim ({D_text}). "
            "Please project one to the other's dimension before calling."
        )
    image_features_n = F.normalize(image_features, dim=-1)          # [bs, D]
    image_features_prom_n = F.normalize(image_features_prom, dim=-1)
    text_features_n = F.normalize(text_features, dim=-1)            # [sum(P_i), D]

    # Rows cum[i]:cum[i+1] of text_features belong to class i.
    cum = [0]
    for c in counts:
        cum.append(cum[-1] + c)

    # The first row of each class's range is its default prompt.
    default_rows = torch.tensor(cum[:-1], dtype=torch.long, device=device)
    default_text = text_features_n.index_select(0, default_rows)    # [total_classes, D]
    default_logits = logit_scale * (image_features_n @ default_text.t())   # [bs, total_classes]

    # One matmul against every prompt; each class then averages its own
    # non-default columns s+1:e instead of running a matmul per class.
    prom_sims = image_features_prom_n @ text_features_n.t()         # [bs, total_prompts]
    external_logits_avg = torch.zeros(bs, total_classes, device=device, dtype=default_logits.dtype)
    for i in range(total_classes):
        s, e = cum[i], cum[i + 1]
        if e - s > 1:
            external_logits_avg[:, i] = prom_sims[:, s + 1:e].mean(dim=1)

    external_logits_avg = logit_scale * external_logits_avg

    return default_logits, external_logits_avg

def visualize_correction(
    default_logits: torch.Tensor,
    external_logits_avg: torch.Tensor,
    prompts_per_class: list,
    inputs,
    test_cur_train_task_id,
    test_cur_test_task_id,
    log_path,
    targets: torch.Tensor,
    class_names: list,
    batch_id: int,
    max_save: int = 4
):

    save_dir = os.path.join(log_path, "logits_correction", f"{test_cur_train_task_id}-{test_cur_test_task_id}")
    os.makedirs(save_dir, exist_ok=True)

    bs = default_logits.size(0)
    save_num = min(bs, max_save)

    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3,1,1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3,1,1)

    for i in range(save_num):
        gt_label = int(targets[i].item())
        default_pred = int(default_logits[i].argmax())
        external_pred = int(external_logits_avg[i].argmax())
        final_logits = 0.7*default_logits[i] + 0.3*external_logits_avg[i]
        final_pred = int(final_logits.argmax())

        if not (default_pred != gt_label and final_pred == gt_label):
            continue

        topk = 5  
        fig, axes = plt.subplots(1, 4, figsize=(20, 5), constrained_layout=True)

        img = inputs[i].detach().cpu()
        img = img * std + mean
        img = img.clamp(0,1)
        img = img.permute(1,2,0).numpy()  # C,H,W -> H,W,C
        axes[0].imshow(img)
        axes[0].axis('off')

        fl = final_logits.detach().cpu().numpy()
        top_idx = fl.argsort()[-topk:][::-1]
        top_vals = fl[top_idx]
        axes[1].bar(np.arange(topk), top_vals, color='skyblue')
        axes[1].set_xticks(np.arange(topk))
        axes[1].set_xticklabels([class_names[j] for j in top_idx], rotation=45, ha='right', fontsize=10)
        for idx, cls in enumerate(top_idx):
            if cls == gt_label:
                axes[1].patches[idx].set_color('lightgreen')
        axes[1].set_title("Final Logits Top5", fontsize=14)
        axes[1].set_ylabel("Logit", fontsize=12)

        dl = default_logits[i].detach().cpu().numpy()
        top_idx = dl.argsort()[-topk:][::-1]
        top_vals = dl[top_idx]
        axes[2].bar(np.arange(topk), top_vals, color='skyblue')
        axes[2].set_xticks(np.arange(topk))
        axes[2].set_xticklabels([class_names[j] for j in top_idx], rotation=45, ha='right', fontsize=10)
        for idx, cls in enumerate(top_idx):
            if cls == gt_label:
                axes[2].patches[idx].set_color('lightgreen')
        axes[2].set_title("Category Logits Top5", fontsize=14)
        axes[2].set_ylabel("Logit", fontsize=12)

        el = external_logits_avg[i].detach().cpu().numpy()
        top_idx = el.argsort()[-topk:][::-1]
        top_vals = el[top_idx]
        axes[3].bar(np.arange(topk), top_vals, color='skyblue')
        axes[3].set_xticks(np.arange(topk))
        axes[3].set_xticklabels([class_names[j] for j in top_idx], rotation=45, ha='right', fontsize=10)
        for idx, cls in enumerate(top_idx):
            if cls == gt_label:
                axes[3].patches[idx].set_color('lightgreen')
        axes[3].set_title("Concept Logits Top5", fontsize=14)
        axes[3].set_ylabel("Logit", fontsize=12)

        save_name = os.path.join(save_dir, f"batch{batch_id}_sample{i}_correction.pdf")
        plt.savefig(save_name)
        plt.close(fig)


def get_task_class_names():
    """Return the class-name list of every task, in task order.

    The eleven lists match, in order, the label sets of FGVC-Aircraft,
    Caltech101 (100 classes), CIFAR-100, DTD, EuroSAT, Oxford Flowers-102,
    Food-101, MNIST, Oxford-IIIT Pet, Stanford Cars and SUN397 -- presumably
    the MTIL task sequence; confirm the order against the dataset loader.
    Names are presumably index-aligned with each dataset's integer labels
    (verify before relying on it).

    Every call builds brand-new lists, so callers may mutate the result
    without affecting later calls.

    Returns:
        list[list[str]]: one list of class names per task (11 lists).
    """
    # 100 classes
    aircraft = [
        '707-320', '727-200', '737-200', '737-300', '737-400', '737-500', '737-600', '737-700',
        '737-800', '737-900', '747-100', '747-200', '747-300', '747-400', '757-200', '757-300',
        '767-200', '767-300', '767-400', '777-200', '777-300', 'A300B4', 'A310', 'A318', 'A319',
        'A320', 'A321', 'A330-200', 'A330-300', 'A340-200', 'A340-300', 'A340-500',
        'A340-600', 'A380', 'ATR-42', 'ATR-72', 'An-12', 'BAE 146-200', 'BAE 146-300', 'BAE-125',
        'Beechcraft 1900', 'Boeing 717', 'C-130', 'C-47', 'CRJ-200', 'CRJ-700', 'CRJ-900',
        'Cessna 172', 'Cessna 208', 'Cessna 525', 'Cessna 560', 'Challenger 600', 'DC-10', 'DC-3',
        'DC-6', 'DC-8', 'DC-9-30', 'DH-82', 'DHC-1', 'DHC-6', 'DHC-8-100', 'DHC-8-300',
        'DR-400', 'Dornier 328', 'E-170', 'E-190', 'E-195', 'EMB-120', 'ERJ 135', 'ERJ 145',
        'Embraer Legacy 600', 'Eurofighter Typhoon', 'F-16A/B', 'F/A-18', 'Falcon 2000',
        'Falcon 900', 'Fokker 100', 'Fokker 50', 'Fokker 70', 'Global Express', 'Gulfstream IV',
        'Gulfstream V', 'Hawk T1', 'Il-76', 'L-1011', 'MD-11', 'MD-80', 'MD-87', 'MD-90',
        'Metroliner', 'Model B200', 'PA-28', 'SR-20', 'Saab 2000', 'Saab 340', 'Spitfire',
        'Tornado', 'Tu-134', 'Tu-154', 'Yak-42',
    ]

    # 100 classes
    caltech = [
        'face', 'leopard', 'motorbike', 'accordion', 'airplane', 'anchor', 'ant', 'barrel', 'bass',
        'beaver', 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', 'camera',
        'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair', 'chandelier', 'cougar_body',
        'cougar_face', 'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup', 'dalmatian',
        'dollar_bill',
        'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'ferry',
        'flamingo', 'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano',
        'hawksbill',
        'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo',
        'ketch', 'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', 'menorah',
        'metronome',
        'minaret', 'nautilus', 'octopus', 'okapi', 'pagoda', 'panda', 'pigeon', 'pizza',
        'platypus', 'pyramid', 'revolver', 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors',
        'scorpion',
        'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', 'stop_sign',
        'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water_lilly',
        'wheelchair',
        'wild_cat', 'windsor_chair', 'wrench', 'yin_yang',
    ]

    # 100 classes
    cifar100 = [
        'apple', 'aquarium fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle',
        'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle',
        'caterpillar', 'cattle',
        'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup',
        'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house',
        'kangaroo',
        'keyboard', 'lamp', 'lawn mower', 'leopard', 'lion', 'lizard', 'lobster', 'man',
        'maple tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak tree', 'orange',
        'orchid', 'otter', 'palm tree',
        'pear', 'pickup truck', 'pine tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum',
        'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew',
        'skunk',
        'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower',
        'sweet pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train',
        'trout', 'tulip',
        'turtle', 'wardrobe', 'whale', 'willow tree', 'wolf', 'woman', 'worm',
    ]

    # 47 classes
    dtd = [
        'banded', 'blotchy', 'braided', 'bubbly', 'bumpy', 'chequered', 'cobwebbed', 'cracked',
        'crosshatched', 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', 'frilly',
        'gauzy', 'grid',
        'grooved', 'honeycombed', 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled',
        'matted', 'meshed', 'paisley', 'perforated', 'pitted', 'pleated', 'polka-dotted',
        'porous', 'potholed', 'scaly',
        'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', 'striped', 'studded',
        'swirly', 'veined', 'waffled', 'woven', 'wrinkled', 'zigzagged',
    ]

    # 10 classes
    eurosat = [
        'Annual Crop Land', 'Forest', 'Herbaceous Vegetation Land', 'Highway or Road',
        'Industrial Buildings', 'Pasture Land', 'Permanent Crop Land', 'Residential Buildings',
        'River', 'Sea or Lake',
    ]

    # 102 classes
    flowers = [
        'pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',
        'english marigold', 'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood',
        'globe thistle', 'snapdragon',
        "colt's foot", 'king protea', 'spear thistle', 'yellow iris', 'globe-flower',
        'purple coneflower', 'peruvian lily', 'balloon flower', 'giant white arum lily',
        'fire lily', 'pincushion flower',
        'fritillary', 'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers',
        'stemless gentian', 'artichoke', 'sweet william', 'carnation', 'garden phlox',
        'love in the mist',
        'mexican aster', 'alpine sea holly', 'ruby-lipped cattleya', 'cape flower',
        'great masterwort', 'siam tulip', 'lenten rose', 'barbeton daisy', 'daffodil',
        'sword lily', 'poinsettia',
        'bolero deep blue', 'wallflower', 'marigold', 'buttercup', 'oxeye daisy',
        'common dandelion', 'petunia', 'wild pansy', 'primula', 'sunflower', 'pelargonium',
        'bishop of llandaff', 'gaura',
        'geranium', 'orange dahlia', 'pink-yellow dahlia', 'cautleya spicata', 'japanese anemone',
        'black-eyed susan', 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus',
        'bearded iris',
        'windflower', 'tree poppy', 'gazania', 'azalea', 'water lily', 'rose', 'thorn apple',
        'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium', 'frangipani',
        'clematis', 'hibiscus',
        'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen', 'watercress',
        'canna lily', 'hippeastrum', 'bee balm', 'ball moss', 'foxglove', 'bougainvillea',
        'camellia', 'mallow',
        'mexican petunia', 'bromelia', 'blanket flower', 'trumpet creeper', 'blackberry lily',
    ]

    # 101 classes
    food = [
        'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad',
        'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta',
        'caesar_salad', 'cannoli',
        'caprese_salad', 'carrot_cake', 'ceviche', 'cheese_plate', 'cheesecake', 'chicken_curry',
        'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros',
        'clam_chowder',
        'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes',
        'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel',
        'filet_mignon',
        'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast',
        'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad',
        'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger',
        'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna',
        'lobster_bisque',
        'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels',
        'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes',
        'panna_cotta', 'peking_duck',
        'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 'ramen',
        'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad',
        'shrimp_and_grits',
        'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak',
        'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare',
        'waffles',
    ]

    # 10 classes
    mnist = [
        '0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six',
        '7 - seven', '8 - eight', '9 - nine',
    ]

    # 37 classes
    oxford_pet = [
        'abyssinian', 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound', 'beagle',
        'bengal', 'birman', 'bombay', 'boxer', 'british_shorthair', 'chihuahua', 'egyptian_mau',
        'english_cocker_spaniel', 'english_setter', 'german_shorthaired', 'great_pyrenees',
        'havanese', 'japanese_chin', 'keeshond', 'leonberger', 'maine_coon', 'miniature_pinscher',
        'newfoundland',
        'persian', 'pomeranian', 'pug', 'ragdoll', 'russian_blue', 'saint_bernard', 'samoyed',
        'scottish_terrier', 'shiba_inu', 'siamese', 'sphynx', 'staffordshire_bull_terrier',
        'wheaten_terrier',
        'yorkshire_terrier',
    ]

    # 196 classes
    stanford_cars = [
        '2000 AM General Hummer SUV', '2012 Acura RL Sedan', '2012 Acura TL Sedan',
        '2008 Acura TL Type-S', '2012 Acura TSX Sedan', '2001 Acura Integra Type R',
        '2012 Acura ZDX Hatchback',
        '2012 Aston Martin V8 Vantage Convertible', '2012 Aston Martin V8 Vantage Coupe',
        '2012 Aston Martin Virage Convertible', '2012 Aston Martin Virage Coupe',
        '2008 Audi RS 4 Convertible',
        '2012 Audi A5 Coupe', '2012 Audi TTS Coupe', '2012 Audi R8 Coupe', '1994 Audi V8 Sedan',
        '1994 Audi 100 Sedan', '1994 Audi 100 Wagon', '2011 Audi TT Hatchback',
        '2011 Audi S6 Sedan',
        '2012 Audi S5 Convertible', '2012 Audi S5 Coupe', '2012 Audi S4 Sedan',
        '2007 Audi S4 Sedan', '2012 Audi TT RS Coupe', '2012 BMW ActiveHybrid 5 Sedan',
        '2012 BMW 1 Series Convertible',
        '2012 BMW 1 Series Coupe', '2012 BMW 3 Series Sedan', '2012 BMW 3 Series Wagon',
        '2007 BMW 6 Series Convertible', '2007 BMW X5 SUV', '2012 BMW X6 SUV',
        '2012 BMW M3 Coupe', '2010 BMW M5 Sedan',
        '2010 BMW M6 Convertible', '2012 BMW X3 SUV', '2012 BMW Z4 Convertible',
        '2012 Bentley Continental Supersports Conv. Convertible', '2009 Bentley Arnage Sedan',
        '2011 Bentley Mulsanne Sedan',
        '2012 Bentley Continental GT Coupe', '2007 Bentley Continental GT Coupe',
        '2007 Bentley Continental Flying Spur Sedan', '2009 Bugatti Veyron 16.4 Convertible',
        '2009 Bugatti Veyron 16.4 Coupe',
        '2012 Buick Regal GS', '2007 Buick Rainier SUV', '2012 Buick Verano Sedan',
        '2012 Buick Enclave SUV', '2012 Cadillac CTS-V Sedan', '2012 Cadillac SRX SUV',
        '2007 Cadillac Escalade EXT Crew Cab',
        '2012 Chevrolet Silverado 1500 Hybrid Crew Cab', '2012 Chevrolet Corvette Convertible',
        '2012 Chevrolet Corvette ZR1', '2007 Chevrolet Corvette Ron Fellows Edition Z06',
        '2012 Chevrolet Traverse SUV', '2012 Chevrolet Camaro Convertible',
        '2010 Chevrolet HHR SS', '2007 Chevrolet Impala Sedan', '2012 Chevrolet Tahoe Hybrid SUV',
        '2012 Chevrolet Sonic Sedan',
        '2007 Chevrolet Express Cargo Van', '2012 Chevrolet Avalanche Crew Cab',
        '2010 Chevrolet Cobalt SS', '2010 Chevrolet Malibu Hybrid Sedan',
        '2009 Chevrolet TrailBlazer SS',
        '2012 Chevrolet Silverado 2500HD Regular Cab',
        '2007 Chevrolet Silverado 1500 Classic Extended Cab', '2007 Chevrolet Express Van',
        '2007 Chevrolet Monte Carlo Coupe', '2007 Chevrolet Malibu Sedan',
        '2012 Chevrolet Silverado 1500 Extended Cab', '2012 Chevrolet Silverado 1500 Regular Cab',
        '2009 Chrysler Aspen SUV', '2010 Chrysler Sebring Convertible',
        '2012 Chrysler Town and Country Minivan',
        '2010 Chrysler 300 SRT-8', '2008 Chrysler Crossfire Convertible',
        '2008 Chrysler PT Cruiser Convertible', '2002 Daewoo Nubira Wagon',
        '2012 Dodge Caliber Wagon', '2007 Dodge Caliber Wagon',
        '1997 Dodge Caravan Minivan', '2010 Dodge Ram Pickup 3500 Crew Cab',
        '2009 Dodge Ram Pickup 3500 Quad Cab', '2009 Dodge Sprinter Cargo Van',
        '2012 Dodge Journey SUV', '2010 Dodge Dakota Crew Cab',
        '2007 Dodge Dakota Club Cab', '2008 Dodge Magnum Wagon', '2011 Dodge Challenger SRT8',
        '2012 Dodge Durango SUV', '2007 Dodge Durango SUV', '2012 Dodge Charger Sedan',
        '2009 Dodge Charger SRT-8',
        '1998 Eagle Talon Hatchback', '2012 FIAT 500 Abarth', '2012 FIAT 500 Convertible',
        '2012 Ferrari FF Coupe', '2012 Ferrari California Convertible',
        '2012 Ferrari 458 Italia Convertible',
        '2012 Ferrari 458 Italia Coupe', '2012 Fisker Karma Sedan',
        '2012 Ford F-450 Super Duty Crew Cab', '2007 Ford Mustang Convertible',
        '2007 Ford Freestar Minivan', '2009 Ford Expedition EL SUV',
        '2012 Ford Edge SUV', '2011 Ford Ranger SuperCab', '2006 Ford GT Coupe',
        '2012 Ford F-150 Regular Cab', '2007 Ford F-150 Regular Cab', '2007 Ford Focus Sedan',
        '2012 Ford E-Series Wagon Van',
        '2012 Ford Fiesta Sedan', '2012 GMC Terrain SUV', '2012 GMC Savana Van',
        '2012 GMC Yukon Hybrid SUV', '2012 GMC Acadia SUV', '2012 GMC Canyon Extended Cab',
        '1993 Geo Metro Convertible',
        '2010 HUMMER H3T Crew Cab', '2009 HUMMER H2 SUT Crew Cab', '2012 Honda Odyssey Minivan',
        '2007 Honda Odyssey Minivan', '2012 Honda Accord Coupe', '2012 Honda Accord Sedan',
        '2012 Hyundai Veloster Hatchback', '2012 Hyundai Santa Fe SUV',
        '2012 Hyundai Tucson SUV', '2012 Hyundai Veracruz SUV', '2012 Hyundai Sonata Hybrid Sedan',
        '2007 Hyundai Elantra Sedan',
        '2012 Hyundai Accent Sedan', '2012 Hyundai Genesis Sedan', '2012 Hyundai Sonata Sedan',
        '2012 Hyundai Elantra Touring Hatchback', '2012 Hyundai Azera Sedan',
        '2012 Infiniti G Coupe IPL',
        '2011 Infiniti QX56 SUV', '2008 Isuzu Ascender SUV', '2012 Jaguar XK XKR',
        '2012 Jeep Patriot SUV', '2012 Jeep Wrangler SUV', '2012 Jeep Liberty SUV',
        '2012 Jeep Grand Cherokee SUV',
        '2012 Jeep Compass SUV', '2008 Lamborghini Reventon Coupe',
        '2012 Lamborghini Aventador Coupe', '2012 Lamborghini Gallardo LP 570-4 Superleggera',
        '2001 Lamborghini Diablo Coupe',
        '2012 Land Rover Range Rover SUV', '2012 Land Rover LR2 SUV',
        '2011 Lincoln Town Car Sedan', '2012 MINI Cooper Roadster Convertible',
        '2012 Maybach Landaulet Convertible', '2011 Mazda Tribute SUV',
        '2012 McLaren MP4-12C Coupe', '1993 Mercedes-Benz 300-Class Convertible',
        '2012 Mercedes-Benz C-Class Sedan', '2009 Mercedes-Benz SL-Class Coupe',
        '2012 Mercedes-Benz E-Class Sedan',
        '2012 Mercedes-Benz S-Class Sedan', '2012 Mercedes-Benz Sprinter Van',
        '2012 Mitsubishi Lancer Sedan', '2012 Nissan Leaf Hatchback',
        '2012 Nissan NV Passenger Van', '2012 Nissan Juke Hatchback',
        '1998 Nissan 240SX Coupe', '1999 Plymouth Neon Coupe', '2012 Porsche Panamera Sedan',
        '2012 Ram C/V Cargo Van Minivan', '2012 Rolls-Royce Phantom Drophead Coupe Convertible',
        '2012 Rolls-Royce Ghost Sedan', '2012 Rolls-Royce Phantom Sedan',
        '2012 Scion xD Hatchback', '2009 Spyker C8 Convertible', '2009 Spyker C8 Coupe',
        '2007 Suzuki Aerio Sedan',
        '2012 Suzuki Kizashi Sedan', '2012 Suzuki SX4 Hatchback', '2012 Suzuki SX4 Sedan',
        '2012 Tesla Model S Sedan', '2012 Toyota Sequoia SUV', '2012 Toyota Camry Sedan',
        '2012 Toyota Corolla Sedan',
        '2012 Toyota 4Runner SUV', '2012 Volkswagen Golf Hatchback',
        '1991 Volkswagen Golf Hatchback', '2012 Volkswagen Beetle Hatchback',
        '2012 Volvo C30 Hatchback', '1993 Volvo 240 Sedan',
        '2007 Volvo XC90 SUV', '2012 smart fortwo Convertible',
    ]

    # 397 classes
    sun397 = [
        'abbey', 'airplane_cabin', 'airport_terminal', 'alley', 'amphitheater',
        'amusement_arcade', 'amusement_park', 'anechoic_chamber', 'outdoor apartment_building',
        'indoor apse', 'aquarium', 'aqueduct',
        'arch', 'archive', 'outdoor arrival_gate', 'art_gallery', 'art_school', 'art_studio',
        'assembly_line', 'outdoor athletic_field', 'public atrium', 'attic', 'auditorium',
        'auto_factory', 'badlands',
        'indoor badminton_court', 'baggage_claim', 'shop bakery', 'exterior balcony',
        'interior balcony', 'ball_pit', 'ballroom', 'bamboo_forest', 'banquet_hall', 'bar',
        'barn', 'barndoor',
        'baseball_field', 'basement', 'basilica', 'outdoor basketball_court', 'bathroom',
        'batters_box', 'bayou', 'indoor bazaar', 'outdoor bazaar', 'beach', 'beauty_salon',
        'bedroom', 'berth',
        'biology_laboratory', 'indoor bistro', 'boardwalk', 'boat_deck', 'boathouse', 'bookstore',
        'indoor booth', 'botanical_garden', 'indoor bow_window', 'outdoor bow_window',
        'bowling_alley',
        'boxing_ring', 'indoor brewery', 'bridge', 'building_facade', 'bullring',
        'burial_chamber', 'bus_interior', 'butchers_shop', 'butte', 'outdoor cabin', 'cafeteria',
        'campsite', 'campus',
        'natural canal', 'urban canal', 'candy_store', 'canyon', 'backseat car_interior',
        'frontseat car_interior', 'carrousel', 'indoor casino', 'castle', 'catacomb',
        'indoor cathedral',
        'outdoor cathedral', 'indoor cavern', 'cemetery', 'chalet', 'cheese_factory',
        'chemistry_lab', 'indoor chicken_coop', 'outdoor chicken_coop', 'childs_room',
        'indoor church', 'outdoor church',
        'classroom', 'clean_room', 'cliff', 'indoor cloister', 'closet', 'clothing_store', 'coast',
        'cockpit', 'coffee_shop', 'computer_room', 'conference_center', 'conference_room',
        'construction_site',
        'control_room', 'outdoor control_tower', 'corn_field', 'corral', 'corridor',
        'cottage_garden', 'courthouse', 'courtroom', 'courtyard', 'exterior covered_bridge',
        'creek', 'crevasse',
        'crosswalk', 'office cubicle', 'dam', 'delicatessen', 'dentists_office', 'sand desert',
        'vegetation desert', 'indoor diner', 'outdoor diner', 'home dinette', 'vehicle dinette',
        'dining_car',
        'dining_room', 'discotheque', 'dock', 'outdoor doorway', 'dorm_room', 'driveway',
        'outdoor driving_range', 'drugstore', 'electrical_substation', 'door elevator',
        'interior elevator',
        'elevator_shaft', 'engine_room', 'indoor escalator', 'excavation', 'indoor factory',
        'fairway', 'fastfood_restaurant', 'cultivated field', 'wild field', 'fire_escape',
        'fire_station',
        'indoor firing_range', 'fishpond', 'indoor florist_shop', 'food_court',
        'broadleaf forest', 'needleleaf forest', 'forest_path', 'forest_road', 'formal_garden',
        'fountain', 'galley',
        'game_room', 'indoor garage', 'garbage_dump', 'gas_station', 'exterior gazebo',
        'indoor general_store', 'outdoor general_store', 'gift_shop', 'golf_course',
        'indoor greenhouse',
        'outdoor greenhouse', 'indoor gymnasium', 'indoor hangar', 'outdoor hangar', 'harbor',
        'hayfield', 'heliport', 'herb_garden', 'highway', 'hill', 'home_office', 'hospital',
        'hospital_room',
        'hot_spring', 'outdoor hot_tub', 'outdoor hotel', 'hotel_room', 'house',
        'outdoor hunting_lodge', 'ice_cream_parlor', 'ice_floe', 'ice_shelf',
        'indoor ice_skating_rink',
        'outdoor ice_skating_rink', 'iceberg', 'igloo', 'industrial_area', 'outdoor inn', 'islet',
        'indoor jacuzzi', 'indoor jail', 'jail_cell', 'jewelry_shop', 'kasbah', 'indoor kennel',
        'outdoor kennel', 'kindergarden_classroom', 'kitchen', 'kitchenette', 'outdoor labyrinth',
        'natural lake', 'landfill', 'landing_deck', 'laundromat', 'lecture_room',
        'indoor library',
        'outdoor library', 'outdoor lido_deck', 'lift_bridge', 'lighthouse', 'limousine_interior',
        'living_room', 'lobby', 'lock_chamber', 'locker_room', 'mansion', 'manufactured_home',
        'indoor market', 'outdoor market', 'marsh', 'martial_arts_gym', 'mausoleum', 'medina',
        'water moat', 'outdoor monastery', 'indoor mosque', 'outdoor mosque', 'motel', 'mountain',
        'mountain_snowy', 'indoor movie_theater', 'indoor museum', 'music_store', 'music_studio',
        'outdoor nuclear_power_plant', 'nursery', 'oast_house', 'outdoor observatory', 'ocean',
        'office', 'office_building', 'outdoor oil_refinery', 'oilrig', 'operating_room',
        'orchard', 'outdoor outhouse', 'pagoda', 'palace', 'pantry', 'park',
        'indoor parking_garage',
        'outdoor parking_garage', 'parking_lot', 'parlor', 'pasture', 'patio', 'pavilion',
        'pharmacy', 'phone_booth', 'physics_laboratory', 'picnic_area', 'indoor pilothouse',
        'outdoor planetarium', 'playground', 'playroom', 'plaza', 'indoor podium',
        'outdoor podium', 'pond', 'establishment poolroom', 'home poolroom',
        'outdoor power_plant',
        'promenade_deck', 'indoor pub', 'pulpit', 'putting_green', 'racecourse', 'raceway',
        'raft', 'railroad_track', 'rainforest', 'reception', 'recreation_room',
        'residential_neighborhood', 'restaurant', 'restaurant_kitchen', 'restaurant_patio',
        'rice_paddy', 'riding_arena', 'river', 'rock_arch', 'rope_bridge', 'ruin',
        'runway', 'sandbar', 'sandbox', 'sauna', 'schoolhouse', 'sea_cliff', 'server_room',
        'shed', 'shoe_shop', 'shopfront', 'indoor shopping_mall', 'shower',
        'skatepark', 'ski_lodge', 'ski_resort', 'ski_slope', 'sky', 'skyscraper', 'slum',
        'snowfield', 'squash_court', 'stable', 'baseball stadium', 'football stadium',
        'indoor stage',
        'staircase', 'street', 'subway_interior', 'platform subway_station', 'supermarket',
        'sushi_bar', 'swamp', 'indoor swimming_pool', 'outdoor swimming_pool',
        'indoor synagogue',
        'outdoor synagogue', 'television_studio', 'east_asia temple', 'south_asia temple',
        'indoor tennis_court', 'outdoor tennis_court', 'outdoor tent', 'indoor_procenium theater',
        'indoor_seats theater', 'thriftshop', 'throne_room', 'ticket_booth', 'toll_plaza',
        'topiary_garden', 'tower', 'toyshop', 'outdoor track', 'train_railway',
        'platform train_station', 'tree_farm', 'tree_house', 'trench', 'coral_reef underwater',
        'utility_room', 'valley', 'van_interior', 'vegetable_garden', 'veranda',
        'veterinarians_office', 'viaduct', 'videostore', 'village', 'vineyard', 'volcano',
        'indoor volleyball_court', 'outdoor volleyball_court', 'waiting_room',
        'indoor warehouse', 'water_tower', 'block waterfall', 'fan waterfall', 'plunge waterfall',
        'watering_hole', 'wave', 'wet_bar', 'wheat_field', 'wind_farm', 'windmill',
        'barrel_storage wine_cellar', 'bottle_storage wine_cellar', 'indoor wrestling_ring',
        'yard', 'youth_hostel',
    ]

    # Task order matters: position k in the returned list is task k.
    return [
        aircraft,
        caltech,
        cifar100,
        dtd,
        eurosat,
        flowers,
        food,
        mnist,
        oxford_pet,
        stanford_cars,
        sun397,
    ]