import json
import os
from itertools import chain

import matplotlib.patches as patches
import numpy as np
import pandas as pd
import torch
from sklearn import metrics

"""
Functions for analysis
"""
def compute_metrics_batch(clf, x, attr, y_hat, y_0, bkgd, thresholds):
    """Evaluate sufficiency, necessity and L0 at every attribution threshold.

    For each ``t`` in ``thresholds`` the attribution map is binarised with
    ``threshold_attr(attr, t)`` (1.0 where ``attr >= t``) and scored with
    ``compute_metrics``.

    Returns:
        Three lists ``(suff, necc, l0)``, each aligned element-wise with
        ``thresholds``.
    """
    # One (suff, necc, l0) triple per threshold, in threshold order.
    per_threshold = [
        compute_metrics(clf, x, y_hat, y_0, threshold_attr(attr, t), bkgd)
        for t in thresholds
    ]
    suff = [triple[0] for triple in per_threshold]
    necc = [triple[1] for triple in per_threshold]
    l0 = [triple[2] for triple in per_threshold]
    return suff, necc, l0

def compute_metrics(clf, x, y_hat, y_0, binary_mask, bkgd):
    """Score one binary mask by sufficiency, necessity and L0 size.

    Two composites are built:

    * ``x_S``: the input where ``binary_mask`` is 1, ``bkgd`` elsewhere.
    * ``x_Sc``: the complement, i.e. ``bkgd`` where the mask is 1, the input
      elsewhere.

    ``clf`` must return a single-element tensor (``.item()`` is called on it).

    Returns:
        ``(suff, necc, l0)`` where ``suff = |y_hat - clf(x_S)|``,
        ``necc = |clf(x_Sc) - y_0|`` and ``l0`` is the mean of ``|binary_mask|``
        (the fraction of selected entries for a 0/1 mask).
    """
    keep = binary_mask
    drop = 1 - binary_mask

    # Evaluate the kept-region composite first, then its complement.
    score_kept = clf(keep * x + drop * bkgd).item()
    score_removed = clf(drop * x + keep * bkgd).item()

    return (
        abs(y_hat - score_kept),
        abs(score_removed - y_0),
        torch.abs(binary_mask).mean().item(),
    )

"""
Functions to process images
"""
def threshold_attr(attr, t):
    """Binarise ``attr``: 1.0 where ``attr >= t``, 0.0 elsewhere."""
    keep = attr >= t
    # Multiplying the boolean mask by a float yields a floating-point 0/1 mask.
    return 1.0 * keep

def process_attr(attr, top_kp, norm_type):
    """Normalise an attribution map and collapse 3-channel maps to grayscale.

    Args:
        attr: Attribution tensor. A leading dimension of size 3 is treated as
            RGB channels (presumably channel-first C x H x W -- confirm with
            callers). It is not modified.
        top_kp: Percentage (0-100) of the nonzero scores that form the "top"
            set when ``norm_type == "top_kp"``. At least one score is always
            kept.
        norm_type: ``"top_kp"`` maps top scores to 1 and divides the rest by
            the top-set threshold, giving values in [0, 1). ``"min/max"``
            rescales linearly to [0, 1]. Any other value applies only the
            absolute value.

    Returns:
        Non-negative tensor: ``1 x H x W`` when ``attr`` had 3 leading
        channels, otherwise the same shape as ``attr``. All-zero maps
        (``top_kp``) and constant maps (``min/max``) come back as zeros.
    """
    # Only magnitude matters. abs() already gives values >= 0 and returns a new
    # tensor, so the caller's tensor is never modified below.
    attr = torch.abs(attr)

    if norm_type == "top_kp":
        flat_attr = attr.reshape(-1)
        nonzero_attr = flat_attr[flat_attr > 0]
        n_nonzero = nonzero_attr.numel()
        # An all-zero map has no threshold to divide by; leave it as zeros.
        if n_nonzero > 0:
            # Keep at least one score in the top set, and never request more
            # elements from topk than exist.
            top_k = min(n_nonzero, max(1, int(n_nonzero * top_kp / 100)))
            top_k_value = torch.topk(nonzero_attr, top_k, largest=True).values[-1]
            # Decide membership from the ORIGINAL scores. Top scores become 1,
            # and the rest are scaled linearly by the threshold into [0, 1).
            # Overwriting with 1 first and then dividing would rescale the
            # top scores too whenever the threshold exceeds 1.
            attr = torch.where(
                attr >= top_k_value, torch.ones_like(attr), attr / top_k_value
            )
    elif norm_type == "min/max":
        attr_min = torch.min(attr)
        attr_range = torch.max(attr) - attr_min
        # A constant map has zero range; return zeros rather than NaNs.
        if attr_range > 0:
            attr = (attr - attr_min) / attr_range
        else:
            attr = torch.zeros_like(attr)

    # Convert 3-channel maps to a single grayscale channel.
    if attr.shape[0] == 3:
        return rgb_to_grayscale(attr)
    return attr

def rgb_to_grayscale(tensor):
    """Collapse a 3-channel, channel-first tensor into one luminance channel.

    Channels are weighted 0.2989 / 0.5870 / 0.1140 (channel 0 / 1 / 2) and
    summed. The result keeps a leading channel dimension of size 1.
    """
    assert tensor.shape[0] == 3, "Input tensor must have 3 channels (RGB)"

    red, green, blue = tensor[0], tensor[1], tensor[2]
    luminance = 0.2989 * red + 0.5870 * green + 0.1140 * blue

    # Re-insert the channel axis at position 0.
    return luminance[None]


def window(image, window_level, window_width):
    """Clamp ``image`` to an intensity window and rescale the window to [0, 1].

    The window spans ``[window_level - window_width // 2,
    window_level + window_width // 2]``. Values below it map to 0, values
    above it map to 1, and values inside are scaled linearly.

    Args:
        image: NumPy array or torch tensor. It is NOT modified.
        window_level: Centre of the window.
        window_width: Width of the window. NOTE: ``//`` floors, so an odd
            integer width yields a window one unit narrower.

    Returns:
        A new array/tensor of the same kind as ``image``.
    """
    image_min = window_level - window_width // 2
    image_max = window_level + window_width // 2
    # ``clip`` (available on both ndarray and torch.Tensor) returns a new
    # object. The previous boolean-mask assignment clamped the caller's
    # image in place.
    clipped = image.clip(image_min, image_max)
    return (clipped - image_min) / (image_max - image_min)


# NOTE(review): the module-level string literal below is disabled code, not a
# docstring. It preserves two experimental helpers (``random_mask`` and
# ``calc_base_score``) that are never defined at runtime because they live
# inside a string. Delete it, or restore it as real functions if still needed.
"""
def random_mask(x, ref_img):
    m = torch.rand(
        (512, 512),
        dtype=torch.float32,
        requires_grad=False,
    ) 
    m = m.repeat(3, 1, 1)
    return x*m + (1-m)*ref_img

def calc_base_score(x, model, device):
    std_pixel = torch.std(x)
    mean_pixel = torch.mean(x)
    x.requires_grad_(False)
    '''
    n = (
        std_pixel
        * torch.randn(
            (10, *x.shape[1:]),
            dtype=torch.float32,
            device=device,
            requires_grad=False,
        )
        + mean_pixel
    )
    '''
    n = mean_pixel * torch.ones(
            (10, *x.shape[1:]),
            dtype=torch.float32,
            device=device,
            requires_grad=False,
        )
    return torch.mean(model(n)).item()

"""

# Hard-coded list of image indices, kept in ascending order. The name suggests
# these are known-bad images to be excluded or handled specially.
# NOTE(review): which dataset/split these indices refer to, and how they were
# identified, is not visible here -- confirm and document the source.
bad_image_idx = [
    11933,
    13167,
    14786,
    18889,
    23057,
    27930,
    28280,
    30219,
    31099,
    49547,
    51148,
    54803,
    60439,
    60664,
    66835,
    70724,
    71545,
    73302,
    80326,
    87191,
    92941,
    95408,
    98876,
    102752,
    103300,
    107578,
    112766,
    120811,
    122325,
    122631,
    124568,
    129830,
    131375,
    134394,
    135825,
    140869,
    141009,
    141083,
    153031,
    165440,
    173385,
    193620,
    195753,
    201277,
    208877,
    219064,
    220852,
    225441,
    227901,
    236187,
    240481,
    241565,
    247716,
    250227,
    261277,
    264452,
    265789,
    268057,
    269724,
    272157,
    283192,
    300121,
    305078,
    311228,
    314908,
    315030,
    316886,
    341616,
    342265,
    343591,
    346422,
    349896,
    355273,
    377314,
    377452,
    379891,
    383257,
    393672,
    395247,
    405499,
    411411,
    413143,
    416566,
    421182,
    423681,
    429387,
    436461,
    437250,
    437405,
    444053,
    446739,
    447970,
    453793,
    459729,
    461317,
    461641,
    474199,
    475867,
    479760,
    492856,
    497749,
    499633,
    505237,
    514100,
    516935,
    518755,
    521637,
    524714,
    530357,
    530500,
    537159,
    545155,
    550627,
    550817,
    561374,
    562532,
    563396,
    563525,
    578539,
    578874,
    586089,
    594891,
    599572,
]
