from matplotlib import pyplot as plt
import random
import os
import sys
import json
import shutil
from collections import defaultdict
import pandas as pd
import numpy as np
import cv2
import torch as ch

from robustness.tools.vis_tools import show_image_row, show_image_column


def get_failing_classes():
    """Rank the 1000 classes by failure ratio of two models.

    Loads ``nonrobust_train_logits.npy``, ``robust_train_logits.npy`` and
    ``train_labels.npy`` from ``./imagenet_features``, takes the argmax of
    each logits array along axis 1 as the prediction, and computes two kinds
    of per-class failure ratio for each model:

    * "preds": wrong predictions of class ``c`` / all predictions of ``c``
    * "labels": wrong predictions on images labelled ``c`` / images labelled ``c``

    Returns:
        A list of eight index arrays (50 class ids each), in this order:
        most/least failing for robust preds, most/least failing for
        non-robust preds, most/least failing for robust labels, most/least
        failing for non-robust labels.

    Note:
        A class with a zero denominator gets a NaN ratio (0/0); ``np.argsort``
        places NaN last, so such classes end up in the "least failing" slices.
    """
    metadata_dir = './imagenet_features'
    num_classes = 1000
    top_k = 50

    def load(file_name):
        return np.load(os.path.join(metadata_dir, file_name))

    nonrobust_preds = np.argmax(load('nonrobust_train_logits.npy'), axis=1)
    robust_preds = np.argmax(load('robust_train_logits.npy'), axis=1)
    labels = load('train_labels.npy')

    def count_per_class(class_ids, mask=None):
        # Number of entries equal to each class id, optionally restricted to `mask`.
        counts = np.zeros(num_classes)
        for cls in range(num_classes):
            selected = class_ids == cls
            if mask is not None:
                selected = selected & mask
            counts[cls] = np.sum(selected)
        return counts

    def most_and_least_failing(failures, totals):
        # Sort classes by descending failure ratio and keep both ends.
        ratios = failures / np.float32(totals)
        ranking = np.argsort(-ratios)
        return [ranking[:top_k], ranking[-top_k:]]

    robust_wrong = np.logical_not(robust_preds == labels)
    nonrobust_wrong = np.logical_not(nonrobust_preds == labels)
    images_per_class = count_per_class(labels)

    li = []
    # Ratios relative to how often each class was predicted.
    li += most_and_least_failing(count_per_class(robust_preds, robust_wrong),
                                 count_per_class(robust_preds))
    li += most_and_least_failing(count_per_class(nonrobust_preds, nonrobust_wrong),
                                 count_per_class(nonrobust_preds))
    # Ratios relative to how many images carry each label.
    li += most_and_least_failing(count_per_class(labels, robust_wrong),
                                 images_per_class)
    li += most_and_least_failing(count_per_class(labels, nonrobust_wrong),
                                 images_per_class)
    return li



def print_with_stars(print_str, total_count=115, prefix="", suffix="", star='*'):
    """Print ``print_str`` centred inside a run of ``star`` fill strings.

    ``total_count - len(print_str)`` fill repetitions are split between the
    two sides, the right side receiving the extra one when the number is odd.
    If ``print_str`` is longer than ``total_count`` no fill is added.
    ``prefix`` and ``suffix`` are placed outside the fill.
    """
    padding = total_count - len(print_str)
    left_len = padding // 2
    right_len = padding - left_len
    # Repeating a string a negative number of times yields "", as required.
    banner = star * left_len + print_str + star * right_len
    print(prefix + banner + suffix)


def aggregate_results(df):
    """Group the non-rejected answers of an MTurk results table by feature.

    Args:
        df: DataFrame with the columns ``WorkerId``, ``Answer.class_index``,
            ``Answer.feature_index``, ``Input.feature_rank``,
            ``Answer.main_question``, ``Answer.confidence``,
            ``Answer.reasons`` and ``AssignmentStatus``.

    Returns:
        ``defaultdict(list)`` mapping ``"<class>_<feature>_<rank>"`` to a list
        of ``(WorkerId, main_answer, confidence, reasons)`` tuples, one per
        row whose ``AssignmentStatus`` is not ``"Rejected"``.

    The function performs no file I/O: the result depends only on ``df``.
    """
    # Extension hooks: lower-cased boilerplate reasons to blank out, and
    # worker ids whose answers should be ignored. Both are currently empty.
    common_reasons = []
    blocked_workers = []

    answers_dict = defaultdict(list)

    for _, content in df.iterrows():
        worker_id = content['WorkerId']

        class_index = int(content['Answer.class_index'])
        feature_index = int(content['Answer.feature_index'])
        feature_rank = int(content['Input.feature_rank'])
        main_index = '_'.join([str(class_index), str(feature_index), str(feature_rank)])

        main_answer = content['Answer.main_question']
        confidence = content['Answer.confidence']

        reasons = content['Answer.reasons']
        # pandas reads an empty cell as float NaN, which has no .lower();
        # only real strings can match a common reason.
        if isinstance(reasons, str) and reasons.lower() in common_reasons:
            reasons = ""

        if content['AssignmentStatus'] == "Rejected" or worker_id in blocked_workers:
            continue
        answers_dict[main_index].append((worker_id, main_answer, confidence, reasons))

    return answers_dict

def get_answers_dict():
    """Load ``final_batch_results/complete_results.csv`` and aggregate it.

    Returns:
        The mapping produced by ``aggregate_results`` for that CSV.
    """
    complete_df_file_path = os.path.join('final_batch_results', 'complete_results.csv')
    # Context manager so the file handle is released once the CSV is parsed.
    with open(complete_df_file_path) as complete_results_file:
        complete_df = pd.read_csv(complete_results_file)

    return aggregate_results(complete_df)
    
def get_causal_features_dict(threshold=3, return_rank=False, validate_heatmap=False):
    """Collect, per class, the features that enough workers marked as causal.

    Reads ``final_batch_results/complete_results.csv``, aggregates it with
    ``aggregate_results`` and keeps every feature for which at least
    ``threshold`` answers equal ``'main_object'``.

    Args:
        threshold: Minimum number of ``'main_object'`` answers required.
        return_rank: If true, store ``(feature_index, feature_rank)`` tuples
            instead of bare feature indices.
        validate_heatmap: Accepted for signature compatibility; not used.

    Returns:
        ``defaultdict(list)`` mapping integer class index to the selected
        features, in the iteration order of the aggregated answers.
    """
    complete_df_file_path = os.path.join('final_batch_results', 'complete_results.csv')
    # Context manager so the file handle is released once the CSV is parsed.
    with open(complete_df_file_path) as complete_results_file:
        complete_df = pd.read_csv(complete_results_file)

    answers_dict = aggregate_results(complete_df)

    causal_features_dict = defaultdict(list)
    for key, answers in answers_dict.items():
        num_causal = sum(1 for answer in answers if answer[1] == 'main_object')

        class_index, feature_index, feature_rank = key.split('_')

        if num_causal >= threshold:
            class_index, feature_index, feature_rank = int(class_index), int(feature_index), int(feature_rank)

            if return_rank:
                causal_features_dict[class_index].append((feature_index, feature_rank))
            else:
                causal_features_dict[class_index].append(feature_index)

    return causal_features_dict
    
def get_spurious_features_dict(threshold=3, return_rank=False, validate_heatmap=True):
    """Collect, per class, the features that enough workers marked as spurious.

    Reads ``final_batch_results/complete_results.csv``, aggregates it with
    ``aggregate_results`` and keeps every feature for which at least
    ``threshold`` answers are ``'separate_object'`` or ``'background'``.

    Args:
        threshold: Minimum number of spurious answers required.
        return_rank: If true, store ``(feature_index, feature_rank)`` tuples
            instead of bare feature indices.
        validate_heatmap: If true, additionally drop features whose
            ``"<class>_<feature>"`` entry in ``aggregate_heatmap_validation()``
            is ``<= 0.5``. The validation results are only loaded when this
            flag is set.

    Returns:
        ``defaultdict(list)`` mapping integer class index to the selected
        features, in the iteration order of the aggregated answers.
    """
    # Load the validation results only when they will actually be consulted.
    same_dict = aggregate_heatmap_validation() if validate_heatmap else None

    complete_df_file_path = os.path.join('final_batch_results', 'complete_results.csv')
    # Context manager so the file handle is released once the CSV is parsed.
    with open(complete_df_file_path) as complete_results_file:
        complete_df = pd.read_csv(complete_results_file)

    answers_dict = aggregate_results(complete_df)

    spurious_answers = ('separate_object', 'background')
    spurious_features_dict = defaultdict(list)
    for key, answers in answers_dict.items():
        num_spurious = sum(1 for answer in answers if answer[1] in spurious_answers)

        class_index, feature_index, feature_rank = key.split('_')
        # Validation results are keyed without the rank component.
        new_key = class_index + '_' + feature_index

        append_feature = True
        if validate_heatmap and (same_dict[new_key] <= 0.5):
            append_feature = False

        if num_spurious >= threshold and append_feature:
            class_index, feature_index, feature_rank = int(class_index), int(feature_index), int(feature_rank)

            if return_rank:
                spurious_features_dict[class_index].append((feature_index, feature_rank))
            else:
                spurious_features_dict[class_index].append(feature_index)

    return spurious_features_dict

def get_specific_spurious_dict(threshold=3, return_rank=False, validate_heatmap=False, choice='background'):
    """Collect, per class, the features that enough workers gave one spurious label.

    Same selection as ``get_spurious_features_dict`` but counting only
    answers equal to ``choice``.

    Args:
        threshold: Minimum number of answers equal to ``choice`` required.
        return_rank: If true, store ``(feature_index, feature_rank)`` tuples
            instead of bare feature indices.
        validate_heatmap: If true, additionally drop features whose
            ``"<class>_<feature>"`` entry in ``aggregate_heatmap_validation()``
            is ``<= 0.5``. The validation results are only loaded when this
            flag is set.
        choice: ``'background'`` or ``'separate_object'``.

    Returns:
        ``defaultdict(list)`` mapping integer class index to the selected
        features, in the iteration order of the aggregated answers.

    Raises:
        AssertionError: If ``choice`` is not one of the two allowed values.
    """
    assert choice in ['background', 'separate_object']

    # Load the validation results only when they will actually be consulted.
    same_dict = aggregate_heatmap_validation() if validate_heatmap else None

    complete_df_file_path = os.path.join('final_batch_results', 'complete_results.csv')
    # Context manager so the file handle is released once the CSV is parsed.
    with open(complete_df_file_path) as complete_results_file:
        complete_df = pd.read_csv(complete_results_file)

    answers_dict = aggregate_results(complete_df)

    spurious_features_dict = defaultdict(list)
    for key, answers in answers_dict.items():
        num_spurious = sum(1 for answer in answers if answer[1] == choice)

        class_index, feature_index, feature_rank = key.split('_')
        # Validation results are keyed without the rank component.
        new_key = class_index + '_' + feature_index

        append_feature = True
        if validate_heatmap and (same_dict[new_key] <= 0.5):
            append_feature = False

        if num_spurious >= threshold and append_feature:
            class_index, feature_index, feature_rank = int(class_index), int(feature_index), int(feature_rank)

            if return_rank:
                spurious_features_dict[class_index].append((feature_index, feature_rank))
            else:
                spurious_features_dict[class_index].append(feature_index)

    return spurious_features_dict



def get_unique_workers():
    """Return the distinct worker ids in ``final_batch_results/complete_results.csv``.

    Returns:
        A tuple ``(worker_ids_unique, num_rows)``: the sorted unique values of
        the ``WorkerId`` column as produced by ``np.unique``, and the number
        of rows in the CSV.
    """
    complete_df_file_path = os.path.join('final_batch_results', 'complete_results.csv')
    # Context manager so the file handle is released once the CSV is parsed.
    with open(complete_df_file_path) as complete_results_file:
        complete_df = pd.read_csv(complete_results_file)

    worker_ids_unique = np.unique(np.array(complete_df['WorkerId']))
    return worker_ids_unique, len(complete_df)

def aggregate_heatmap_validation():
    """Compute, per feature, the fraction of validation answers that are ``'same'``.

    Reads ``final_batch_results/validation_results_complete.csv`` and, for
    every ``(Input.class_id, Input.feature_id)`` pair, considers the rows
    whose ``AssignmentStatus`` is not ``"Rejected"``.

    Returns:
        ``defaultdict(float)`` mapping ``"<class_id>_<feature_id>"`` to
        ``num_same / total`` over those rows. Keys with no accepted row are
        absent, so looking them up yields ``0.0``.
    """
    complete_df_file_path = os.path.join('final_batch_results', 'validation_results_complete.csv')
    # Context manager so the file handle is released once the CSV is parsed.
    with open(complete_df_file_path) as complete_results_file:
        complete_df = pd.read_csv(complete_results_file)

    # Extension hook: worker ids whose answers should be ignored (currently none).
    blocked_workers = []

    total_counts = defaultdict(int)
    same_counts = defaultdict(int)

    for _, content in complete_df.iterrows():
        worker_id = content['WorkerId']

        class_index = int(content['Input.class_id'])
        feature_index = int(content['Input.feature_id'])
        main_index = str(class_index) + '_' + str(feature_index)

        main_answer = content['Answer.main_question']

        if content['AssignmentStatus'] == "Rejected" or worker_id in blocked_workers:
            continue

        total_counts[main_index] += 1
        if main_answer == 'same':
            same_counts[main_index] += 1

    same_dict = defaultdict(float)
    # Every key in total_counts has at least one accepted row, so total >= 1.
    for key, total in total_counts.items():
        same_dict[key] = same_counts[key] / total
    return same_dict


def get_core_spurious_features_dict(csv_path='whole_imagenet_results/approved_results_new.csv'):
    """Split every annotated feature of every class into core or spurious.

    Reads the CSV at ``csv_path``, aggregates it with ``aggregate_results``
    and labels a feature spurious when at least 3 of its answers are
    ``'separate_object'`` or ``'background'``; every other feature is core.

    Args:
        csv_path: Path of the MTurk results CSV.

    Returns:
        ``(core_features_dict, spurious_features_dict)``, two
        ``defaultdict(list)`` objects mapping integer class index to a list
        of integer feature indices.
    """
    complete_df = pd.read_csv(csv_path)
    answers_dict = aggregate_results(complete_df)

    spurious_answers = ('separate_object', 'background')
    core_features_dict = defaultdict(list)
    spurious_features_dict = defaultdict(list)
    for key, answers in answers_dict.items():
        num_spurious = sum(1 for answer in answers if answer[1] in spurious_answers)

        # The rank component of the key is parsed but not needed here.
        class_index, feature_index, _feature_rank = (int(part) for part in key.split('_'))
        if num_spurious >= 3:
            spurious_features_dict[class_index].append(feature_index)
        else:
            core_features_dict[class_index].append(feature_index)

    return core_features_dict, spurious_features_dict


class MTurk_Results:
    """In-memory view of one MTurk results CSV, keyed by ``"<class>_<feature>"``.

    Construction reads the CSV and populates every derived attribute:
    ``answers_dict``, ``reasons_dict``, ``feature_rank_dict``, the
    class/feature maps, the core/spurious labels and the
    background/separate/ambiguous lists.
    """

    def __init__(self, csv_path):
        self.csv_path = csv_path
        self.dataframe = pd.read_csv(self.csv_path)

        self.aggregate_results(self.dataframe)

        # The remaining views are all derived from the aggregated answers.
        self.class_feature_maps(self.answers_dict)
        self.core_spurious_labels_dict(self.answers_dict)
        self.spurious_feature_lists(self.answers_dict)

    def aggregate_results(self, dataframe):
        """Group non-rejected answers and reasons by feature key.

        Sets ``answers_dict`` (key -> list of
        ``(WorkerId, main_answer, confidence, reasons)``), ``reasons_dict``
        (key -> list of reasons) and ``feature_rank_dict`` (key -> rank of
        the last row seen for that key, rejected rows included).
        """
        grouped_answers = defaultdict(list)
        grouped_reasons = defaultdict(list)
        rank_by_key = defaultdict(int)

        for _, row in dataframe.iterrows():
            worker = row['WorkerId']

            cls = int(row['Answer.class_index'])
            feat = int(row['Answer.feature_index'])
            rank = int(row['Input.feature_rank'])
            key = '{}_{}'.format(cls, feat)

            answer = row['Answer.main_question']
            confidence = row['Answer.confidence']
            reasons = row['Answer.reasons']

            if row['AssignmentStatus'] != "Rejected":
                grouped_answers[key].append((worker, answer, confidence, reasons))
                grouped_reasons[key].append(reasons)

            rank_by_key[key] = rank

        self.answers_dict = grouped_answers
        self.feature_rank_dict = rank_by_key
        self.reasons_dict = grouped_reasons

    def core_spurious_labels_dict(self, answers_dict):
        """Label each key ``'spurious'`` (>= 3 spurious answers) or ``'core'``.

        Sets ``core_spurious_dict``, ``core_list``, ``spurious_list``,
        ``core_features_dict`` and ``spurious_features_dict``.
        """
        spurious_answers = ('separate_object', 'background')

        features_by_label = {'core': defaultdict(list), 'spurious': defaultdict(list)}
        keys_by_label = {'core': [], 'spurious': []}
        label_by_key = {}

        for key, answers in answers_dict.items():
            cls, feat = map(int, key.split('_'))

            votes = sum(1 for answer in answers if answer[1] in spurious_answers)
            label = 'spurious' if votes >= 3 else 'core'

            features_by_label[label][cls].append(feat)
            label_by_key[key] = label
            keys_by_label[label].append(key)

        self.core_spurious_dict = label_by_key
        self.core_list = keys_by_label['core']
        self.spurious_list = keys_by_label['spurious']

        self.core_features_dict = features_by_label['core']
        self.spurious_features_dict = features_by_label['spurious']

    def spurious_feature_lists(self, answers_dict):
        """Sort spurious keys into background, separate-object and ambiguous.

        A key is "background" with >= 3 ``'background'`` answers, otherwise
        "separate" with >= 3 ``'separate_object'`` answers, otherwise
        "ambiguous" when the two counts together reach 3. Sets
        ``background_list``, ``separate_list`` and ``ambiguous_list``.
        """
        background_keys, separate_keys, ambiguous_keys = [], [], []

        for key, answers in answers_dict.items():
            given = [answer[1] for answer in answers]
            n_background = sum(1 for value in given if value == 'background')
            n_separate = sum(1 for value in given if value == 'separate_object')

            if n_background >= 3:
                background_keys.append(key)
            elif n_separate >= 3:
                separate_keys.append(key)
            elif n_background + n_separate >= 3:
                ambiguous_keys.append(key)

        self.background_list = background_keys
        self.separate_list = separate_keys
        self.ambiguous_list = ambiguous_keys

    def class_feature_maps(self, answers_dict):
        """Build the class -> features and feature -> classes lookups.

        Sets ``class_to_features_dict`` and ``feature_to_classes_dict``
        from the keys of ``answers_dict``.
        """
        classes_of_feature = defaultdict(list)
        features_of_class = defaultdict(list)

        for key in answers_dict:
            cls, feat = map(int, key.split('_'))
            classes_of_feature[feat].append(cls)
            features_of_class[cls].append(feat)

        self.class_to_features_dict = features_of_class
        self.feature_to_classes_dict = classes_of_feature
    
    
class MTurk_Validation_Results:
    """Vote counts from one MTurk heatmap-validation CSV.

    Each row answers a main question plus fifteen heatmap questions named
    ``a1``..``a5``, ``b1``..``b5`` and ``c1``..``c5`` for one
    ``"<class>_<feature>"`` key.
    """

    def __init__(self, csv_path):
        self.csv_path = csv_path
        self.dataframe = pd.read_csv(self.csv_path)
        self.aggregate_results(self.dataframe)

    def aggregate_results(self, df):
        """Tally the answers in ``df``.

        Sets ``main_count_dict`` (key -> number of ``'same'`` main answers),
        ``heatmap_count_dict`` (``"<key>_<heatmap>"`` -> number of
        ``'different'`` answers) and ``assignment_ids_dict`` (key -> list of
        ``AssignmentId`` values).
        """
        same_votes = defaultdict(int)
        different_votes = defaultdict(int)
        assignments = defaultdict(list)

        heatmap_names = [panel + str(slot) for panel in ('a', 'b', 'c') for slot in range(1, 6)]

        for _, row in df.iterrows():
            cls = int(row['Answer.class_index'])
            feat = int(row['Answer.feature_index'])
            main_key = '{}_{}'.format(cls, feat)

            assignments[main_key].append(row['AssignmentId'])

            for name in heatmap_names:
                if row['Answer.' + name] == 'different':
                    different_votes[main_key + '_' + name] += 1

            if row['Answer.main_question'] == 'same':
                same_votes[main_key] += 1

        self.main_count_dict = same_votes
        self.heatmap_count_dict = different_votes
        self.assignment_ids_dict = assignments

    def count_same_heatmaps(self, main_key, min_threshold=3, max_threshold=5):
        """Count the heatmaps of ``main_key`` whose 'different' tally is in range.

        A heatmap counts when its number of ``'different'`` answers lies
        between ``5 - max_threshold`` and ``5 - min_threshold`` inclusive.
        Heatmaps never seen are looked up through the defaultdict and so
        count as zero 'different' answers.
        """
        lowest = 5 - max_threshold
        highest = 5 - min_threshold

        matching = 0
        for panel in ('a', 'b', 'c'):
            for slot in range(1, 6):
                different = self.heatmap_count_dict[main_key + '_' + panel + str(slot)]
                if lowest <= different <= highest:
                    matching += 1
        return matching

    

def get_feature_type_dict(csv_path='whole_imagenet_results/approved_results_new.csv'):
    """Split every annotated feature of every class into core or spurious.

    Reads the CSV at ``csv_path``, aggregates it with ``aggregate_results``
    and labels a feature spurious when at least 3 of its answers are
    ``'separate_object'`` or ``'background'``; every other feature is core.

    Args:
        csv_path: Path of the MTurk results CSV.

    Returns:
        ``(core_features_dict, spurious_features_dict)``, two
        ``defaultdict(list)`` objects mapping integer class index to a list
        of integer feature indices.
    """
    complete_df = pd.read_csv(csv_path)
    answers_dict = aggregate_results(complete_df)

    spurious_answers = ('separate_object', 'background')
    core_features_dict = defaultdict(list)
    spurious_features_dict = defaultdict(list)
    for key, answers in answers_dict.items():
        num_spurious = sum(1 for answer in answers if answer[1] in spurious_answers)

        # The rank component of the key is parsed but not needed here.
        class_index, feature_index, _feature_rank = (int(part) for part in key.split('_'))
        if num_spurious >= 3:
            spurious_features_dict[class_index].append(feature_index)
        else:
            core_features_dict[class_index].append(feature_index)

    return core_features_dict, spurious_features_dict

# Model identifiers known to this module.
model_name_list = ['resnet18', 'resnet50', 'vgg19_bn', 'inception_v3_google',
                   'googlenet', 'shufflenetv2_x1', 'mobilenet_v2', 'mobilenet_v3_large',
                   'resnext50_32x4d', 'wide_resnet50_2', 'mnasnet1_0', 'efficientnet-b0',
                   'efficientnet-b4', 'efficientnet-b7', 'clip_vit_b16', 'clip_vit_b32',
                   'deit_base_patch16_224', 'deit_base_distilled_patch16_224',
                   'vit_base_patch16_224', 'vit_base_patch32_224']


# Human-readable display name for every identifier in model_name_list.
# 'googlenet' was previously missing, so looking it up raised KeyError.
model_name_map = {
    'resnet18': 'Resnet-18',
    'resnet50': 'Resnet-50',
    'vgg19_bn': 'Vgg-19',
    'inception_v3_google': 'Inception-V3',
    'googlenet': 'Googlenet',
    'shufflenetv2_x1': 'Shufflenet-V2',
    'mobilenet_v2': 'Mobilenet-V2',
    'mobilenet_v3_large': 'Mobilenet-V3',
    'resnext50_32x4d': 'Resnext50-32x4d',
    'wide_resnet50_2': 'Wide-Resnet50-2',
    'mnasnet1_0': 'MNAS-net-1-0',
    'efficientnet-b0': 'Efficientnet-B0',
    'efficientnet-b4': 'Efficientnet-B4',
    'efficientnet-b7': 'Efficientnet-B7',
    'clip_vit_b16': 'CLIP VIT-B16',
    'clip_vit_b32': 'CLIP VIT-B32',
    'deit_base_patch16_224': 'DEIT-16',
    'deit_base_distilled_patch16_224': 'DEIT-16-distilled',
    'vit_base_patch16_224': 'VIT-B16',
    'vit_base_patch32_224': 'VIT-B32'
}

