from matplotlib import pyplot as plt
import random
import os
import sys
import json
import shutil
from collections import defaultdict
import pandas as pd
import numpy as np
import cv2
import torch as ch

from robustness.tools.vis_tools import show_image_row, show_image_column


def _class_counts(values, num_classes):
    """Count how often each class id in ``[0, num_classes)`` occurs in ``values``.

    Ids outside that range are ignored (matching an ``values == i`` scan over
    ``i in range(num_classes)``). The result is float64 so it can be used
    directly in ratio computations.
    """
    # assumes values hold integer class ids (e.g. argmax outputs / labels) — TODO confirm label dtype
    values = np.asarray(values).astype(np.int64)
    values = values[(values >= 0) & (values < num_classes)]
    return np.bincount(values, minlength=num_classes).astype(np.float64)


def _rank_extremes(ratios, top_k):
    """Return ``(top_k highest, top_k lowest)`` class ids of ``ratios``.

    Both slices come from one descending ordering, so the "lowest" slice is
    itself ordered from higher to lower. NaN ratios sort to the very end.
    """
    order = np.argsort(-ratios)
    return order[:top_k], order[max(len(order) - top_k, 0):]


def get_failing_classes(imagenet_metadata_dir='./imagenet_features', num_classes=1000, top_k=50):
    """Rank classes by how often a robust and a non-robust model misclassify them.

    Loads ``nonrobust_train_logits.npy``, ``robust_train_logits.npy`` and
    ``train_labels.npy`` from ``imagenet_metadata_dir``; predictions are the
    ``argmax`` over axis 1 of each logits array. For each model two failure
    ratios are computed per class ``c``:

    * by prediction: wrong predictions equal to ``c`` / all predictions equal to ``c``;
    * by label: misclassified images labelled ``c`` / all images labelled ``c``.

    Args:
        imagenet_metadata_dir: directory containing the three ``.npy`` files.
        num_classes: class ids ``0 .. num_classes - 1`` are ranked.
        top_k: number of classes returned at each end of every ranking.

    Returns:
        A list of eight index arrays, each ordered from higher to lower
        failure ratio::

            [most, least failing by robust prediction,
             most, least failing by non-robust prediction,
             most, least failing by robust label,
             most, least failing by non-robust label]

        A class that is never predicted (prediction ranking) or has no images
        (label ranking) has a 0/0 ratio, i.e. NaN, and sorts to the end — into
        the "least failing" slice.
    """
    def load_preds(file_name):
        logits = np.load(os.path.join(imagenet_metadata_dir, file_name))
        return np.argmax(logits, axis=1)

    robust_preds = load_preds('robust_train_logits.npy')
    nonrobust_preds = load_preds('nonrobust_train_logits.npy')
    labels = np.load(os.path.join(imagenet_metadata_dir, 'train_labels.npy'))

    images_per_class = _class_counts(labels, num_classes)

    by_pred = []
    by_label = []
    for preds in (robust_preds, nonrobust_preds):
        wrong = preds != labels
        failures_by_pred = _class_counts(preds[wrong], num_classes)
        failures_by_label = _class_counts(labels[wrong], num_classes)
        preds_per_class = _class_counts(preds, num_classes)
        # 0/0 for unseen classes is expected and yields NaN; silence the warning.
        with np.errstate(divide='ignore', invalid='ignore'):
            by_pred.extend(_rank_extremes(failures_by_pred / preds_per_class, top_k))
            by_label.extend(_rank_extremes(failures_by_label / images_per_class, top_k))
    return by_pred + by_label



def print_with_stars(print_str, total_count=115, prefix="", suffix="", star='*'):
    """Print ``print_str`` centred inside a banner of ``star`` repeats.

    The padding ``total_count - len(print_str)`` is split in two halves; when
    it is odd the extra repeat goes on the right. If ``print_str`` is already
    at least ``total_count`` long, no padding is added. ``prefix`` and
    ``suffix`` are attached outside the banner before printing.
    """
    padding = total_count - len(print_str)
    left = padding // 2
    right = padding - left
    # Repeating a string a non-positive number of times yields "".
    banner = star * left + print_str + star * right
    print(prefix + banner + suffix)


def aggregate_results(df):
    """Group the kept crowd-sourced answers in ``df`` by (class, feature, rank).

    A row is kept unless its ``AssignmentStatus`` is ``"Rejected"`` or its
    ``WorkerId`` is in the (currently empty) block list.

    Args:
        df: ``pandas.DataFrame`` with the columns ``WorkerId``,
            ``AssignmentStatus``, ``Answer.class_index``,
            ``Answer.feature_index``, ``Input.feature_rank``,
            ``Answer.main_question``, ``Answer.confidence`` and
            ``Answer.reasons``.

    Returns:
        ``defaultdict(list)`` mapping
        ``"<class_index>_<feature_index>_<feature_rank>"`` to a list of
        ``(WorkerId, main_answer, confidence, reasons)`` tuples in row order.
    """
    # Lower-cased boilerplate reasons that are blanked out; currently none.
    common_reasons = set()
    # Workers whose answers are discarded; currently none.
    blocked_workers = set()

    answers_dict = defaultdict(list)
    for _, content in df.iterrows():
        worker_id = content['WorkerId']
        if content['AssignmentStatus'] == "Rejected" or worker_id in blocked_workers:
            continue

        class_index = int(content['Answer.class_index'])
        feature_index = int(content['Answer.feature_index'])
        feature_rank = int(content['Input.feature_rank'])
        main_index = f"{class_index}_{feature_index}_{feature_rank}"

        reasons = content['Answer.reasons']
        # Empty CSV cells arrive as NaN (a float), which has no .lower().
        if isinstance(reasons, str) and reasons.lower() in common_reasons:
            reasons = ""

        answers_dict[main_index].append(
            (worker_id, content['Answer.main_question'], content['Answer.confidence'], reasons))
    return answers_dict

def get_answers_dict():
    """Load ``final_batch_results/complete_results.csv`` and group its answers.

    Returns:
        The ``defaultdict(list)`` produced by ``aggregate_results`` for the
        loaded DataFrame.
    """
    complete_df_file_path = os.path.join('final_batch_results', 'complete_results.csv')
    # `with` guarantees the handle is closed once pandas has parsed it.
    with open(complete_df_file_path) as complete_results_file:
        complete_df = pd.read_csv(complete_results_file)
    return aggregate_results(complete_df)
    
def get_causal_features_dict(threshold=3, return_rank=False, validate_heatmap=False):
    """Map each class index to the features workers judged part of the main object.

    A (class, feature, rank) entry from
    ``final_batch_results/complete_results.csv`` is selected when at least
    ``threshold`` of its kept answers are ``'main_object'``.

    Args:
        threshold: minimum number of ``'main_object'`` answers required.
        return_rank: if True, store ``(feature_index, feature_rank)`` tuples;
            otherwise store only ``feature_index``.
        validate_heatmap: accepted for signature parity with the spurious
            variants but currently ignored.

    Returns:
        ``defaultdict(list)`` from int class index to the selected features.
    """
    complete_df_file_path = os.path.join('final_batch_results', 'complete_results.csv')
    with open(complete_df_file_path) as complete_results_file:
        complete_df = pd.read_csv(complete_results_file)
    answers_dict = aggregate_results(complete_df)

    causal_features_dict = defaultdict(list)
    for key, answers in answers_dict.items():
        num_causal = sum(1 for answer in answers if answer[1] == 'main_object')
        if num_causal < threshold:
            continue

        class_index, feature_index, feature_rank = (int(part) for part in key.split('_'))
        if return_rank:
            causal_features_dict[class_index].append((feature_index, feature_rank))
        else:
            causal_features_dict[class_index].append(feature_index)
    return causal_features_dict
    
def get_spurious_features_dict(threshold=3, return_rank=False, validate_heatmap=True):
    """Map each class index to the features workers judged spurious.

    A (class, feature, rank) entry from
    ``final_batch_results/complete_results.csv`` counts as spurious when at
    least ``threshold`` of its kept answers are ``'separate_object'`` or
    ``'background'``.

    Args:
        threshold: minimum number of spurious answers required.
        return_rank: if True, store ``(feature_index, feature_rank)`` tuples;
            otherwise store only ``feature_index``.
        validate_heatmap: if True, additionally require that more than half
            of the heatmap-validation answers for ``"<class>_<feature>"``
            were ``'same'``. Features with no validation answers are dropped.

    Returns:
        ``defaultdict(list)`` from int class index to the selected features.
    """
    # Only read the validation CSV when its results are actually used.
    same_dict = aggregate_heatmap_validation() if validate_heatmap else {}

    complete_df_file_path = os.path.join('final_batch_results', 'complete_results.csv')
    with open(complete_df_file_path) as complete_results_file:
        complete_df = pd.read_csv(complete_results_file)
    answers_dict = aggregate_results(complete_df)

    spurious_answers = ('separate_object', 'background')
    spurious_features_dict = defaultdict(list)
    for key, answers in answers_dict.items():
        num_spurious = sum(1 for answer in answers if answer[1] in spurious_answers)
        if num_spurious < threshold:
            continue

        class_index, feature_index, feature_rank = key.split('_')
        heatmap_key = class_index + '_' + feature_index
        # A missing key means no validation answers: treat as 0.0 agreement.
        if validate_heatmap and same_dict.get(heatmap_key, 0.0) <= 0.5:
            continue

        class_index, feature_index, feature_rank = int(class_index), int(feature_index), int(feature_rank)
        if return_rank:
            spurious_features_dict[class_index].append((feature_index, feature_rank))
        else:
            spurious_features_dict[class_index].append(feature_index)
    return spurious_features_dict

def get_specific_spurious_dict(threshold=3, return_rank=False, validate_heatmap=False, choice='background'):
    """Map each class index to features workers judged as one specific spurious kind.

    Like ``get_spurious_features_dict``, but only answers equal to ``choice``
    are counted.

    Args:
        threshold: minimum number of ``choice`` answers required.
        return_rank: if True, store ``(feature_index, feature_rank)`` tuples;
            otherwise store only ``feature_index``.
        validate_heatmap: if True, additionally require that more than half
            of the heatmap-validation answers for ``"<class>_<feature>"``
            were ``'same'``. Features with no validation answers are dropped.
        choice: either ``'background'`` or ``'separate_object'``.

    Returns:
        ``defaultdict(list)`` from int class index to the selected features.

    Raises:
        AssertionError: if ``choice`` is not one of the two allowed values.
    """
    assert choice in ['background', 'separate_object'], \
        f"choice must be 'background' or 'separate_object', got {choice!r}"

    # Only read the validation CSV when its results are actually used.
    same_dict = aggregate_heatmap_validation() if validate_heatmap else {}

    complete_df_file_path = os.path.join('final_batch_results', 'complete_results.csv')
    with open(complete_df_file_path) as complete_results_file:
        complete_df = pd.read_csv(complete_results_file)
    answers_dict = aggregate_results(complete_df)

    spurious_features_dict = defaultdict(list)
    for key, answers in answers_dict.items():
        num_spurious = sum(1 for answer in answers if answer[1] == choice)
        if num_spurious < threshold:
            continue

        class_index, feature_index, feature_rank = key.split('_')
        heatmap_key = class_index + '_' + feature_index
        # A missing key means no validation answers: treat as 0.0 agreement.
        if validate_heatmap and same_dict.get(heatmap_key, 0.0) <= 0.5:
            continue

        class_index, feature_index, feature_rank = int(class_index), int(feature_index), int(feature_rank)
        if return_rank:
            spurious_features_dict[class_index].append((feature_index, feature_rank))
        else:
            spurious_features_dict[class_index].append(feature_index)
    return spurious_features_dict



def get_unique_workers():
    """Return the distinct worker ids and the row count of the complete results.

    Reads ``final_batch_results/complete_results.csv``. All rows are counted,
    including rejected assignments.

    Returns:
        ``(worker_ids_unique, num_rows)``: the sorted array from ``np.unique``
        over the ``WorkerId`` column, and the number of rows in the CSV.
    """
    complete_df_file_path = os.path.join('final_batch_results', 'complete_results.csv')
    with open(complete_df_file_path) as complete_results_file:
        complete_df = pd.read_csv(complete_results_file)

    worker_ids_unique = np.unique(np.array(complete_df['WorkerId']))
    return worker_ids_unique, len(complete_df)

def aggregate_heatmap_validation():
    """Compute, per (class, feature), the share of validation answers equal to ``'same'``.

    Reads ``final_batch_results/validation_results_complete.csv``. Rows whose
    ``AssignmentStatus`` is ``"Rejected"`` and rows from blocked workers
    (currently none) are skipped.

    Returns:
        ``defaultdict(float)`` mapping ``"<class_id>_<feature_id>"`` to the
        fraction of kept answers that were ``'same'``. Keys with no kept
        answers are absent, so indexing them yields ``0.0``.
    """
    complete_df_file_path = os.path.join('final_batch_results', 'validation_results_complete.csv')
    with open(complete_df_file_path) as complete_results_file:
        complete_df = pd.read_csv(complete_results_file)

    # Workers whose answers are discarded; currently none.
    blocked_workers = set()

    num_same = defaultdict(int)
    num_total = defaultdict(int)
    for _, content in complete_df.iterrows():
        if content['AssignmentStatus'] == "Rejected" or content['WorkerId'] in blocked_workers:
            continue
        key = str(int(content['Input.class_id'])) + '_' + str(int(content['Input.feature_id']))
        num_total[key] += 1
        if content['Answer.main_question'] == 'same':
            num_same[key] += 1

    # defaultdict(float) is part of the contract: callers index unknown keys.
    same_dict = defaultdict(float)
    for key, total in num_total.items():
        same_dict[key] = num_same[key] / total
    return same_dict
