import sys
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import utils
import glob
import json
import re
import pandas as pd
import pathlib
import datetime

from fnmatch import filter
from collections import defaultdict
from tqdm import tqdm

from research_pool.config import dataset_to_batch, \
                                 order_to_dataset_epsilon, dataset_to_resizings, \
                                 robust_l2_mnist, robust_l2_cifar, robust_l2_imagenet, robust_linf_cifar, robust_linf_imagenet



def safe_load(o_path):
    """Load an ``.npy`` file, or warn and return None when it does not exist.

    ``allow_pickle=True`` lets object arrays load, but unpickling can execute
    code embedded in the file -- only point this at trusted experiment outputs.
    """
    if os.path.exists(o_path):
        return np.load(o_path, allow_pickle=True)
    print(f"Warn: {o_path} was not found!")
    return None

def expand_maybe(x):
    """Return ``x`` with a leading axis of size 1 added if it has fewer than 4 dims."""
    needs_leading_axis = len(x.shape) < 4
    return np.expand_dims(x, 0) if needs_leading_axis else x


class AdversarialSample(object):
    """Handle on one saved adversarial example and its companion ``.npy`` files.

    Only paths are stored, plus the modification time of the adversarial file.
    Arrays are read on demand through ``safe_load``, so every getter returns
    None (after printing a warning) when its file is missing.
    """

    def __init__(self, adv_path, benign_path, trajectory_path, gradients_path):
        self.adv_path = adv_path
        self.benign_path = benign_path
        self.trajectory_path = trajectory_path
        self.gradients_path = gradients_path
        # stat() raises if adv_path is absent, so the adversarial file must exist.
        self.pathlib_obj = pathlib.Path(adv_path)
        modified_ts = self.pathlib_obj.stat().st_mtime
        self.mod_time = datetime.datetime.fromtimestamp(modified_ts)

    def get_adv(self):
        """Load the adversarial array (None if missing)."""
        return safe_load(self.adv_path)

    def get_benign(self):
        """Load the benign input array (None if missing)."""
        return safe_load(self.benign_path)

    def get_trajectory(self):
        """Load the saved trajectory array (None if missing)."""
        return safe_load(self.trajectory_path)

    def get_gradients(self):
        """Load the saved gradients array (None if missing)."""
        return safe_load(self.gradients_path)
    

# Display names for attack identifiers, in "<prior>+<attack>" form.
att_pretty = {
    # OPT family
    "OPT_attack": "OPT",
    "RandSampling_OPT_attack": "Rand+OPT",
    "Sampling_OPT_attack": "BiLN+OPT",
    "HLM_OPT_attack": "AE+OPT",
    # Sign-OPT family
    "Sign_OPT": "Sign-OPT",
    "RandSampling_Sign_OPT": "Rand+Sign-OPT",
    "Sampling_Sign_OPT": "BiLN+Sign-OPT",
    "HLM_Sign_OPT": "AE+Sign-OPT",
    # RayS family
    "RayS": "RayS",
    "Sampling_RayS": "BiLN+RayS",
    "HLM_RayS": "AE+RayS",
    # HSJA family
    "HSJA": "HSJA",
    "Sampling_HSJA": "BiLN+HSJA",
    "RandSampling_HSJA": "Rand+HSJA",
    "HLM_HSJA": "AE+HSJA",
}

# Names for tables: the same attacks in "<attack>+<prior>" form.
att_pretty_flipped = {
    # OPT family
    "OPT_attack": "OPT",
    "RandSampling_OPT_attack": "OPT+Rand",
    "Sampling_OPT_attack": "OPT+BiLN",
    "HLM_OPT_attack": "OPT+AE",
    # Sign-OPT family
    "Sign_OPT": "Sign-OPT",
    "RandSampling_Sign_OPT": "Sign-OPT+Rand",
    "Sampling_Sign_OPT": "Sign-OPT+BiLN",
    "HLM_Sign_OPT": "Sign-OPT+AE",
    # RayS family
    "RayS": "RayS",
    "Sampling_RayS": "RayS+BiLN",
    "HLM_RayS": "RayS+AE",
    # HSJA family
    "HSJA": "HSJA",
    "Sampling_HSJA": "HSJA+BiLN",
    "RandSampling_HSJA": "HSJA+Rand",
    "HLM_HSJA": "HSJA+AE",
}

# Row rank of each attack variant in the CSV tables; the position in this list
# is the rank (0 = first row).
att_order = {
    name: rank
    for rank, name in enumerate([
        "HSJA", "Sampling_HSJA", "RandSampling_HSJA", "HLM_HSJA",
        "OPT_attack", "Sampling_OPT_attack", "HLM_OPT_attack",
        "Sign_OPT", "Sampling_Sign_OPT", "RandSampling_Sign_OPT", "HLM_Sign_OPT",
        "RayS", "Sampling_RayS", "HLM_RayS",
    ])
}

# Shades for BiLN ("Sampling_") variants; StateDatabase uses oranges[0].
oranges = ['orange', 'goldenrod', 'gold']

# Colour per attack variant: baseline dodgerblue, BiLN/Rand orange, AE green.
att_colors = {
    # OPT family
    "OPT_attack": 'dodgerblue',
    "Sampling_OPT_attack": 'orange',
    "HLM_OPT_attack": 'green',
    # Sign-OPT family
    "Sign_OPT": 'dodgerblue',
    "Sampling_Sign_OPT": 'orange',
    "RandSampling_Sign_OPT": 'orange',
    "HLM_Sign_OPT": 'green',
    # RayS family
    "RayS": 'dodgerblue',
    "Sampling_RayS": 'orange',
    "HLM_RayS": 'green',
    # HSJA family
    "HSJA": 'dodgerblue',
    "Sampling_HSJA": 'orange',
    "RandSampling_HSJA": 'orange',
    "HLM_HSJA": 'green',
}

# Display names for dataset / model identifiers.
dataset_pretty = {
    # MNIST models
    'MNIST': 'Natural',
    'MNIST_madry': 'Madry Adv. Tr',
    'MNIST_trades': 'TRADES',
    'MNIST_rob_manifold': 'Robust Manifold',
    'MNIST_deep_camma': 'Deep Camma (MNIST)',
    # CIFAR-10 models
    'CIFAR10': 'Natural CIFAR-10',
    'CIFAR10_madry': 'Madry CIFAR-10',
    'CIFAR10_trades': 'TRADES',
    'CIFAR10_sense': 'SENSE',
    'CIFAR10_fs': 'Feat. Scattering',
    'CIFAR10_interp': 'Interpolation',
    'CIFAR10_smooth110': 'Smoothing',
    # ImageNet models
    'Imagenet': 'Natural ImageNet',
    'Imagenet_smooth50': 'Smoothing ImageNet',
    'Imagenet_madry8': 'Madry ImageNet',
    'Imagenet_madry4': 'Madry ImageNet',
}


class ExperimentState(object):
    """One attack run (attack x dataset x run_id) rebuilt from its saved JSON state.

    Every key of ``state`` becomes an attribute. Keys read by this class are
    ``attack``, ``dataset``, ``run_id``, ``save_dir``, ``byproduct_dir``,
    ``query_limit``, ``targeted``, ``a``, ``b`` and, for ``Sampling_*``
    attacks other than Sampling_RayS, ``resize_dim``.

    Construction indexes the saved samples on disk (``init_database``). It then
    loads or recomputes per-query distance statistics (``recreate_log``).
    """

    def __init__(self, state):
        for key in state:
            setattr(self, key, state[key])
        self.yi_succ = []
        self.iter_succ = []
        self.init_database()
        self.attack_pretty = att_pretty[self.attack]
        self.dataset_pretty = dataset_pretty[self.dataset]
        self.extra_naming()
        self.color = None
        self.recreate_log()

    def is_targeted(self):
        """Return the ``targeted`` flag from the saved state."""
        return self.targeted

    def get_outdir(self):
        """Return ``<save_dir>/<attack>/<dataset>/<run_id>``."""
        return os.path.join(self.save_dir, self.attack, self.dataset, str(self.run_id))

    def init_database(self):
        """Index the saved samples found under ``<outdir>/images``.

        Layout read: ``images/<yi>/iter=<n>/`` containing ``*adv_*`` files plus
        ``benign.npy``, ``trajectory.npy`` and ``gradients.npy``.

        Populates:
        - ``sample_database``: key ``"<yi>.<n>"`` -> list of AdversarialSample.
        - ``yi_succ``: distinct integer ``yi`` values.
        - ``iter_succ``: distinct ``n`` values.

        Empty ``iter=`` folders are skipped. If the images folder does not
        exist, the database stays empty.
        """
        self.sample_database = {}
        images_folder = os.path.join(self.get_outdir(), 'images')
        if not os.path.exists(images_folder):
            return

        yi_avail = os.listdir(images_folder)
        self.yi_succ = list(set(int(s) for s in yi_avail))
        iter_succ = set()

        for yi in yi_avail:
            # ia is unique for every attack
            for ia in os.listdir(os.path.join(images_folder, yi)):
                ia_folder = os.path.join(images_folder, yi, ia)
                contents = os.listdir(ia_folder)
                if not contents:
                    continue

                ben_ = os.path.join(ia_folder, "benign.npy")
                trajectory_ = os.path.join(ia_folder, "trajectory.npy")
                gradients_ = os.path.join(ia_folder, "gradients.npy")

                iax = int(ia.replace('iter=', ''))
                iter_succ.add(iax)
                # ``filter`` is fnmatch.filter (imported at file level).
                self.sample_database[f"{yi}.{iax}"] = [
                    AdversarialSample(os.path.join(ia_folder, advf), ben_, trajectory_, gradients_)
                    for advf in filter(contents, "*adv_*")
                ]

        self.iter_succ = list(iter_succ)

    def recreate_log(self):
        """Load, or rebuild and cache, per-query mean/std of the sample distances.

        Statistics are cached as ``<byproduct_dir>/<run_id>_means.npy`` and
        ``<run_id>_stds.npy``.

        When rebuilding, each sample gets a row of length ``query_limit``. The
        row starts as the trajectory's first distance, and each later entry
        (column 0 = distance, column 1 = query count) overwrites it from its
        query count onward. Entries at or beyond ``query_limit`` stop the scan.
        Samples whose trajectory is missing or unreadable keep an all-zero row.

        Sets ``self.means`` and ``self.stds`` (mean/std over samples).
        """
        reset = False  # set True to force recomputation even if a cache exists
        log_means_path = os.path.join(self.byproduct_dir, f"{str(self.run_id)}_means.npy")
        log_stds_path = os.path.join(self.byproduct_dir, f"{str(self.run_id)}_stds.npy")

        if not os.path.exists(log_means_path) or not os.path.exists(log_stds_path) or reset:
            print(f"Recreating log statistics id={self.run_id} in {self.byproduct_dir}")
            samples = self.all_samples()

            # B x T x dist
            exp_log = np.zeros((len(samples), self.query_limit, 1))

            # Trajectory scan
            for i, sample_obj in tqdm(enumerate(samples)):
                try:
                    trajectory = sample_obj.get_trajectory()
                except (OSError, ValueError, EOFError):
                    # Unreadable or corrupt file: keep this sample's row at zero.
                    print(f'Warn: Bad trajectory in {self.dataset} - {self.attack}')
                    continue
                if trajectory is None:
                    # Missing file (safe_load already warned): keep the row at zero.
                    continue

                dists = trajectory[:, 0]
                queries = trajectory[:, 1]

                for j in range(len(dists)):
                    if queries[j] >= self.query_limit:
                        break

                    if j == 0:
                        exp_log[i, :] = dists[j]
                    else:
                        # int() because slice bounds must be integers, in case
                        # the trajectory array is float-typed.
                        exp_log[i, int(queries[j]):] = dists[j]

            exp_mean = np.mean(exp_log, axis=0)
            exp_std = np.std(exp_log, axis=0)

            np.save(log_means_path, exp_mean)
            np.save(log_stds_path, exp_std)
        else:
            print(f"Load statistics from {self.byproduct_dir}")
            exp_mean = np.load(log_means_path)
            exp_std = np.load(log_stds_path)

        self.means = exp_mean
        self.stds = exp_std

    def extra_naming(self):
        """Append variant details (a/b factors or resize dim) to ``attack_pretty``."""
        if self.attack == "Sampling_RayS":
            if self.a == 1:
                self.attack_pretty = f"{self.attack_pretty} b={self.b}"
            if self.b == 1:
                self.attack_pretty = f"{self.attack_pretty} a={self.a}"

        elif "Sampling_" in self.attack:
            self.attack_pretty = f"{self.attack_pretty} {self.resize_dim}"

    def rand_yi_samples(self, yi):
        """Return the sample list of a random iteration for class ``yi``, or None."""
        avail = filter(list(self.sample_database.keys()), f"{str(yi)}.*")
        if len(avail) == 0:
            return None
        return self.sample_database[np.random.choice(avail)]

    def get_sample(self, i=0, k=0):
        """Return sample ``k`` of the first iteration of class ``yi_succ[i]``, or None."""
        # i=class index, k=sample index
        if len(self.yi_succ) == 0:
            return None
        yi = self.yi_succ[i]
        return self.get_yi_sample(yi, k=k)

    def get_yi_sample(self, yi, j=0, k=0):
        """Return sample ``k`` of the ``j``-th iteration recorded for class ``yi``, or None."""
        avail = filter(list(self.sample_database.keys()), f"{str(yi)}.*")
        if len(avail) == 0:
            return None
        return self.sample_database[avail[j]][k]

    def rand_iter_samples(self, i):
        """Return the sample list of a random class at iteration ``i``, or None."""
        avail = filter(list(self.sample_database.keys()), f"*.{str(i)}")
        if len(avail) == 0:
            return None
        return self.sample_database[np.random.choice(avail)]

    def all_yi_sample(self, yi):
        """Return one sample list per iteration recorded for class ``yi``, or None."""
        avail = filter(list(self.sample_database.keys()), f"{str(yi)}.*")
        if len(avail) == 0:
            return None
        # Index the dict; the original called it, which raised TypeError.
        return [self.sample_database[k] for k in avail]

    def all_samples(self):
        """Return every indexed sample, flattened in database order."""
        samples = []
        for k in self.sample_database:
            samples.extend(self.sample_database[k])
        return samples

    def all_latest_samples(self):
        """Return every indexed sample, most recently modified first."""
        return sorted(self.all_samples(), key=lambda x: x.mod_time, reverse=True)

    def get_log(self):
        """Load ``<outdir>/<dataset>-<attack>_log.npy`` (None if missing)."""
        base_path = self.get_outdir()
        base_name = f"{self.dataset}-{self.attack}_log.npy"
        o_path = os.path.join(base_path, base_name)
        return safe_load(o_path)

    def get_all_log(self):
        """Load ``<outdir>/<dataset>-<attack>_all_log.npy`` (None if missing)."""
        base_path = self.get_outdir()
        base_name = f"{self.dataset}-{self.attack}_all_log.npy"
        o_path = os.path.join(base_path, base_name)
        return safe_load(o_path)

    def get_mean_std(self):
        """Return the ``(means, stds)`` computed by ``recreate_log``."""
        return self.means, self.stds

    def set_color(self, c):
        """Set the plotting colour of this run."""
        self.color = c


class StateDatabase(object):
    """Discover experiment runs from JSON state files and group them for reporting.

    Args:
        args: mapping; ``args['byproduct_dir']`` is injected into every state.
        configs: iterable of paths to JSON state files.
        filters: three flags selecting which runs are kept:
            - ``filters[0]`` ("cifar10 + madry"): drop interp/trades/sense/_fs datasets.
            - ``filters[1]``: keep only RayS runs when truthy; drop RayS runs otherwise.
            - ``filters[2]`` ("cifar10 appendix"): drop CIFAR10 and *madry* datasets.

    After construction:
    - ``store`` maps ``(attack, dataset)`` to a list of ExperimentState.
    - ``datasets`` is a deterministically sorted list.
    - ``normal_attacks`` lists base attack names with variant prefixes stripped.
    - ``dataset_to_yi`` / ``dataset_to_iter`` hold the class indices /
      iterations shared by all runs of each dataset.
    """

    def __init__(self, args, configs, filters):
        self.filters = filters
        self.args = args
        self.store = defaultdict(list)
        self.datasets = set()
        self.attacks = set()
        self.normal_attacks = set()
        self.test_batch = None
        self.norm = None
        self.targeted = None
        self.early_stop = None
        self.dataset_to_yi = {}
        self.dataset_to_iter = {}
        self.init_discovery(configs)

    def _is_excluded(self, state):
        """Return True if ``state`` is dropped by the fixed or flag-selected filters.

        These checks run before the state contributes to the run-wide settings
        (test_batch, norm, targeted, early_stop).
        """
        attack = state['attack']
        dataset = state['dataset']

        if attack in ('HLM_HSJA', 'HLM_RayS'):
            return True
        # MNIST BiLN runs (except RayS) are kept only for resize dims 7 and 14.
        if ("Sampling_" in attack and "MNIST" in dataset and "RayS" not in attack
                and state['resize_dim'] not in [7, 14]):
            return True
        if 'OPT_attack' in attack:
            return True
        if self.filters[0] and any(tag in dataset for tag in ('interp', 'trade', 'sense', '_fs')):
            return True
        # filters[1] truthy -> RayS only; falsy -> no RayS.
        if ("RayS" in attack) != bool(self.filters[1]):
            return True
        if self.filters[2] and (dataset == 'CIFAR10' or 'madry' in dataset):
            return True
        return False

    @staticmethod
    def _rank_datasets(groups):
        """Map every dataset name in ``groups`` (flattened in order) to a distinct rank.

        A name listed more than once keeps its last rank.
        """
        ordered = [name for group in groups for name in group]
        return {name: rank for rank, name in enumerate(ordered)}

    def init_discovery(self, configs):
        """Load each JSON state in ``configs`` and populate the database."""
        ds_yi = defaultdict(list)
        ds_iter = defaultdict(list)

        for state_file in configs:
            with open(state_file) as state_f:
                state = json.load(state_f)

            if self._is_excluded(state):
                continue

            # legacy support: older states have no a/b factors
            if 'a' not in state:
                state['a'] = 1
                state['b'] = 1

            # Run-wide settings come from the first state that passes the filters.
            if self.test_batch is None:
                self.test_batch = state['test_batch']
            if self.norm is None:
                self.norm = state['order']
            if self.targeted is None:
                self.targeted = state['targeted']
            if self.early_stop is None:
                state.setdefault('early_stopping', False)  # legacy
                self.early_stop = state['early_stopping']

            # post setting filter: with norm 2, CIFAR10 datasets are skipped
            if self.norm == 2 and 'CIFAR10' in state['dataset']:
                continue

            self.datasets.add(state['dataset'])
            self.attacks.add(state['attack'])

            state['byproduct_dir'] = self.args['byproduct_dir']
            if '092620-013652' in state['save_dir']:
                state['save_dir'] = f"{state['save_dir']}_orig"  # Pull from x100 sample dir

            print(state['save_dir'])

            st = ExperimentState(state)
            self.store[(state['attack'], state['dataset'])].append(st)

            ds_yi[state['dataset']].append(set(st.yi_succ))
            ds_iter[state['dataset']].append(set(st.iter_succ))

        # Deterministic dataset order: the CIFAR-10 group, then ImageNet, then
        # MNIST, each followed by its robust variants. Every name gets its own
        # rank. Unknown names go last, ordered by name, so ties never depend on
        # set iteration order.
        ds_order = self._rank_datasets([
            ['CIFAR10'], robust_linf_cifar, robust_l2_cifar,
            ['Imagenet'], robust_l2_imagenet, robust_linf_imagenet,
            ['MNIST'], robust_l2_mnist,
        ])
        self.datasets = sorted(
            self.datasets,
            key=lambda name: (ds_order.get(name, float('inf')), name),
        )

        for dataset in self.datasets:
            self.dataset_to_yi[dataset] = set.intersection(*ds_yi[dataset])
            self.dataset_to_iter[dataset] = set.intersection(*ds_iter[dataset])

        # Base attack name: strip the first matching variant prefix.
        for attack in self.attacks:
            for prefix in ("RandSampling_", "Sampling_", "HLM_"):
                if prefix in attack:
                    attack = attack.replace(prefix, "")
                    break
            self.normal_attacks.add(attack)

        # deterministic sort of base attacks
        det_order = {'RayS': 0, 'HSJA': 1, 'OPT_attack': 2, 'Sign_OPT': 3}
        self.normal_attacks = sorted(self.normal_attacks, key=det_order.__getitem__)

        # Colors
        for dataset in self.datasets:
            for state in self.states_at(dataset):
                if "RandSampling_" in state.attack:
                    state.set_color('darkgoldenrod')
                elif "Sampling_" in state.attack:
                    state.set_color(oranges[0])
                else:
                    state.set_color(att_colors[state.attack])

    def _keys_at(self, dataset, attack):
        """Return ``store`` keys for ``dataset``, optionally narrowed by ``attack``.

        A key ``(attack_name, dataset_name)`` matches when either element equals
        ``dataset``. If ``attack`` is given, ``attack_name`` must also contain it
        as a substring (e.g. "HSJA" also matches "Sampling_HSJA").
        """
        dataset_keys = [k for k in list(self.store.keys()) if dataset in k]
        if attack is None:
            return dataset_keys
        return [k for k in dataset_keys if attack in k[0]]

    def sum_at(self, dataset, attack=None):
        """Count the states stored under the keys selected by ``_keys_at``."""
        return sum(len(self.store[k]) for k in self._keys_at(dataset, attack))

    def states_at(self, dataset, attack=None):
        """Return the states selected by ``_keys_at``, in a deterministic order.

        Order: base attacks, Sampling variants with a == b == 1, a > 1
        (by a), b > 1 (by b), then HLM variants.
        """
        out = []
        for k in self._keys_at(dataset, attack):
            out.extend(self.store[k])

        def keyfn(st):
            if 'Sampling_' in st.attack:
                if st.a > 1:
                    return 10 + st.a
                elif st.b > 1:
                    return 20 + st.b
                else:
                    return 5
            elif 'HLM_' in st.attack:
                # last
                return 30
            else:
                # always first
                return 0

        return sorted(out, key=keyfn)

    
def score_dataframe_to_csv(sdb: StateDatabase, df: pd.DataFrame, score_name: str):
    """
    Build CSV rows of per-attack scores, raw and normalized per dataset.

    df will have 'dataset' (pretty), 'attack' (pretty), and '<score_name>' as
    columns. An attack name may carry a space-separated suffix such as those
    added by ExperimentState.extra_naming (e.g. "BiLN+HSJA 14",
    "BiLN+RayS b=2"). The suffix is dropped to find the table row.

    Columns:
    - They follow ``sdb.datasets``, the same order as the header.
    - A cell is left empty when ``df`` has no score for that (attack, dataset).
    - Normalized cells divide by the maximum score of that dataset in ``df``.

    Rows are ordered by ``att_order``. Variants of the same base attack share
    a row; the last one wins and a warning is printed.

    Returns:
        (lines, normalized_lines): lists of newline-terminated CSV lines,
        header first.
    """
    columns = [dataset_pretty[d] for d in sdb.datasets]
    header = f"Attack Variant,{','.join(columns)}\n"

    ds_to_max_val = {
        dataset: df.loc[df['dataset'] == dataset, score_name].max()
        for dataset in df['dataset'].unique()
    }
    print(ds_to_max_val)

    # pretty attack name -> internal attack id (values of att_pretty are unique)
    ugly_by_pretty = {value: key for key, value in att_pretty.items()}

    lines = {-1: header}
    normalized_lines = {-1: header}

    for attack in df['attack'].unique():
        # Pretty names contain no spaces, so anything after the first space is a
        # variant suffix (resize dim, a=/b= factors).
        ugly_attack = ugly_by_pretty[attack.split(' ', 1)[0]]
        weight = att_order[ugly_attack]
        label = att_pretty_flipped[ugly_attack]
        if weight in lines:
            print(f"Warn: {attack} overwrites an earlier variant in row {label}")

        s = [label]  # row
        ns = [label]
        attack_rows = df[df['attack'] == attack]
        for dataset in columns:
            cell = attack_rows.loc[attack_rows['dataset'] == dataset, score_name]
            if cell.empty:
                # No score for this pair: keep the column aligned with the header.
                s.append('')
                ns.append('')
                continue
            score = cell.iloc[0]
            normalized = score / ds_to_max_val[dataset]
            s.append(f"{score:.3f}")
            ns.append(f"{normalized:.3f}")

        lines[weight] = ','.join(s) + '\n'
        normalized_lines[weight] = ','.join(ns) + '\n'

    lines = [lines[w] for w in sorted(lines)]
    normalized_lines = [normalized_lines[w] for w in sorted(normalized_lines)]
    return lines, normalized_lines