
from tensorflow import keras
import numpy as np
from tensorflow.keras.preprocessing.image import load_img
import os
import PIL
from PIL import ImageOps
import pandas as pd
import tensorflow as tf
from sklearn.utils import shuffle

# Width of every target vector produced by SIXRAY.__getitem__: one 0/1 slot per
# threat class (Gun, Knife, Wrench, Pliers, Scissors in calculate_class_weights).
num_classes = 5

class SIXRAY(keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(images, labels)`` batches as NumPy arrays.

    Records are read from ``csv_file``. The first column is expected to be
    ``name`` and the remaining columns the per-class 0/1 labels (``-1``
    entries in the CSV are rewritten to ``0``). Each record's image is loaded
    from ``<data_dir>/<name>.jpg``, resized to ``img_size`` and scaled to
    [0, 1].

    With ``balance=False`` the (optionally shuffled) records are served in
    order, ``batch_size`` at a time, dropping a trailing partial batch.
    With ``balance=True`` every batch holds the positives assigned to it
    (roughly ``batch_size / 2``) topped up with negatives.
    """

    def __init__(self, batch_size, img_size, csv_file, data_dir=None, balance=False, shuffle_records=True, seed=42):
        """Load the record table and, if requested, prepare balanced batches.

        Args:
            batch_size: Number of samples per batch.
            img_size: Target-size tuple passed to ``load_img``; it must be a
                tuple because it is concatenated into the batch shape.
            csv_file: Anything ``pd.read_csv`` accepts (path or buffer).
            data_dir: Directory holding the ``<name>.jpg`` images. When
                ``None``, image names are used as paths relative to the
                current working directory.
            balance: Build class-mixed batches with about half positives.
            shuffle_records: Shuffle the records once (using ``seed``).
            seed: Random state used for every shuffle done by this object.
        """
        super().__init__()
        self.seed = seed
        self.shuffle_records = shuffle_records
        self.balance = balance
        self.batch_size = batch_size
        self.img_size = img_size
        self.data_dir = data_dir
        self.df = pd.read_csv(csv_file)
        # The CSV marks an absent threat with -1; targets must be 0/1.
        self.df.replace(-1, 0, inplace=True)
        if self.shuffle_records:
            self.df = shuffle(self.df, random_state=self.seed).reset_index(drop=True)
        if self.balance:
            self._prepare_balanced_batches()
        self.df = self.df.reset_index(drop=True)
        # Pristine copy so active-learning selections can always be re-drawn
        # from the full table.
        self.full_df = self.df.copy()

    def _prepare_balanced_batches(self):
        # Split records into positives/negatives and deal each positive to a
        # minibatch id in range(n_batches).
        # TODO: negatives are taken in (possibly shuffled) record order;
        # clustering them and sampling from each cluster may give a more
        # representative mix, e.g.
        # https://towardsdatascience.com/how-to-cluster-images-based-on-visual-similarity-cd6e7209fe34
        is_positive = self.df['name'].str.contains('P', case=False)
        self.positives = self.df[is_positive]
        self.n_p = len(self.positives)
        self.all_negatives = self.df[~is_positive]
        # Not used for batching (all_negatives is sliced directly); kept because
        # it is a public attribute.
        self.negatives_balanced = self.all_negatives.iloc[0:self.n_p]
        # Aim for ~batch_size/2 positives per batch. This must be an integer: a
        # fractional modulus yields fractional minibatch ids that never equal a
        # batch index. Flooring also keeps every batch at <= batch_size positives.
        self.n_batches = max(1, (2 * self.n_p) // self.batch_size)
        self.positives_sorted = self.positives.sort_values(
            ascending=False, by=['Gun', 'Knife', 'Wrench', 'Pliers', 'Scissors'])
        self.positives_sorted = self.positives_sorted.reset_index(drop=True)
        # Dealing the class-sorted positives round-robin spreads each threat
        # class across the batches.
        self.positives_sorted['Minibatch'] = self.positives_sorted.index % self.n_batches

    def choose_records_for_active_learning(self, indxs):
        """Restrict ``self.df`` to the rows of ``full_df`` at positions ``indxs``.

        Only the unbalanced path of ``__getitem__`` (and
        ``calculate_class_weights``) reads ``self.df``; balanced batches are
        drawn from the tables built in ``__init__``.
        """
        self.df = self.full_df.iloc[indxs].reset_index(drop=True)

    def __len__(self):
        """Return the number of batches per epoch."""
        if self.balance:
            return self.n_batches
        # Floor division: a trailing partial batch is not served.
        return len(self.df) // self.batch_size

    def __getitem__(self, idx):
        """Return ``(x, y)`` for batch number ``idx``.

        ``x`` is float32 with shape ``(batch_size,) + img_size + (3,)`` and
        values in [0, 1]; ``y`` is uint8 with shape
        ``(batch_size, num_classes)``. Rows past the records available for
        this batch stay all-zero.
        """
        if self.balance:
            batch_records = self._balanced_batch(idx)
        else:
            start = idx * self.batch_size
            # .loc slicing is label-based and inclusive; df has a RangeIndex.
            batch_records = self.df.loc[start:start + self.batch_size - 1, :]
        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        y = np.zeros((self.batch_size, num_classes), dtype="uint8")
        for j, (_, row) in enumerate(batch_records.iterrows()):
            img = load_img(self._image_path(row['name']), target_size=self.img_size)
            x[j] = np.asarray(img, dtype="float32") / 255.
            # Every column after the leading ``name`` column is a class label.
            y[j] = row.values[1:]
        return x, y

    def _balanced_batch(self, idx):
        # Positives dealt to minibatch ``idx`` plus a slice of negatives that
        # fills the batch, shuffled together (deterministically via ``seed``).
        minibatch_p = self.positives_sorted[self.positives_sorted['Minibatch'] == idx]
        start = idx * self.batch_size
        minibatch_n = self.all_negatives.iloc[start:start + (self.batch_size - len(minibatch_p))]
        minibatch = pd.concat([minibatch_p, minibatch_n])
        minibatch = shuffle(minibatch, random_state=self.seed)
        minibatch = minibatch.reset_index(drop=True)
        return minibatch.drop(['Minibatch'], axis=1)

    def _image_path(self, name):
        # Use the platform separator rather than a hard-coded backslash so this
        # also works outside Windows; without data_dir use a relative path.
        filename = name + '.jpg'
        if self.data_dir is None:
            return filename
        return os.path.join(self.data_dir, filename)

    def calculate_class_weights(self):
        """Store per-class weights for ``self.df`` in ``self.class_weights``.

        Class ``k`` gets ``1 - share_k``, where ``share_k`` is that class's
        fraction of all positive labels in ``self.df``. Keys are integer ids
        0-4 (presumably matching the CSV label-column order — confirm). If
        ``self.df`` has no positive labels, every class gets weight ``1.0``.
        """
        id_to_name = {
            0: 'Gun',
            1: 'Knife',
            2: 'Wrench',
            3: 'Pliers',
            4: 'Scissors',
        }
        class_counts = {k: self.df[name].sum() for k, name in id_to_name.items()}
        all_threats = sum(class_counts.values())
        if all_threats == 0:
            # No positives at all: avoid 0/0 (NaN weights) and weight evenly.
            self.class_weights = {k: 1.0 for k in id_to_name}
            return
        self.class_weights = {
            k: 1 - count / all_threats for k, count in class_counts.items()
        }

    # Used for hacky tf.data loading
    def getitem(self, index):
        """Plain-name alias of ``__getitem__``."""
        return self.__getitem__(index)

    