import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset
import torch.optim as optim
# NOTE: keep `from tool import *` above the explicit imports below so that it
# cannot shadow them (e.g. a module-level `datetime` exported by `tool`).
from tool import *
from datetime import datetime
import shutil
import json
import copy
# os and pickle are used below; import them explicitly instead of relying on
# whatever `from tool import *` happens to export.
import os
import pickle
from tool import setup_seed
# Seed the RNGs via tool.setup_seed before anything random runs
# (presumably torch / numpy / random — confirm in tool.py).
setup_seed(10)

basedir = '/home/***/data/undergraky/condenseconvcrossentropy'
# First CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# device = torch.device('cpu')
hyperpara = {}

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps every
# channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# NOTE: batch_size is not used by the loader below, which serves the whole
# subset as a single batch (batch_size=train_size).
batch_size = 4
train_size = 500  # number of CIFAR-10 training images actually used

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
# Keep the first `train_size` images; no shuffling, so the subset is fixed.
train_set = Subset(trainset, range(train_size))
trainloader = torch.utils.data.DataLoader(
    train_set, batch_size=train_size, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')




# Model / optimiser hyperparameters. The same dict is dumped as JSON to
# para.txt in the run directory, so keyword order here is the file order.
hyperpara.update(
    input_channel=3,
    middle_channel=32,
    learning_rate=2e-6,
    padding='valid',
    kernel_size=5,
    convlayernum=3,
    act='tanh',
)



class Net(nn.Module):
    """Three conv layers followed by a two-layer fully connected head.

    ``forward`` computes
    ``fc2(tanh(fc1(flatten(act(conv3(act(conv2(act(conv1(x))))))))))``
    and returns 10 unnormalised logits per sample. The three ``BatchNorm2d``
    layers are registered (they appear in ``state_dict``) but are not applied
    in ``forward``.

    Args:
        input_channel: channels of the input images.
        middle_channel: output channels of every conv layer.
        kernel_size: square kernel size shared by all conv layers.
        padding: padding passed to every ``nn.Conv2d`` (e.g. ``'valid'``,
            ``'same'`` or an int). The ``fc1`` input width is measured from
            the real conv output, so any padding ``nn.Conv2d`` accepts works.
        convlayernum: number of conv layers. ``forward`` always applies
            exactly three, so any other value is rejected.
        act_name: activation after each conv layer: ``'tanh'``, ``'relu'``
            or ``'sigmoid'``.
        input_size: height/width of the square input images (default 32,
            the value that used to be hard-coded).

    Raises:
        ValueError: if ``convlayernum != 3`` or ``act_name`` is unknown.
    """

    _ACTIVATIONS = {'tanh': nn.Tanh, 'relu': nn.ReLU, 'sigmoid': nn.Sigmoid}

    def __init__(self, input_channel, middle_channel, kernel_size, padding,
                 convlayernum, act_name, input_size=32):
        super().__init__()
        # Fail fast instead of erroring later inside forward().
        if convlayernum != 3:
            raise ValueError(
                'Net.forward applies exactly 3 conv layers; got convlayernum=%r'
                % (convlayernum,))
        if act_name not in self._ACTIVATIONS:
            raise ValueError('unknown act_name %r; expected one of %s'
                             % (act_name, sorted(self._ACTIVATIONS)))

        self.input_channel = input_channel
        self.middle_channel = middle_channel
        self.kernel_size = kernel_size
        self.padding = padding
        self.convlayernum = convlayernum
        self.act_name = act_name
        self.input_size = input_size

        self.conv1 = nn.Conv2d(
            self.input_channel, self.middle_channel, self.kernel_size, padding=self.padding)
        self.conv2 = nn.Conv2d(
            self.middle_channel, self.middle_channel, self.kernel_size, padding=self.padding)
        self.conv3 = nn.Conv2d(
            self.middle_channel, self.middle_channel, self.kernel_size, padding=self.padding)

        # Measure the flattened conv-stack output with a zero probe instead of
        # a padding-specific formula. The probe is deterministic (no RNG
        # draws), so layer creation order and seeded initialisation are
        # unchanged.
        with torch.no_grad():
            probe = torch.zeros(1, self.input_channel, self.input_size, self.input_size)
            conv_out = self.conv3(self.conv2(self.conv1(probe)))
        self.fc1 = nn.Linear(conv_out.numel(), 1024)
        self.fc2 = nn.Linear(1024, 10)

        self.batchnorm2d1 = nn.BatchNorm2d(self.middle_channel)
        self.batchnorm2d2 = nn.BatchNorm2d(self.middle_channel)
        self.batchnorm2d3 = nn.BatchNorm2d(self.middle_channel)

        self.af = self._ACTIVATIONS[self.act_name]()

        for obj in self.modules():
            if isinstance(obj, nn.Conv2d):
                # std = fan_avg ** -2 with fan_avg = (in + out) * k^2 / 2.
                # NOTE(review): far smaller than Xavier's fan_avg ** -0.5;
                # presumably an intentional small-initialisation regime — confirm.
                std = ((obj.in_channels + obj.out_channels)
                       * self.kernel_size ** 2 / 2) ** (-2)
                nn.init.normal_(obj.weight.data, 0, std)
                if obj.bias is not None:
                    nn.init.normal_(obj.bias.data, 0, std)
            if isinstance(obj, nn.Linear):
                nn.init.normal_(obj.weight.data, 0, 0.0001)
                if obj.bias is not None:
                    nn.init.normal_(obj.bias.data, 0, 0.0001)

    def forward(self, x):
        """Return logits of shape ``(batch, 10)`` for the image batch ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            # batchnorm2d1..3 are deliberately not applied here.
            x = self.af(conv(x))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = torch.tanh(self.fc1(x))
        return self.fc2(x)

# def weight_normal_init(m):
#         if isinstance(m, nn.Linear):
#             nn.init.normal_(m.weight.data,0,0.00005)
#             nn.init.normal_(m.bias.data,0,0.00005)

#         if isinstance(m,nn.Conv2d):
#             nn.init.normal_(m.weight.data,0,0.00005)
#             nn.init.normal_(m.bias.data,0,0.00005)


def gain_model_weight_vector(model):
    """Snapshot the conv weights of ``model`` as 2-D NumPy arrays.

    Every ``state_dict`` entry whose key contains both ``'conv'`` and
    ``'weight'`` is flattened from dim 2 onward and then has its first two
    dims merged. An ``nn.Conv2d`` kernel of shape ``(out, in, kh, kw)``
    therefore becomes ``(out * in, kh * kw)``: one row per
    (output channel, input channel) pair.

    The tensors are deep-copied before conversion, so the returned arrays do
    not change when the model keeps training.

    Returns:
        list of numpy.ndarray, in ``state_dict`` order.
    """
    per_kernel = [
        copy.deepcopy(param.detach().cpu().flatten(start_dim=2))
        for key, param in model.state_dict().items()
        if 'conv' in key and 'weight' in key
    ]
    return [kernel.flatten(start_dim=0, end_dim=1).numpy() for kernel in per_kernel]


# Fresh run directory named '<yymmddHHMMSS><act>' under basedir, with a copy
# of this script stored next to the results.
subFolderName = '%s%s' % (datetime.now().strftime("%y%m%d%H%M%S"), hyperpara['act'])
# NOTE(review): `makedir` comes from `tool`; assumed to create the directory
# and return its path.
savedir = makedir(os.path.join(basedir, subFolderName))
shutil.copy(__file__, savedir)


# save hyperparameters as JSON text; `with` closes the file even if the
# write raises
para_str = json.dumps(hyperpara, indent=0)
with open(os.path.join(savedir, 'para.txt'), 'w') as para_file:
    para_file.write(para_str)

net = Net(hyperpara['input_channel'], hyperpara['middle_channel'],
          hyperpara['kernel_size'], hyperpara['padding'],
          hyperpara['convlayernum'], hyperpara['act'])
# net.apply(weight_normal_init)
net = net.to(device)
# Checkpoint of the untrained (initial) parameters.
torch.save(net.state_dict(), os.path.join(savedir, 'epoch=0.pt'))
criterion = nn.CrossEntropyLoss()
# Alternative optimiser setup, kept for reference:
# optimizer = optim.SGD(net.parameters(), hyperpara['learning_rate'], momentum=0)
# scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=50000,gamma=0.05)
optimizer = optim.Adam(net.parameters(), lr=hyperpara['learning_rate'])

# Training record (pickled at the end of the run); every series starts empty.
R = {series: [] for series in (
    'losses', 'accuracy', 'weightvector',
    'weight1', 'weight2', 'weight3',
    'bias1', 'bias2', 'bias3',
)}
# The first 'weightvector' entry is the snapshot of the initial parameters.
weightvector = gain_model_weight_vector(net)
R['weightvector'].append(weightvector)
# Full-batch training: gather every batch the loader yields into one
# (inputs, labels) pair on `device`. With batch_size == train_size that is
# exactly one batch. Previously only the last batch survived the loop, which
# would silently drop data if the loader's batch size changed.
input_batches, label_batches = [], []
for batch_inputs, batch_labels in trainloader:
    input_batches.append(batch_inputs)
    label_batches.append(batch_labels)
inputs = torch.cat(input_batches).to(device)
labels = torch.cat(label_batches).to(device)
# One optimisation step per "epoch": the whole training subset is one batch.
num_epochs = 15000
save_every = 100  # checkpoint / snapshot interval, in optimisation steps

for epoch in range(num_epochs):  # loop over the dataset multiple times
    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    # scheduler.step()
    loss_value = loss.item()
    R['losses'].append(loss_value)

    # Accuracy of the same (pre-step) forward pass that produced `loss`.
    # Stored as a Python float, like the loss, so the pickled record holds no
    # device tensors and can be loaded on a CPU-only machine.
    prediction = torch.argmax(outputs, 1)
    acc = ((prediction == labels).sum().float() / len(labels)).item()
    R['accuracy'].append(acc)
    print('step %d  loss %.6f  acc %.4f' % (epoch + 1, loss_value, acc))

    # Snapshot AFTER the step, so 'epoch=N.pt' and the N/save_every-th
    # snapshot hold the parameters after exactly N steps. This also saves the
    # final parameters; the old pre-step snapshot was one step behind its
    # label and never saved the last step.
    steps_done = epoch + 1
    if steps_done % save_every == 0:
        torch.save(net.state_dict(), os.path.join(
            savedir, 'epoch=%s.pt' % steps_done))
        weightvector = gain_model_weight_vector(net)
        R['weightvector'].append(weightvector)
        for idx, layer in enumerate((net.conv1, net.conv2, net.conv3), start=1):
            R['bias%d' % idx].append(copy.deepcopy(layer.bias.data.cpu()))
            R['weight%d' % idx].append(copy.deepcopy(layer.weight.data.cpu()))
# Persist the training record. The `with` block closes the file itself.
# NOTE: unpickling can execute code, so only load this file if you trust it.
with open(os.path.join(savedir, 'trainpro.pkl'), 'wb') as f:
    pickle.dump(R, f)

# Loss curve on log-log axes. Steps are numbered from 1 because x = 0 has no
# position on a logarithmic axis, which would distort the first point.
steps = range(1, len(R['losses']) + 1)
plt.plot(steps, R['losses'])
plt.yscale('log')
plt.xscale('log')
plt.savefig(os.path.join(savedir, 'losses.png'))
plt.show()
print('Finished Training')
