import numpy as np
import pandas as pd

from spn.algorithms.LearningWrappers import learn_parametric, learn_classifier
from spn.structure.leaves.parametric.Parametric import Categorical, Gaussian
from spn.structure.Base import Context

from spn.algorithms.MPE import mpe
from spn.io.Graphics import plot_spn, plot_spn2

from spn.algorithms.Validity import is_valid
from spn.algorithms.Statistics import get_structure_stats


def read_csv(filename):
    """Load latent vectors and class labels from a CSV file.

    The file must provide a ``z`` column whose cells hold a bracketed,
    comma-separated list of numbers (for example ``"[0.1, -0.2, 0.3]"``)
    and a ``class_label`` column.  Any other columns are ignored.

    Args:
        filename: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        A ``(data, class_label)`` tuple.  ``data`` is a list holding one
        list of floats per row; ``class_label`` is a list holding the
        label of each row.  Both follow the row order of the file.

    Raises:
        KeyError: If the ``z`` or ``class_label`` column is missing.
        IndexError: If a ``z`` cell contains no ``[``.
        ValueError: If an entry inside the brackets is not a number.
    """
    df = pd.read_csv(filename)

    data = []
    for z_string in df['z']:
        # Isolate the list body: the text after '[' and before ']'.
        inner = z_string.split('[')[1].split(']')[0]
        # Split on bare commas; float() ignores surrounding whitespace, so
        # both "1.0, 2.0" and "1.0,2.0" are parsed.
        data.append([float(token) for token in inner.split(',')])

    # Iterating the NumPy array keeps the same scalar types that per-cell
    # lookups on the frame would produce.
    class_label = list(df['class_label'].to_numpy())

    return data, class_label


if __name__ == '__main__':
    # Script entry point: fit an SPN classifier on the training latents once,
    # then report its accuracy on every corrupted/attacked validation set.
    np.random.seed(42)

    feature_indices = [2, 8, 9] # traffic data, need to adjust the value based on visualization outputs of Filter-VAE

    # Latent columns that are dropped so only `feature_indices` remain.
    # Assumes the latent vectors have 10 dimensions -- adjust range(10) if the
    # encoder changes.  Computed once and reused for train and test data.
    delete_indices = np.delete(range(10), feature_indices)

    # One Gaussian leaf per selected latent feature plus a Categorical leaf for
    # the class label, which is stored in the last column of the data matrix.
    parametric_types = [Gaussian] * len(feature_indices) + [Categorical]
    label_column = len(feature_indices)

    ## Get train data
    # Alternative source: '../ControlVAE/data_class_v5/train_z_label_semi.csv'
    train_filename = '/data/open-datasets/traffic/train/train_z_label.csv'
    X_train, y_train = read_csv(train_filename)
    X_train = np.delete(np.asarray(X_train), delete_indices, 1)
    y_train = np.reshape(np.asarray(y_train), (X_train.shape[0], 1))
    train_data = np.c_[X_train, y_train]

    ## Train the classifier
    # The training data is identical for every phase below, so the model is
    # learned a single time here instead of once per phase.
    spn_classification = learn_classifier(
        train_data,
        Context(parametric_types=parametric_types).add_domains(train_data),
        learn_parametric,
        label_column,  # Specify label column
    )

    # Optional diagnostics:
    # print(is_valid(spn_classification))
    # print(get_structure_stats(spn_classification))

    # NOTE: an earlier single-file evaluation read
    # '/data/open-datasets/traffic/val/val_z_label_{attack}.csv' with attack in
    # ['all', 'elastic', 'g_blur', 'g_noise', 'splatter', 'sticker'] (or
    # '../ControlVAE/data_class_v5/test_z_label_semi.csv') and ran the same
    # steps as the loop below.

    ## For Corruptions/Attacks experiments ###
    # Only the keys (phase names) are used in this script; the values are kept
    # as given.
    severity_map = {
        'gaussian_noise' : 5, 
        'shot_noise' : 5, 
        'impulse_noise' : 5, 


        'glass_blur' : 5, 
        'defocus_blur' : 5, 
        'motion_blur': 5, 
        'zoom_blur' : 5, 

        'fog': 5,
        'frost': 5,
        'snow': 5,
        'contrast' : 6, 
        'brightness' : 8, 
        'elastic_transform' : 5,
        
        'jpeg_compression' : 5,
        'pixelate' : 7,
        
        'pgd_attack_random' : None,
        'ROA' : None
    }

    for phase in severity_map:
        print('>> Phase : %s' % phase)
        test_filename = f'./latents/val_z_label_{phase}.csv'
        X_test, y_test = read_csv(test_filename)
        X_test = np.delete(np.asarray(X_test), delete_indices, 1)
        y_test = np.asarray(y_test)

        # The label column is set to NaN to mark it as missing; the predicted
        # label is read back from the last column of the mpe() result.
        y_empty = np.empty((X_test.shape[0], 1))
        y_empty[:] = np.nan
        test_data = np.c_[X_test, y_empty]

        ## Accuracy
        predict = mpe(spn_classification, test_data)[:, -1].reshape(X_test.shape[0])
        correct_count = (y_test == predict).sum()
        acc = correct_count / X_test.shape[0]
        print("Accuracy:", acc)

    ## Plot SPN
    # The model does not change between phases, so it is drawn once.
    plot_spn(spn_classification, 'spnFlow.png')