import torch
import torch.nn as nn
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import pandas as pd


def cifar100(batch_size, num_users):
    """Build the CIFAR-100 training and test data loaders.

    Both loaders yield batches of ``batch_size * num_users`` samples. The
    dataset is downloaded into ``/kaggle/working/cifar100`` if it is missing.

    Args:
        batch_size: Multiplied by ``num_users`` to get the loader batch size.
        num_users: Multiplied by ``batch_size`` to get the loader batch size.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    data_root = '/kaggle/working/cifar100'
    loader_batch = batch_size * num_users

    # Per-channel (R, G, B) mean and standard deviation handed to Normalize.
    train_mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    train_std = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
    test_mean = (0.5088964127604166, 0.48739301317401956, 0.44194221124387256)
    test_std = (0.2682515741720801, 0.2573637364478126, 0.2770957707973042)

    train_steps = [
        # Pad 4 pixels on every side, then take a random 32x32 crop.
        transforms.RandomCrop(32, padding=4),
        # Mirror the image horizontally half of the time.
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(train_mean, train_std),
    ]
    transform_train = transforms.Compose(train_steps)

    # No augmentation at evaluation time: tensor conversion and normalisation only.
    test_steps = [
        transforms.ToTensor(),
        transforms.Normalize(test_mean, test_std),
    ]
    transform_test = transforms.Compose(test_steps)

    # Training split, reshuffled every epoch.
    trainset = datasets.CIFAR100(
        root=data_root, train=True, download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=loader_batch, shuffle=True)

    # Test split, iterated in a fixed order.
    testset = datasets.CIFAR100(
        root=data_root, train=False, download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=loader_batch, shuffle=False)

    return train_loader, test_loader


class ResidualBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut branch.

    ``expansion`` is 1: the block outputs exactly ``outchannel`` channels.
    """
    expansion = 1

    def __init__(self, inchannel, outchannel, stride=1):
        """Create the main branch and the shortcut.

        Args:
            inchannel: Number of input channels.
            outchannel: Number of output channels.
            stride: Stride of the first convolution and of the projection
                shortcut, when one is needed.
        """
        super().__init__()
        main_branch = [
            nn.Conv2d(inchannel, outchannel, kernel_size=3,
                      stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
        ]
        self.left = nn.Sequential(*main_branch)

        # A 1x1 projection is only required when the main branch changes the
        # spatial size or the channel count; otherwise the input passes through.
        shape_changes = stride != 1 or inchannel != outchannel
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(outchannel),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(left(x) + shortcut(x))``."""
        branch = self.left(x)
        # In-place accumulation of the shortcut into the main-branch output.
        branch.add_(self.shortcut(x))
        return F.relu(branch)


class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    ``expansion`` is 4: the block outputs ``4 * planes`` channels.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        """Create the three convolution stages and the shortcut.

        Args:
            in_planes: Number of input channels.
            planes: Width of the two inner convolutions.
            stride: Stride of the 3x3 convolution and of the projection
                shortcut, when one is needed.
        """
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Project the input only when its shape differs from the block output.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(bn3(conv3(...)) + shortcut(x))``."""
        reduced = F.relu(self.bn1(self.conv1(x)))
        mixed = F.relu(self.bn2(self.conv2(reduced)))
        expanded = self.bn3(self.conv3(mixed))
        # In-place accumulation of the shortcut into the main-branch output.
        expanded.add_(self.shortcut(x))
        return F.relu(expanded)


class ResNet(nn.Module):
    """ResNet with a 3x3 stride-1 stem and four residual stages.

    The stages use base widths 64, 128, 256 and 512; stages 2-4 start with a
    stride-2 block. The classifier is a single linear layer on top of globally
    average-pooled features.
    """

    def __init__(self, block, num_blocks, num_classes=100):
        """Assemble the network.

        Args:
            block: Block class (or factory) called as
                ``block(in_channels, channels, stride)``; it must expose an
                ``expansion`` attribute giving its output-channel multiplier.
            num_blocks: Sequence of four ints, the number of blocks per stage.
            num_classes: Size of the classifier output.
        """
        super(ResNet, self).__init__()
        # Running input-channel count, updated by make_layer as blocks are added.
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self.make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self.make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self.make_layer(block, 512, num_blocks[3], stride=2)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks.

        Only the first block uses ``stride``; the remaining blocks use stride 1.
        ``self.inchannel`` is advanced to the stage's output channel count.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.inchannel, channels, block_stride))
            self.inchannel = channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. For 32x32 inputs the final map is 4x4, so this
        # matches the former fixed 4x4 pooling, but it also keeps the feature
        # size equal to what ``fc`` expects for other input resolutions.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out


def ResNet50(**kwargs):
    """Build a ``ResNet`` from ``Bottleneck`` blocks with stage depths [3, 4, 6, 3].

    Keyword arguments are forwarded unchanged to ``ResNet``.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)


def ResNet18(**kwargs):
    """Build a ``ResNet`` from ``ResidualBlock`` blocks with stage depths [2, 2, 2, 2].

    Keyword arguments are forwarded unchanged to ``ResNet``.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet(ResidualBlock, stage_depths, **kwargs)


def ResNet34(**kwargs):
    """Build a ``ResNet`` from ``ResidualBlock`` blocks with stage depths [3, 4, 6, 3].

    Keyword arguments are forwarded unchanged to ``ResNet``.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(ResidualBlock, stage_depths, **kwargs)


class QSGDCompressor(object):
    """Stochastic uniform quantizer for one gradient tensor.

    Magnitudes are scaled by ``s / norm`` with ``s = 2 ** (n_bit - 1) - 1`` and
    randomly rounded to one of the two neighbouring integer levels; signs are
    kept separately.
    """

    def __init__(self, size, shape, n_bit, use_cuda):
        """Store the quantizer configuration.

        Args:
            size: Number of elements per row after flattening the tensor.
            shape: Shape of the tensor being compressed; codes are returned in it.
            n_bit: Bit budget; determines the number of levels ``s``.
            use_cuda: Kept for backward compatibility. The random thresholds
                are now created on the device of the tensor being compressed,
                so this flag no longer influences the computation.
        """
        self.bit = n_bit
        self.s = pow(2, self.bit - 1) - 1
        self.dim = size
        self.shape = shape
        self.use_cuda = use_cuda
        self.code_dtype = torch.int32

    def compress(self, vec, norm):
        """Quantize ``vec`` relative to ``norm``.

        Args:
            vec: Tensor whose element count is a multiple of ``self.dim``.
            norm: Scale used for normalisation (number or tensor).

        Returns:
            ``[norm, signs, levels]`` where ``signs`` is a boolean tensor that is
            True for strictly positive entries and ``levels`` holds the integer
            codes, both in ``self.shape``.
        """
        # reshape (rather than view) also accepts non-contiguous gradients.
        vec = vec.reshape(-1, self.dim)

        # A zero norm would turn vec / norm into 0/0 = NaN. Dividing by 1
        # instead is harmless: decompress multiplies by the original norm.
        scale = torch.as_tensor(norm, dtype=vec.dtype, device=vec.device)
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)

        scaled_vec = torch.abs(vec / scale) * self.s
        # Truncation equals floor here because the values are non-negative.
        levels = scaled_vec.type(self.code_dtype)
        round_up_prob = scaled_vec - levels.type(torch.float32)
        # Draw the thresholds directly on vec's device, so this works on
        # CPU-only machines and never mixes devices.
        r = torch.rand(levels.shape, device=vec.device)
        levels += (round_up_prob > r).type(self.code_dtype)

        # Zeros are stored as "not positive"; their level is 0, so the sign
        # chosen for them does not matter.
        signs = torch.sign(vec) > 0
        return [norm, signs.view(self.shape), levels.view(self.shape)]

    def decompress(self, signature):
        """Rebuild a float tensor of shape ``self.shape`` from ``compress`` output."""
        [norm, signs, l] = signature
        assert l.shape == signs.shape
        # Map the boolean signs to +1 / -1.
        scaled_vec = l.type(torch.float32) * (2 * signs.type(torch.float32) - 1)
        compressed = (scaled_vec.view((-1, self.dim))) * norm / self.s
        return compressed.view(self.shape)


class TernCompressor(object):
    """Stochastic ternary quantizer for one gradient tensor.

    Same scheme as a uniform stochastic quantizer with a single level
    (``s = 1``): each magnitude, divided by ``norm``, is randomly rounded to 0
    or 1, and signs are kept separately.
    """

    def __init__(self, size, shape, use_cuda):
        """Store the quantizer configuration.

        Args:
            size: Number of elements per row after flattening the tensor.
            shape: Shape of the tensor being compressed; codes are returned in it.
            use_cuda: Kept for backward compatibility. The random thresholds
                are now created on the device of the tensor being compressed,
                so this flag no longer influences the computation.
        """
        self.s = 1
        self.dim = size
        self.shape = shape
        self.use_cuda = use_cuda
        self.code_dtype = torch.int32

    def compress(self, vec, norm):
        """Quantize ``vec`` relative to ``norm``.

        Args:
            vec: Tensor whose element count is a multiple of ``self.dim``.
            norm: Scale used for normalisation (number or tensor).

        Returns:
            ``[norm, signs, levels]`` where ``signs`` is a boolean tensor that is
            True for strictly positive entries and ``levels`` holds the integer
            codes, both in ``self.shape``.
        """
        # reshape (rather than view) also accepts non-contiguous gradients.
        vec = vec.reshape(-1, self.dim)

        # A zero norm would turn vec / norm into 0/0 = NaN. Dividing by 1
        # instead is harmless: decompress multiplies by the original norm.
        scale = torch.as_tensor(norm, dtype=vec.dtype, device=vec.device)
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)

        scaled_vec = torch.abs(vec / scale) * self.s
        # Truncation equals floor here because the values are non-negative.
        levels = scaled_vec.type(self.code_dtype)
        round_up_prob = scaled_vec - levels.type(torch.float32)
        # Draw the thresholds directly on vec's device, so this works on
        # CPU-only machines and never mixes devices.
        r = torch.rand(levels.shape, device=vec.device)
        levels += (round_up_prob > r).type(self.code_dtype)

        # Zeros are stored as "not positive"; their level is 0, so the sign
        # chosen for them does not matter.
        signs = torch.sign(vec) > 0
        return [norm, signs.view(self.shape), levels.view(self.shape)]

    def decompress(self, signature):
        """Rebuild a float tensor of shape ``self.shape`` from ``compress`` output."""
        [norm, signs, l] = signature
        assert l.shape == signs.shape
        # Map the boolean signs to +1 / -1.
        scaled_vec = l.type(torch.float32) * (2 * signs.type(torch.float32) - 1)
        compressed = (scaled_vec.view((-1, self.dim))) * norm / self.s
        return compressed.view(self.shape)


class SIGNCompressor(object):
    """Compressor that keeps only the element-wise sign of a tensor."""

    def __init__(self, size=None, shape=None, args=None):
        """Accept and ignore all arguments; the compressor holds no state."""

    @staticmethod
    def compress(vec, norm):
        """Return the element-wise sign of ``vec``; ``norm`` is not used."""
        sign_only = torch.sign(vec)
        return sign_only

    @staticmethod
    def decompress(signature):
        """Return ``signature`` unchanged."""
        return signature


class IdenticalCompressor(object):
    """No-op compressor: hands back a copy of the tensor it is given."""

    def __init__(self, size=None, shape=None, args=None):
        """Accept and ignore all arguments; the compressor holds no state."""

    @staticmethod
    def compress(vec, norm):
        """Return a clone of ``vec``; ``norm`` is not used."""
        duplicate = vec.clone()
        return duplicate

    @staticmethod
    def decompress(signature):
        """Return ``signature`` unchanged."""
        return signature


class PSQuantizer():
    """Collects per-user quantized gradients and averages them into ``.grad``.

    One compressor is created per parameter. ``record`` stores a
    compress/decompress round trip of the current gradients; ``apply`` replaces
    each gradient by the mean of everything recorded since the last ``apply``.
    """

    def __init__(self, parameters, n_bit, use_cuda):
        """Create one compressor per parameter.

        Args:
            parameters: Iterable of parameters whose gradients are quantized.
            n_bit: Selects the compressor: ``1.1`` sign, ``2.1`` ternary,
                ``32`` identity, anything else QSGD with that bit budget.
            use_cuda: Forwarded to the ternary and QSGD compressors.
        """
        self.parameters = list(parameters)
        self.num_layers = len(self.parameters)
        self.compressors = [
            self._make_compressor(param, n_bit, use_cuda) for param in self.parameters
        ]
        self.compressed_gradients = [list() for _ in range(self.num_layers)]

    @staticmethod
    def _make_compressor(param, n_bit, use_cuda):
        """Return the compressor selected by ``n_bit`` for ``param``."""
        if n_bit == 1.1:
            return SIGNCompressor()
        if n_bit == 2.1:
            return TernCompressor(param.numel(), param.shape, use_cuda)
        if n_bit == 32:
            return IdenticalCompressor()
        return QSGDCompressor(param.numel(), param.shape, n_bit, use_cuda)

    def record(self, norm):
        """Quantize the current gradients with scale ``norm`` and store them.

        Parameters without a gradient (``grad is None``) are skipped.
        """
        for compressor, param, bucket in zip(
                self.compressors, self.parameters, self.compressed_gradients):
            if param.grad is None:
                continue
            bucket.append(
                compressor.decompress(compressor.compress(param.grad.data, norm))
            )

    def apply(self):
        """Write the mean of the recorded gradients into ``.grad`` and reset.

        Parameters for which nothing was recorded keep their current gradient.
        """
        for param, bucket in zip(self.parameters, self.compressed_gradients):
            if not bucket:
                continue
            averaged = torch.stack(bucket, dim=0).mean(dim=0)
            if param.grad is None:
                # The gradient was released after recording; recreate it.
                param.grad = averaged
            else:
                param.grad.data = averaged
            bucket.clear()


def test(device, model, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    return correct / len(test_loader.dataset)


def quantized_train(model, optimizer, loss_func, train_loader, test_loader, device, n_bit, EPOCH, loopNum,
                    NUM_USER):
    """Train ``model`` with a fixed gradient-quantization setting.

    Every loader batch is split into ``NUM_USER`` consecutive slices (the last
    slice also takes the remainder). For each slice the gradients are computed,
    quantized with the largest absolute gradient entry as scale and recorded;
    the averaged quantized gradient then drives one optimizer step.

    Args:
        model: Network to train.
        optimizer: Optimizer bound to ``model``'s parameters.
        loss_func: Loss called as ``loss_func(output, target)``.
        train_loader: Iterable of ``(data, target)`` training batches.
        test_loader: Loader passed to ``test`` for accuracy evaluation.
        device: Device the batches are moved to.
        n_bit: Quantization setting handed to ``PSQuantizer``.
        EPOCH: Maximum number of passes over ``train_loader``.
        loopNum: Training stops once this many steps have been recorded.
        NUM_USER: Number of simulated users per step.

    Returns:
        ``(all_loss, all_norm, all_acc, all_bit)``: per-step loss of the last
        user, per-step mean of the users' gradient scales, accuracy measured
        every 20 steps, and the per-step value of ``n_bit``.
    """
    all_loss = []
    all_bit = []
    all_acc = []
    all_norm = []
    # Follow the training device instead of assuming CUDA is available.
    use_cuda = torch.device(device).type == 'cuda'
    quantizer = PSQuantizer(model.parameters(), n_bit, use_cuda=use_cuda)
    # training...
    for epoch in range(EPOCH):
        print(epoch)
        for data, target in train_loader:
            # A batch with fewer samples than users would give empty slices.
            if len(data) < NUM_USER:
                continue
            model.train()
            user_batch_size = len(data) // NUM_USER
            user_norms = []
            for user_id in range(NUM_USER):
                start = user_id * user_batch_size
                # The last user also receives the leftover samples.
                stop = None if user_id == NUM_USER - 1 else start + user_batch_size
                x = data[start:stop].to(device)
                y = target[start:stop].to(device)
                optimizer.zero_grad()
                output = model(x)
                loss = loss_func(output, y)  # cross entropy loss
                loss.backward()  # backpropagation, compute gradients
                # Largest absolute gradient entry over all parameters, reduced
                # on the device so only a single scalar is copied to the host.
                norm = torch.stack([
                    p.grad.detach().abs().max()
                    for p in model.parameters() if p.grad is not None
                ]).max().item()
                user_norms.append(norm)
                quantizer.record(norm)
            quantizer.apply()
            optimizer.step()
            all_norm.append(np.mean(user_norms))
            # Only the loss of the last user in the step is logged.
            all_loss.append(loss.item())
            all_bit.append(n_bit)
            if len(all_loss) % 20 == 0:
                acc = test(device, model, test_loader)
                all_acc.append(acc)
            if len(all_loss) == loopNum:
                return all_loss, all_norm, all_acc, all_bit
    return all_loss, all_norm, all_acc, all_bit


def dyquantized_train(model, optimizer, loss_func, train_loader, test_loader, device, EPOCH, loopNum,
                      NUM_USER):
    """Train ``model`` while re-selecting the quantization bit width over time.

    Training starts with 3 bits. Every ``gap`` (100) steps the bit width is
    recomputed as ``ceil(log2(norm * k1 * k2 ** step + 1)) + 1`` from the most
    recently recorded gradient scale, and a new ``PSQuantizer`` is created.
    Within a step, every loader batch is split into ``NUM_USER`` consecutive
    slices (the last slice also takes the remainder); each slice's gradients
    are quantized with the largest absolute gradient entry as scale, and the
    averaged quantized gradient drives one optimizer step.

    Args:
        model: Network to train.
        optimizer: Optimizer bound to ``model``'s parameters.
        loss_func: Loss called as ``loss_func(output, target)``.
        train_loader: Iterable of ``(data, target)`` training batches.
        test_loader: Loader passed to ``test`` for accuracy evaluation.
        device: Device the batches are moved to.
        EPOCH: Maximum number of passes over ``train_loader``.
        loopNum: Training stops once this many steps have been recorded.
        NUM_USER: Number of simulated users per step.

    Returns:
        ``(all_loss, all_norm, all_acc, all_bit)``: per-step loss of the last
        user, per-step mean of the users' gradient scales, accuracy measured
        every 20 steps, and the bit width used at each step.
    """
    all_loss = []
    all_norm = []
    all_bit = []
    all_acc = []
    k1 = 15
    k2 = 1.0004
    gap = 100
    # Follow the training device instead of assuming CUDA is available.
    use_cuda = torch.device(device).type == 'cuda'
    # training...
    for epoch in range(EPOCH):
        print(epoch)
        for data, target in train_loader:
            # A batch with fewer samples than users would give empty slices.
            if len(data) < NUM_USER:
                continue
            model.train()
            user_batch_size = len(data) // NUM_USER
            # determine bits
            if len(all_loss) % gap == 0:
                if len(all_loss) == 0:
                    n_bit = 3
                else:
                    n_bit = np.ceil(np.log2(all_norm[-1] * (k1 * pow(k2, len(all_loss))) + 1)) + 1
                    # A zero norm yields n_bit == 1, i.e. zero quantization
                    # levels and a division by zero when decompressing.
                    n_bit = max(n_bit, 2)
                quantizer = PSQuantizer(model.parameters(), n_bit, use_cuda=use_cuda)
            user_norms = []
            for user_id in range(NUM_USER):
                start = user_id * user_batch_size
                # The last user also receives the leftover samples.
                stop = None if user_id == NUM_USER - 1 else start + user_batch_size
                x = data[start:stop].to(device)
                y = target[start:stop].to(device)
                optimizer.zero_grad()
                output = model(x)
                loss = loss_func(output, y)  # cross entropy loss
                loss.backward()  # backpropagation, compute gradients
                # Largest absolute gradient entry over all parameters, reduced
                # on the device so only a single scalar is copied to the host.
                norm = torch.stack([
                    p.grad.detach().abs().max()
                    for p in model.parameters() if p.grad is not None
                ]).max().item()
                user_norms.append(norm)
                quantizer.record(norm)
            quantizer.apply()
            optimizer.step()
            all_norm.append(np.mean(user_norms))
            # Only the loss of the last user in the step is logged.
            all_loss.append(loss.item())
            all_bit.append(n_bit)
            if len(all_loss) % 20 == 0:
                acc = test(device, model, test_loader)
                all_acc.append(acc)
            if len(all_loss) == loopNum:
                return all_loss, all_norm, all_acc, all_bit
    return all_loss, all_norm, all_acc, all_bit


def adaqsd_train(model, optimizer, loss_func, train_loader, test_loader, device, EPOCH, loopNum,
                NUM_USER):
    """Train ``model`` and raise the quantization bit width from a moment ratio.

    Training starts with 2 bits. After every step, Adam-style bias-corrected
    first and second moments of the (averaged, quantized) gradient are updated
    and ``msdr = ||m|| / (||sqrt(v)|| + epsilon)`` is logged. Every ``gap``
    (100) steps the bit width is increased by one if ``msdr`` fell below
    ``kaba * flag``, where ``flag`` is the ``msdr`` at the previous increase
    (initially 1e4) and ``kaba = 0.4 ** (1 - step / loopNum)``; a new
    ``PSQuantizer`` is created at each of those checks.

    Args:
        model: Network to train.
        optimizer: Optimizer bound to ``model``'s parameters.
        loss_func: Loss called as ``loss_func(output, target)``.
        train_loader: Iterable of ``(data, target)`` training batches.
        test_loader: Loader passed to ``test`` for accuracy evaluation.
        device: Device the batches are moved to.
        EPOCH: Maximum number of passes over ``train_loader``.
        loopNum: Training stops once this many steps have been recorded.
        NUM_USER: Number of simulated users per step.

    Returns:
        ``(all_loss, all_norm, all_acc, all_bit)``: per-step loss of the last
        user, per-step ``msdr``, accuracy measured every 20 steps, and the bit
        width used at each step.
    """
    all_loss = []
    all_norm = []
    all_bit = []
    all_acc = []
    beta1 = 0.9
    beta2 = 0.999
    epsilon = 1e-8
    m0 = 0
    v0 = 0
    flag = 1e4
    kaba = 0.4
    n_bit = 2
    gap = 100
    # Follow the training device instead of assuming CUDA is available.
    use_cuda = torch.device(device).type == 'cuda'
    # training...
    for epoch in range(EPOCH):
        print(epoch)
        for data, target in train_loader:
            # A batch with fewer samples than users would give empty slices.
            if len(data) < NUM_USER:
                continue
            model.train()
            user_batch_size = len(data) // NUM_USER
            # determine bits
            if len(all_loss) % gap == 0:
                if len(all_loss) == 0:
                    n_bit = 2
                elif msdr < kaba * flag:
                    n_bit = n_bit + 1
                    flag = msdr
                quantizer = PSQuantizer(model.parameters(), n_bit, use_cuda=use_cuda)
            for user_id in range(NUM_USER):
                start = user_id * user_batch_size
                # The last user also receives the leftover samples.
                stop = None if user_id == NUM_USER - 1 else start + user_batch_size
                x = data[start:stop].to(device)
                y = target[start:stop].to(device)
                optimizer.zero_grad()
                output = model(x)
                loss = loss_func(output, y)  # cross entropy loss
                loss.backward()  # backpropagation, compute gradients
                # Largest absolute gradient entry over all parameters, reduced
                # on the device so only a single scalar is copied to the host.
                norm = torch.stack([
                    p.grad.detach().abs().max()
                    for p in model.parameters() if p.grad is not None
                ]).max().item()
                quantizer.record(norm)
            quantizer.apply()
            optimizer.step()
            # Flatten all gradients with one concatenation. The former leading
            # 0.0 entry is dropped; it contributed nothing to either norm.
            z = torch.cat([
                p.grad.detach().reshape(-1)
                for p in model.parameters() if p.grad is not None
            ])
            m0 = beta1 * m0 + (1 - beta1) * z
            v0 = beta2 * v0 + (1 - beta2) * (z ** 2)
            # Bias-corrected moments, as in Adam.
            mt = m0 / (1 - pow(beta1, len(all_loss) + 1))
            vt = v0 / (1 - pow(beta2, len(all_loss) + 1))
            # Euclidean norms computed on the device; only the ratio is copied.
            msdr = (torch.linalg.norm(mt) / (torch.linalg.norm(vt ** 0.5) + epsilon)).item()
            kaba = pow(0.4, (1 - len(all_loss) / loopNum))
            all_norm.append(msdr)
            # Only the loss of the last user in the step is logged.
            all_loss.append(loss.item())
            all_bit.append(n_bit)
            if len(all_loss) % 20 == 0:
                acc = test(device, model, test_loader)
                all_acc.append(acc)
            if len(all_loss) == loopNum:
                return all_loss, all_norm, all_acc, all_bit
    return all_loss, all_norm, all_acc, all_bit


def main():
    """Run the configured experiments and write the traces to CSV files.

    For every entry of ``BIT`` a fresh ResNet-34 is trained ``NUM_REP`` times.
    Loss and accuracy are averaged over the repetitions; the bit-width and
    norm traces are taken from the last repetition only. Results are written
    to ``/kaggle/working/{loss1,acc1,bit1,norm1}.csv``.

    Raises:
        ValueError: If an entry of ``BIT`` selects no training routine.
    """
    # Hyperparameters
    BATCH_SIZE = 32
    LR = 0.1
    EPOCH = 50
    NUM_REP = 1
    loopNum = 6000
    NUM_USER = 8
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Loading data
    train_loader, test_loader = cifar100(BATCH_SIZE, NUM_USER)

    # training
    # Meaning of each BIT entry:
    #   > 0.5 : quantized_train with that setting passed on as n_bit
    #   0     : dyquantized_train
    #   -1    : adaqsd_train
    BIT = [32]
    Perform = pd.DataFrame(columns=BIT)
    Accrucy = pd.DataFrame(columns=BIT)
    NUM_BIT = pd.DataFrame(columns=BIT)
    NUM_NORM = pd.DataFrame(columns=BIT)

    for bit in BIT:
        loss_func = torch.nn.CrossEntropyLoss()
        Loss = np.zeros(loopNum)
        Acc = np.zeros(int(loopNum/20))
        for i in range(NUM_REP):
            model = ResNet34().to(device)
            optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
            if bit > 0.5:
                all_loss, all_norm, all_acc, all_bit = quantized_train(
                    model, optimizer, loss_func, train_loader, test_loader,
                    device, bit, EPOCH, loopNum, NUM_USER)
            elif bit == 0:
                all_loss, all_norm, all_acc, all_bit = dyquantized_train(
                    model, optimizer, loss_func, train_loader, test_loader,
                    device, EPOCH, loopNum, NUM_USER)
            elif bit == -1:
                all_loss, all_norm, all_acc, all_bit = adaqsd_train(
                    model, optimizer, loss_func, train_loader, test_loader,
                    device, EPOCH, loopNum, NUM_USER)
            else:
                # Without this, an unknown setting would fall through to a
                # NameError or silently reuse the previous setting's results.
                raise ValueError(
                    f"unsupported BIT entry {bit!r}: expected a value > 0.5, 0 or -1")
            Loss = Loss + np.array(all_loss)
            Acc = Acc + np.array(all_acc)
        Loss = Loss / NUM_REP
        Acc = Acc / NUM_REP

        Perform.loc[:, bit] = Loss
        Accrucy.loc[:, bit] = Acc
        # These two traces come from the last repetition only.
        NUM_BIT.loc[:, bit] = all_bit
        NUM_NORM.loc[:, bit] = all_norm

    filename1 = '/kaggle/working/loss1.csv'
    filename2 = '/kaggle/working/acc1.csv'
    filename3 = '/kaggle/working/bit1.csv'
    filename4 = '/kaggle/working/norm1.csv'
    Perform.to_csv(filename1)
    Accrucy.to_csv(filename2)
    NUM_BIT.to_csv(filename3)
    NUM_NORM.to_csv(filename4)


# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()