import torch
import numpy as np
import pickle
import os
import torchvision
import random
cpath = os.path.dirname(__file__)

import sys
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Length of one patch; every sample is two patches concatenated (2 * dim values).
dim = 500

# Default sizes of the four subgroups (class 1 majority/minority, class 2
# majority/minority). data_generate() declares the same values as defaults.
n_11 = 200
n_12 = 100
n_21 = 100
n_22 = 50

# Class 1: amplitude written at the feature position of each subgroup.
feature_size_1 = 4
feature_size_2 = 2

# Class 2: amplitude written at the feature position of each subgroup.
feature_size_3 = 1.5
feature_size_4 = 1

# Feature positions inside a patch, one per subgroup.  They are sampled
# without replacement so that no two subgroups share a position: a collision
# would make those subgroups indistinguishable by position.
# NOTE: no seed is set before this draw in this part of the file, so the
# positions differ from run to run unless `random` is seeded earlier.
(
    feature_entry_1,
    feature_entry_2,
    feature_entry_3,
    feature_entry_4,
) = random.sample(range(dim), 4)

# Scratch tensors left over from experimenting with torch.cat.
x = torch.ones(100)
y = torch.zeros(100)


def _make_samples(n, feature_entry, feature_size, label, data_noise_level):
    """Build ``n`` ``[input, label]`` pairs for one subgroup.

    Each input is a ``(1, 2 * dim)`` tensor made of two patches in random
    order: a signal patch that is zero except for ``feature_size`` at
    ``feature_entry``, and a Gaussian noise patch scaled by
    ``data_noise_level`` whose values at all four feature positions are
    forced to zero.
    """
    reserved = [feature_entry_1, feature_entry_2, feature_entry_3, feature_entry_4]
    samples = []
    for _ in range(n):
        signal_patch = torch.zeros(dim)
        signal_patch[feature_entry] = feature_size

        # The noise draw happens before the coin flip, matching the order in
        # which the two random generators were consumed originally.
        noise_patch = data_noise_level * torch.randn(dim)
        noise_patch[reserved] = 0

        if random.randint(0, 1) == 0:
            patches = (signal_patch, noise_patch)
        else:
            patches = (noise_patch, signal_patch)
        samples.append([torch.unsqueeze(torch.cat(patches, dim=0), 0), torch.tensor(label)])
    return samples


def data_generate(data_noise_level = 0.3, n_11 = 200, n_12 = 100, n_21 = 100, n_22 = 50,):
    """Generate the synthetic train and test sets.

    Both sets contain the same four subgroups, appended in this order:
    ``n_11`` and ``n_12`` samples with label 0 (features 1 and 2), then
    ``n_21`` and ``n_22`` samples with label 1 (features 3 and 4).  The
    train set is generated completely before the test set.

    Args:
        data_noise_level: Scale applied to the standard-normal noise patch.
        n_11, n_12, n_21, n_22: Number of samples per subgroup, used for both
            the train and the test set.

    Returns:
        ``(dataset_train, dataset_test, feature_entries)`` where each dataset
        is a list of ``[input, label]`` pairs (input shape ``(1, 2 * dim)``,
        label a 0-dim integer tensor) and ``feature_entries`` lists the four
        feature positions in subgroup order.
    """
    # (count, feature position, feature amplitude, class label) per subgroup.
    groups = [
        (n_11, feature_entry_1, feature_size_1, 0),
        (n_12, feature_entry_2, feature_size_2, 0),
        (n_21, feature_entry_3, feature_size_3, 1),
        (n_22, feature_entry_4, feature_size_4, 1),
    ]

    def build_split():
        split = []
        for count, entry, size, label in groups:
            split.extend(_make_samples(count, entry, size, label, data_noise_level))
        return split

    dataset_train = build_split()
    dataset_test = build_split()

    feature_entries = [feature_entry_1, feature_entry_2, feature_entry_3, feature_entry_4]

    return dataset_train, dataset_test, feature_entries


# Number of convolution filters used when the model is built.
num_neurons = 64


class ConvNet(nn.Module):
    """One strided 1-D convolution, a ReLU, and a two-output linear layer.

    The convolution has kernel size and stride both equal to ``dim``, so each
    filter is applied to non-overlapping windows of ``dim`` inputs.  The
    linear layer takes ``2 * num_neurons`` features, i.e. two windows per
    filter.
    """

    def __init__(self, num_neurons):
        super().__init__()
        self.conv1 = nn.Conv1d(
            in_channels=1,
            out_channels=num_neurons,
            kernel_size=dim,
            stride=dim,
            padding=0,
        )
        self.relu = nn.ReLU()
        self.fc = nn.Linear(num_neurons * 2, 2)

    def forward(self, x):
        """Return the two output scores for each sample in ``x``."""
        activations = self.relu(self.conv1(x))
        # Flatten everything except the batch dimension for the linear layer.
        flattened = activations.view(activations.size(0), -1)
        return self.fc(flattened)

def model_setup():
    """Create a ``ConvNet`` with a fixed output layer and zeroed, frozen biases.

    The linear layer's weight is set by hand: row 0 is one on the first
    ``num_neurons`` inputs and zero on the rest, row 1 is the opposite.  Both
    biases are set to zero.  The linear weight and both biases have
    ``requires_grad`` disabled, which leaves the convolution weight as the
    only parameter that still requires gradients.

    The model is moved to the GPU when CUDA is available and stays on the CPU
    otherwise, instead of failing on machines without CUDA.

    Returns:
        The initialised model.
    """
    model = ConvNet(num_neurons)
    if torch.cuda.is_available():
        model = model.cuda()

    ones = torch.ones(num_neurons)
    zeros = torch.zeros(num_neurons)

    # copy_ moves the CPU-built rows onto whatever device the model is on.
    with torch.no_grad():
        model.fc.weight[0].copy_(torch.cat((ones, zeros), 0))
        model.fc.weight[1].copy_(torch.cat((zeros, ones), 0))
        model.conv1.bias.fill_(0.0)
        model.fc.bias.fill_(0.0)

    model.conv1.bias.requires_grad = False
    model.fc.bias.requires_grad = False
    model.fc.weight.requires_grad = False
    return model

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


def model_train_eval(train_loader, test_loader, noise_scale,num_epochs,clip=32):
    """Train the global ``model`` with clipped, noised gradients and evaluate it.

    For every training batch the gradient of each sample is computed
    separately, clipped to a total norm of ``clip``, and summed.  The sum is
    averaged over the batch, Gaussian noise with standard deviation
    ``noise_scale`` is added, and the optimizer takes one step.  After every
    epoch the test set is evaluated and test samples are bucketed into four
    groups by which feature position (from the global ``feature_entries``)
    is non-zero in either patch.

    Relies on the module-level ``model``, ``criterion``, ``optimizer``,
    ``feature_entries`` and ``dim``.  Batches are moved to the device the
    model's parameters live on.

    Differences from the earlier behaviour, all bug fixes:
      * the gradient accumulator is reset for every batch (it used to be
        created once per epoch, so earlier batches and their noise leaked
        into later updates);
      * only parameters with ``requires_grad=True`` receive a gradient and
        noise, so frozen parameters are no longer moved by the optimizer;
      * the epoch line prints the epoch index, and the training loss is a
        per-sample average (assumes ``criterion`` averages over the batch,
        as ``nn.CrossEntropyLoss()`` does by default);
      * empty groups yield NaN instead of raising ``ZeroDivisionError``.

    Args:
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` test batches.
        noise_scale: Standard deviation of the noise added to the averaged
            gradient.
        num_epochs: Number of passes over ``train_loader``.
        clip: Maximum total gradient norm allowed per sample.

    Returns:
        A 4-tuple with the mean test loss of each group after the last epoch.
    """
    device = next(model.parameters()).device
    trainable = [p for p in model.parameters() if p.requires_grad]

    # Defined up front so the summary below is valid even when num_epochs == 0.
    test_loss = [0.0] * 4
    correct = [0] * 4
    total = [0] * 4

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_correct = train_total = 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            batch_n = y.size(0)

            # Fresh accumulator per batch: sum of the clipped per-sample gradients.
            summed = [torch.zeros_like(p) for p in trainable]
            for i in range(batch_n):
                optimizer.zero_grad()
                loss = criterion(model(x[i:i + 1]), y[i:i + 1])
                loss.backward()
                torch.nn.utils.clip_grad_norm_(trainable, clip)
                for acc, p in zip(summed, trainable):
                    if p.grad is not None:
                        acc += p.grad.detach()

            # Training metrics are taken before the parameter update.
            with torch.no_grad():
                outputs = model(x)
                batch_loss = criterion(outputs, y)
            _, predicted = torch.max(outputs, 1)
            train_loss += batch_loss.item() * batch_n
            train_total += batch_n
            train_correct += (predicted == y).sum().item()

            # Drop the last per-sample gradient, then install averaged + noised
            # gradients on the trainable parameters only.
            optimizer.zero_grad()
            for acc, p in zip(summed, trainable):
                # Noise is drawn on the CPU generator and then moved, so it is
                # governed by torch.manual_seed on every device.
                noise = (torch.randn(acc.shape) * noise_scale).to(device)
                p.grad = acc / batch_n + noise

            optimizer.step()

        model.eval()

        test_loss = [0.0] * 4
        correct = [0] * 4
        total = [0] * 4

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)

                for i in range(len(targets)):
                    sample = inputs[i, 0]
                    # First group whose feature position is non-zero in either
                    # patch; samples matching none of them are not counted.
                    group = next(
                        (
                            k
                            for k, entry in enumerate(feature_entries[:4])
                            if sample[entry] != 0 or sample[entry + dim] != 0
                        ),
                        None,
                    )
                    if group is None:
                        continue

                    # Batch of one, so the criterion sees the usual (N, C) / (N,) shapes.
                    loss = criterion(outputs[i:i + 1], targets[i:i + 1]).item()
                    if targets[i] == predicted[i]:
                        correct[group] += 1
                    total[group] += 1
                    test_loss[group] += loss

        print('epoch', epoch)
        print('train loss', _safe_ratio(train_loss, train_total))
        print('train acc', _safe_ratio(train_correct, train_total))
        print('test loss', _safe_ratio(sum(test_loss), sum(total)))
        print('test acc', _safe_ratio(sum(correct), sum(total)))

    for i in range(4):
        accuracy = 100 * _safe_ratio(correct[i], total[i])
        testloss = _safe_ratio(test_loss[i], total[i])
        testloss_d = round(testloss, 10)
        test_loss_str = str(testloss_d)
        print('0', total[0])
        print('1', total[1])
        print('2', total[2])
        print('3', total[3])
        print('Accuracy for class %d: %.2f%%' % (i, accuracy))
        print('Test loss for class %d: %.2f ' % (i, testloss))
        print('Test loss value', test_loss_str)

    return tuple(_safe_ratio(test_loss[k], total[k]) for k in range(4))

# Experiment driver: train one model per noise level, record the per-group
# test loss after the last epoch, save the curves and plot them.
gpu=1
iter = 10

# One noise standard deviation per run; the same list is the plot's x axis,
# so the axis always matches the number of runs.
noise = [0.01 + 0.2 * j for j in range(iter)]

tl_11 = []
tl_12 = []
tl_21 = []
tl_22 = []
for i in range(iter):
    # Same seeds for every run, so runs differ only in the noise level.
    random.seed(12333)
    torch.manual_seed(1325)
    model = model_setup()
    learning_rate = 0.02
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    batch_size = 128
    num_epochs = 20
    noise_scale = noise[i]

    dataset_train, dataset_test,feature_entries = data_generate()
    train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)
    test_loss_11,test_loss_12,test_loss_21,test_loss_22 = model_train_eval(train_loader,test_loader,noise_scale,num_epochs)

    tl_11.append(test_loss_11)
    tl_12.append(test_loss_12)
    tl_21.append(test_loss_21)
    tl_22.append(test_loss_22)


print(tl_11)

# Save before plotting: plt.show() blocks until the window is closed, and a
# plotting failure should not cost the results of the training runs.
torch.save(tl_11, 'tl_11.pt')
torch.save(tl_12, 'tl_12.pt')
torch.save(tl_21, 'tl_21.pt')
torch.save(tl_22, 'tl_22.pt')

plt.figure(figsize=(5.5, 3.5))
plt.plot(noise, tl_11,'-*')
plt.plot(noise, tl_12,'--')
plt.plot(noise, tl_21,'-.')
plt.plot(noise, tl_22)

ax = plt.gca()
ax.tick_params(axis='x', labelsize=14)
ax.tick_params(axis='y', labelsize=14)

# Raw string: "\s" is not a valid escape sequence in a normal string literal.
plt.xlabel(r'DP noise standard deviation ($\sigma_n$)',fontsize=16)
plt.legend(['Class 1 Maj','Class 1 Min', 'Class 2 Maj', 'Class 2 Min'],fontsize=11)
plt.ylabel('Test loss',fontsize=16)
plt.grid()
plt.subplots_adjust(left = 0.15, right = 0.99, bottom=0.17, top=0.94)

plt.show()




