#include <bits/stdc++.h>
#include "cifar/cifar10_reader.hpp"
#include <omp.h>
using namespace std;

// Test images split into two classes by main (labels < 5 -> index 0, else 1),
// up to 5000 per class, each holding 32*32*3 values divided by 255.
float Imgs[2][5000][32*32*3];
// Perturbed copies of Imgs built in main; G reads these instead of Imgs when ADV != 0.
float Advs[2][5000][32*32*3];
// res[class][i]: value returned by M for the i-th image of that class.
int res[2][500];
// Selects the input set used by G: 0 -> Imgs, nonzero -> Advs.
int ADV = 0;

// Standard normal noise source. It is reached from inside an OpenMP parallel
// loop (main -> M -> H -> G), and a distribution object may carry internal
// state between calls, so every thread gets its own instance to avoid a data race.
thread_local std::normal_distribution<float> d(0, 1);

int G(int a1, int b1, int a2, int b2, std::mt19937& gen){
    /*
    Evaluates the (smoothed) indicator G_a(x), where a = Imgs[a1][b1] and
    x = (ADV ? Advs : Imgs)[a2][b2] + noise, with one standard-normal draw
    per value taken from `gen`.

    Returns 1 if, for every value i, (x_i - a_i) + a_i == x_i holds in float
    arithmetic, and 0 at the first value where it does not.
    NOTE(review): presumably this round-trip test is what detects that x was
    generated as a + noise. It relies on strict IEEE float semantics;
    reassociating optimizations such as -ffast-math could fold it away.
    */

    // Use a fresh distribution for every call. The noise sequence then depends
    // only on the state of `gen`. A shared distribution object may hold state
    // cached from earlier calls, including calls from other OpenMP threads.
    std::normal_distribution<float> noise(0.0f, 1.0f);

    const float* anchor = Imgs[a1][b1];
    const float* input = ADV ? Advs[a2][b2] : Imgs[a2][b2];

    for (int i = 0; i < 3*32*32; i++) {
        const float x = input[i] + noise(gen);
        if ((x - anchor[i]) + anchor[i] != x)
            return 0;
    }
    return 1;
}
int H(int idx, int a, int b, int iter){
    /*
    Evaluates H_A(x), where A = the first 500 images of Imgs[idx] and
    x = (ADV ? Advs : Imgs)[a][b] + noise.

    Returns 1 as soon as G accepts x against some image of A, otherwise 0.
    The noise is the stream produced by an mt19937 seeded with `iter`. Every
    candidate gets an identical copy of that engine, so all 500 comparisons
    see the same noise sample.
    */

    // Explicit cast: list-initializing the engine straight from an int is a
    // narrowing conversion (rejected by Clang, a warning under GCC).
    const std::mt19937 seeded(static_cast<std::mt19937::result_type>(iter));

    for (int i = 0; i < 500; i++) {
        // Copying the seeded engine yields the same stream as re-seeding,
        // without rerunning the seeding recurrence for every candidate.
        std::mt19937 gen = seeded;
        if (G(idx, i, a, b, gen))
            return 1;
    }
    return 0;
}

int M(int a, int b, int num_samples){
    /*
    Evaluates the (smoothed) classifier M on input (a, b), using
    num_samples noise samples (seeds 0 .. num_samples-1).

    For each sample:
      - if H matches an image of Imgs[0], it counts 1;
      - else if H matches an image of Imgs[1], it counts 0;
      - otherwise a fallback thresholds the first value plus noise: it counts
        1 when the value is <= 0.5, and 0 when it is > 0.5.
    Returns the total count, i.e. the number of samples attributed to Imgs[0].
    Returns 0 when num_samples <= 0.
    */
    int ret = 0;
    for (int iter = 0; iter < num_samples; iter++) {
        if (H(0, a, b, iter)) { ret += 1; continue; }
        if (H(1, a, b, iter)) continue;

        // Fresh engine and distribution: this draw is the first value of the
        // iter-seeded noise stream, independent of other calls and threads.
        std::mt19937 gen(static_cast<std::mt19937::result_type>(iter));
        std::normal_distribution<float> noise(0.0f, 1.0f);

        // NOTE(review): the fallback always reads Imgs (ignores ADV) and
        // subtracts the same first-value shift that main uses to build Advs
        // ((2a-1)*240/255). Presumably it is meant to score the perturbed
        // input in both passes. Confirm this is intended.
        const float shift = static_cast<float>(2*a - 1) * 240 / 255;
        ret += (Imgs[a][b][0] + noise(gen) - shift) > 0.5 ? 0 : 1;
    }
    return ret;
}

// Runs M with `num_samples` noise samples on the first 500 images of each
// class and stores the result in res[class][index].
static void evaluate_all(int num_samples){
    // Per-item cost varies widely (H stops at its first match), so hand out
    // iterations dynamically instead of in fixed contiguous chunks.
    #pragma omp parallel for schedule(dynamic)
    for (int idx = 0; idx < 1000; idx++) {
        const int a = idx % 2;
        const int b = idx / 2;
        res[a][b] = M(a, b, num_samples);
    }
}

// Writes res to `path`. Format: the sample count on the first line, then the
// 500 class-0 results, a newline, and the 500 class-1 results. Each value is
// followed by a space. Returns false if the file could not be written.
static bool write_results(const char* path, int num_samples){
    std::ofstream out(path);
    if (!out) {
        fprintf(stderr, "error - could not open %s\n", path);
        return false;
    }
    out << num_samples << '\n';
    for (int i = 0; i < 500; i++)
        out << res[0][i] << " ";
    out << '\n';
    for (int i = 0; i < 500; i++)
        out << res[1][i] << " ";
    out.close();
    if (out.fail()) {
        fprintf(stderr, "error - failed writing %s\n", path);
        return false;
    }
    return true;
}

int main(int argc, char *argv[]){
    // Number of noise samples per input: optional first argument, default 100.
    int N = 100;
    if (argc > 1 && sscanf(argv[1], "%i", &N) != 1) {
        fprintf(stderr, "error - not an integer");
        return 1;  // N would otherwise be used uninitialized/garbage
    }

    auto dataset = cifar::read_dataset<std::vector, std::vector, uint8_t, uint8_t>();

    // Split the test set by label (< 5 -> index 0, else 1). Advs is a
    // perturbed copy:
    //   - the first value is shifted by -(2*idx-1)*240, i.e. +240 for
    //     index 0 and -240 for index 1;
    //   - every other value is lowered by 1.
    // All values are then divided by 255.
    int inds[2] = {0, 0};
    const size_t n_images = std::min(dataset.test_images.size(), dataset.test_labels.size());
    for (size_t i = 0; i < n_images && i < 10000; i++) {
        const int idx = dataset.test_labels[i] < 5 ? 0 : 1;
        if (inds[idx] >= 5000)  // Imgs/Advs hold at most 5000 images per class
            continue;
        float* img = Imgs[idx][inds[idx]];
        float* adv = Advs[idx][inds[idx]];
        for (int j = 0; j < 32*32*3; j++) {
            const int pixel = dataset.test_images[i][j];
            img[j] = static_cast<float>(pixel) / 255;
            if (j == 0)
                adv[j] = static_cast<float>(pixel - (2*idx - 1)*240) / 255;
            else
                adv[j] = static_cast<float>(pixel - 1) / 255;
        }
        inds[idx]++;
    }

    // H and evaluate_all both index the first 500 images of each class.
    if (inds[0] < 500 || inds[1] < 500) {
        fprintf(stderr, "error - expected at least 500 test images per class, got %d and %d\n",
                inds[0], inds[1]);
        return 1;
    }

    evaluate_all(N);
    bool ok = write_results("results.txt", N);

    // Second pass: G now adds noise to Advs instead of Imgs.
    ADV = 1;
    evaluate_all(N);
    ok = write_results("adv_results.txt", N) && ok;

    return ok ? 0 : 1;
}
