import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset
import torch.optim as optim
from tool import *
from datetime import datetime
import shutil
import json
import copy
from tool import *
basedir = '/home/***/data/undergraky/condenseMNIST'
# Prefer the second GPU (as originally configured), but fall back to the first
# GPU on single-GPU machines instead of failing later in `net.to(device)` with
# an invalid device ordinal.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# device = torch.device('cpu')
hyperpara = {}
# setup_seed(199)

# Convert to tensor, then normalize the single channel with mean 0.1307 / std 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

batch_size = 4  # NOTE(review): not used below; the loader batches the whole subset
train_size = 500
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
# Fixed, contiguous slice of the training set: samples 1000 .. 1000 + train_size - 1.
train_set = Subset(trainset, range(1000, 1000 + train_size))
# batch_size=train_size and shuffle=False: the whole subset arrives as one batch.
trainloader = torch.utils.data.DataLoader(train_set, batch_size=train_size,
                                          shuffle=False, num_workers=2)

def label2onehot(k, num_classes=10):
    """Return a one-hot row vector for class index ``k``.

    Args:
        k: class index (a Python int, or a 0-dim integer tensor such as the
            items obtained by iterating over a batch of labels).
        num_classes: length of the vector; defaults to 10 (the MNIST digits).

    Returns:
        A float32 tensor of shape ``(1, num_classes)`` with 1 at column ``k``
        and 0 elsewhere.
    """
    vec = torch.zeros(1, num_classes)
    vec[0, k] = 1
    return vec

print(label2onehot(0))  # quick sanity check of the one-hot helper

# Experiment configuration; also written to para.txt in the run directory.
hyperpara['input_channel'] = 1
hyperpara['middle_channel'] = 256
hyperpara['learning_rate'] = 5e-6
hyperpara['padding'] = 'valid'
hyperpara['kernel_size'] = 3
hyperpara['convlayernum'] = 1
hyperpara['output_dim'] = 1  # NOTE(review): recorded only; Net's fc1 always has 1 output
hyperpara['act'] = 'tanh'

# Run directory named <yymmddHHMMSS><act>, e.g. '240101120000tanh'.
subFolderName = '%s%s' % (datetime.now().strftime("%y%m%d%H%M%S"), hyperpara['act'])
savedir = makedir(os.path.join(basedir, subFolderName))
# Keep a copy of this script next to the run's outputs.
shutil.copy(__file__, savedir)

# save hyperparameters (the context manager closes the file even if the write fails)
para_str = json.dumps(hyperpara, indent=0)
with open("%s/para.txt" % (savedir), "w") as para_file:
    para_file.write(para_str)

class Net(nn.Module):
    """Single-output convolutional network used for the condensation runs.

    Forward pass: ``convlayernum`` stride-1 conv layers (``conv1``, then
    ``conv2`` and ``conv3`` when requested), each followed by the chosen
    activation, then a flatten and a bias-free linear readout ``fc1`` to one
    scalar per sample.

    ``fc2`` and the three ``BatchNorm2d`` modules are created but not used in
    ``forward``. They are kept so that ``state_dict`` keys match checkpoints
    saved by earlier runs.

    Args:
        input_channel: number of input image channels.
        middle_channel: number of channels produced by every conv layer.
        kernel_size: square kernel size shared by all conv layers.
        padding: ``'valid'``, ``'same'`` or an int, forwarded to ``nn.Conv2d``.
        convlayernum: number of conv layers applied in ``forward`` (1 to 3).
        act_name: ``'tanh'``, ``'relu'`` or ``'sigmoid'``.

    Raises:
        ValueError: if ``act_name`` or ``convlayernum`` is unsupported.
    """

    # Side length of the square input images; fc1's input width depends on it.
    input_size = 28
    # Activation functions selectable through ``act_name``.
    _ACTIVATIONS = {'tanh': torch.tanh, 'relu': F.relu, 'sigmoid': torch.sigmoid}

    def __init__(self, input_channel, middle_channel, kernel_size, padding, convlayernum, act_name):
        super().__init__()
        if act_name not in self._ACTIVATIONS:
            raise ValueError('unsupported activation %r; expected one of %s'
                             % (act_name, sorted(self._ACTIVATIONS)))
        if not 1 <= convlayernum <= 3:
            raise ValueError('convlayernum must be 1, 2 or 3, got %r' % (convlayernum,))
        self.input_channel = input_channel
        self.middle_channel = middle_channel
        self.kernel_size = kernel_size
        self.padding = padding
        self.convlayernum = convlayernum
        self.act_name = act_name

        # Module creation order is unchanged, so state_dict keys stay the same.
        self.conv1 = nn.Conv2d(
            self.input_channel, self.middle_channel, self.kernel_size, padding=self.padding, bias=True)
        self.conv2 = nn.Conv2d(
            self.middle_channel, self.middle_channel, self.kernel_size, padding=self.padding, bias=False)
        self.conv3 = nn.Conv2d(
            self.middle_channel, self.middle_channel, self.kernel_size, padding=self.padding, bias=False)
        spatial = self._conv_output_size(
            self.input_size, self.kernel_size, self.padding, self.convlayernum)
        self.fc1 = nn.Linear(self.middle_channel * spatial ** 2, 1, bias=False)
        self.fc2 = nn.Linear(1024, 10)  # unused in forward

        self.batchnorm2d1 = nn.BatchNorm2d(self.middle_channel)  # unused in forward
        self.batchnorm2d2 = nn.BatchNorm2d(self.middle_channel)  # unused in forward
        self.batchnorm2d3 = nn.BatchNorm2d(self.middle_channel)  # unused in forward

        for obj in self.modules():
            if isinstance(obj, nn.Conv2d):
                # std = ((fan_in + fan_out) / 2) ** -2. This is far smaller than
                # Xavier's ** -1/2, presumably a deliberate small-init regime; confirm.
                nn.init.normal_(obj.weight.data, 0, ((
                    obj.in_channels+obj.out_channels)*self.kernel_size**2/2)**(-2))
                if obj.bias is not None:
                    nn.init.normal_(obj.bias.data, 0, ((
                        obj.in_channels+obj.out_channels)*self.kernel_size**2/2)**(-2))
            if isinstance(obj, nn.Linear):
                nn.init.normal_(obj.weight.data, 0, 0.0001)
                if obj.bias is not None:
                    nn.init.normal_(obj.bias.data, 0, 0.0001)

    @staticmethod
    def _conv_output_size(size, kernel_size, padding, num_layers):
        """Return the side length after ``num_layers`` stride-1, dilation-1 convolutions."""
        if padding == 'same':
            return size
        pad = 0 if padding == 'valid' else int(padding)
        return size - num_layers * (kernel_size - 1 - 2 * pad)

    def forward(self, x):
        """Map a batch of images to one scalar each.

        Assumes ``x`` is (N, input_channel, 28, 28), because ``fc1`` is sized
        for ``input_size`` x ``input_size`` inputs. Returns a tensor of shape (N, 1).
        """
        act = self._ACTIVATIONS[self.act_name]
        for conv in (self.conv1, self.conv2, self.conv3)[:self.convlayernum]:
            x = act(conv(x))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        return self.fc1(x)

# def weight_normal_init(m):
#         if isinstance(m, nn.Linear):
#             nn.init.normal_(m.weight.data,0,0.00005)
#             nn.init.normal_(m.bias.data,0,0.00005)

#         if isinstance(m,nn.Conv2d):
#             nn.init.normal_(m.weight.data,0,0.00005)
#             nn.init.normal_(m.bias.data,0,0.00005)


def gain_model_weight_vector(model):
    """Collect every convolution weight of ``model`` as a 2-D NumPy array.

    Only ``state_dict`` entries whose key contains both ``'conv'`` and
    ``'weight'`` are used. A Conv2d weight of shape ``(out, in, kh, kw)``
    becomes ``(out * in, kh * kw)``: one row per (output, input) channel pair,
    holding that pair's kernel.

    The arrays are independent copies, so later training steps do not change them.

    Args:
        model: an ``nn.Module``.

    Returns:
        A list of NumPy arrays in ``state_dict`` order.
    """
    def _as_kernel_rows(tensor):
        # Take a private copy first: on CPU, .cpu() returns the parameter's own storage.
        snapshot = tensor.detach().cpu().clone()
        return snapshot.flatten(start_dim=2).flatten(start_dim=0, end_dim=1).numpy()

    return [
        _as_kernel_rows(tensor)
        for key, tensor in model.state_dict().items()
        if 'conv' in key and 'weight' in key
    ]


import pickle  # explicit import; previously only available implicitly via `from tool import *`

net = Net(hyperpara['input_channel'], hyperpara['middle_channel'],
          hyperpara['kernel_size'], hyperpara['padding'], hyperpara['convlayernum'], hyperpara['act'])
net = net.to(device)
# Checkpoint of the untrained network (0 optimizer steps).
torch.save(net.state_dict(), os.path.join(savedir, 'epoch=0.pt'))
# The targets are the digit labels themselves, as floats, regressed with MSE.
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), hyperpara['learning_rate'], momentum=0)
# optimizer = optim.Adam(net.parameters(), hyperpara['learning_rate'])

# Training record, pickled to trainpro.pkl at the end.
R = {
    'losses': [],        # loss of every step
    'weightvector': [],  # gain_model_weight_vector(net) snapshots
    'weight1': [],       # conv1 weight tensor snapshots
    'bias': [],          # currently never filled; kept for readers of the pickle
}

for data in trainloader:
    # With batch_size == train_size there is a single batch, so after this loop
    # `inputs`/`labels` hold the whole training subset and are reused every step.
    # One-hot targets (float32, shape (N, 10)); not used by the MSE loss below.
    onehot = F.one_hot(data[1], num_classes=10).float().to(device)
    labels = torch.unsqueeze(data[1], dim=1).float()  # (N, 1) float targets
    inputs, labels = data[0].to(device), labels.to(device)

R['weightvector'].append(gain_model_weight_vector(net))
num_epochs = 10000
save_every = 100
for epoch in range(num_epochs):  # one full-batch gradient step per "epoch"
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    loss_value = loss.item()
    R['losses'].append(loss_value)

    # Record after the update, so that 'epoch=N.pt' and the matching R snapshots
    # hold the parameters after exactly N steps (consistent with 'epoch=0.pt').
    step = epoch + 1
    if step % save_every == 0:
        torch.save(net.state_dict(), os.path.join(
            savedir, 'epoch=%s.pt' % step))
        print(loss_value)
        R['weightvector'].append(gain_model_weight_vector(net))
        R['weight1'].append(copy.deepcopy(net.conv1.weight.data.cpu()))

with open(os.path.join(savedir, 'trainpro.pkl'), 'wb') as f:
    pickle.dump(R, f)
plt.plot(R['losses'])
plt.yscale('log')
plt.xscale('log')
plt.savefig(os.path.join(savedir, 'losses.png'))
plt.show()
print('Finished Training')
