# Standard library imports
import os
import random

# Third party imports
import cv2
import torch
import numpy as np
import pandas as pd
from scipy import stats
from torch.utils.data import Subset, DataLoader

# Local imports
from dataset import CelebAHQ
from clf import Classifier
from explainer import Explainer

"""
Functions to load models and data
"""
def load_classifier(state_dict_path=None, eval_mode=False):
    """Build a ``Classifier`` and optionally restore its weights from disk.

    Args:
        state_dict_path: Path to a saved state dict, or ``None`` to keep the
            weights the constructor produced.
        eval_mode: When truthy, call ``clf.eval()`` before returning.

    Returns:
        The constructed ``Classifier`` instance.
    """
    clf = Classifier(fine_tune=False)
    if state_dict_path is not None:
        # Deserialize onto the CPU so a checkpoint written from a GPU can also
        # be restored on a CPU-only host; load_state_dict then copies the
        # values into the model's own parameters.
        state_dict = torch.load(state_dict_path, map_location="cpu")
        clf.load_state_dict(state_dict)
        print("Loading from state dict")
    else:
        print("Loading untrained model")
    if eval_mode:
        clf.eval()
        print("classifier loaded in eval mode")
    return clf

def load_explainer(clf, device, bkgd, num_channels=3, state_dict_path=None, eval_mode=False):
    """Build an ``Explainer`` and optionally restore its weights from disk.

    Args:
        clf: Classifier handed to the ``Explainer`` constructor.
        device: Device handed to the ``Explainer`` constructor; also used as
            the ``map_location`` when a checkpoint is deserialized.
        bkgd: Background value/tensor handed to the ``Explainer`` constructor.
        num_channels: Forwarded to the constructor as ``num_channels``.
        state_dict_path: Path to a saved state dict, or ``None`` to keep the
            weights the constructor produced.
        eval_mode: When truthy, call ``explainer.eval()`` before returning.

    Returns:
        The constructed ``Explainer`` instance.
    """
    explainer = Explainer(clf, device, bkgd, num_channels=num_channels)
    if state_dict_path is not None:
        # The original passed the builtin ``map`` as load_state_dict's second
        # positional argument (``strict``). The device remapping belongs to
        # torch.load; strict loading stays at its default (True), which is how
        # the truthy ``map`` object behaved.
        state_dict = torch.load(state_dict_path, map_location=device)
        explainer.load_state_dict(state_dict)
        print("Loading from state dict")
    else:
        print("Loading untrained explainer")
    if eval_mode:
        explainer.eval()
        print("explainer loaded in eval mode")

    return explainer

def load_train_val_data(root_dir, img_dir, csv_file, model_type="clf"):
    """Return the training and validation subsets of the ``CelebAHQ`` dataset.

    Args:
        root_dir: Directory that contains ``csv_file``.
        img_dir: Image directory passed to ``CelebAHQ``.
        csv_file: Name of the CSV read with its first column as the index.
        model_type: Prefix of the index files, which are read from
            ``data/<model_type>_train_idx.csv`` and
            ``data/<model_type>_val_idx.csv``.

    Returns:
        A ``(train_data, val_data)`` pair of ``Subset`` objects.
    """
    # Table from the original CSV file, indexed by its first column.
    table = pd.read_csv(os.path.join(root_dir, csv_file), index_col=[0])

    def _read_index_file(suffix):
        # Header-less CSV; the sample indices are taken from column 0 below.
        return pd.read_csv(os.path.join("data", model_type + suffix), header=None)

    train_rows = _read_index_file("_train_idx.csv")
    val_rows = _read_index_file("_val_idx.csv")

    # Each split wraps its own dataset instance, restricted to the listed rows.
    train_data = Subset(
        CelebAHQ(img_dir, table, "Smiling", transform=False),
        train_rows[0].values,
    )
    val_data = Subset(
        CelebAHQ(img_dir, table, "Smiling", transform=False),
        val_rows[0].values,
    )
    return train_data, val_data

"""
Functions for analysis
"""
def process_data(results):
    """Split per-sample result records into three NumPy arrays.

    Args:
        results: Container indexable by ``0 .. len(results) - 1`` whose
            entries provide the keys ``'suff'``, ``'necc'`` and ``'l0'``.

    Returns:
        Tuple ``(suff, necc, l0)`` of ``np.ndarray``, one entry per record.
    """
    positions = range(len(results))

    def _column(key):
        # Positional lookup (rather than iteration) keeps int-keyed mappings working.
        return np.array([results[pos][key] for pos in positions])

    return _column('suff'), _column('necc'), _column('l0')

def calc_95ci(array):
    """Return the mean and the 95% confidence half-width along axis 0.

    The half-width uses the normal approximation: 1.96 standard errors of the
    mean on either side (1.96 is the two-sided 95% quantile of the standard
    normal distribution).

    Args:
        array: Array whose first axis indexes the samples.

    Returns:
        Tuple ``(mean, ci)``; the interval is ``mean - ci`` to ``mean + ci``.
    """
    mean = array.mean(axis=0)
    sem = stats.sem(array, axis=0)
    ci = 1.96 * sem
    return mean, ci

def compute_metrics_batch(clf, x, attr, y_hat, y_0, bkgd, thresholds):
    """Evaluate ``compute_metrics`` once per threshold.

    For every value in ``thresholds`` the attribution map is binarized with
    ``threshold_attr`` and the resulting mask is scored.

    Returns:
        Three lists ``(suff, necc, l0)``, each with one entry per threshold,
        in the order the thresholds were given.
    """
    scored = [
        compute_metrics(clf, x, y_hat, y_0, threshold_attr(attr, cutoff), bkgd)
        for cutoff in thresholds
    ]
    # Transpose the per-threshold triples into one list per metric.
    suff = [triple[0] for triple in scored]
    necc = [triple[1] for triple in scored]
    l0 = [triple[2] for triple in scored]
    return suff, necc, l0

def compute_metrics(clf, x, y_hat, y_0, binary_mask, bkgd):
    """Score a binary mask by sufficiency, necessity and size.

    Args:
        clf: Callable applied to the masked inputs; ``.item()`` is called on
            its output, so it must yield a single-element tensor.
        x: Input being explained.
        y_hat: Reference prediction compared against the kept-region output.
        y_0: Reference prediction compared against the removed-region output.
        binary_mask: Mask multiplied element-wise with ``x``.
        bkgd: Replacement values used wherever ``x`` is masked out.

    Returns:
        Tuple ``(suff, necc, l0)``: ``|y_hat - clf(x_S)|``,
        ``|clf(x_Sc) - y_0|`` and the mean absolute value of the mask.
    """
    complement = 1 - binary_mask

    # x_S keeps x inside the mask; x_Sc keeps x outside of it.
    kept = binary_mask * x + complement * bkgd
    removed = complement * x + binary_mask * bkgd
    out_kept = clf(kept).item()
    out_removed = clf(removed).item()

    sufficiency = abs(y_hat - out_kept)
    necessity = abs(out_removed - y_0)
    mask_size = binary_mask.abs().mean().item()

    return sufficiency, necessity, mask_size

"""
Functions to process images
"""
def threshold_attr(attr, t):
    """Binarize ``attr``: 1.0 where ``attr >= t`` and 0.0 elsewhere."""
    at_or_above = attr >= t
    # Multiplying by 1.0 turns the boolean mask into floating-point values.
    return at_or_above * 1.0

def process_attr(attr, top_kp, norm_type):
    """Normalize attribution magnitudes and collapse RGB maps to one channel.

    The absolute value of ``attr`` is taken first (on a copy; the caller's
    tensor is not modified) and then rescaled according to ``norm_type``:

    * ``"top_kp"``: the cutoff is the smallest score among the top
      ``top_kp`` percent of the strictly positive scores. Scores at or above
      the cutoff become 1; the rest are divided by the cutoff, so they scale
      linearly into ``[0, 1)``. At least one and at most all positive scores
      count as "top". A map without positive scores is returned as all zeros.
    * ``"min/max"``: scores are mapped linearly onto ``[0, 1]``. A constant
      map is returned as all zeros instead of dividing by zero.
    * any other value: the magnitudes are returned without rescaling.

    Args:
        attr: Attribution tensor.
        top_kp: Percentage (0-100) used by the ``"top_kp"`` mode.
        norm_type: Normalization mode, see above.

    Returns:
        The normalized tensor. When its first dimension has size 3 it is
        converted with ``rgb_to_grayscale``; otherwise it is returned as is.
    """
    # abs() allocates a new tensor, so nothing below touches the input.
    attr = torch.clamp(torch.abs(attr), min=0)
    if norm_type == "top_kp":
        # reshape (not view) so non-contiguous inputs are accepted too.
        flat_attr = attr.reshape(-1)
        nonzero_attr = flat_attr[flat_attr > 0]

        if len(nonzero_attr) > 0:
            # Keep k within [1, n]: a tiny percentage must not round down to
            # an empty top-k, and more than 100% cannot exceed the available
            # scores.
            top_k = int(len(nonzero_attr) * top_kp / 100)
            top_k = min(max(top_k, 1), len(nonzero_attr))
            top_k_value = torch.topk(nonzero_attr, top_k, largest=True).values[-1]

            # Decide membership before rescaling. Writing the 1s in place and
            # then dividing everything below the cutoff would divide those 1s
            # again whenever the cutoff is larger than 1.
            in_top_k = attr >= top_k_value
            attr = torch.where(in_top_k, torch.ones_like(attr), attr / top_k_value)
    elif norm_type == "min/max":
        lowest = torch.min(attr)
        span = torch.max(attr) - lowest
        if span == 0:
            # Constant map: no spread to normalize, avoid 0/0 -> NaN.
            attr = torch.zeros_like(attr)
        else:
            attr = (attr - lowest) / span

    # Convert to grayscale image
    if attr.shape[0] == 3:
        gs_attr = rgb_to_grayscale(attr)
    else:
        gs_attr = attr

    return gs_attr
    
def rgb_to_grayscale(tensor):
    """Collapse a 3-channel tensor into a single weighted channel.

    Args:
        tensor: Tensor whose first dimension has size 3; it is read as the
            red, green and blue channels in that order.

    Returns:
        The weighted channel sum with a leading dimension of size 1.
    """
    # The first dimension must hold exactly the three colour channels.
    assert tensor.shape[0] == 3, "Input tensor must have 3 channels (RGB)"

    red = tensor[0]
    green = tensor[1]
    blue = tensor[2]

    # Weighted sum of the colour channels.
    luminance = 0.2989 * red + 0.5870 * green + 0.1140 * blue

    return torch.unsqueeze(luminance, 0)