import numpy as np
import pandas as pd

from spn.algorithms.LearningWrappers import learn_parametric, learn_classifier
from spn.structure.leaves.parametric.Parametric import Categorical, Gaussian
from spn.structure.Base import Context

from spn.algorithms.MPE import mpe
from spn.io.Graphics import plot_spn, plot_spn2

from spn.algorithms.Validity import is_valid
from spn.algorithms.Statistics import get_structure_stats


def _parse_vector(text):
    """Parse a bracketed list of numbers such as ``"[0.1, -2.0, 3]"`` into floats.

    Only the text between the first ``[`` and the following ``]`` is used.
    Entries may be separated by commas and/or whitespace, so both
    ``"[1.0, 2.0]"`` and ``"[1.0 2.0]"`` parse; ``"[]"`` yields ``[]``.
    """
    inner = str(text).split('[', 1)[1].split(']', 1)[0]
    return [float(token) for token in inner.replace(',', ' ').split()]


def read_csv(filename, vector_column='z', label_column='class_label'):
    """Load latent vectors and class labels from a CSV file.

    Args:
        filename: path or file-like object accepted by ``pandas.read_csv``.
        vector_column: column whose cells hold a bracketed list of numbers
            (e.g. ``"[0.1, -2.0, 3.5]"``); defaults to ``'z'``.
        label_column: column holding the class label; defaults to
            ``'class_label'``.

    Returns:
        ``(data, labels)``: ``data`` is a list with one list of floats per row,
        ``labels`` is a list of the label-column values, both in file order.
    """
    df = pd.read_csv(filename)
    data = [_parse_vector(cell) for cell in df[vector_column]]
    labels = df[label_column].tolist()
    return data, labels


def _select_features(X, feature_indices):
    """Return the columns of the 2-D array ``X`` listed in ``feature_indices``.

    Indices are de-duplicated and taken in ascending order, so the column
    order does not depend on how ``feature_indices`` is written. Works for
    latent vectors of any width (no hard-coded dimensionality).
    """
    columns = np.unique(np.asarray(feature_indices, dtype=int))
    return X[:, columns]


def _load_features(filename, feature_indices):
    """Read ``filename`` with ``read_csv`` and keep only ``feature_indices``.

    Returns ``(X, y)`` as NumPy arrays: ``X`` holds the selected feature
    columns, ``y`` the class labels.
    """
    X, y = read_csv(filename)
    return _select_features(np.asarray(X), feature_indices), np.asarray(y)


def _train_spn_classifier(train_data):
    """Learn an SPN classifier on ``train_data`` whose last column is the label.

    Every feature column is modelled as Gaussian and the label column as
    Categorical; the label index passed to ``learn_classifier`` is the last
    column.
    """
    n_features = train_data.shape[1] - 1
    if n_features < 1:
        raise ValueError('train_data needs at least one feature column besides the label')
    parametric_types = [Gaussian] * n_features + [Categorical]
    context = Context(parametric_types=parametric_types).add_domains(train_data)
    return learn_classifier(train_data, context, learn_parametric, n_features)


def _random_test_subset(X, y, skip_fraction=0.75):
    """Return the rows after the first ``skip_fraction`` of a random permutation.

    With the default this keeps a random 25% of the rows. The permutation
    comes from the global NumPy RNG, so the result depends on
    ``np.random.seed`` and on earlier draws.
    """
    n_rows = X.shape[0]
    order = np.random.permutation(n_rows)
    keep = order[int(n_rows * skip_fraction):]
    return X[keep], y[keep]


def _accuracy(spn, X_test, y_test):
    """Fraction of rows whose predicted label equals ``y_test``.

    A NaN label column is appended to ``X_test`` (marking the label as
    unobserved) and the prediction is read from the last column of ``mpe``'s
    output.
    """
    n_rows = X_test.shape[0]
    query = np.c_[X_test, np.full((n_rows, 1), np.nan)]
    predicted = mpe(spn, query)[:, -1]
    return (y_test == predicted).sum() / n_rows


if __name__ == '__main__':
    np.random.seed(42)

    feature_indices = [0, 1, 3, 8]  # For real_data W0230

    ## Get train data
    # Alternative: '../ControlVAE/data_real_signs/train_z_label_240.csv'
    train_filename = '../ControlVAE/data_real/train_z_label_230.csv'
    X_train, y_train = _load_features(train_filename, feature_indices)
    train_data = np.c_[X_train, y_train.reshape(-1, 1)]
    print(train_data.shape)

    ## Train once: the model depends only on train_data, so every phase below
    ## is scored against the same classifier.
    spn_classification = _train_spn_classifier(train_data)
    # Optional diagnostics:
    #   plot_spn(spn_classification, 'spnFlow.png')
    #   plot_spn2(spn_classification, 'spnFlow.png')
    #   print(is_valid(spn_classification))
    #   print(get_structure_stats(spn_classification))

    ## For Corruptions/Attacks experiments ###
    # Only the keys are used below (each selects a test CSV); the severity
    # values are not read by this script.
    severity_map = {
        'gaussian_noise' : 5,
        'shot_noise' : 5,
        'impulse_noise' : 5,

        'glass_blur' : 5,
        'defocus_blur' : 5,
        'motion_blur': 5,
        'zoom_blur' : 5,

        'fog': 5,
        'frost': 5,
        'snow': 5,
        'contrast' : 6,
        'brightness' : 8,
        'elastic_transform' : 5,

        'jpeg_compression' : 5,
        'pixelate' : 7,
        'pgd_attack_random' : None,
        'ROA' : None
    }

    for phase in severity_map:
        print(f'>> Phase : {phase}')
        # Alternative: f'../ControlVAE/Attack/attacked_z_real_signs/{test_ckpt}/val_z_label_{phase}.csv'
        test_filename = f'../ControlVAE/Attack/attacked_z_real/val_z_label_{phase}.csv'
        X_test, y_test = _load_features(test_filename, feature_indices)

        # Evaluate on a random 25% of the rows. NOTE(review): the subset is
        # drawn from the global RNG, so each phase gets a different subset;
        # if the per-phase CSVs are row-aligned, a shared subset may be intended.
        X_test, y_test = _random_test_subset(X_test, y_test)

        acc = _accuracy(spn_classification, X_test, y_test)
        print("Accuracy:", acc)