from cProfile import label
import re
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torch
import torchvision
import os
import matplotlib.pyplot as plt
import matplotlib.ticker
import sys
import numpy as np
import pickle
import time
# Cached simulation results are pickled into a file that lives in the same
# directory as this script.
curr_path = os.path.split(os.path.abspath(__file__))[0]
data_file_name = os.path.join(curr_path, "change_p_model_err_NN_MNIST.pkl")

# Side length of the resized images; every image is SIZE x SIZE pixels.
SIZE = 7

class TaskData():
    """MNIST samples split into several binary-classification tasks.

    Each task owns a random subset of five digits; a sample is labelled 1.0
    when its digit belongs to the task's subset and 0.0 otherwise.  The task
    with index ``task_num - 1`` is the test task and uses different block
    sizes (``nr`` / ``test_num``) than the training tasks (``nt`` / ``nv``).

    Construction downloads MNIST into ``'Datasets'`` if necessary, prints the
    digit subsets, and draws random numbers from torch's global RNG.
    """

    def __init__(self) -> None:
        preprocess = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize((SIZE, SIZE)),
            torchvision.transforms.Normalize((0.1307,), (0.3081,))])
        self.data = torchvision.datasets.MNIST(
            'Datasets', train=True, download=True, transform=preprocess)
        # Random permutation used to pick samples from the dataset.
        self.shuffle = torch.randperm(len(self.data))
        self.task_num = 5
        # num of training samples for each training task
        self.nt = 1000
        # num of validation samples for each training task
        self.nv = 100
        # num of training samples for the test task
        self.nr = 500
        # num of test samples for the test task
        self.test_num = 1000
        # One sorted tensor of five distinct digits per task.
        self.task_content = [
            torch.randperm(10)[:5].sort()[0] for _ in range(self.task_num)
        ]
        print(self.task_content)
        # format data: build the train block and then the validation block of
        # each task in turn (the order matters for the random noise drawn).
        self.train_data = []
        self.validate_data = []
        for task_index in range(self.task_num):
            train_block = self.get_task_data_block(task_index)
            self.train_data.append(train_block)
            validate_block = self.get_task_data_block(
                task_index, is_validate=True)
            self.validate_data.append(validate_block)

    def get_task_data(self, task_index, data_index, is_validate=False):
        """Return one dataset item belonging to a task.

        Tasks are laid out back to back in the shuffled index, each starting
        at ``task_index * (nt + nv)``.  Validation samples follow the task's
        training samples, i.e. they start ``nt`` further on (``nr`` for the
        last task).
        """
        position = task_index * (self.nt + self.nv) + data_index
        if is_validate:
            is_last_task = task_index == self.task_num - 1
            position += self.nr if is_last_task else self.nt
        return self.data[self.shuffle[position]]

    def get_task_data_block(self, task_index, is_validate=False):
        """Assemble all samples of one task split into dense tensors.

        Returns a tuple ``(input_data, label_data, noise_input)`` where
        ``input_data`` is ``(SIZE ** 2, n)`` with one flattened image per
        column, ``label_data`` is ``(1, n)`` with 0.0/1.0 entries and
        ``noise_input`` is ``(SIZE ** 2, n)`` standard normal noise -- or all
        zeros for the validation (test) split of the last task.
        """
        is_last_task = task_index == self.task_num - 1
        if is_last_task:
            count = self.test_num if is_validate else self.nr
        else:
            count = self.nv if is_validate else self.nt

        pixels = SIZE ** 2
        input_data = torch.zeros(pixels, count)
        label_data = torch.zeros(1, count)
        digits_of_task = self.task_content[task_index]
        for col in range(count):
            sample = self.get_task_data(task_index, col, is_validate)
            # sample[0] is the image tensor; its first channel is flattened
            # into one column.
            input_data[:, col] = sample[0][0].view(pixels)
            label_data[0, col] = 1.0 if sample[1] in digits_of_task else 0.0

        # add noise (the test samples of the test task stay noise-free)
        if is_last_task and is_validate:
            noise_input = torch.zeros(pixels, count)
        else:
            noise_input = torch.randn(pixels, count)
        return input_data, label_data, noise_input


def get_params(_w):
    """Slice a flat parameter vector into the four layer tensors.

    The layout of ``_w`` is ``[fc1 | fc1_bias | fc2 | fc2_bias]`` with sizes
    ``SIZE**2 * width``, ``width``, ``width`` and ``1``.  The results are
    created with slicing and ``view``, so they share storage with ``_w``.

    Reads the module-level names ``SIZE`` and ``width``.  Note that ``width``
    is later rebound to a plotting helper at the bottom of this file, so this
    function must only be called while ``width`` holds the integer network
    width.

    Returns:
        ``(fc1, fc1_bias, fc2, fc2_bias)`` with shapes ``(SIZE**2, width)``,
        ``(width, 1)``, ``(width, 1)`` and ``(1, 1)``.
    """
    n_inputs = SIZE ** 2
    # Cumulative end offsets of the four consecutive segments.
    fc1_end = n_inputs * width
    fc1_bias_end = fc1_end + width
    fc2_end = fc1_bias_end + width
    fc2_bias_end = fc2_end + 1

    fc1 = _w[:fc1_end].view(n_inputs, width)
    fc1_bias = _w[fc1_end:fc1_bias_end].view(width, 1)
    fc2 = _w[fc1_bias_end:fc2_end].view(width, 1)
    fc2_bias = _w[fc2_end:fc2_bias_end].view(1, 1)
    return fc1, fc1_bias, fc2, fc2_bias


def cal_loss(input_data, label_data):
    """Mean squared error of the two-layer network on a batch.

    The network is ``sigmoid(fc2.T @ relu(fc1.T @ x + fc1_bias) + fc2_bias)``
    and uses the module-level tensors ``fc1``, ``fc1_bias``, ``fc2`` and
    ``fc2_bias``.

    Args:
        input_data: tensor with one sample per column.
        label_data: targets, reshaped together with the outputs to
            ``(n, 1)`` where ``n`` is the number of columns of ``input_data``.

    Returns:
        Scalar tensor: squared 2-norm of the residual divided by ``n``.
    """
    num_samples = input_data.shape[1]
    hidden = F.relu(fc1.T @ input_data + fc1_bias.repeat(1, num_samples))
    pre_activation = fc2.T @ hidden + fc2_bias.repeat(1, num_samples)
    residual = (torch.sigmoid(pre_activation) - label_data).view(
        num_samples, 1)
    return torch.norm(residual) ** 2 / num_samples


def cal_w_diff_block(which_task):
    """Return the scaled gradient of one task's noisy training loss.

    Evaluates ``cal_loss`` on the task's training block with ``sigma`` times
    the stored noise added to the inputs, back-propagates, and returns a copy
    of ``w_common.grad`` multiplied by ``step_size``.

    ``w_common.grad`` is not reset here, so anything already accumulated in
    it is part of the result; callers zero it as needed.
    """
    inputs, labels, noise = a.train_data[which_task]
    noisy_inputs = inputs + noise * sigma
    cal_loss(noisy_inputs, labels).backward()
    scaled_gradient = w_common.grad.clone() * step_size
    return scaled_gradient


def validate_or_test_block(which_task, is_test=False):
    """Evaluate a task at its adapted parameters and optionally update.

    ``w_common`` is temporarily shifted by ``-w_diff[which_task]``, the loss
    on the task's validation block (inputs plus ``sigma`` times the stored
    noise) is computed and back-propagated, and the shift is undone again.
    Unless ``is_test`` is true, ``w_common`` additionally takes a gradient
    step of size ``step_size_validate`` using the gradient obtained at the
    shifted point.  ``w_common.grad`` is zeroed before returning.

    All updates use in-place tensor methods, so ``w_common`` keeps its
    identity and the parameter views into it stay valid.

    Returns:
        The loss at the shifted parameters as a Python float.
    """
    inputs, labels, noise = a.validate_data[which_task]
    task_shift = w_diff[which_task]

    # Move to the task-specific parameters.
    with torch.no_grad():
        w_common.sub_(task_shift)

    loss = cal_loss(inputs + noise * sigma, labels)
    loss.backward()

    with torch.no_grad():
        if not is_test:
            w_common.sub_(w_common.grad * step_size_validate)
        # Move back to the shared parameters.
        w_common.add_(task_shift)
    w_common.grad.zero_()
    return loss.item()


# Seed torch's global RNG before building the tasks: TaskData() draws its
# sample permutation, digit subsets and input noise from it.
torch.manual_seed(0)
a = TaskData()

# --- optimisation settings ---
num_epoches = 500
step_size = 0.05  # scales the training-block gradient (w_diff)
step_size_validate = 0.3  # step applied to w_common after the validation pass
sigma = 0.3  # noise level

# --- sweep settings ---
width_list = [
    40, 50, 60, 70, 80, 90, 100,
    200, 300, 400, 500,
]
simu_repeat_num = 30

# One inner list per simulation run, one entry per network width.
train_loss_all = []
test_loss_all = []

# if true, then run the simulation and over-write the saved data;
# otherwise the previously saved results are loaded from data_file_name.
FLAG_RUN_AND_SAVE = False

if FLAG_RUN_AND_SAVE:
    # Run the whole sweep: simu_repeat_num runs x len(width_list) widths.
    start_time = time.time()
    for each_run in range(simu_repeat_num):
        print("{}.  elapsed time: {:.1f} min".format(
            each_run, (time.time() - start_time) / 60))
        train_loss_each_run = []
        test_loss_each_run = []
        for width in width_list:
            # Re-seed per width so every width of a run starts from the same
            # RNG state.
            torch.manual_seed(each_run)
            param_num = (SIZE ** 2) * width + width + width + 1
            # All network parameters live in one flat vector; w_diff holds
            # one row of the same length per task.
            w_common = torch.rand(param_num, requires_grad=True)
            w_diff = torch.zeros(a.task_num, param_num)
            # initialize the weights
            # (fc1 ... fc2_bias are views into w_common and are the
            # module-level names that cal_loss reads)
            fc1, fc1_bias, fc2, fc2_bias = get_params(w_common)
            with torch.no_grad():
                fc1 /= SIZE
                fc2 /= np.sqrt(width)

            plot_loss = []
            for epoch_index in range(num_epoches):
                # if epoch_index % (num_epoches // 5) == 1:
                #     print(epoch_index)
                train_loss = 0
                # Training tasks are all tasks except the last one.
                for which_task in range(a.task_num - 1):
                    w_diff[which_task] = cal_w_diff_block(which_task)
                    w_common.grad.zero_()
                    train_loss += validate_or_test_block(
                        which_task) / (a.task_num - 1)
                plot_loss.append(train_loss)
            # Test task: adapt on its training block, then evaluate only
            # (is_test=True, so w_common is not updated).
            w_diff[a.task_num - 1] = cal_w_diff_block(a.task_num - 1)
            test_loss = validate_or_test_block(a.task_num - 1, True)
            print("p={}, train_loss={}, test_loss={}".format(
                width, train_loss, test_loss))
            # train_loss is the value from the final epoch.
            train_loss_each_run.append(train_loss)
            test_loss_each_run.append(test_loss)
        train_loss_all.append(train_loss_each_run)
        test_loss_all.append(test_loss_each_run)

    # The context manager closes the file even if pickling fails.
    with open(data_file_name, 'wb') as output:
        pickle.dump({'width_Array': width_list, 'train_loss': train_loss_all,
                     'test_loss': test_loss_all}, output)

else:
    # load the saved data
    # NOTE: pickle.load can execute arbitrary code; only load result files
    # that were written by this script.
    with open(data_file_name, 'rb') as pkl_file:
        data1 = pickle.load(pkl_file)
    width_list = data1['width_Array']
    train_loss_all = data1['train_loss']
    test_loss_all = data1['test_loss']

# Take the maximum within every run first and then across runs.  Calling
# max() directly on the nested list would compare the per-run lists
# lexicographically and only inspect a single run.
print(max(max(run_losses) for run_losses in train_loss_all))  # largest train error


def width(p, w):
    """Box width at position ``p`` that looks uniform on a log x-axis.

    Returns the linear distance between ``10**(log10(p) + w/2)`` and
    ``10**(log10(p) - w/2)``, i.e. a span of ``w`` decades centred on ``p``.
    ``p`` may be a scalar or a sequence (handled by ``np.log10``).
    """
    log_center = np.log10(p)
    half_span = w / 2.
    upper_edge = 10 ** (log_center + half_span)
    lower_edge = 10 ** (log_center - half_span)
    return upper_edge - lower_edge


# Test error versus network width: mean curve drawn over per-width boxplots.
# Each inner list of test_loss_all is one run, so after conversion the rows
# are runs and the columns are widths.
test_loss_array = np.asarray(test_loss_all)

fig1, ax1 = plt.subplots()
ax1.plot(width_list, test_loss_array.mean(axis=0))
# One box per width; widths are scaled so the boxes look equally wide on the
# logarithmic x-axis.
ax1.boxplot(test_loss_array.T.tolist(),
            positions=width_list, widths=width(width_list, 0.03))
ax1.set_xlabel('network width (num of neurons)')
ax1.set_ylabel('test error')
ax1.set_xscale('log')
# The first positional parameter of grid() is the on/off flag (not `which`
# or `axis`), so pass an explicit boolean to switch the grid on.
ax1.grid(True)
ax1.set_xticks([40, 100, 500])
ax1.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
# plt.savefig(os.path.join(curr_path, 'change_width_NN_err.eps'),
#             format='eps', bbox_inches='tight')
plt.show()
