
import argparse
from collections import defaultdict
from contextlib import redirect_stdout
from glob import glob
import json
from pathlib import Path
import pickle
import sys

import numpy as np

sys.path.insert(0, '../Common/')
from Load import load_standard

from Config import get_data_dir, get_out_features, id_from_path

###
# Experiment filtering
###

def extract_info(dataset):
    """Summarise the features and blindspots of a synthetic dataset.

    Args:
        dataset: object exposing ``get_active_features(remove_defaults=...)``,
            which returns a mutable collection of (feature, option) pairs, and
            ``blindspots``, a list of collections of (feature, option, value)
            triples.  The collection returned by ``get_active_features`` is
            modified in place (two background entries are removed); the
            blindspots are copied before being edited.

    Returns:
        dict with 'num_features', 'num_objects', 'num_blindspots' (ints) and
        the per-blindspot lists 'blindspots_size', 'blindspots_rp',
        'blindspots_texture', 'blindspots_circle'.
    """
    # Every image has a background, and 'relative-position' is the
    # meta-feature we add ourselves, so neither counts as a real feature.
    active = dataset.get_active_features(remove_defaults = False)
    active.remove(('background', 'presence'))
    active.remove(('background', 'relative-position'))

    # Every blindspot carries the background-presence term; drop it from a copy.
    specs = []
    for spec in dataset.blindspots:
        spec = spec.copy()
        spec.remove(('background', 'presence', 1))
        specs.append(spec)

    # 'Objects' are the distinct non-background features.
    objects = {feature for feature, _ in active}
    objects.discard('background')

    def feature_options(spec):
        # The (feature, option) part of every term in a blindspot.
        return [(feature, option) for feature, option, _ in spec]

    return {
        'num_features': len(active),
        'num_objects': len(objects),
        'num_blindspots': len(specs),
        'blindspots_size': [len(spec) for spec in specs],
        'blindspots_rp': [('background', 'relative-position') in feature_options(spec) for spec in specs],
        'blindspots_texture': ['texture' in [option for _, option, _ in spec] for spec in specs],
        'blindspots_circle': [('circle', 'presence') in feature_options(spec) for spec in specs],
    }

def get_results(directory, method_name, verbose = True):
    """Evaluate one method's hypothesized blindspots on one dataset's test fold.

    Args:
        directory: dataset directory; its last path component is the dataset
            name used to locate './Outputs/<name>/...'.
        method_name: subdirectory of './Outputs/<name>/' holding the method's
            'map.pkl' (cluster key -> indices into the test-fold file list).
        verbose: forwarded to compute_metrics.

    Returns:
        The metrics dict produced by compute_metrics.
    """
    name = directory.split('/')[-1]

    # The method's clusters, expressed as indices into the test-fold files.
    with open('./Outputs/{}/{}/map.pkl'.format(name, method_name), 'rb') as fh:
        index_clusters = pickle.load(fh)

    test_fold = load_standard('object', '{}/{}'.format(get_data_dir(), name), get_out_features(), base_dir = './Outputs/{}'.format(name), fold = 'test')
    test_files = test_fold['files']

    # Re-express every cluster as image ids so it is comparable with images.json.
    cluster_map = {
        key: [id_from_path(test_files[idx]) for idx in index_clusters[key]]
        for key in index_clusters
    }

    with open('{}/test/images.json'.format(directory), 'r') as fh:
        images = json.load(fh)

    # NOTE(review): each blindspot collects the *position* of an image in
    # images.json, not its key.  That only matches the ids from id_from_path if
    # the JSON keys are the ids 0..n-1 in order — confirm.  Blindspot keys are
    # also created in first-seen order, and per-blindspot consumers index the
    # results positionally — confirm this matches dataset.blindspots order.
    blindspot_map = defaultdict(list)
    for position, image_key in enumerate(images):
        for blindspot in images[image_key]['contained']:
            blindspot_map[blindspot].append(position)

    return compute_metrics(blindspot_map, cluster_map, verbose = verbose)
   
def filter_results_dataset(filter_func, method_name):
    """Pool dataset-level metrics over every dataset accepted by a filter.

    For each directory under ``get_data_dir()``, ``filter_func(name, info)`` is
    called with the directory name and the ``extract_info`` summary of its
    'dataset.pkl'.  A ``None`` result skips the dataset; any other value is the
    group key under which that dataset's metrics are pooled.  The verbose
    report of ``get_results`` for each accepted dataset is written to
    './Outputs/<name>/<method_name>/metrics.txt'.

    Returns:
        dict mapping group key -> {metric: [mean, standard error, count]} for
        'dr', 'fdr', 'error_found', 'error_returned', 'error_composed' and
        'error_mixed'.  Values of -1 (undefined for a dataset) are left out of
        every metric except 'dr'; a metric with no samples is reported as [-1].
    """
    metric_names = ['dr', 'fdr', 'error_found', 'error_returned', 'error_composed', 'error_mixed']
    filtered_results = defaultdict(lambda: defaultdict(list))

    # Sorted so samples are pooled in a reproducible order.
    for directory in sorted(glob('{}/*'.format(get_data_dir()))):
        name = directory.split('/')[-1]

        # Get the dataset info
        with open('{}/dataset.pkl'.format(directory), 'rb') as f:
            d = pickle.load(f)
        info = extract_info(d)

        # Check the filter conditions; None means "skip this dataset"
        key = filter_func(name, info)
        if key is None:
            continue

        # Send the verbose report to the per-dataset metrics file.  The context
        # managers close the file and restore stdout even if get_results raises.
        with open('./Outputs/{}/{}/metrics.txt'.format(name, method_name), 'w') as log, redirect_stdout(log):
            results = get_results(directory, method_name)

        # 'dr' is always defined; the other metrics use -1 for "undefined".
        filtered_results[key]['dr'].append(results['dr'])
        for metric in metric_names[1:]:
            value = results[metric]
            if value != -1:
                filtered_results[key][metric].append(value)

    out = {}
    for key in filtered_results:
        summary = {}
        for metric in metric_names:
            samples = filtered_results[key][metric]
            count = len(samples)
            if count > 0:
                # [mean, standard error of the mean, number of samples]
                summary[metric] = [np.mean(samples), np.std(samples) / np.sqrt(count), count]
            else:
                summary[metric] = [-1]
        out[key] = summary

    return out

def filter_results_blindspot(filter_func, method_name):
    """Pool per-blindspot metrics over every dataset accepted by a filter.

    For each directory under ``get_data_dir()``, ``filter_func(name, info)`` is
    called with the directory name and the ``extract_info`` summary of its
    'dataset.pkl'.  A ``None`` result skips the dataset.  Otherwise it must
    return one group key per blindspot; the i-th key receives the i-th entry of
    the ``covered`` and ``counts`` arrays from ``get_results``.  The verbose
    report of ``get_results`` is written to
    './Outputs/<name>/<method_name>/metrics.txt'.

    Returns:
        dict mapping group key -> {'covered': [mean, std err, n],
        'counts': [mean, std err, n]}.  'counts' only pools blindspots that
        were covered; a metric with no samples is reported as [-1].
    """
    filtered_results = defaultdict(lambda: defaultdict(list))

    # Sorted so samples are pooled in a reproducible order.
    for directory in sorted(glob('{}/*'.format(get_data_dir()))):
        name = directory.split('/')[-1]

        # Get the dataset info
        with open('{}/dataset.pkl'.format(directory), 'rb') as f:
            d = pickle.load(f)
        info = extract_info(d)

        # Check the filter conditions; None means "skip this dataset"
        keys = filter_func(name, info)
        if keys is None:
            continue

        # Send the verbose report to the per-dataset metrics file.  The context
        # managers close the file and restore stdout even if get_results raises.
        with open('./Outputs/{}/{}/metrics.txt'.format(name, method_name), 'w') as log, redirect_stdout(log):
            results = get_results(directory, method_name)

        # NOTE(review): assumes the blindspot order in the results matches the
        # order of dataset.blindspots that the keys were derived from.
        for i, key in enumerate(keys):
            was_covered = results['covered'][i]
            filtered_results[key]['covered'].append(was_covered)
            if was_covered == 1:
                filtered_results[key]['counts'].append(results['counts'][i])

    out = {}
    for key in filtered_results:
        summary = {}
        for metric in ['covered', 'counts']:
            samples = filtered_results[key][metric]
            count = len(samples)
            if count > 0:
                # [mean, standard error of the mean, number of samples]
                summary[metric] = [np.mean(samples), np.std(samples) / np.sqrt(count), count]
            else:
                summary[metric] = [-1]
        out[key] = summary

    return out

def print_results(results):
    """Print one line per group key (in sorted order) with each metric's mean.

    ``results`` maps group key -> {metric: [mean, ...]}, as returned by the
    filter_results_* functions; only the first element (the mean) is shown,
    rounded to two decimals.
    """
    for group in sorted(results):
        metrics = results[group]
        means = {metric: np.round(stats[0], 2) for metric, stats in metrics.items()}
        print(group, means)
   
###
# Metrics
###

def _fraction_of(points, pool):
    """Return the fraction of ``points`` that also appear in ``pool``."""
    return len(np.intersect1d(points, pool)) / len(points)


def _exclusive_points(group_map, flags):
    """Return the points that lie only in flagged groups.

    ``flags[j]`` is the 0/1 indicator for the j-th group of ``group_map`` (in
    iteration order).  The result contains every point of a flagged group that
    does not also belong to some unflagged group.
    """
    flagged = set()
    unflagged = set()
    for flag, key in zip(flags, group_map):
        if flag == 1:
            flagged.update(group_map[key])
        elif flag == 0:
            unflagged.update(group_map[key])
    return list(flagged - unflagged)


def _mean_over_uncovered(percentages, covered, dr, exclude_found):
    """Reduce a per-blindspot diagnostic percentage to a single score.

    With ``exclude_found`` False this is the plain mean over all true
    blindspots.  Otherwise only uncovered blindspots contribute.  If every
    blindspot was covered (``dr == 1``), nothing is left to average and -1.0 is
    returned as a sentinel.
    """
    if not exclude_found:
        return np.mean(percentages)
    if dr == 1.0:
        return -1.0
    uncovered = 1 - covered
    return np.sum(uncovered * percentages) / np.sum(uncovered)


def compute_metrics(blindspot_map, group_map, thresh_p = 0.8, thresh_r = 0.8, verbose = True, exclude_found = True):
    """Score hypothesized blindspots (groups) against the true blindspots.

    Args:
        blindspot_map: mapping from each true blindspot to the ids of the
            points it contains.  It must be non-empty, and so must each entry.
        group_map: mapping from each hypothesized blindspot to the ids of its
            points.  Iteration order matters: groups are merged in this order
            when computing recall, and the false discovery rate is measured over
            a prefix of it.  Empty groups are allowed; their precision is 0.
        thresh_p: minimum precision for a group to "belong to" a true blindspot.
        thresh_r: minimum recall for a true blindspot to count as covered.
        verbose: print the intermediate matrices and scores.
        exclude_found: average the diagnostic error_* scores over the uncovered
            true blindspots only (each is -1.0 when every blindspot is covered);
            otherwise average over all true blindspots.

    Returns:
        dict with
          'dr'       discovery rate: fraction of true blindspots covered;
          'fdr'      false-positive rate among the groups up to the deepest one
                     used by any covered blindspot, or -1 if none was covered;
          'covered'  0/1 array with one entry per true blindspot, in
                     blindspot_map order;
          'counts'   number of groups merged for each true blindspot;
          'error_found', 'error_returned', 'error_composed', 'error_mixed'
                     diagnostics describing why blindspots were missed.

    Raises:
        ValueError: if blindspot_map is empty.
    """
    num_blindspots = len(blindspot_map)
    num_groups = len(group_map)
    if num_blindspots == 0:
        # Every rate below is normalised by the number of true blindspots.
        raise ValueError('compute_metrics needs at least one true blindspot')

    ###
    # Main Metrics
    ###

    # Blindspot Precision (share of group j inside blindspot i) and Naive
    # Blindspot Recall (share of blindspot i inside group j) both use the same
    # intersection size, so compute them together.
    bp = np.zeros((num_blindspots, num_groups))
    nbr = np.zeros((num_blindspots, num_groups))
    for i, b_key in enumerate(blindspot_map):
        b_points = blindspot_map[b_key]
        for j, g_key in enumerate(group_map):
            g_points = group_map[g_key]
            shared = len(np.intersect1d(b_points, g_points))
            # An empty group overlaps nothing; avoid dividing by zero.
            bp[i, j] = shared / max(len(g_points), 1)
            nbr[i, j] = shared / len(b_points)

    belong = 1 * (bp >= thresh_p)

    # A group is a false positive if it is not precise enough for any true blindspot.
    fp = 1 * (np.max(bp, axis = 0) < thresh_p)

    if verbose:
        print('Blindspot Precision')
        print(np.round(bp, 2))
        print('Belongs To')
        print(belong)
        print('False Positive')
        print(fp)
        print()

    if verbose:
        print('Naive Blindspot Recall')
        print(np.round(nbr, 2))
        print()

    # Blindspot Recall, Covered, and Discovery Rate.
    # For each true blindspot, the groups that are precise for it are merged in
    # group_map order until recall reaches thresh_r.  Stopping early uses the
    # fewest groups possible, so BR is a lower bound.
    br = np.zeros((num_blindspots))
    depth = -1 * np.ones((num_blindspots))
    counts = np.zeros((num_blindspots))
    for i, b_key in enumerate(blindspot_map):
        agg = []
        for j, g_key in enumerate(group_map):
            if bp[i, j] < thresh_p:
                continue
            agg.extend(group_map[g_key])
            recall = _fraction_of(blindspot_map[b_key], agg)
            br[i] = recall
            depth[i] = j
            counts[i] += 1
            if recall >= thresh_r:
                break
    covered = 1 * (br >= thresh_r)

    dr = np.mean(covered)

    if verbose:
        print('Blindspot Recall')
        print(np.round(br, 2))
        print('Covered')
        print(covered)
        print('Depth')
        print(depth)
        print('Counts')
        print(counts)
        print('Discovery Rate')
        print(np.round(dr, 2))
        print()

    # False Discovery Rate: false-positive rate over the groups up to and
    # including the deepest group used by any covered blindspot.
    if dr != 0.0:
        depth_max = int(np.max(depth * covered))
        fdr = np.mean(fp[:(depth_max + 1)])
    else:
        depth_max = -1
        fdr = -1

    if verbose:
        print('Max Depth')
        print(depth_max)
        print('False Discovery Rate')
        print(np.round(fdr, 2))
        print()

    ###
    # Metrics calculated in order to understand why methods fail
    ###

    # error_found: share of each true blindspot captured by ALL groups that are
    # precise for it (recall without early stopping).
    found = np.zeros((num_blindspots))
    for i, b_key in enumerate(blindspot_map):
        agg = []
        for j, g_key in enumerate(group_map):
            if bp[i, j] >= thresh_p:
                agg.extend(group_map[g_key])
        found[i] = _fraction_of(blindspot_map[b_key], agg)
    error_found = _mean_over_uncovered(found, covered, dr, exclude_found)

    # error_returned: share of each true blindspot that was not part of any
    # hypothesized blindspot.  The pool of returned points is the same for every
    # blindspot, so build it once.
    returned = [point for g_key in group_map for point in group_map[g_key]]
    missing = np.array([1 - _fraction_of(blindspot_map[b_key], returned) for b_key in blindspot_map])
    error_returned = _mean_over_uncovered(missing, covered, dr, exclude_found)

    # Summing a column of bp gives the fraction of that group's points that lie
    # in *any* true blindspot.
    total = np.sum(bp, axis = 0)

    # error_composed: share of each true blindspot found only in false-positive
    # groups that are nevertheless mostly made of true-blindspot points, i.e.
    # groups that merge several true blindspots.
    composed = fp * (total >= thresh_p)
    composed_points = _exclusive_points(group_map, composed)
    percentages = np.array([_fraction_of(blindspot_map[b_key], composed_points) for b_key in blindspot_map])
    error_composed = _mean_over_uncovered(percentages, covered, dr, exclude_found)

    if verbose:
        print('Composed Indicators')
        print(composed)
        print('Percentages')
        print(np.round(percentages, 2))
        print('Score')
        print(np.round(error_composed, 2))
        print()

    # error_mixed: share of each true blindspot found only in false-positive
    # groups that mix true-blindspot points with points from no true blindspot.
    mixed = fp * (total < thresh_p)
    mixed_points = _exclusive_points(group_map, mixed)
    percentages = np.array([_fraction_of(blindspot_map[b_key], mixed_points) for b_key in blindspot_map])
    error_mixed = _mean_over_uncovered(percentages, covered, dr, exclude_found)

    if verbose:
        print('Mixed Indicators')
        print(mixed)
        print('Percentages')
        print(np.round(percentages, 2))
        print('Score')
        print(np.round(error_mixed, 2))
        print()

    ###
    # Setup the output
    ###

    out = {}
    out['dr'] = dr
    out['fdr'] = fdr
    out['covered'] = covered
    out['counts'] = counts

    out['error_found'] = error_found
    out['error_returned'] = error_returned
    out['error_composed'] = error_composed
    out['error_mixed'] = error_mixed

    return out

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description = 'Runs the analysis of the synthetic results')
    parser.add_argument('--method', type = str, default = '')
    parser.add_argument('--mode', type = str, default = '')
    args = parser.parse_args()

    method = args.method
    mode = args.mode

    # Each mode selects dataset directories by name ('<prefix>-<num>').  The
    # prefix is checked first so unrelated names never reach int().
    if mode == 'hps':
        # Configurations used for Hyper Parameter Search (1-20)
        def check_name(name, val):
            split = name.split('-')
            if split[0] != 'complex' or int(split[1]) > 20:
                return None
            return val
    elif mode == 'complex':
        # Configurations used for the results of the Complex configurations (21-120)
        def check_name(name, val):
            split = name.split('-')
            if split[0] != 'complex' or not 20 < int(split[1]) <= 120:
                return None
            return val
    elif mode == 'dc':
        def check_name(name, val):
            if name.split('-')[0] != 'dc':
                return None
            return val
    else:
        # Report on stderr and exit non-zero so calling scripts see the failure.
        sys.exit('Bad "mode" parameter')

    def wrap(check_info):
        """Combine a query over extract_info's summary with the mode's name filter."""
        def f(name, info):
            return check_name(name, check_info(info))
        return f

    out_dir = './Outputs/analysis/{}'.format(mode)
    Path(out_dir).mkdir(parents = True, exist_ok = True)
    out_file = '{}/{}.txt'.format(out_dir, method)
    out = {}

    # TODO: this is slow because we re-compute all the metrics for every query

    # The report goes to out_file; stdout is restored and the file closed even on error.
    with open(out_file, 'w') as report, redirect_stdout(report):
        print('Dataset Queries')
        print()
        tests = {
            'Average': lambda info: 'all',
            'Num Blindspots': lambda info: info['num_blindspots'],
            'Num Dataset Features': lambda info: info['num_features']}
        for label, func in tests.items():
            print(label)
            v = filter_results_dataset(wrap(func), method)
            out[label] = v
            print_results(v)
            print()
        print()
        print('Blindspot Queries')
        print()
        tests = {
            'Average_blindspots': lambda info: [0 for i in info['blindspots_size']],
            'Num Blindspot Features': lambda info: info['blindspots_size'],
            'Blindspot uses "relative position"': lambda info: info['blindspots_rp'],
            'Blindspot uses "texture"': lambda info: info['blindspots_texture'],
            'Blindspot uses "presence of circle"': lambda info: info['blindspots_circle']}
        for label, func in tests.items():
            print(label)
            v = filter_results_blindspot(wrap(func), method)
            out[label] = v
            print_results(v)
            print()
        print()

    with open('{}/{}.pkl'.format(out_dir, method), 'wb') as f:
        pickle.dump(out, f)
    