from IPython.display import display
import numpy as np
import os
import pandas as pd
import pickle
from PIL import Image
from sklearn.feature_selection import mutual_info_classif
import sys
import torch
from torchvision import transforms
from torch.utils.data import Dataset

sys.path.append(os.path.join(os.path.dirname(__file__), '../Dependencies/barlow'))
from standard_utils import *
from decision_tree_utils import *
from visualize_utils import *

def save_image(image, name, dpi=80):
    """Save an image array to ``name`` at its native pixel resolution.

    The figure is sized so that the image's width x height pixels map onto
    ``figsize`` inches at ``dpi`` dots per inch, and a single axes fills the
    whole canvas with ticks and frame turned off.

    Args:
        image: Array of shape (height, width) or (height, width, channels),
            passed to ``imshow``.
        name: Output file path.
        dpi: Dots per inch, used both to size the figure and to save it.
    """
    # Only the first two dimensions matter; slicing also accepts 2-D
    # (grayscale) arrays, which a 3-way unpack would reject.
    height, width = image.shape[:2]
    figsize = width / float(dpi), height / float(dpi)
    fig = plt.figure(figsize=figsize, dpi=dpi)
    try:
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
        # Render single-channel arrays in gray rather than a false-color map.
        ax.imshow(image, cmap='gray' if image.ndim == 2 else None)
        # Save with the same dpi used to compute figsize; otherwise matplotlib's
        # default savefig dpi rescales the output image.
        fig.savefig(name, dpi=dpi)
    finally:
        # Close this specific figure even if saving fails, so repeated calls
        # do not accumulate open figures.
        plt.close(fig)
    
###
# Setup
###

# Most of these are placeholder values so the code runs

# Dataset name -> root directory handed to the dataset constructor
# (see dataset_function(DATA_PATH_DICT[DATA]) in sample_failure_explanation).
DATA_PATH_DICT = {'ImageNet': './'}
# Selects both the DATA_PATH_DICT entry and the class looked up on `datasets`.
DATA = 'ImageNet'
# Directory passed to load_model when loading the robust checkpoint.
MODEL_PATH = '../Dependencies/barlow/models/'
# Dataset class looked up by name. NOTE(review): `datasets` is not imported
# explicitly in this file; presumably it comes from one of the barlow star
# imports above. TODO confirm.
dataset_function = getattr(datasets, DATA)

###
# Copied from barlow/failure_explanation_sample.ipynb 
###
            
# Train decision tree
def train_decision_tree(train_sparse_features, train_failure, max_depth=1, criterion="entropy"):
    """Fit a class-balanced decision tree that predicts model failures.

    Args:
        train_sparse_features: Feature matrix, presumably (n_samples,
            n_features). TODO confirm against CustomDecisionTreeClassifier.
        train_failure: Boolean array, True where the model failed on a sample.
        max_depth: Maximum tree depth.
        criterion: Split criterion forwarded to CustomDecisionTreeClassifier.

    Returns:
        The fitted CustomDecisionTreeClassifier.
    """
    decision_tree = CustomDecisionTreeClassifier(
        max_depth=max_depth, criterion=criterion,
        class_weight=failure_class_weight(train_failure))
    decision_tree.fit_tree(train_sparse_features, train_failure)
    return decision_tree


def failure_class_weight(train_failure):
    """Return the ``class_weight`` dict balancing failures (1) vs. successes (0).

    Failures are up-weighted by (#successes / #failures) so that both classes
    carry equal total weight. If either class is absent, the ratio is
    undefined: it would be a division by zero, or a zero weight for every
    sample. In that case both classes get weight 1.

    Args:
        train_failure: Boolean sequence, True where the model failed.

    Returns:
        ``{0: 1, 1: rel_weight}``.
    """
    num_true = np.sum(train_failure)
    num_false = np.sum(np.logical_not(train_failure))
    if num_true == 0 or num_false == 0:
        rel_weight = 1.0
    else:
        rel_weight = num_false / num_true
    return {0: 1, 1: rel_weight}

# Rank leaf nodes by importance (leaf precision x recall), i.e., their contribution to the average leaf error rate, highest first
def important_leaf_nodes(decision_tree, precision_array, recall_array):
    """Return the tree's leaf ids ordered from most to least important.

    A leaf's importance is the product of its entries in ``precision_array``
    and ``recall_array``. Within this file, the caller passes the per-node
    error rate and error coverage as these two arrays.

    Args:
        decision_tree: Fitted tree exposing ``leaf_ids``.
        precision_array: Per-node values, indexed by node id.
        recall_array: Per-node values, indexed by node id.

    Returns:
        numpy array of leaf ids sorted by decreasing precision * recall.
        Leaves with equal scores keep their order from ``decision_tree.leaf_ids``.
    """
    # Accept lists as well as arrays; the final fancy indexing needs an ndarray.
    leaf_ids = np.asarray(decision_tree.leaf_ids)
    leaf_precision = np.asarray(precision_array)[leaf_ids]
    leaf_recall = np.asarray(recall_array)[leaf_ids]
    leaf_precision_recall = leaf_precision * leaf_recall

    # Stable sort on the negated score: descending order, deterministic ties.
    important_leaves = np.argsort(-leaf_precision_recall, kind="stable")
    return leaf_ids[important_leaves]

###
# Based on barlow/failure_explanation_sample.ipynb
###

# Modified to not have a 'root' directory as our setup passes full file paths
class FailureExplanationDataset(Dataset):
    """Failure Explanation Dataset.

    Reads a metadata CSV that has "Labels" and "Predictions" columns and keeps
    only the rows belonging to a single class. Each item is
    ``(image, label, prediction)``.
    """

    def __init__(self, csv_file, class_name, grouping="label", transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with labels and predictions.
            class_name: Class to keep. Rows are matched on the "Labels" column
                when ``grouping == "label"``, otherwise on "Predictions".
            grouping (string, optional): "label" selects by ground truth; any
                other value selects by prediction.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        metadata = pd.read_csv(csv_file)

        group_column = "Labels" if grouping == "label" else "Predictions"
        self.metadata_frame = metadata[metadata[group_column] == class_name]

        # Map every class name seen as a prediction or label in the selected
        # rows to a contiguous integer index (np.unique returns sorted names).
        class_names = np.concatenate([
            self.metadata_frame["Predictions"].to_numpy(),
            self.metadata_frame["Labels"].to_numpy(),
        ])
        unique_class_names = np.unique(class_names)
        self.class_indices_dict = dict(zip(unique_class_names, np.arange(len(unique_class_names))))

        self.transform = transform

    def __len__(self):
        return len(self.metadata_frame)

    def __getitem__(self, idx):
        """Return ``(image, label, prediction)`` for the ``idx``-th selected row.

        Columns are read by position: column 0 is used as the image path,
        column 1 as the prediction and column 2 as the label. This assumes the
        CSV is laid out in that order; TODO confirm against the CSV writer.
        """
        img_name = self.metadata_frame.iloc[idx, 0]
        prediction = self.metadata_frame.iloc[idx, 1]
        label = self.metadata_frame.iloc[idx, 2]

        # The context manager releases the file handle even if decoding fails.
        # convert() returns a new, fully loaded image, so it remains usable
        # after the source file is closed.
        with Image.open(img_name) as img:
            img = img.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label, prediction
    

# Modified to use either a robust imagenet model or a custom robust model
# Modified to compute visualizations more efficiently (no redundancy)
def sample_failure_explanation(csv_file, class_name, use_imagenet = True, model = None, grouping = "label",
                               max_depth = 11, num_show = 10, out_dir = '', skip_feature_vis = False):
    """Explain where a classifier fails on one class via a decision tree on robust features.

    Extracts latent features with a robust model for every image of ``class_name`` in
    ``csv_file``, fits a decision tree that predicts failures (prediction != label),
    prints a summary of the most important leaves ("groups"), and saves example
    images per group and per splitting feature into ``out_dir``.

    Args:
        csv_file: Metadata CSV consumed by FailureExplanationDataset.
        class_name: Class to analyse.
        use_imagenet: If True, load 'imagenet_l2_3_0.pt' from MODEL_PATH; otherwise
            use ``model``.
        model: Robust model used when ``use_imagenet`` is False; it is called as
            ``model(ims, with_latent=True)`` on CUDA tensors.
        grouping: "label" selects rows by ground truth, anything else by prediction.
        max_depth: Maximum depth of the failure decision tree.
        num_show: Maximum number of leaves (groups) to report.
        out_dir: Directory for group_*.png / feature_*.png; '' means the current
            working directory.
        skip_feature_vis: If True, skip writing the feature_*.png visualizations.

    Returns:
        dict mapping the 0-based group rank to the concatenation of that leaf's
        failure indices followed by its success indices, as returned by
        ``decision_tree.compute_leaf_truedata``.

    Raises:
        ValueError: if ``use_imagenet`` is False and ``model`` is None, or if no
            CSV row matches ``class_name``.
    """
    # Select which robust model is used for the feature extraction
    if use_imagenet:
        robust_model_name = 'imagenet_l2_3_0.pt'
        robust_model = load_model(robust_model_name, MODEL_PATH, dataset_function(DATA_PATH_DICT[DATA]))
    elif model is None:
        raise ValueError('model must be provided when use_imagenet is False')
    else:
        robust_model = model

    # Setup the dataloader
    data_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    sample_dataset = FailureExplanationDataset(csv_file, class_name, grouping=grouping, transform=data_transform)
    if len(sample_dataset) == 0:
        # Otherwise np.concatenate on zero batches fails with an opaque error.
        raise ValueError('No rows in {} match class {!r} (grouping={!r})'.format(csv_file, class_name, grouping))

    data_loader = torch.utils.data.DataLoader(sample_dataset,
                                              batch_size=32,
                                              shuffle=False, num_workers=4)

    # Extract the features
    train_features, train_labels, train_preds = _extract_features(robust_model, data_loader)

    train_failure = np.logical_not(train_preds == train_labels)
    train_success = np.logical_not(train_failure)
    train_base_error_rate = np.sum(train_failure) / len(train_failure)

    # Train the decision tree to predict where the model makes mistakes using those features
    decision_tree = train_decision_tree(train_features, train_failure, max_depth=max_depth, criterion="entropy")
    train_precision, train_recall, _ = decision_tree.compute_precision_recall(
        train_features, train_failure)

    error_rate_array, error_coverage_array = decision_tree.compute_leaf_error_rate_coverage(
        train_features, train_failure)
    important_leaf_ids = important_leaf_nodes(decision_tree, error_rate_array, error_coverage_array)

    ###
    # Visualize the results
    ###

    # Start with a summary of the data
    print('###')
    print('# General Info')
    print('###')
    print()
    print('Class: ', class_name)
    print('Overall Error Rate: ', np.round(train_base_error_rate, 3))
    print()
    print()

    print('###')
    print('# Tree Info')
    print('###')
    print()
    print('Precision: ', np.round(train_precision, 3))
    print('Recall: ', np.round(train_recall, 3))
    print()
    print()

    out = {}
    # node_id -> {'id': feature id, 'threshold': split threshold, 'name': 'F<k>'},
    # where k numbers nodes in order of first appearance on a printed decision path.
    node_info = {}
    for i, leaf_id in enumerate(important_leaf_ids[:num_show]):
        decision_path = decision_tree.compute_decision_path(leaf_id)
        leaf_failure_indices = decision_tree.compute_leaf_truedata(train_features, train_failure, leaf_id)
        leaf_success_indices = decision_tree.compute_leaf_truedata(train_features, train_success, leaf_id)
        out[i] = np.concatenate([leaf_failure_indices, leaf_success_indices])

        print('###')
        print('# Group ', i + 1)
        print('### ')
        print()

        if len(leaf_failure_indices) != 0:
            images = display_failures(leaf_id, leaf_failure_indices, data_loader, grouping, num_images=6)
            # os.path.join avoids writing to '/group_N.png' when out_dir is ''.
            save_image(np.vstack(images), os.path.join(out_dir, 'group_{}.png'.format(i + 1)))

            print('Group Error Rate: ', np.round(error_rate_array[leaf_id], 3))
            print('Percentage of all Errors in this Group: ', np.round(error_coverage_array[leaf_id], 3))
            print('Features that define this Group:')
            for node_id, feature_id, feature_threshold, direction in decision_path:
                if node_id not in node_info:
                    node_info[node_id] = {'id': feature_id, 'threshold': feature_threshold,
                                          'name': 'F{}'.format(len(node_info))}
                node_name = node_info[node_id]['name']

                if direction == 'left':
                    print('-  Feature {} is less present'.format(node_name))
                else:
                    print('-  Feature {} is more present'.format(node_name))
        else:
            print('Group Error Rate: 0.0')

        print()
        print()

    if not skip_feature_vis:
        indices = np.arange(len(data_loader.dataset))
        for info in node_info.values():
            images = feature_visualization(robust_model, indices, train_features, info['id'],
                                           data_loader, grouping, num_images=6)
            save_image(np.vstack(images), os.path.join(out_dir, 'feature_{}.png'.format(info['name'])))

    return out


def _extract_features(robust_model, data_loader):
    """Run ``robust_model`` over ``data_loader`` and collect latents, labels and predictions.

    Returns:
        (features, labels, preds) as numpy arrays concatenated over batches along axis 0.
    """
    features, labels, preds = [], [], []
    # Inference only: skip building the autograd graph to save (GPU) memory.
    with torch.no_grad():
        for ims, batch_labels, batch_preds in data_loader:
            (_, batch_features), _ = robust_model(ims.cuda(), with_latent=True)
            features.append(batch_features.detach().cpu().numpy())
            labels.append(np.array(batch_labels))
            preds.append(np.array(batch_preds))
    return (np.concatenate(features, axis=0),
            np.concatenate(labels, axis=0),
            np.concatenate(preds, axis=0))
        