import torch
import torch.nn as nn
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import pandas as pd
from torchtext.data import Iterator, BucketIterator, TabularDataset
from torchtext import data
from torchtext.vocab import Vectors, GloVe


class QSGDCompressor(object):
    """Stochastic uniform quantiser for a tensor of fixed shape.

    ``compress`` normalises the input by ``norm``, scales magnitudes onto
    ``s = 2**(n_bit - 1) - 1`` levels and rounds each one stochastically to
    an integer level; signs are stored separately.  ``decompress`` inverts
    the scaling.
    """

    def __init__(self, size, shape, n_bit, use_cuda):
        """
        :param size: number of elements in the tensor (used as row width)
        :param shape: shape the signature tensors are viewed as
        :param n_bit: bit width; determines the number of levels ``s``
        :param use_cuda: kept for interface compatibility; the random draws
            now follow the device of the tensor being compressed
        """
        self.bit = n_bit
        self.s = pow(2, self.bit - 1) - 1
        self.dim = size
        self.shape = shape
        self.use_cuda = use_cuda
        self.code_dtype = torch.int32

    @staticmethod
    def _safe_divisor(norm):
        """Return ``norm`` with zeros replaced by one so division stays finite."""
        if torch.is_tensor(norm):
            return torch.where(norm == 0, torch.ones_like(norm), norm)
        return norm if norm != 0 else 1.0

    def compress(self, vec, norm):
        """
        :param vec: torch tensor with ``self.dim`` elements
        :param norm: scale the tensor is divided by before quantisation
        :return: ``[norm, signs, quantized_intervals]`` where ``signs`` is a
            boolean tensor (True for strictly positive entries) and
            ``quantized_intervals`` is an int32 tensor of levels, both viewed
            as ``self.shape``
        """
        # reshape (not view) so non-contiguous gradients are accepted too.
        flat = vec.reshape(-1, self.dim)
        # A zero norm would turn every entry into NaN before the int cast;
        # dividing by one instead yields all-zero levels for an all-zero input.
        normalized_vec = flat / self._safe_divisor(norm)

        scaled_vec = torch.abs(normalized_vec) * self.s
        l = scaled_vec.type(self.code_dtype)
        # Fractional part = probability of rounding up to the next level.
        probabilities = scaled_vec - l.type(torch.float32)
        # Draw on the tensor's own device instead of trusting ``use_cuda``,
        # which crashed whenever the flag and the tensor device disagreed.
        r = torch.rand(l.size(), device=flat.device)
        l += (probabilities > r).type(self.code_dtype)

        signs = torch.sign(flat) > 0
        return [norm, signs.view(self.shape), l.view(self.shape)]

    def decompress(self, signature):
        """
        :param signature: ``[norm, signs, levels]`` as returned by ``compress``
        :return: float tensor of shape ``self.shape`` equal to
            ``(+/-)levels * norm / s``
        """
        [norm, signs, l] = signature
        assert l.shape == signs.shape
        # Map boolean signs to +1 / -1.
        scaled_vec = l.type(torch.float32) * (2 * signs.type(torch.float32) - 1)
        compressed = (scaled_vec.view((-1, self.dim))) * norm / self.s
        return compressed.view(self.shape)


class TernCompressor(object):
    """Stochastic quantiser that rounds ``|vec| / norm`` to integer levels.

    Unlike ``QSGDCompressor`` no level multiplier is applied, so when
    ``norm`` bounds the magnitudes the stored levels are 0 or 1 and the
    reconstruction takes values in ``{-norm, 0, +norm}``.
    """

    def __init__(self, size, shape, use_cuda):
        """
        :param size: number of elements in the tensor (used as row width)
        :param shape: shape the signature tensors are viewed as
        :param use_cuda: kept for interface compatibility; the random draws
            now follow the device of the tensor being compressed
        """
        self.dim = size
        self.shape = shape
        self.use_cuda = use_cuda
        self.code_dtype = torch.int32

    @staticmethod
    def _safe_divisor(norm):
        """Return ``norm`` with zeros replaced by one so division stays finite."""
        if torch.is_tensor(norm):
            return torch.where(norm == 0, torch.ones_like(norm), norm)
        return norm if norm != 0 else 1.0

    def compress(self, vec, norm):
        """
        :param vec: torch tensor with ``self.dim`` elements
        :param norm: scale the tensor is divided by before quantisation
        :return: ``[norm, signs, levels]``; ``signs`` is boolean (True for
            strictly positive entries), ``levels`` is int32, both viewed as
            ``self.shape``
        """
        # reshape (not view) so non-contiguous gradients are accepted too.
        flat = vec.reshape(-1, self.dim)
        # Guard against a zero norm, which would produce NaN before the cast.
        normalized_vec = flat / self._safe_divisor(norm)
        scaled_vec = torch.abs(normalized_vec)
        l = scaled_vec.type(self.code_dtype)
        # Fractional part = probability of rounding up to the next level.
        probabilities = scaled_vec - l.type(torch.float32)
        # Draw on the tensor's own device instead of trusting ``use_cuda``.
        r = torch.rand(l.size(), device=flat.device)
        l += (probabilities > r).type(self.code_dtype)

        signs = torch.sign(flat) > 0
        return [norm, signs.view(self.shape), l.view(self.shape)]

    def decompress(self, signature):
        """
        :param signature: ``[norm, signs, levels]`` as returned by ``compress``
        :return: float tensor of shape ``self.shape`` equal to
            ``(+/-)levels * norm``
        """
        [norm, signs, l] = signature
        assert l.shape == signs.shape
        # Map boolean signs to +1 / -1.
        scaled_vec = l.type(torch.float32) * (2 * signs.type(torch.float32) - 1)
        compressed = (scaled_vec.view((-1, self.dim))) * norm
        return compressed.view(self.shape)


class SIGNCompressor(object):
    """Compressor that keeps only the element-wise sign of a tensor."""

    def __init__(self, size=None, shape=None, args=None):
        """Accept the common constructor arguments; none of them is used."""

    @staticmethod
    def compress(vec, norm):
        """Return the element-wise sign of ``vec``; ``norm`` is ignored."""
        return vec.sign()

    @staticmethod
    def decompress(signature):
        """Return the signature unchanged (the signs are the reconstruction)."""
        return signature


class IdenticalCompressor(object):
    """No-op compressor: the signature is a copy of the input tensor."""

    def __init__(self, size=None, shape=None, args=None):
        """Accept the common constructor arguments; none of them is used."""

    @staticmethod
    def compress(vec, norm):
        """Return an independent copy of ``vec``; ``norm`` is ignored."""
        return torch.clone(vec)

    @staticmethod
    def decompress(signature):
        """Return the signature unchanged."""
        return signature


class PSQuantizer():
    """Collects quantised per-parameter gradients and averages them.

    One compressor is created per parameter.  ``record`` stores the
    compressed-then-decompressed gradient of every parameter; ``apply``
    replaces each parameter's gradient by the mean of what was recorded and
    empties the buffers.
    """

    def __init__(self, parameters, n_bit, use_cuda):
        """
        :param parameters: iterable of parameters whose gradients are quantised
        :param n_bit: selects the compressor: ``1.1`` -> ``SIGNCompressor``,
            ``2.1`` -> ``TernCompressor``, ``32`` -> ``IdenticalCompressor``,
            anything else -> ``QSGDCompressor`` with that bit width
        :param use_cuda: forwarded to the compressors that take it
        """
        self.parameters = list(parameters)
        self.num_layers = len(self.parameters)
        self.compressors = [
            self._make_compressor(param, n_bit, use_cuda) for param in self.parameters
        ]
        self.compressed_gradients = [list() for _ in range(self.num_layers)]

    @staticmethod
    def _make_compressor(param, n_bit, use_cuda):
        """Build the compressor selected by ``n_bit`` for one parameter."""
        if n_bit == 1.1:
            return SIGNCompressor()
        if n_bit == 2.1:
            return TernCompressor(param.numel(), param.shape, use_cuda)
        if n_bit == 32:
            return IdenticalCompressor()
        return QSGDCompressor(param.numel(), param.shape, n_bit, use_cuda)

    def record(self, norm):
        """Quantise the current gradient of every parameter and store it.

        :param norm: scale passed to each compressor's ``compress``
        """
        for i, param in enumerate(self.parameters):
            if param.grad is None:
                # A plain zero tensor: ``0 * param`` would keep an autograd
                # graph alive and turn inf/NaN weights into NaN gradients.
                decompressed_g = torch.zeros_like(param.data)
            else:
                compressor = self.compressors[i]
                decompressed_g = compressor.decompress(
                    compressor.compress(param.grad.data, norm)
                )
            self.compressed_gradients[i].append(decompressed_g)

    def apply(self):
        """Write the mean of the recorded gradients into ``.grad`` and reset.

        Parameters with nothing recorded are left untouched.  A parameter
        whose ``.grad`` is still ``None`` receives the averaged tensor as its
        gradient instead of raising ``AttributeError``.
        """
        for param, recorded in zip(self.parameters, self.compressed_gradients):
            if not recorded:
                continue
            averaged = torch.stack(recorded, dim=0).mean(dim=0)
            if param.grad is None:
                param.grad = averaged
            else:
                param.grad.data = averaged
            recorded.clear()


def get_data_iter(train_csv, test_csv, BATCH_SIZE, NUM_USER, fix_length):
    """Build the train/test iterators and the vocabulary from two CSV files.

    Both files are read with their header row skipped and three columns
    named ``label``, ``title`` and ``text``; ``title`` is mapped to no field.
    The training iterator yields batches of ``BATCH_SIZE * NUM_USER``
    examples, the test iterator batches of ``BATCH_SIZE``.  The vocabulary
    is built from the training set with 300-dimensional GloVe 6B vectors.

    :return: ``(train_iter, test_iter, vocab)``
    """
    text_field = data.Field(sequential=True, lower=True, fix_length=fix_length, batch_first=True)
    label_field = data.Field(sequential=False, use_vocab=False)

    def column_spec():
        # A fresh list per dataset, mirroring the two separate lists used before.
        return [("label", label_field), ("title", None), ("text", text_field)]

    train_set = TabularDataset(path=train_csv, format="csv", fields=column_spec(), skip_header=True)
    train_iter = BucketIterator(
        train_set,
        batch_size=BATCH_SIZE*NUM_USER,
        device=-1,
        sort_key=lambda example: len(example.text),
        sort_within_batch=False,
        repeat=False,
    )

    test_set = TabularDataset(path=test_csv, format="csv", fields=column_spec(), skip_header=True)
    test_iter = Iterator(
        test_set,
        batch_size=BATCH_SIZE,
        device=-1,
        sort=False,
        sort_within_batch=False,
        repeat=False,
    )

    text_field.build_vocab(train_set, vectors=GloVe(name='6B', dim=300))
    return train_iter, test_iter, text_field.vocab


class LSTMEncoder(nn.Module):
    """Bidirectional LSTM text classifier with additive self-attention pooling.

    Token embeddings are fed through a bidirectional LSTM; one attention
    weight per time step pools the LSTM outputs into a sentence vector, which
    is projected to ``feat_dim`` features and then to ``label_size`` logits.
    """

    def __init__(self, vocab, label_size, emb_dim, hidden_size, num_layers, d_a, feat_dim):
        """
        :param vocab: object with ``len()`` (vocabulary size) and a
            ``vectors`` tensor used to initialise the embedding table
        :param label_size: number of output classes
        :param emb_dim: embedding dimension
        :param hidden_size: LSTM hidden size per direction
        :param num_layers: number of stacked LSTM layers
        :param d_a: width of the attention hidden layer
        :param feat_dim: width of the sentence feature layer
        """
        super(LSTMEncoder, self).__init__()
        self.embed = nn.Embedding(len(vocab), emb_dim)
        # Initialise the embedding table from the pretrained word vectors
        # and keep it trainable.
        self.embed.weight.data.copy_(vocab.vectors)
        self.embed.weight.requires_grad = True
        self.bilstm = nn.LSTM(emb_dim, hidden_size, num_layers, bidirectional=True, batch_first=True)
        self.fc1 = nn.Linear(hidden_size*2, d_a)
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(d_a, 1)
        # Applied to the transposed scores in ``forward`` so that dim=2 is the
        # sequence axis.
        self.softmax = nn.Softmax(dim=2)
        self.fc3 = nn.Linear(hidden_size*2, feat_dim)
        self.fc = nn.Linear(feat_dim, label_size)

    def forward(self, x, target=None):
        """
        :param x: token indices, ``[bsz, max_len]``
        :param target: unused; kept for interface compatibility
        :return: class logits, ``[bsz, label_size]``
        """
        x = self.embed(x)  # [bsz, max_len] -> [bsz, max_len, emb_dim]
        x, _ = self.bilstm(x)  # [bsz, max_len, hidden_size*2]
        scores = self.fc2(self.tanh(self.fc1(x)))  # [bsz, max_len, 1]
        # Normalise over the sequence axis.  Taking the softmax over the
        # trailing size-1 axis (as before) made every weight exactly 1, so the
        # pooling degenerated into a plain sum over time steps.
        attention = self.softmax(scores.transpose(1, 2))  # [bsz, 1, max_len]
        x = x.transpose(1, 2)  # [bsz, hidden_size*2, max_len]
        sentence_embedding = torch.sum(attention*x, dim=2)  # [bsz, hidden_size*2]
        sentence_embedding = self.fc3(sentence_embedding)  # [bsz, feat_dim]
        output = self.fc(sentence_embedding)  # [bsz, label_size]
        return output


def model_test(device, net, test_iter):
    """Return the classification accuracy of ``net`` over ``test_iter``.

    :param device: device the inputs and labels are moved to
    :param net: model called as ``net(data)``; must return per-class scores
        with classes along dimension 1
    :param test_iter: iterable of batches exposing ``text`` and ``label``
    :return: fraction of correctly classified examples, or ``0.0`` when the
        iterator yields no examples
    """
    net.eval()  # switch the model to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in test_iter:
            # Dataset labels are 1, 2, 3, 4 while PyTorch class indices start
            # at 0, hence the shift by one.
            data, label = batch.text, batch.label - 1
            data, label = data.to(device), label.to(device)
            outputs = net(data)
            # torch.max(..., 1) returns (values, indices); the index of the
            # largest score in each row is the predicted class.
            _, predicted = torch.max(outputs.data, 1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    if total == 0:
        # Nothing to evaluate: avoid ZeroDivisionError.
        return 0.0
    return correct / total


def quantized_train(model, optimizer, loss_func, train_loader, test_loader, device, n_bit, EPOCH, loopNum,
                    NUM_USER, BATCH_SIZE=32):
    """Train ``model`` with a fixed gradient-quantisation setting.

    Every full batch is split evenly between ``NUM_USER`` simulated users.
    Each user's gradient is passed through a ``PSQuantizer`` (compressed and
    immediately decompressed); the mean over users replaces ``.grad`` before
    a single ``optimizer.step()``.

    :param model: network to train
    :param optimizer: optimiser stepping ``model``'s parameters
    :param loss_func: callable ``loss_func(output, target)``
    :param train_loader: iterable of batches exposing ``text`` and ``label``
    :param test_loader: iterable passed to ``model_test``
    :param device: device the mini-batches are moved to
    :param n_bit: quantiser selector forwarded to ``PSQuantizer``
    :param EPOCH: maximum number of passes over ``train_loader``
    :param loopNum: stop after this many optimisation steps
    :param NUM_USER: number of simulated users per step
    :param BATCH_SIZE: per-user batch size; batches whose size differs from
        ``BATCH_SIZE * NUM_USER`` are skipped
    :return: ``(all_loss, all_norm, all_acc, all_bit)`` - per-step loss of the
        last user, per-step mean gradient max-norm, accuracy every 20 steps,
        and per-step ``n_bit``
    """
    all_loss = []
    all_bit = []
    all_acc = []
    all_norm = []
    # Follow the device actually in use rather than assuming a GPU.
    use_cuda = torch.device(device).type == 'cuda'
    quantizer = PSQuantizer(model.parameters(), n_bit, use_cuda=use_cuda)
    full_batch = BATCH_SIZE * NUM_USER
    # training...
    for epoch in range(EPOCH):
        print(epoch)
        for batch in train_loader:
            model.train()
            data, target = batch.text, batch.label - 1
            # Compare by value: an identity test on ints only happens to work
            # for small cached integers and would skip every batch otherwise.
            if data.size(0) != full_batch:
                continue
            user_norms = []
            for user_id in range(NUM_USER):
                start = user_id * BATCH_SIZE
                stop = start + BATCH_SIZE
                x = data[start:stop].to(device)
                y = target[start:stop].to(device)
                optimizer.zero_grad()
                output = model(x)
                loss = loss_func(output, y)  # loss for this user's shard
                loss.backward()  # backpropagation, compute gradients
                # Largest absolute gradient entry over all parameters,
                # reduced on the gradients' own device (no host copies).
                norm = max(
                    (p.grad.detach().abs().max().item()
                     for p in model.parameters() if p.grad is not None),
                    default=0.0,
                )
                user_norms.append(norm)
                quantizer.record(norm)
            quantizer.apply()
            optimizer.step()
            all_norm.append(np.mean(user_norms))
            all_loss.append(loss.item())  # loss of the last user in this step
            all_bit.append(n_bit)
            if len(all_loss) % 20 == 0:
                acc = model_test(device, model, test_loader)
                print(acc)
                all_acc.append(acc)
            if len(all_loss) == loopNum:
                return all_loss, all_norm, all_acc, all_bit
    return all_loss, all_norm, all_acc, all_bit


def dyquantized_train(model, optimizer, loss_func, train_loader, test_loader, device, EPOCH, loopNum,
                    NUM_USER, BATCH_SIZE=32):
    """Train ``model`` while re-selecting the quantisation bit width periodically.

    Works like fixed-width quantised training, except that every ``gap``
    (100) optimisation steps the bit width is recomputed from the mean of the
    last 50 recorded gradient norms, scaled by ``k1 * k2**step``, and a new
    ``PSQuantizer`` is built.  The first step uses 3 bits.

    :param model: network to train
    :param optimizer: optimiser stepping ``model``'s parameters
    :param loss_func: callable ``loss_func(output, target)``
    :param train_loader: iterable of batches exposing ``text`` and ``label``
    :param test_loader: iterable passed to ``model_test``
    :param device: device the mini-batches are moved to
    :param EPOCH: maximum number of passes over ``train_loader``
    :param loopNum: stop after this many optimisation steps
    :param NUM_USER: number of simulated users per step
    :param BATCH_SIZE: per-user batch size; batches whose size differs from
        ``BATCH_SIZE * NUM_USER`` are skipped
    :return: ``(all_loss, all_norm, all_acc, all_bit)`` - per-step loss of the
        last user, per-step mean gradient max-norm, accuracy every 20 steps,
        and the bit width used at each step
    """
    all_loss = []
    all_norm = []
    all_bit = []
    all_acc = []
    k1 = 3
    k2 = 1.003
    gap = 100
    # Follow the device actually in use rather than assuming a GPU.
    use_cuda = torch.device(device).type == 'cuda'
    full_batch = BATCH_SIZE * NUM_USER
    # training...
    for epoch in range(EPOCH):
        print(epoch)
        for batch in train_loader:
            model.train()
            data, target = batch.text, batch.label - 1
            # Compare by value: an identity test on ints only happens to work
            # for small cached integers and would skip every batch otherwise.
            if data.size(0) != full_batch:
                continue
            # determine bits
            if len(all_loss) % gap == 0:
                if not all_loss:
                    n_bit = 3
                else:
                    recent_norm = np.mean(all_norm[-50:])
                    raw_level = np.log2(recent_norm * (k1 * pow(k2, len(all_loss))) + 1)
                    # Keep the previous width if the norms have diverged.
                    if np.isfinite(raw_level):
                        level = int(round(float(raw_level)))
                        n_bit = 2 if level == 0 else level + 1
                quantizer = PSQuantizer(model.parameters(), n_bit, use_cuda=use_cuda)
            user_norms = []
            for user_id in range(NUM_USER):
                start = user_id * BATCH_SIZE
                stop = start + BATCH_SIZE
                x = data[start:stop].to(device)
                y = target[start:stop].to(device)
                optimizer.zero_grad()
                output = model(x)
                loss = loss_func(output, y)  # loss for this user's shard
                loss.backward()  # backpropagation, compute gradients
                # Largest absolute gradient entry over all parameters,
                # reduced on the gradients' own device (no host copies).
                norm = max(
                    (p.grad.detach().abs().max().item()
                     for p in model.parameters() if p.grad is not None),
                    default=0.0,
                )
                user_norms.append(norm)
                quantizer.record(norm)
            quantizer.apply()
            optimizer.step()
            all_norm.append(np.mean(user_norms))
            all_loss.append(loss.item())  # loss of the last user in this step
            all_bit.append(n_bit)
            if len(all_loss) % 20 == 0:
                acc = model_test(device, model, test_loader)
                print(acc)
                all_acc.append(acc)
            if len(all_loss) == loopNum:
                return all_loss, all_norm, all_acc, all_bit
    return all_loss, all_norm, all_acc, all_bit


def adaqsd_train(model, optimizer, loss_func, train_loader, test_loader, device, EPOCH, loopNum,
                NUM_USER, BATCH_SIZE=32):
    """Train ``model`` with a bit width that grows as a gradient statistic shrinks.

    After every optimisation step, bias-corrected exponential moving averages
    of the aggregated gradient (``mt``) and of its square (``vt``) are
    updated, and ``msdr = ||mt|| / (||sqrt(vt)|| + epsilon)`` is recorded.
    Every ``gap`` (100) steps the bit width is increased by one if ``msdr``
    has fallen below ``kaba * flag``, where ``flag`` is the ``msdr`` value at
    the previous increase; a new ``PSQuantizer`` is then built.

    :param model: network to train
    :param optimizer: optimiser stepping ``model``'s parameters
    :param loss_func: callable ``loss_func(output, target)``
    :param train_loader: iterable of batches exposing ``text`` and ``label``
    :param test_loader: iterable passed to ``model_test``
    :param device: device the mini-batches are moved to
    :param EPOCH: maximum number of passes over ``train_loader``
    :param loopNum: stop after this many optimisation steps
    :param NUM_USER: number of simulated users per step
    :param BATCH_SIZE: per-user batch size; batches whose size differs from
        ``BATCH_SIZE * NUM_USER`` are skipped
    :return: ``(all_loss, all_norm, all_acc, all_bit)`` - per-step loss of the
        last user, per-step ``msdr``, accuracy every 20 steps, and the bit
        width used at each step
    """
    all_loss = []
    all_norm = []
    all_bit = []
    all_acc = []
    beta1 = 0.9
    beta2 = 0.999
    epsilon = 1e-8
    m0 = 0
    v0 = 0
    flag = 1e4
    kaba = 0.4
    n_bit = 2
    gap = 100
    msdr = float('inf')  # overwritten after the first optimisation step
    # Follow the device actually in use rather than assuming a GPU.
    use_cuda = torch.device(device).type == 'cuda'
    full_batch = BATCH_SIZE * NUM_USER
    # training...
    for epoch in range(EPOCH):
        print(epoch)
        for batch in train_loader:
            model.train()
            data, target = batch.text, batch.label - 1
            # Compare by value: an identity test on ints only happens to work
            # for small cached integers and would skip every batch otherwise.
            if data.size(0) != full_batch:
                continue
            # determine bits
            if len(all_loss) % gap == 0:
                if not all_loss:
                    n_bit = 3
                elif msdr < kaba*flag:
                    n_bit = n_bit + 1
                    flag = msdr
                quantizer = PSQuantizer(model.parameters(), n_bit, use_cuda=use_cuda)
            for user_id in range(NUM_USER):
                start = user_id * BATCH_SIZE
                stop = start + BATCH_SIZE
                x = data[start:stop].to(device)
                y = target[start:stop].to(device)
                optimizer.zero_grad()
                output = model(x)
                loss = loss_func(output, y)  # loss for this user's shard
                loss.backward()  # backpropagation, compute gradients
                # Largest absolute gradient entry over all parameters,
                # reduced on the gradients' own device (no host copies).
                norm = max(
                    (p.grad.detach().abs().max().item()
                     for p in model.parameters() if p.grad is not None),
                    default=0.0,
                )
                quantizer.record(norm)
            quantizer.apply()
            optimizer.step()
            # Flatten the aggregated gradient of every parameter into one
            # vector with a single concatenation, on the parameters' device.
            z = torch.cat([
                p.grad.detach().reshape(-1) if p.grad is not None
                else torch.zeros(p.numel(), dtype=p.dtype, device=p.device)
                for p in model.parameters()
            ])
            m0 = beta1 * m0 + (1 - beta1) * z
            v0 = beta2 * v0 + (1 - beta2) * (z ** 2)
            step_count = len(all_loss) + 1
            mt = m0 / (1 - pow(beta1, step_count))
            vt = v0 / (1 - pow(beta2, step_count))
            # Plain vector 2-norms computed in torch.
            msdr = mt.norm(p=2).item() / (vt.sqrt().norm(p=2).item() + epsilon)
            kaba = pow(0.4, (1 - len(all_loss) / loopNum))
            all_norm.append(msdr)
            all_loss.append(loss.item())  # loss of the last user in this step
            all_bit.append(n_bit)
            if len(all_loss) % 20 == 0:
                acc = model_test(device, model, test_loader)
                all_acc.append(acc)
            if len(all_loss) == loopNum:
                return all_loss, all_norm, all_acc, all_bit
    return all_loss, all_norm, all_acc, all_bit


def main():
    """Run the configured training experiments and write the results to CSV.

    For every entry of ``BIT`` a fresh model is trained ``NUM_REP`` times;
    the per-step loss and the periodic accuracy are averaged over the
    repetitions, while the bit-width and norm traces come from the last
    repetition.  Four CSV files (loss, accuracy, bit width, norm) are written
    under ``/kaggle/working``.

    :raises ValueError: if an entry of ``BIT`` selects no training routine
    """
    # Hyperparameters
    BATCH_SIZE = 32
    LR = 0.005
    EPOCH = 50
    NUM_REP = 1
    loopNum = 1000
    NUM_USER = 8
    emb_dim = 300  # word-vector dimension
    hidden_size = 200
    label_size = 4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Loading data
    # train_csv = "data/train.csv"
    # test_csv = "data/test.csv"
    train_csv = "../input/ag-news-classification-dataset/train.csv"
    test_csv = "../input/ag-news-classification-dataset/test.csv"

    # training
    # Mode selector, one column of results per entry:
    #   > 0.5 : fixed setting passed to quantized_train as n_bit
    #   0     : dynamic bit width (dyquantized_train)
    #   -1    : AdaQS (adaqsd_train)
    BIT = [0]
    Perform = pd.DataFrame(columns=BIT)
    Accrucy = pd.DataFrame(columns=BIT)
    NUM_BIT = pd.DataFrame(columns=BIT)
    NUM_NORM = pd.DataFrame(columns=BIT)

    for bit in BIT:
        loss_func = torch.nn.CrossEntropyLoss()
        Loss = np.zeros(loopNum)
        Acc = np.zeros(loopNum // 20)
        for i in range(NUM_REP):
            train_loader, test_loader, vocab = get_data_iter(train_csv, test_csv, BATCH_SIZE, NUM_USER, fix_length=200)
            model = LSTMEncoder(vocab, label_size, emb_dim, hidden_size, num_layers=2, d_a=128, feat_dim=128).to(device)
            # model = FastText(vocab=vocab, vec_dim=emb_dim, label_size=label_size, hidden_size=hidden_size).to(device)
            # optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
            optimizer = torch.optim.Adam(model.parameters(), lr=LR)
            if bit > 0.5:
                all_loss, all_norm, all_acc, all_bit = quantized_train(model, optimizer, loss_func, train_loader,
                                                                   test_loader, device, bit, EPOCH, loopNum, NUM_USER)
            elif bit == 0:
                all_loss, all_norm, all_acc, all_bit = dyquantized_train(model, optimizer, loss_func, train_loader, test_loader,
                                                                         device, EPOCH, loopNum, NUM_USER)
            elif bit == -1:
                all_loss, all_norm, all_acc, all_bit = adaqsd_train(model, optimizer, loss_func, train_loader, test_loader, device,
                                                                    EPOCH, loopNum, NUM_USER)
            else:
                # Fail with a clear message instead of a NameError on all_loss.
                raise ValueError("unsupported BIT entry: {!r}".format(bit))
            Loss = Loss + np.array(all_loss)
            Acc = Acc + np.array(all_acc)
        Loss = Loss / NUM_REP
        Acc = Acc / NUM_REP

        # Plain column assignment also works while the frame has no rows yet.
        Perform[bit] = Loss
        Accrucy[bit] = Acc
        NUM_BIT[bit] = all_bit
        NUM_NORM[bit] = all_norm

    filename1 = '/kaggle/working/loss1.csv'
    filename2 = '/kaggle/working/acc1.csv'
    filename3 = '/kaggle/working/bit1.csv'
    filename4 = '/kaggle/working/norm1.csv'
    Perform.to_csv(filename1)
    Accrucy.to_csv(filename2)
    NUM_BIT.to_csv(filename3)
    NUM_NORM.to_csv(filename4)


# Run the experiment only when this file is executed as a script.
if "__main__" == __name__:
    main()