import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from sklearn.metrics import accuracy_score
from statsmodels.stats.proportion import proportion_confint

from sae.utils import filter_out_nosiy_activation, load_activation_data
from utils.dataset_loader import ImageDatasetLoader


def load_class_names(root, dataset_name):
    """Load the human-readable class names for a dataset.

    The names are read from ``{root}/data/etc/{dataset_name}_classnames.txt``.
    On every line the first space-separated token is dropped and the
    remaining tokens, re-joined with single spaces, form the class name.

    Args:
        root: Base directory that contains the ``data/etc`` folder.
        dataset_name: Dataset identifier; only ``"imagenet"`` is handled.

    Returns:
        A list with one class name per line of the file, in file order, or
        ``None`` when ``dataset_name`` is not ``"imagenet"``.
    """
    if dataset_name != "imagenet":
        # No class-name file is handled for other datasets.
        return None

    with open(f"{root}/data/etc/{dataset_name}_classnames.txt", "r") as f:
        # Drop the leading token of each line and keep the rest as the name.
        return [" ".join(line.strip().split(" ")[1:]) for line in f]


def get_image_text_cos(clip_model, clip_processor, heatmap_images_high, label_embedding, device):
    """Return, per image, the highest image-to-label similarity score.

    The images are embedded with ``get_image_embedding`` and multiplied with
    the transposed ``label_embedding``; the maximum along dim 1 of that
    product is taken for each image.

    Args:
        clip_model: Model forwarded to ``get_image_embedding``.
        clip_processor: Processor forwarded to ``get_image_embedding``.
        heatmap_images_high: Images to embed.
        label_embedding: Label embedding tensor; presumably one L2-normalised
            row per label so the product is a cosine similarity (confirm
            with the caller).
        device: Device forwarded to ``get_image_embedding``.

    Returns:
        A NumPy array holding the row-wise maximum of the similarity matrix.
    """
    image_features = get_image_embedding(clip_model, clip_processor, heatmap_images_high, device)
    similarity = image_features @ label_embedding.T
    # ``max(dim=1)`` yields (values, indices); only the values are needed.
    best_scores, _ = similarity.max(dim=1)
    return best_scores.cpu().numpy()


def get_text_embedding(clip_model, clip_processor, text, device):
    """Compute L2-normalised CLIP text features.

    Args:
        clip_model: Object exposing ``get_text_features(**inputs)``.
        clip_processor: Callable that turns ``text`` into a mapping of
            tensors (called with ``return_tensors="pt"``, padding and
            truncation enabled).
        text: Text input handed to ``clip_processor`` unchanged.
        device: Device every processor output is moved to.

    Returns:
        The text features normalised to unit L2 norm along dim 1.
    """
    encoded = clip_processor(text, return_tensors="pt", padding=True, truncation=True)

    # Move every processor output onto the requested device.
    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(device)

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        text_features = clip_model.get_text_features(**model_inputs)

    return F.normalize(text_features, p=2, dim=1)


def generate_random_patch_mask(batch_dict, concept_only_mask, image_size=256, patch_size=16):
    """Mask each image with randomly chosen square patches.

    For every image the sum of its reference mask divided by
    ``image_size**2`` gives a coverage fraction. That fraction of the
    ``(image_size // patch_size)**2`` grid cells is drawn at random without
    replacement, and two images are produced: one keeping only the drawn
    patches and one with exactly those patches zeroed out.

    Args:
        batch_dict: Mapping whose ``"image"`` entry is an iterable of images
            convertible with ``np.array``. Assumes each array is
            ``image_size`` x ``image_size`` with a trailing channel axis --
            TODO confirm with the caller.
        concept_only_mask: Indexable collection with one reference mask per
            image; only its sum is used.
        image_size: Side length used for the coverage and grid computation.
        patch_size: Side length, in pixels, of one grid cell.

    Returns:
        A tuple ``(mask_concept_only, mask_concept_exclude)`` of two lists of
        PIL images, in the same order as the input images.
    """
    kept_images = []
    removed_images = []
    for position, picture in enumerate(batch_dict["image"]):
        # Fraction of the image area covered by the reference mask.
        coverage = np.sum(concept_only_mask[position]) / (image_size**2)

        grid_side = image_size // patch_size
        grid_cells = grid_side**2
        n_selected = int(coverage * grid_cells)

        selected = np.random.choice(grid_cells, n_selected, replace=False)

        # Flat cell index -> (row, col) on the patch grid.
        grid = np.zeros((grid_side, grid_side), dtype=np.uint8)
        grid[selected // grid_side, selected % grid_side] = 1

        # Blow each grid cell up to patch_size x patch_size pixels and add a
        # trailing axis so the mask broadcasts over the colour channels.
        pixel_mask = np.kron(grid, np.ones((patch_size, patch_size), dtype=np.uint8))[..., np.newaxis]
        pixels = np.array(picture)

        for keep, bucket in ((pixel_mask, kept_images), (1 - pixel_mask, removed_images)):
            composed = np.clip(pixels * keep, 0, 255).astype(np.uint8)
            bucket.append(Image.fromarray(composed))

    return kept_images, removed_images


def masking(batch_dict, masks, resize_size=256, blend_rate=0.0, gamma=0.5, reverse=False):
    """Darken images according to per-image soft masks.

    Each image is resized to ``resize_size`` x ``resize_size`` and converted
    to three RGB channels. Its mask is min-max normalised to ``[0, 1]``,
    raised to ``gamma``, optionally inverted, and every pixel is scaled by
    ``blend_rate + (1 - blend_rate) * mask``.

    Args:
        batch_dict: Mapping whose ``"image"`` entry is an iterable of PIL
            images.
        masks: Indexable collection with one array per image. Assumed to be
            2-D with the same height/width as the resized image -- TODO
            confirm with the caller.
        resize_size: Target side length of the resized image.
        blend_rate: Minimum brightness factor applied where the mask is 0.
        gamma: Exponent applied to the normalised mask.
        reverse: If true, use ``1 - mask`` so high-mask regions are darkened.

    Returns:
        A tuple ``(masked_images, processed_masks)``: the blended PIL images
        and the processed masks with the broadcast axis removed again.
    """
    masked_images = []
    processed_masks = []
    for i, image in enumerate(batch_dict["image"]):
        # Resize first, then force three channels. For RGB/RGBA input this
        # matches slicing the first three channels, and it also gives
        # grayscale or palette images a proper (H, W, 3) layout instead of a
        # 2-D array that cannot be broadcast against the mask.
        resized = image.resize((resize_size, resize_size)).convert("RGB")
        image_array = np.array(resized).astype(np.float32)

        # Min-max normalise; the epsilon avoids division by zero for a
        # constant mask.
        lowest = masks[i].min()
        highest = masks[i].max()
        mask = (masks[i] - lowest) / (highest - lowest + 1e-10)
        mask = np.expand_dims(mask, axis=-1)  # (H, W, 1) to broadcast over RGB

        mask = mask**gamma

        if reverse:
            mask = 1 - mask

        blended = image_array * (blend_rate + (1 - blend_rate) * mask)
        blended = np.clip(blended, 0, 255).astype(np.uint8)

        masked_images.append(Image.fromarray(blended))
        processed_masks.append(mask.squeeze(-1))
    return masked_images, processed_masks


def get_image_embedding(clip_model, clip_processor, image, device):
    """Compute L2-normalised CLIP image features.

    Args:
        clip_model: Object exposing ``get_image_features(**inputs)``.
        clip_processor: Callable that turns ``images=image`` into a mapping
            of tensors (called with ``return_tensors="pt"`` and padding).
        image: Image input handed to ``clip_processor`` unchanged.
        device: Device every processor output is moved to.

    Returns:
        The image features normalised to unit L2 norm along dim 1.
    """
    processed = clip_processor(images=image, return_tensors="pt", padding=True)

    # Move every processor output onto the requested device.
    model_inputs = {}
    for name, tensor in processed.items():
        model_inputs[name] = tensor.to(device)

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        image_features = clip_model.get_image_features(**model_inputs)

    return F.normalize(image_features, p=2, dim=1)


def load_training_stats(root, dataset_name):
    """Load the precomputed training statistics for a dataset.

    Two ``.npy`` files are read from ``{root}/data/{dataset_name}_analysis``.

    Args:
        root: Base directory that contains the ``data`` folder.
        dataset_name: Dataset identifier used in the analysis folder name.

    Returns:
        A dict with the arrays under ``"train_mean_var_activation"`` and
        ``"high_activating_ratio"``.
    """
    # Result key -> file to load; the files are read in this order.
    stat_files = {
        "train_mean_var_activation": f"{root}/data/{dataset_name}_analysis/train_mean_var_activation_per_class.npy",
        "high_activating_ratio": f"{root}/data/{dataset_name}_analysis//high_activation_ratio.npy",
    }
    return {key: np.load(path) for key, path in stat_files.items()}


def load_data(dataset_name, split="test", label_name="label", root="./"):
    """Load the evaluation split, its filtered activations and the train split.

    Args:
        dataset_name: Dataset identifier passed to the dataset and
            activation loaders.
        split: Split loaded as the evaluation set and used for activations.
        label_name: Column/key of the dataset that holds the labels.
        root: Base directory passed to the activation helpers.

    Returns:
        A dict with ``"train_dataset"``, ``"train_labels"``,
        ``"test_dataset"``, ``"test_labels"`` and ``"test_activations"``.
        The ``"test_*"`` entries refer to ``split``, whatever its name.
    """
    # Evaluation split and its labels.
    eval_set = ImageDatasetLoader.load_dataset(dataset_name, seed=1, split=split)
    eval_labels = np.array(eval_set[label_name])

    # Activations for the same split, passed through the project's filter.
    raw_activations = load_activation_data(root, dataset_name, split=split)
    eval_activations = filter_out_nosiy_activation(root, raw_activations)

    # The training split is always loaded alongside.
    train_set = ImageDatasetLoader.load_dataset(dataset_name, seed=1, split="train")
    train_label_array = np.array(train_set[label_name])

    result = {}
    result["train_dataset"] = train_set
    result["train_labels"] = train_label_array
    result["test_dataset"] = eval_set
    result["test_labels"] = eval_labels
    result["test_activations"] = eval_activations
    return result


def load_pred_results(root, dataset_name, subset="imagenet"):
    """Read the stored prediction table for a dataset subset.

    Args:
        root: Base directory that contains the ``data`` folder.
        dataset_name: Dataset identifier used in the analysis folder name.
        subset: Prefix of the ``*_predictions.csv`` file to read.

    Returns:
        The CSV contents as a ``pandas.DataFrame``.
    """
    csv_path = f"{root}/data/{dataset_name}_analysis/{subset}_predictions.csv"
    return pd.read_csv(csv_path)


def get_high_low_group_acc(pred_results, test_class_indices, test_class_activation, latent_idx, top_n=10):
    """Compare accuracy on the highest- and lowest-activating samples.

    The samples selected by ``test_class_indices`` are split by the
    activation of one latent: the "high" group has activation at or above
    the ``100 - top_n`` percentile, the "low" group at or below the
    ``top_n`` percentile.

    Args:
        pred_results: Table with ``"gt_label"`` and ``"pred_label"`` columns,
            or ``None`` when no predictions are available.
        test_class_indices: Index selecting the samples from the table.
        test_class_activation: 2-D activations, indexed as
            ``[:, latent_idx]``; rows presumably align with the selected
            samples (confirm with the caller).
        latent_idx: Column of the latent to split on.
        top_n: Percentile width of each group.

    Returns:
        A dict with ``"high_group_acc"`` and ``"low_group_acc"``; both are
        ``-1`` when ``pred_results`` is ``None``.
    """
    if pred_results is None:
        # Sentinel values signal that no predictions were supplied.
        return {"high_group_acc": -1, "low_group_acc": -1}

    labels_true = pred_results["gt_label"].to_numpy()[test_class_indices]
    labels_pred = pred_results["pred_label"].to_numpy()[test_class_indices]

    activation = test_class_activation[:, latent_idx]
    upper_cut = np.percentile(activation, 100 - top_n)
    lower_cut = np.percentile(activation, top_n)

    # Both comparisons are inclusive, so ties at a threshold join the group.
    top_rows = np.where(activation >= upper_cut)[0]
    bottom_rows = np.where(activation <= lower_cut)[0]

    scores = {}
    scores["high_group_acc"] = accuracy_score(labels_true[top_rows], labels_pred[top_rows])
    scores["low_group_acc"] = accuracy_score(labels_true[bottom_rows], labels_pred[bottom_rows])
    return scores


def get_cls_acc(pred_results, test_class_indices, test_class_activation):
    """Compute accuracy and its 95% confidence interval for selected samples.

    Args:
        pred_results: Table with ``"gt_label"`` and ``"pred_label"`` columns,
            or ``None`` when no predictions are available.
        test_class_indices: Index (integer positions or boolean mask)
            selecting the samples from the table.
        test_class_activation: Unused; kept so existing callers keep working.

    Returns:
        A dict with ``"acc"`` (accuracy on the selected samples) and
        ``"lower"`` / ``"upper"`` (normal-approximation confidence bounds at
        ``alpha=0.05``). All three are ``-1`` when ``pred_results`` is
        ``None``.
    """
    if pred_results is None:
        # Sentinel values signal that no predictions were supplied.
        return {
            "acc": -1,
            "lower": -1,
            "upper": -1,
        }

    sample_gt = pred_results["gt_label"].to_numpy()[test_class_indices]
    sample_pred = pred_results["pred_label"].to_numpy()[test_class_indices]

    # Count on the selected samples themselves rather than on the index
    # object, so a boolean mask yields the right number of observations.
    num_samples = len(sample_gt)
    num_correct = int(np.sum(sample_gt == sample_pred))

    acc = accuracy_score(sample_gt, sample_pred)
    lower, upper = proportion_confint(num_correct, num_samples, alpha=0.05, method="normal")
    return {
        "acc": acc,
        "lower": lower,
        "upper": upper,
    }
