from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import Multi30k, multi30k
from typing import Iterable, List

# Redirect the Multi30k 'train' and 'valid' archive downloads to a GitHub-hosted copy
# (presumably because the default URLs are unreliable -- confirm).
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

# Translation direction: 'de' is the source language, 'en' the target language.
SRC_LANGUAGE, TGT_LANGUAGE = 'de', 'en'

# Per-language spaCy tokenizers. The spaCy models must be installed first:
#   pip install -U spacy
#   python -m spacy download en_core_web_sm
#   python -m spacy download de_core_news_sm
token_transform = {
    SRC_LANGUAGE: get_tokenizer('spacy', language='de_core_news_sm'),
    TGT_LANGUAGE: get_tokenizer('spacy', language='en_core_web_sm'),
}

# Per-language vocabularies; filled in once the training data has been tokenized.
vocab_transform = {}


# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> Iterable[List[str]]:
    """Lazily tokenize one side of every sample in ``data_iter``.

    Args:
        data_iter: iterable of samples; element 0 of a sample is read for
            ``SRC_LANGUAGE`` and element 1 for ``TGT_LANGUAGE``.
        language: ``SRC_LANGUAGE`` or ``TGT_LANGUAGE``.

    Yields:
        The output of ``token_transform[language]`` for each sample
        (presumably a list of token strings -- confirm against the tokenizer).

    Raises:
        KeyError: on first iteration, if ``language`` is not one of the two
            configured languages.
    """
    # Both lookups are loop-invariant, so resolve them once up front.
    sample_position = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}[language]
    tokenize = token_transform[language]

    for data_sample in data_iter:
        yield tokenize(data_sample[sample_position])


# Special symbols. A symbol's position in this list is its vocabulary index, because the
# list is inserted at the front of every vocabulary (special_first=True) below.
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = range(len(special_symbols))

for ln in (SRC_LANGUAGE, TGT_LANGUAGE):
    # A new training iterator is created for each language before it is tokenized.
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    # Build torchtext's Vocab object from the tokenized training sentences.
    vocab_transform[ln] = build_vocab_from_iterator(
        yield_tokens(train_iter, ln),
        min_freq=1,
        specials=special_symbols,
        special_first=True,
    )
    # Return UNK_IDX for tokens missing from the vocabulary; without a default index
    # the lookup raises RuntimeError for unknown tokens.
    vocab_transform[ln].set_default_index(UNK_IDX)


from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
import numpy as np
import copy
import os
import matplotlib.pyplot as plt
import pandas as pd

# Device used for the models and masks below: the GPU when CUDA is available, otherwise the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to token embeddings, then apply dropout.

    The table is stored in the non-trainable buffer ``pos_embedding`` of shape
    ``(maxlen, 1, emb_size)``. Even columns hold sines and odd columns hold
    cosines. The singleton middle axis broadcasts over axis 1 of the input.

    Args:
        emb_size: size of the embedding axis; both even and odd sizes are supported.
        dropout: dropout probability applied to the sum.
        maxlen: number of positions in the table, i.e. the longest supported input.
    """

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = pos * den  # (maxlen, ceil(emb_size / 2))

        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(angles)
        # An odd emb_size has one fewer odd column than even columns, so trim the
        # cosine angles to match. For an even emb_size the slice is a no-op.
        pos_embedding[:, 1::2] = torch.cos(angles[:, :emb_size // 2])
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        # A buffer follows the module across devices and is saved in the state dict,
        # but is not a trainable parameter.
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        """Return ``dropout(token_embedding + positions)``.

        Axis 0 of ``token_embedding`` is treated as the position axis.

        Raises:
            ValueError: if axis 0 is longer than the ``maxlen`` table.
        """
        seq_len = token_embedding.size(0)
        max_positions = self.pos_embedding.size(0)
        if seq_len > max_positions:
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional table size {max_positions}")
        return self.dropout(token_embedding + self.pos_embedding[:seq_len, :])


# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    """Map token indices to learned embedding vectors scaled by ``sqrt(emb_size)``."""

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.emb_size = emb_size
        self.embedding = nn.Embedding(vocab_size, emb_size)

    def forward(self, tokens: Tensor):
        """Return the scaled embeddings for ``tokens`` (cast to int64 before lookup)."""
        token_ids = tokens.long()
        scale = math.sqrt(self.emb_size)
        return scale * self.embedding(token_ids)


# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder transformer with token embeddings and a linear output layer.

    Wraps ``torch.nn.Transformer`` with separate source/target token embeddings,
    one shared positional encoding, and a ``generator`` layer that projects
    decoder outputs to ``tgt_vocab_size`` logits.
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        # NOTE: keep this registration order. Code later in this file addresses
        # parameters by their position in parameters(): transformer first, then
        # generator, then the source and target embeddings.
        self.transformer = Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """Run the full encoder-decoder pass and return target-vocabulary logits.

        ``src`` and ``trg`` hold token indices; axis 0 is presumably the sequence
        axis, matching how ``create_mask`` builds the masks -- confirm.
        No memory (cross-attention) mask is used.
        """
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        decoded = self.transformer(
            src_emb,
            tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(decoded)

    def encode(self, src: Tensor, src_mask: Tensor):
        """Embed ``src`` and run only the encoder stack."""
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        return self.transformer.encoder(src_emb, src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Embed ``tgt`` and run only the decoder stack against encoder output ``memory``."""
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        return self.transformer.decoder(tgt_emb, memory, tgt_mask)

#### Structure: position of each tensor in Seq2SeqTransformer.parameters() (3 encoder + 3 decoder layers)
### Encoder
## first encoder layer
# 0 & 1: multi-head attention input weight & bias
# 2 & 3: multi-head attention output weight & bias
# 4 & 5: linear 1 weight & bias
# 6 & 7: linear 2 weight & bias
# 8 & 9: layer norm 1 weight & bias
# 10 & 11: layer norm 2 weight & bias

## second encoder layer
# 12 & 13: multi-head attention input weight & bias
# 14 & 15: multi-head attention output weight & bias
# 16 & 17: linear 1 weight & bias
# 18 & 19: linear 2 weight & bias
# 20 & 21: layer norm 1 weight & bias
# 22 & 23: layer norm 2 weight & bias

## third encoder layer
# 24 & 25: multi-head attention input weight & bias
# 26 & 27: multi-head attention output weight & bias
# 28 & 29: linear 1 weight & bias
# 30 & 31: linear 2 weight & bias
# 32 & 33: layer norm 1 weight & bias
# 34 & 35: layer norm 2 weight & bias

## 36 & 37: Layer norm

### Decoder
## first decoder layer
# 38 & 39: first multi-head attention input weight & bias
# 40 & 41: first multi-head attention output weight & bias
# 42 & 43: second multi-head attention input weight & bias
# 44 & 45: second multi-head attention output weight & bias
# 46 & 47: linear 1 weight & bias
# 48 & 49: linear 2 weight & bias
# 50 & 51: layer norm 1 weight & bias
# 52 & 53: layer norm 2 weight & bias
# 54 & 55: layer norm 3 weight & bias

## second decoder layer
# 56 & 57: first multi-head attention input weight & bias
# 58 & 59: first multi-head attention output weight & bias
# 60 & 61: second multi-head attention input weight & bias
# 62 & 63: second multi-head attention output weight & bias
# 64 & 65: linear 1 weight & bias
# 66 & 67: linear 2 weight & bias
# 68 & 69: layer norm 1 weight & bias
# 70 & 71: layer norm 2 weight & bias
# 72 & 73: layer norm 3 weight & bias

## third decoder layer
# 74 & 75: first multi-head attention input weight & bias
# 76 & 77: first multi-head attention output weight & bias
# 78 & 79: second multi-head attention input weight & bias
# 80 & 81: second multi-head attention output weight & bias
# 82 & 83: linear 1 weight & bias
# 84 & 85: linear 2 weight & bias
# 86 & 87: layer norm 1 weight & bias
# 88 & 89: layer norm 2 weight & bias
# 90 & 91: layer norm 3 weight & bias

## 92 & 93: Layer norm

## 94 & 95: generator weight & bias
## 96: source TokenEmbedding weight
## 97: target TokenEmbedding weight


def generate_square_subsequent_mask(sz):
    """Return an ``(sz, sz)`` float mask on DEVICE for causal attention.

    Entries on and below the diagonal are ``0.0``; entries strictly above the
    diagonal are ``-inf``, so position ``i`` cannot attend to positions ``j > i``.
    """
    blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float, device=DEVICE)
    # triu(diagonal=1) keeps the strictly upper triangle and zeroes everything else.
    return torch.triu(blocked, diagonal=1)


def create_mask(src, tgt):
    """Build the attention and padding masks for one source/target batch.

    Axis 0 of ``src`` and ``tgt`` is used as the sequence length.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)`` where
        ``src_mask`` is an all-False square boolean mask (nothing is blocked),
        ``tgt_mask`` is the causal mask from ``generate_square_subsequent_mask``,
        and each padding mask is True where the token equals ``PAD_IDX``, with
        axes 0 and 1 swapped relative to the input.
    """
    src_len, tgt_len = src.shape[0], tgt.shape[0]

    src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=DEVICE)
    tgt_mask = generate_square_subsequent_mask(tgt_len)

    src_padding_mask = src.eq(PAD_IDX).transpose(0, 1)
    tgt_padding_mask = tgt.eq(PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask


######################################################################
# Define the model hyperparameters and instantiate a model from them. Below, we also
# define the loss function (cross-entropy) and the learning rates for the optimizers.
#
# Seed the RNG first so that the weight initialisation below is reproducible.
torch.manual_seed(0)

# Vocabulary sizes come from the vocabularies built from the training data.
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 512
NHEAD = 8
# Passed to Seq2SeqTransformer as dim_feedforward.
FFN_HID_DIM = 512
BATCH_SIZE = 512
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

# Xavier-uniform initialisation for every parameter with more than one dimension
# (weight matrices); 1-D parameters such as biases keep their default initialisation.
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)


# Target positions equal to PAD_IDX are ignored by the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
# Learning rates (presumably for Adam and SGD respectively, judging by the names -- confirm).
learning_rate_ada = 0.0001
learning_rate_sgd = 0.1


from torch.nn.utils.rnn import pad_sequence


# helper function to club together sequential operations
def sequential_transforms(*transforms):
    """Compose ``transforms`` into one callable that applies them left to right.

    With no transforms, the returned callable hands its input back unchanged.
    """
    def func(txt_input):
        result = txt_input
        for step in transforms:
            result = step(result)
        return result

    return func


# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    """Return a 1-D int64 tensor ``[BOS_IDX, *token_ids, EOS_IDX]``.

    The dtype is pinned because ``torch.tensor([])`` infers a floating-point
    dtype, so an empty ``token_ids`` (an empty sentence) would otherwise produce
    a non-integer tensor once concatenated with the BOS/EOS markers.
    """
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)


# src and tgt language text transforms to convert raw strings into tensors indices
# Per-language pipeline: raw string -> tokens -> vocabulary indices -> tensor with BOS/EOS.
text_transform = {
    ln: sequential_transforms(
        token_transform[ln],  # Tokenization
        vocab_transform[ln],  # Numericalization
        tensor_transform,  # Add BOS/EOS and create tensor
    )
    for ln in (SRC_LANGUAGE, TGT_LANGUAGE)
}


# function to collate data samples into batch tensors
def collate_fn(batch):
    """Turn a list of ``(src_text, tgt_text)`` pairs into two padded index tensors.

    Trailing newlines are stripped from each sentence, which is then converted with
    the language's ``text_transform``. Sequences are padded with ``PAD_IDX`` to the
    longest one in the batch, using ``pad_sequence``'s default layout.

    Returns:
        ``(src_batch, tgt_batch)``.
    """
    encode_src = text_transform[SRC_LANGUAGE]
    encode_tgt = text_transform[TGT_LANGUAGE]

    encoded = [
        (encode_src(src_sample.rstrip("\n")), encode_tgt(tgt_sample.rstrip("\n")))
        for src_sample, tgt_sample in batch
    ]
    src_batch = pad_sequence([pair[0] for pair in encoded], padding_value=PAD_IDX)
    tgt_batch = pad_sequence([pair[1] for pair in encoded], padding_value=PAD_IDX)
    return src_batch, tgt_batch


######################################################################
# Let's define training and evaluation loop that will be called for each
# epoch.
#

from torch.utils.data import DataLoader

def load_model(path, method_name, learning_rate):
    """Rebuild the model and its optimizer from a checkpoint file.

    Args:
        path: checkpoint file containing ``'model_state_dict'`` and
            ``'optimizer_state_dict'`` entries.
        method_name: ``'SGD'`` (momentum 0.9) or ``'Adam'`` (betas (0.9, 0.98), eps 1e-9).
        learning_rate: learning rate used to construct the optimizer.
            NOTE(review): ``optimizer.load_state_dict`` presumably restores the
            checkpoint's own hyperparameters afterwards -- confirm.

    Returns:
        ``(model, optimizer)`` with the checkpoint state loaded; the model is on DEVICE.

    Raises:
        ValueError: if ``method_name`` is not ``'SGD'`` or ``'Adam'``.
    """
    # Validate before any expensive work. Previously an unsupported name only
    # failed with UnboundLocalError after the model and checkpoint were loaded.
    if method_name not in ('SGD', 'Adam'):
        raise ValueError(f"Unsupported optimizer {method_name!r}; expected 'SGD' or 'Adam'")

    model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                               NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM).to(DEVICE)
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])

    if method_name == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-9)

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer


# Calculate H_ii of the coordinates in the index_list in the given layer
def cal_diag_hessian(loss, optimizer, all_params, layer_list, layer, layer_type, index_list, method_name, epoch, save_path):
    """Append diagonal-Hessian, gradient and optimizer-state samples for one parameter.

    For each sampled coordinate of ``all_params[layer_list[layer]]`` one
    tab-separated line is appended to
    ``<save_path>/layer<param index>_epoch<epoch>_<method_name>_diag.txt``:

    * ``'SGD'``:  ``H_ii, grad, momentum_buffer``
    * ``'Adam'``: ``H_ii, grad, exp_avg_sq, exp_avg``

    When ``epoch == 1`` the optimizer-state columns are written as ``0``; when
    ``epoch > 1`` they are read from the optimizer state. For any other
    ``method_name``, or ``epoch < 1``, the values are still computed but no line
    is written.

    Args:
        loss: scalar loss whose autograd graph is still alive.
        optimizer: optimizer holding the state for ``all_params``.
        all_params: sequence of parameter tensors, addressed by position.
        layer_list: parameter positions (indices into ``all_params``) being analysed.
        layer: position within ``layer_list`` / ``index_list`` to process.
        layer_type: ``'conv'`` (4 index components per coordinate), ``'bn'`` (1),
            anything else (2).
        index_list: array whose ``[layer][r]`` entry holds the coordinate of row ``r``;
            ``index_list.shape[1]`` is the number of rows.
        method_name: optimizer name, ``'SGD'`` or ``'Adam'``.
        epoch: checkpoint epoch; selects whether optimizer state is read.
        save_path: existing directory that receives the output file.
    """
    param_index = layer_list[layer]
    param = all_params[param_index]
    # create_graph=True keeps the first-order gradient differentiable so that each
    # sampled entry can be differentiated a second time below.
    grad_params = torch.autograd.grad(loss, param, create_graph=True)[0]
    row_num = index_list.shape[1]

    # Number of index components that address one scalar of the parameter.
    if layer_type == 'conv':
        index_len = 4
    elif layer_type == 'bn':
        index_len = 1
    else:
        index_len = 2

    # Optimizer-state tensors to sample, in output-column order.
    if method_name == 'Adam':
        state_keys = ('exp_avg_sq', 'exp_avg')
    elif method_name == 'SGD':
        state_keys = ('momentum_buffer',)
    else:
        state_keys = None

    state_tensors = ()
    if state_keys is not None and epoch > 1:
        # state_dict()['state'] is keyed by parameter index. Looking the entry up by
        # key (not by position in the list of values) stays correct even when some
        # earlier parameter has no optimizer state.
        param_state = optimizer.state_dict()['state'][param_index]
        state_tensors = tuple(param_state[key] for key in state_keys)

    txt_name = save_path + '/layer' + str(param_index) + '_epoch' + str(epoch) + '_' + method_name + '_diag.txt'
    # The context manager closes the file even if a gradient computation raises.
    with open(txt_name, 'a') as fo:
        for r in range(row_num):
            coord = tuple(int(i) for i in index_list[layer][r][:index_len])
            grad_entry = grad_params[coord]
            grad = grad_entry.item()

            # Differentiating one gradient entry gives one Hessian row; keep only its
            # diagonal element. retain_graph=True lets the next row reuse the graph.
            h = torch.autograd.grad(grad_entry, param, retain_graph=True)[0]
            diagH = h[coord].item()
            del h

            if state_keys is None or epoch < 1:
                continue
            if epoch == 1:
                extras = [str(0)] * len(state_keys)
            else:
                extras = [str(tensor[coord].item()) for tensor in state_tensors]
            fo.write('\t'.join([str(diagH), str(grad)] + extras) + '\n')


def cal_hessian(epoch, model, optimizer, layer_list, index, method_name, save_path):
    """Sample diagonal-Hessian entries for every listed parameter on one training batch.

    The model is switched to eval mode and the loss is computed on the first batch of
    the Multi30k training split only. ``cal_diag_hessian`` is then called once per
    entry of ``layer_list`` with layer type ``'fc'``.

    Args:
        epoch: checkpoint epoch, forwarded to ``cal_diag_hessian``.
        model: model to evaluate.
        optimizer: optimizer whose first parameter group supplies the parameter list.
        layer_list: parameter positions to analyse.
        index: sampled coordinates, forwarded to ``cal_diag_hessian``.
        method_name: optimizer name, forwarded to ``cal_diag_hessian``.
        save_path: output directory, forwarded to ``cal_diag_hessian``.
    """
    model.eval()
    dataset = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)
    params = optimizer.param_groups[0]['params']
    layer_count = len(layer_list)

    for src, tgt in loader:
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)

        # The decoder input drops the last position; the expected output drops the first.
        tgt_input = tgt[:-1, :]
        masks = create_mask(src, tgt_input)
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = masks

        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        vocab_dim = logits.shape[-1]
        loss = loss_fn(logits.reshape(-1, vocab_dim), tgt_out.reshape(-1))
        # retain_graph=True keeps the graph alive for the autograd calls made below.
        loss.backward(retain_graph=True)

        for layer in range(layer_count):
            cal_diag_hessian(loss, optimizer, params, layer_list, layer, 'fc',
                             index, method_name, epoch, save_path)

        # Only the first batch is analysed.
        break


from timeit import default_timer as timer


# Checkpoint epochs to analyse.
epoch_list = [1,31,56]

# Positions (in model.parameters() order) of the weight matrices to sample from.
# Full candidate list, kept for reference:
# fc_layers = [0, 2, 4, 6, 12, 14, 16, 18, 24, 26, 28, 30, 38, 40, 42, 44, 46, 48, 56, 58, 60, 62, 64, 66, 74, 76, 78, 80, 82, 84]
fc_layers = [4, 12, 16, 24, 30, 42, 48, 60, 66, 80]
layer_num = len(fc_layers)
row_num = 200  # the number of coordinates sampled per layer

init_model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM).to(DEVICE)
init_optimizer = torch.optim.SGD(init_model.parameters(), lr=learning_rate_sgd)

all_params = init_optimizer.param_groups[0]['params']

# Load the pre-sampled coordinates (row_num rows of 2 indices per layer) from the
# per-layer CSV files into index_fc.
index_fc = np.zeros((layer_num, row_num, 2), dtype=int)
path = './translation_results/'
for layer in range(layer_num):
    df_read = pd.read_csv('rand_coord_translation/rand_coord_layer' + str(fc_layers[layer]) + '.csv')
    index_fc[layer] = df_read.values


save_path = path+'diagHessian_adaptGrad_'+str(row_num)
# makedirs(exist_ok=True) also creates missing parent directories and avoids the
# check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs(save_path, exist_ok=True)
for method_name in ['SGD','Adam']:
    for epoch in epoch_list:
        start_time = timer()
        # NOTE(review): the checkpoint name says 'batch1024' while BATCH_SIZE is 512 -- confirm.
        load_path = path + method_name + '_epoch' + str(epoch) + '_batch1024.pth.tar'
        model, optimizer = load_model(load_path, method_name, 0.01)
        cal_hessian(epoch, model, optimizer, fc_layers, index_fc, method_name, save_path)
        end_time = timer()
        print(f"{method_name}, Epoch: {epoch}, Epoch time = {(end_time - start_time):.3f}s")
