#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 25 09:44:51 2021

@author: pooya
"""
import csv
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.feature_selection import chi2

## import MNIST
dataset_name = 'MNIST'

# Number of training rows kept after random subsampling of the training file.
N_TRAIN_SUBSET = 48000
# Seed for the subsampling permutation; None draws a fresh subset every run,
# an int makes the train/test split (and hence the reported error) reproducible.
RANDOM_SEED = None


def _load_mnist_csv(path):
    """Read an MNIST-style CSV file into (features, labels).

    Each data row is expected to hold integers: the label in the first column
    followed by the pixel values. A non-numeric first row (a column header,
    which some distributions of these files include) and blank lines are
    skipped. The result is sized from the rows actually read, so a short file
    cannot leave behind all-zero rows that would masquerade as class-0 samples.

    Parameters
    ----------
    path : str
        Path of the CSV file to read.

    Returns
    -------
    X : ndarray of float, shape (n_rows, n_cols - 1)
        Pixel columns divided by 255.
    Y : ndarray of float, shape (n_rows,)
        First column (the label), kept as float like the original script.

    Raises
    ------
    ValueError
        If the file contains no data rows, a non-header row is not all
        integers, or rows have differing lengths.
    """
    rows = []
    with open(path, newline='') as file:
        reader = csv.reader(file, delimiter=',')
        for line_no, row in enumerate(reader):
            if not row:
                continue  # blank line, e.g. a trailing newline at end of file
            try:
                rows.append(np.array([int(v) for v in row], dtype=float))
            except ValueError:
                if line_no == 0:
                    continue  # column header such as "label,1x1,1x2,..."
                raise
    if not rows:
        raise ValueError(f"no data rows found in {path}")
    data = np.vstack(rows)  # raises ValueError on ragged rows
    return data[:, 1:] / 255, data[:, 0]


Xtr, Ytr = _load_mnist_csv('./mnist_train.csv')

# Train on a random subset of the training rows.
N = np.shape(Xtr)
rng = np.random.default_rng(RANDOM_SEED)
select = rng.permutation(N[0])
Xtr = Xtr[select[:N_TRAIN_SUBSET], :]
Ytr = Ytr[select[:N_TRAIN_SUBSET]]

X_test, Y_test = _load_mnist_csv('./mnist_test.csv')
# Per-sample mean centering (X - X.mean(axis=1, keepdims=True)) was tried on
# both splits and is intentionally left disabled.
#### MNIST imported
#### Feature selection: keep the N_FEATURES pixels with the highest chi2 score
N_FEATURES = 70
chi_scores = chi2(Xtr, Ytr)
# chi2 returns NaN for a feature that is zero for every training row (its
# expected count is 0, giving 0/0). np.argsort orders NaN *last*, i.e. as the
# highest scores, so without this step the "top" features would be mostly
# constant blank pixels. Rank such features below every real score instead.
scores = np.where(np.isnan(chi_scores[0]), -np.inf, chi_scores[0])
# Ascending argsort; the last N_FEATURES indices are the highest-scoring ones.
features = np.argsort(scores)[-N_FEATURES:]
Xtr = Xtr[:, features]
X_test = X_test[:, features]

#### Depth-limited decision tree on the selected features
DTC = DecisionTreeClassifier(max_depth=5, criterion='gini')
RF = DTC.fit(Xtr, Ytr)
Y_pred = RF.predict(X_test)
# Fraction of test samples the tree misclassifies.
Err_tree_test_no = 1 - np.mean(Y_pred == Y_test)
print(Err_tree_test_no)