import torch
import numpy as np
import matplotlib
# Select the Tk-based backend before pyplot is imported, so the figure shown
# by plt.show() at the end of the script uses it.
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import pickle
import os
import torchvision
import random
# Directory containing this script. NOTE(review): not referenced anywhere in
# the visible part of the file; the .pt files below are opened with paths
# relative to the current working directory instead.
cpath = os.path.dirname(__file__)

import sys
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Truthy => evaluate() and the train/test loops move each batch to CUDA.
# NOTE(review): the model itself is moved to CUDA unconditionally further down.
gpu = 1

# Inputs are read from the current working directory. The two datasets are
# later wrapped in DataLoaders; feature_entries is indexed with 0..3 when the
# test samples are split into four groups.
# NOTE(review): torch.load unpickles the file contents - only load trusted files.
dataset_train = torch.load('dataset_train.pt')
dataset_test = torch.load('dataset_test.pt')
feature_entries = torch.load('feature_entries.pt')

# Seed the torch RNG (weight init, DataLoader shuffling, gradient noise).
seed = 1
torch.manual_seed(seed)


# Commented-out smoke-test input (100 zeros followed by 100 ones).
#x1 = torch.zeros(100)
#x2 = torch.ones(100)*1
#x = torch.unsqueeze(torch.unsqueeze(torch.cat((x1,x2),dim=0),0),0).cuda()
#print(x)

# Number of output channels of the convolution; passed to ConvNet.
num_neurons = 64
class ConvNet(nn.Module):
    """Strided 1-D convolution followed by ReLU and a two-way linear read-out.

    The convolution has one input channel, ``num_neurons`` output channels and
    uses kernel size 100 with stride 100 and no padding, i.e. it looks at
    non-overlapping windows of 100 entries. The linear layer expects
    ``2 * num_neurons`` features, so the convolution must yield exactly two
    positions per channel (input length 200..299).
    """

    def __init__(self,num_neurons):
        super().__init__()
        window = 100
        # Layers are created in the order conv1 -> relu -> fc so that seeded
        # parameter initialisation is reproducible.
        self.conv1 = nn.Conv1d(
            in_channels=1,
            out_channels=num_neurons,
            kernel_size=window,
            stride=window,
            padding=0,
        )
        self.relu = nn.ReLU()
        self.fc = nn.Linear(in_features=2 * num_neurons, out_features=2)

    def forward(self, x):
        """Return the two class logits for every sample in ``x``."""
        activations = self.relu(self.conv1(x))
        # Collapse (channels, positions) into one feature axis per sample.
        flattened = activations.view(activations.size(0), -1)
        return self.fc(flattened)

# Build the network and honour the `gpu` flag instead of always calling .cuda().
model = ConvNet(num_neurons)
if gpu:
    model = model.cuda()

# Hand-crafted, fixed read-out layer: output 0 sums the first `num_neurons`
# flattened features, output 1 sums the remaining `num_neurons` features.
# (An interleaved 1,0,1,0,... pattern used to be built here in a loop but was
# overwritten before use, so that dead code has been removed.)
# NOTE(review): ConvNet flattens a (channels, positions) map channel-major, so
# "first half" means the first num_neurons/2 channels rather than one window
# position - confirm this is the intended read-out.
t1 = torch.ones(num_neurons)
t2 = torch.zeros(num_neurons)

weight_manual1 = torch.cat((t1, t2), 0)
weight_manual2 = torch.cat((t2, t1), 0)

with torch.no_grad():
    # copy_ handles the CPU -> model-device transfer.
    model.fc.weight.copy_(torch.stack((weight_manual1, weight_manual2)))
    # All biases are fixed at zero.
    model.conv1.bias.zero_()
    model.fc.bias.zero_()

print(model.fc.weight.data)

# Only the convolution weights remain trainable.
model.conv1.bias.requires_grad = False
model.fc.bias.requires_grad = False
model.fc.weight.requires_grad = False

criterion = nn.CrossEntropyLoss()
# Give the optimizer only the trainable parameters. SGD updates every
# parameter whose .grad is set, regardless of requires_grad, so handing it the
# frozen tensors lets any manually assigned gradient (or noise) move them.
optimizer = optim.SGD(
    [param for param in model.parameters() if param.requires_grad],
    lr=0.06,
)

batch_size = 64
num_epochs = 20
noise_scale = 0.8  # std-dev of the Gaussian noise added to the batch gradient

# Counts every parameter, frozen ones included.
total_params = sum(p.numel() for p in model.parameters())
print('Total number of parameters:', total_params)

train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=True)

def evaluate(model, test_loader):
    """Run ``model`` over ``test_loader`` without gradients and collect results.

    Batches are moved to the device of the model's parameters, so the function
    works on CPU and GPU alike and no longer depends on a global flag or on an
    unconditional ``.cuda()`` call.

    Args:
        model: module mapping a batch of inputs to per-class logits.
        test_loader: iterable yielding ``(data, target)`` batches.

    Returns:
        Tuple ``(predictions, labels, accuracy, outputs)`` where
        ``predictions`` are the softmax probabilities, ``labels`` the
        concatenated targets, ``accuracy`` the fraction of correct arg-max
        predictions (a float) and ``outputs`` the raw logits.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()

    # A parameter-free model has no device of its own: leave the data in place.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    predictions = []
    outputs = []
    labels = []
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            if device is not None:
                data, target = data.to(device), target.to(device)

            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            outputs.append(output)
            predictions.append(torch.softmax(output, dim=1))
            labels.append(target)

    if total == 0:
        raise ValueError('evaluate() received no samples from test_loader')

    predictions = torch.cat(predictions)
    outputs = torch.cat(outputs)
    labels = torch.cat(labels)
    accuracy = correct / total
    return predictions, labels, accuracy,outputs

def compute_ece(predictions, labels, outputs, num_bins):
    """Compute the expected calibration error (ECE) over equal-width bins.

    Confidence is the largest class probability of each sample. Samples are
    grouped into ``num_bins`` bins ``(lo, hi]`` on [0, 1]; the ECE is the sum
    over bins of ``|accuracy - mean confidence|`` weighted by the fraction of
    samples in the bin. Empty bins carry zero weight, so they no longer turn
    the result into NaN (for two classes the confidence is at least 0.5, so
    the lower bins are normally empty).

    Args:
        predictions: (N, C) class probabilities.
        labels: (N,) integer class labels.
        outputs: (N, 2) raw logits; only used for ``outputs_pred``, which
            indexes class ``1 - label`` and is therefore binary-only.
        num_bins: number of equal-width confidence bins.

    Returns:
        Tuple ``(ece, bin_num, confidence, bin_acc, outputs_pred)``:
        ``ece`` is a float, ``bin_num`` the per-bin sample counts,
        ``confidence`` the per-sample confidence, ``bin_acc`` the per-bin
        accuracy (0 for empty bins) and ``outputs_pred`` a list of 0-dim
        tensors holding the probability assigned to the true class.
    """
    confidence = torch.max(predictions, dim=1).values
    predicted = torch.argmax(predictions, dim=1)
    correct = predicted.eq(labels)
    total_samples = labels.size(0)

    # Probability of the true class from the two logits, computed for all
    # samples at once: 1 / (1 + exp(other_logit - true_logit)).
    rows = torch.arange(total_samples, device=outputs.device)
    true_logit = outputs[rows, labels]
    other_logit = outputs[rows, 1 - labels]
    true_class_prob = 1 / (1 + torch.exp(other_logit - true_logit))
    # Callers iterate over / index a list of 0-dim tensors.
    outputs_pred = list(true_class_prob.unbind(0))

    device = confidence.device
    bin_boundaries = torch.linspace(0, 1, num_bins + 1).to(device)
    bin_correct = torch.zeros(num_bins).to(device)
    bin_confidence = torch.zeros(num_bins).to(device)
    bin_num = torch.zeros(num_bins).to(device)

    for bin_idx in range(num_bins):
        mask = (confidence > bin_boundaries[bin_idx]) & (confidence <= bin_boundaries[bin_idx + 1])
        count = int(mask.sum().item())
        bin_num[bin_idx] = count
        if count == 0:
            # mean() of an empty selection is NaN; keep the bin at zero.
            continue
        bin_correct[bin_idx] = correct[mask].sum().item()
        bin_confidence[bin_idx] = confidence[mask].mean().item()

    # The epsilon keeps empty bins at accuracy 0 instead of 0/0.
    bin_acc = bin_correct / (bin_num + 1e-7)

    if total_samples == 0:
        ece = 0.0
    else:
        gaps = torch.abs(bin_acc - bin_confidence)
        ece = (gaps * bin_num).sum().item() / total_samples

    return ece, bin_num, confidence, bin_acc, outputs_pred



def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


# Only parameters that are still trainable take part in the noisy update.
# Assigning .grad (and noise) to frozen parameters would let the optimizer
# move them, because SGD only checks whether .grad is set.
trainable_params = [p for p in model.parameters() if p.requires_grad]
# Batches follow the model's device.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    # ------------------------------------------------------------------
    # Training: per-sample clipped gradients, averaged over the batch,
    # plus Gaussian noise, then one SGD step per batch.
    # ------------------------------------------------------------------
    model.train()
    begin_time = time.time()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)
        batch_len = y.size(0)

        # Fresh accumulators for every batch: reusing one set per epoch leaks
        # earlier batches (and their noise) into later updates.
        aggregated_gradients = [torch.zeros_like(p) for p in trainable_params]

        for i in range(batch_len):
            input_exp = x[i:i + 1]
            label_exp = y[i:i + 1]
            output_exp = model(input_exp)
            loss = criterion(output_exp, label_exp)
            optimizer.zero_grad()
            loss.backward()
            # Clip each sample's gradient to norm 2 before accumulating.
            torch.nn.utils.clip_grad_norm_(trainable_params, 2)

            for accumulated, param in zip(aggregated_gradients, trainable_params):
                if param.grad is not None:
                    accumulated += param.grad.detach()

        # Training metrics on the pre-update weights; no graph is needed.
        with torch.no_grad():
            outputs = model(x)
            loss = criterion(outputs, y)
        _, predicted = torch.max(outputs, 1)
        # criterion returns the batch mean; weight it by the batch size so
        # train_loss / train_total is a per-sample average.
        train_loss += loss.item() * batch_len
        train_total += batch_len
        train_correct += (predicted == y).sum().item()

        # Averaged gradient plus noise becomes the gradient the optimizer sees.
        # A new tensor is assigned so .grad never aliases the accumulator.
        for accumulated, param in zip(aggregated_gradients, trainable_params):
            noise = torch.randn(accumulated.shape) * noise_scale
            param.grad = accumulated / batch_len + noise.to(accumulated.device)

        optimizer.step()

    # ------------------------------------------------------------------
    # Per-group test metrics. A sample belongs to the first group g whose
    # feature position (feature_entries[g], or that position + 100) is
    # non-zero; samples matching no group are not counted.
    # ------------------------------------------------------------------
    model.eval()

    test_loss = [0] * 4
    correct = [0] * 4
    total = [0] * 4

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            for i in range(len(targets)):
                sample = inputs[i, 0]
                group = None
                for g in range(4):
                    if sample[feature_entries[g]] != 0 or sample[feature_entries[g] + 100] != 0:
                        group = g
                        break
                if group is None:
                    continue

                # Single-sample batch so the mean-reduced criterion returns
                # this sample's loss.
                loss = criterion(outputs[i:i + 1], targets[i:i + 1]).item()
                if targets[i] == predicted[i]:
                    correct[group] += 1
                total[group] += 1
                test_loss[group] += loss

    end_time = time.time()

    # ------------------------------------------------------------------
    # Calibration metrics and reporting.
    # ------------------------------------------------------------------
    num_bins = 20
    predictions, labels, accuracy, outputs = evaluate(model, test_loader)
    print('outputs', len(outputs))
    ece, bin_num, confidence, bin_acc, outputs_pred = compute_ece(predictions, labels, outputs, num_bins=num_bins)
    print('labels', outputs_pred)

    print('ece acc', bin_acc)
    print('ece num', bin_num)
    print('ece', ece)

    print('epoch', epoch)
    print('train loss', _ratio(train_loss, train_total))
    print('train acc', _ratio(train_correct, train_total))
    print('test loss', _ratio(sum(test_loss), sum(total)))
    print('test acc', _ratio(sum(correct), sum(total)))
    for i in range(4):
        # An empty group reports NaN instead of raising ZeroDivisionError.
        group_accuracy = 100 * _ratio(correct[i], total[i])
        testloss = _ratio(test_loss[i], total[i])
        testloss_d = round(testloss, 10)
        test_loss_str = str(testloss_d)
        print('0', total[0])
        print('1', total[1])
        print('2', total[2])
        print('3', total[3])
        print('Accuracy for class %d: %.2f%%' % (i, group_accuracy))
        print('Test loss for class %d: %.2f '  % (i, testloss))
        print('Test loss value', test_loss_str)
    print('time', end_time - begin_time)

# Histogram of the per-sample true-class probabilities from the last epoch.
# Each entry is moved to host memory before it is handed to matplotlib.
outputs_pred_cpu = [pred.cpu() for pred in outputs_pred]

font = {'family': 'Times New Roman'}
plt.rc('font', **font)

plt.figure(figsize=(5.5, 3.5))

ax = plt.gca()
tick_label_size = 16
for axis_name in ('x', 'y'):
    ax.tick_params(axis=axis_name, labelsize=tick_label_size)

plt.hist(outputs_pred_cpu, bins=10, edgecolor='black')
plt.grid(True)
plt.xlabel('Prediction', fontsize=16)
plt.ylabel('#Data', fontsize=16)
plt.subplots_adjust(left=0.15, right=0.95, bottom=0.17, top=0.94)

plt.show()

# The last epoch's confidences are written once plt.show() has returned.
torch.save(confidence, 'confidence.pt')