import torch
import numpy as np
import json
import pickle
import os

from datetime import datetime


def safe_load(o_path):
    """Load ``o_path`` with ``np.load`` if the file exists.

    Prints a warning and returns None when the path does not exist.
    NOTE: ``allow_pickle=True`` lets the file execute arbitrary code on load;
    only use this on trusted files.

    Args:
        o_path: path to a file readable by ``np.load``.

    Returns:
        Whatever ``np.load`` returns, or None if the path is missing.
    """
    if os.path.exists(o_path):
        return np.load(o_path, allow_pickle=True)
    print(f"Warn: {o_path} was not found!")
    return None
    

def distance(x_adv, x, norm='l2'):
    """Distance between a perturbed batch ``x_adv`` and the reference batch ``x``.

    Both tensors are flattened per sample, with dim 0 treated as the batch dim.

    Args:
        x_adv: perturbed tensor, same shape as ``x``.
        x: reference tensor.
        norm: ``'l2'`` or ``2`` for the Euclidean norm; ``'linf'``, ``'inf'``
            or ``float('inf')`` / ``np.inf`` for the max-norm.

    Returns:
        float. For l2, the L2 norm of the whole flattened batch. For linf,
        the sum over samples of each sample's max absolute difference.

    Raises:
        ValueError: if ``norm`` is not one of the supported values.
    """
    # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work.
    diff = (x_adv - x).reshape(x.size(0), -1)
    if norm == 'l2' or norm == 2:
        # NOTE(review): this aggregates over the whole batch, while linf sums
        # per-sample maxima; the two differ in kind for batch > 1 — confirm intent.
        return torch.sqrt(torch.sum(diff * diff)).item()
    if norm in ('linf', 'inf') or norm == float('inf'):
        return torch.sum(torch.max(torch.abs(diff), 1)[0]).item()
    raise ValueError(f"Unsupported norm: {norm!r}")
    
    
def get_time_stamp():
    """Return the current local time formatted as ``MMDDYY-HHMMSS``."""
    fmt = '%m%d%y-%H%M%S'
    return datetime.now().strftime(fmt)


def report(fname, s, log=print):
    """Send ``s`` to ``log``, then append it as one line to the file ``fname``.

    Args:
        fname: path of the log file; opened in append mode.
        s: object to report; it is formatted as with ``f"{s}"``.
        log: callable that receives ``s`` unchanged (defaults to ``print``).
    """
    log(s)
    line = "{}\n".format(s)
    with open(fname, 'a') as handle:
        handle.write(line)

        
def list_directories(directory: str):
    """Return the sub-directory names of ``directory`` in a random order.

    The entry names are sorted first and then shuffled with
    ``np.random.permutation``, so the order is reproducible under a fixed
    NumPy seed. Non-directory entries are dropped after shuffling.
    """
    shuffled = np.random.permutation(sorted(os.listdir(directory)))
    return [name for name in shuffled
            if os.path.isdir(os.path.join(directory, name))]


def load_json(config_filepath):
    """Parse the JSON file at ``config_filepath`` and return the decoded object."""
    with open(config_filepath) as handle:
        return json.load(handle)


def save_json(state, f_path, dry_run=False):
    """Write ``state`` to ``f_path`` as JSON.

    Args:
        state: JSON-serializable object.
        f_path: destination path; overwritten if it already exists.
        dry_run: when True, nothing is written.
    """
    if dry_run:
        return
    with open(f_path, 'w') as config_file:
        json.dump(state, config_file)

        
def pickle_write(fpath, obj):
    """Serialize ``obj`` with ``pickle`` into ``fpath``, overwriting any existing file."""
    with open(fpath, 'wb') as sink:
        pickle.dump(obj, sink)


def pickle_load(fpath):
    """Deserialize and return the pickled object stored at ``fpath``.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(fpath, 'rb') as source:
        return pickle.load(source)


def load_sample(dataset, ix):
    """Fetch sample ``ix`` from ``dataset`` as a batch of one.

    Args:
        dataset: indexable object whose items are ``(x, y)`` pairs with
            ``x`` a tensor.
        ix: index of the sample.

    Returns:
        ``(xi, yi)``. ``xi`` has a new leading dimension of size 1. ``yi`` is
        converted with ``torch.tensor`` when it is not already a tensor.
    """
    xi, yi = dataset[ix]

    xi = xi.unsqueeze(0)
    # isinstance also accepts Tensor subclasses, which must not be re-wrapped
    # (torch.tensor on a tensor copies it and emits a warning).
    if not isinstance(yi, torch.Tensor):
        yi = torch.tensor(yi)

    return xi, yi


class EligibleIndex(object):
    """Hand out dataset indices that the model classifies correctly, balanced per class.

    Candidate indices come from a pickled per-class index database. Only
    indices whose ``model_wrapper.predict_label`` output equals the label are
    queued in ``im_selection``. They are consumed with :meth:`pop`, and
    completed indices are recorded with :meth:`step` until
    ``state['test_batch']`` of them exist.
    """

    def __init__(self, state, dataset, model_wrapper, restart_indices):
        """
        Args:
            state: config dict; reads 'classes_select', 'ix_database_path',
                'test_batch' and 'dataset'.
            dataset: indexable dataset passed to ``load_sample``.
            model_wrapper: object with ``predict_label(x)`` returning a tensor.
            restart_indices: indices completed in a previous run (may be
                empty or None); they count towards the target and are skipped
                by :meth:`pop`.
        """
        self.mw = model_wrapper
        self.dataset = dataset
        self.classes_select = state['classes_select']
        # NOTE: loaded with pickle; the database file must be trusted.
        # Presumably maps class -> list of indices (it is indexed by class and
        # shuffled/popped below) — confirm against the database builder.
        self.db = pickle_load(state['ix_database_path'])
        self.n = state['test_batch']
        print(f"Loaded indices database for dataset {state['dataset']}.")

        self.K = len(state['classes_select'])
        # Default per-class quota; at least one index per class.
        self.nk = max(1, state['test_batch'] // self.K)
        if restart_indices is None:
            restart_indices = []
        self.restart_indices = restart_indices

        # Check if we are restarting.
        if len(restart_indices) > 0:
            # Copy so that step() does not append into the caller's sequence
            # (and so arrays/tuples are supported).
            self.successful_indices = list(restart_indices)
            print(f"\tRESTART: Updated with {len(restart_indices)} indices.")
        else:
            self.successful_indices = []

        self.im_selection = []
        self.backup = []  # not read or written anywhere in this class
        print("Starting indices refresh...")
        self.refresh_indices()

    def refresh_indices(self, refill=None):
        """Shuffle the database and queue correctly-classified indices per class.

        Indices are popped from ``self.db``, so an index is never queued twice
        across refreshes. Indices already in ``successful_indices`` may be
        queued; :meth:`pop` skips them.

        Args:
            refill: if truthy, the number of extra indices wanted. The
                per-class quota becomes ``ceil(refill / K)`` and selection stops
                after the first class at which ``im_selection`` holds at least
                ``refill`` entries. Otherwise the quota is ``self.nk``.
        """
        for k in self.classes_select:
            np.random.shuffle(self.db[k])

        nk = int(np.ceil(refill / self.K)) if refill else self.nk

        counting = 0  # indices actually queued during this call
        for k in self.classes_select:
            ix_k_keep = self._take_correct(k, nk)
            self.im_selection.extend(ix_k_keep)
            counting += len(ix_k_keep)
            if refill and len(self.im_selection) >= refill:
                break

        print(f"Attempting {len(self.im_selection)} images "
              f"({nk} preferred, ~{int(np.ceil(counting / self.K))} actual per class).")
        print(f"\tVISUAL CHECK\t{', '.join([str(ix) for ix in self.im_selection][:10])}...")

    def _take_correct(self, k, nk):
        """Pop class-``k`` indices from the db until ``nk`` are correctly classified or none remain."""
        keep = []
        # `<` (not `!=`) so a non-positive quota cannot drain the whole class.
        while len(keep) < nk and len(self.db[k]) > 0:
            ix = self.db[k].pop()
            xi, yi = load_sample(self.dataset, ix)

            yi = int(yi.cpu().item())
            dec = int(self.mw.predict_label(xi).cpu().item())
            # Misclassified samples are discarded (already popped from the db).
            if dec == yi:
                keep.append(ix)
        return keep

    def is_empty(self):
        """Return True when no queued indices remain."""
        return len(self.im_selection) == 0

    def stopping_condition(self):
        """Return True when the queue is exhausted or enough indices succeeded.

        Side effect: if the queue is empty but the target ``n`` is not yet
        reached, first tries to refill it via :meth:`refresh_indices`.
        """
        if len(self.im_selection) == 0 and len(self.successful_indices) < self.n:
            self.refresh_indices(refill=self.n - len(self.successful_indices))

        cond1 = len(self.im_selection) == 0
        cond2 = len(self.successful_indices) >= self.n
        return cond1 or cond2

    def pop(self):
        """Remove and return the last queued index.

        Returns None if that index is already in ``successful_indices``, for
        example when it was completed in a previous run. Raises IndexError if
        the queue is empty.
        """
        ix = self.im_selection.pop()
        if ix in self.successful_indices:
            # Found a restarted ix.
            print(f"\tskipping {ix} from previous run.")
            return None
        return ix

    def step(self, ix):
        """Record ``ix`` as successfully processed."""
        self.successful_indices.append(ix)

    def status(self):
        """Return progress as ``"<successful>/<target>"``."""
        return f"{len(self.successful_indices)}/{self.n}"
