import os
import gc
import random
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.utils.data as Data
from torch.utils.data.dataset import TensorDataset

def load_data(path, suffix, batch_size = 128):
    """Build train/test ``DataLoader`` objects from three ``.npy`` files.

    The files ``loc_<suffix>.npy``, ``vel_<suffix>.npy`` and
    ``edges_<suffix>.npy`` are read from ``path`` and split 80/20 with a fixed
    ``random_state`` of 42. Locations and velocities are rescaled to
    ``[-1, 1]`` using the min/max of the *training* split only, flattened per
    time step and concatenated into one feature tensor. Edge values are mapped
    through ``(e + 1) / 2`` and the diagonal (self-edge) entries are dropped.

    Args:
        path: Directory that holds the ``.npy`` files.
        suffix: Dataset suffix used in the file names.
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    def _read(kind):
        # File names follow the '<kind>_<suffix>.npy' pattern.
        return np.load(os.path.join(path, kind + '_' + suffix + '.npy'))

    splits = train_test_split(
        _read('loc'), _read('vel'), _read('edges'),
        test_size=0.2, random_state=42)
    loc_tr, loc_te, vel_tr, vel_te, edge_tr, edge_te = splits

    # Layout noted by the original author:
    # [num_samples, num_timesteps, num_dims, num_atoms]
    num_atoms = loc_tr.shape[3]

    # Normalisation statistics come from the training split only.
    loc_hi, loc_lo = loc_tr.max(), loc_tr.min()
    vel_hi, vel_lo = vel_tr.max(), vel_tr.min()

    def _rescale(arr, lo, hi):
        # Affine map of [lo, hi] onto [-1, 1].
        return (arr - lo) * 2 / (hi - lo) - 1

    def _features(loc, vel):
        # Normalise, merge the two trailing axes, then join loc and vel.
        loc = _rescale(loc, loc_lo, loc_hi)
        vel = _rescale(vel, vel_lo, vel_hi)
        loc = loc.reshape(*loc.shape[:2], -1)
        vel = vel.reshape(*vel.shape[:2], -1)
        return torch.FloatTensor(np.concatenate([loc, vel], axis=-1))

    # Flat indices of the off-diagonal cells of a num_atoms x num_atoms grid.
    off_diag_idx = np.ravel_multi_index(
        np.where(np.ones((num_atoms, num_atoms)) - np.eye(num_atoms)),
        [num_atoms, num_atoms])

    def _edge_targets(edges):
        # Flatten each adjacency matrix, apply (e + 1) / 2, drop self-edges.
        flat = np.reshape(edges, [-1, num_atoms ** 2])
        flat = np.array((flat + 1) / 2, dtype=np.int64)
        return torch.FloatTensor(flat)[:, off_diag_idx]

    train_set = TensorDataset(_features(loc_tr, vel_tr), _edge_targets(edge_tr))
    test_set = TensorDataset(_features(loc_te, vel_te), _edge_targets(edge_te))

    train_data_loader = Data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_data_loader = Data.DataLoader(test_set, batch_size=batch_size)

    return train_data_loader, test_data_loader

class blocked_grad(torch.autograd.Function):
    """Identity in the forward pass whose gradient is gated by ``mask``.

    ``forward`` returns ``x`` unchanged. ``backward`` multiplies the incoming
    gradient by ``mask`` and hands back an all-zero gradient for ``mask``
    itself.
    """

    @staticmethod
    def forward(ctx, x, mask):
        # Both tensors are stashed; only the mask is consumed in backward.
        ctx.save_for_backward(x, mask)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        _, gate = ctx.saved_tensors
        grad_x = grad_output * gate
        grad_mask = gate * 0.0
        return grad_x, grad_mask

class GroupLinearLayer(nn.Module):
    """A bank of ``num_blocks`` independent, bias-free linear maps.

    Block ``b`` owns its own ``(din, dout)`` weight matrix ``w[b]`` and is
    applied to slice ``x[:, b, :]`` of the input.
    """

    def __init__(self, din, dout, num_blocks):
        super().__init__()
        # One weight matrix per block, drawn from a normal scaled by 0.01.
        self.w = nn.Parameter(torch.randn(num_blocks, din, dout) * 0.01)

    def forward(self, x):
        """Map ``x`` of shape (batch, num_blocks, din) to (batch, num_blocks, dout)."""
        # bmm batches over its leading axis, so put the block axis first,
        # multiply each block by its own matrix, then restore batch-first order.
        blocks_first = x.transpose(0, 1)
        projected = torch.bmm(blocks_first, self.w)
        return projected.transpose(0, 1)

class GroupLSTMCell(nn.Module):
    """Runs ``num_lstms`` independent LSTM cells in one batched call.

    Each cell has its own input-to-hidden and hidden-to-hidden weights, held
    in two ``GroupLinearLayer`` banks that emit the four gate pre-activations
    stacked along the last axis.
    """

    def __init__(self, inp_size, hidden_size, num_lstms):
        super().__init__()
        self.inp_size = inp_size
        self.hidden_size = hidden_size

        # Each projection produces 4 * hidden_size features per cell.
        self.i2h = GroupLinearLayer(inp_size, 4 * hidden_size, num_lstms)
        self.h2h = GroupLinearLayer(hidden_size, 4 * hidden_size, num_lstms)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        bound = 1.0 / np.sqrt(self.hidden_size)
        for param in self.parameters():
            param.data.uniform_(-bound, bound)

    def forward(self, x, hid_state):
        """Advance all cells by one step.

        Args:
            x: (batch_size, num_lstms, input_size)
            hid_state: tuple ``(h, c)``, each (batch_size, num_lstms, hidden_size)

        Returns:
            ``(h_t, c_t)``, each (batch_size, num_lstms, hidden_size).
        """
        h_prev, c_prev = hid_state

        # (batch_size, num_lstms, 4 * hidden_size)
        preact = self.i2h(x) + self.h2h(h_prev)

        # The last axis is laid out as [input | forget | output | candidate],
        # each hidden_size wide.
        i_pre, f_pre, o_pre, g_pre = preact.split(self.hidden_size, dim=2)
        i_t = i_pre.sigmoid()
        f_t = f_pre.sigmoid()
        o_t = o_pre.sigmoid()
        g_t = g_pre.tanh()

        # All products below are element-wise.
        c_t = f_t * c_prev + i_t * g_t
        h_t = o_t * c_t.tanh()

        return h_t, c_t

class GroupGRUCell(nn.Module):
    """Runs ``num_grus`` independent GRU cells in one batched call.

    Each cell has its own input-to-hidden and hidden-to-hidden weights, held
    in two ``GroupLinearLayer`` banks that emit the three gate pre-activations
    stacked along the last axis.
    """

    def __init__(self, input_size, hidden_size, num_grus):
        super(GroupGRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Each projection produces 3 * hidden_size features per cell.
        self.x2h = GroupLinearLayer(input_size, 3 * hidden_size, num_grus)
        self.h2h = GroupLinearLayer(hidden_size, 3 * hidden_size, num_grus)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).

        The previous implementation replaced every weight with 1.0 (the
        uniform draw was commented out and ``std`` was unused), so all hidden
        features of a cell started out identical. It also rebound ``w.data``
        to a freshly allocated CPU tensor. The in-place uniform draw below
        matches ``GroupLSTMCell.reset_parameters`` and keeps each parameter on
        its current device.
        """
        std = 1.0 / np.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(self, x, hidden):
        """Advance all cells by one step.

        Args:
            x: (batch_size, num_grus, input_size)
            hidden: (batch_size, num_grus, hidden_size)

        Returns:
            The new hidden state, (batch_size, num_grus, hidden_size).
        """
        gate_x = self.x2h(x)
        gate_h = self.h2h(hidden)

        # Split the last axis into three equal slices: reset, input (update)
        # and candidate ("new") pre-activations.
        i_r, i_i, i_n = gate_x.chunk(3, 2)
        h_r, h_i, h_n = gate_h.chunk(3, 2)

        resetgate = torch.sigmoid(i_r + h_r)
        inputgate = torch.sigmoid(i_i + h_i)
        newgate = torch.tanh(i_n + (resetgate * h_n))

        # Equivalent to (1 - inputgate) * newgate + inputgate * hidden.
        hy = newgate + inputgate * (hidden - newgate)

        return hy

# Core module: given a new input x, RIMCell updates the per-unit hidden state (and the cell state when an LSTM is used).
class RIMCell(nn.Module):
    """One time step of a Recurrent Independent Mechanisms layer.

    ``num_units`` recurrent units (grouped LSTM or GRU cells) compete for the
    input through an input-attention step. Only the ``k`` units with the
    highest attention score on the real input are updated; the rest keep
    their previous state. The updated units then exchange information through
    a multi-head communication-attention step with a residual connection.

    Args:
        device: Device on which the key/value projections are created.
        input_size: Feature size of the input ``x``.
        hidden_size: Hidden size of every unit.
        num_units: Number of recurrent units.
        k: Number of units activated per step.
        rnn_cell: ``'GRU'`` selects ``GroupGRUCell``; anything else selects
            ``GroupLSTMCell``.
        input_key_size, input_value_size, input_query_size: Per-head sizes of
            the input attention. Key and query sizes must be equal for the
            score product to be defined.
        num_input_heads: Number of input-attention heads.
        input_dropout: Dropout probability on the input-attention weights.
        comm_key_size, comm_value_size, comm_query_size: Per-head sizes of the
            communication attention. ``comm_value_size`` is overridden with
            ``hidden_size`` when they differ.
        num_comm_heads: Number of communication-attention heads.
        comm_dropout: Dropout probability on the communication-attention
            weights.
    """

    def __init__(self, 
        device, input_size, hidden_size, num_units, k, rnn_cell, input_key_size = 64, input_value_size = 400, input_query_size = 64,
        num_input_heads = 1, input_dropout = 0.1, comm_key_size = 32, comm_value_size = 100, comm_query_size = 32, num_comm_heads = 4, comm_dropout = 0.1
    ):
        super().__init__()
        if comm_value_size != hidden_size:
            # The residual in communication_attention adds the attention
            # output to h, so the value size has to equal hidden_size.
            comm_value_size = hidden_size
        self.device = device
        self.hidden_size = hidden_size
        self.num_units = num_units
        self.rnn_cell = rnn_cell
        self.key_size = input_key_size
        self.k = k
        self.num_input_heads = num_input_heads
        self.num_comm_heads = num_comm_heads
        self.input_key_size = input_key_size
        self.input_query_size = input_query_size
        self.input_value_size = input_value_size

        self.comm_key_size = comm_key_size
        self.comm_query_size = comm_query_size
        self.comm_value_size = comm_value_size

        # Input attention: keys/values come from the input, queries from the
        # hidden states. The key projection is sized with the key size and the
        # query projection with the query size (they were crossed before,
        # which only worked because both sizes have to be equal anyway).
        self.key = nn.Linear(input_size, num_input_heads * input_key_size).to(self.device)
        self.value = nn.Linear(input_size, num_input_heads * input_value_size).to(self.device)

        if self.rnn_cell == 'GRU':
            self.rnn = GroupGRUCell(input_value_size, hidden_size, num_units)
        else:
            self.rnn = GroupLSTMCell(input_value_size, hidden_size, num_units)
        self.query = GroupLinearLayer(hidden_size, input_query_size * num_input_heads, self.num_units)

        # Communication attention: every projection is per-unit.
        self.query_ = GroupLinearLayer(hidden_size, comm_query_size * num_comm_heads, self.num_units)
        self.key_ = GroupLinearLayer(hidden_size, comm_key_size * num_comm_heads, self.num_units)
        self.value_ = GroupLinearLayer(hidden_size, comm_value_size * num_comm_heads, self.num_units)
        self.comm_attention_output = GroupLinearLayer(num_comm_heads * comm_value_size, comm_value_size, self.num_units)

        # Each dropout uses its own probability (the two were swapped before).
        self.comm_dropout = nn.Dropout(p=comm_dropout)
        self.input_dropout = nn.Dropout(p=input_dropout)

    def transpose_for_scores(self, x, num_attention_heads, attention_head_size):
        """Split the last axis into heads and move the head axis to position 1.

        (batch, n, heads * size) -> (batch, heads, n, size)
        """
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def input_attention_mask(self, x, h):
        """Attend from the units to the input and pick the active units.

        Args:
            x: (batch_size, 2, input_size); row 0 is the real input and row 1
                the appended null input.
            h: (batch_size, num_units, hidden_size)

        Returns:
            inputs: (batch_size, num_units, input_value_size) tensor of
                attended inputs, zeroed for inactive units.
            mask_: (batch_size, num_units) float tensor; 1 marks an active
                unit and 0 an inactive one.
        """
        key_layer = self.transpose_for_scores(self.key(x), self.num_input_heads, self.input_key_size)
        # Values are averaged over the heads.
        value_layer = torch.mean(
            self.transpose_for_scores(self.value(x), self.num_input_heads, self.input_value_size), dim=1)
        query_layer = self.transpose_for_scores(self.query(h), self.num_input_heads, self.input_query_size)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / np.sqrt(self.input_key_size)
        attention_scores = torch.mean(attention_scores, dim=1)  # average over heads

        # Column 0 holds each unit's score on the real (non-null) input; the
        # k best-scoring units per sample become active.
        not_null_scores = attention_scores[:, :, 0]
        topk_idx = torch.topk(not_null_scores, self.k, dim=1).indices
        mask_ = torch.zeros(x.size(0), self.num_units, device=x.device)
        mask_.scatter_(1, topk_idx, 1.0)

        attention_probs = self.input_dropout(nn.Softmax(dim=-1)(attention_scores))
        inputs = torch.matmul(attention_probs, value_layer) * mask_.unsqueeze(2)

        return inputs, mask_

    def communication_attention(self, h, mask):
        """Let the units exchange information through multi-head attention.

        Args:
            h: (batch_size, num_units, hidden_size)
            mask: (batch_size, num_units) activity mask returned by
                ``input_attention_mask``.

        Returns:
            (batch_size, num_units, hidden_size): ``h`` plus the attended
            context. Rows of inactive units receive no context, so for those
            units the result equals ``h``.
        """
        query_layer = self.transpose_for_scores(self.query_(h), self.num_comm_heads, self.comm_query_size)
        key_layer = self.transpose_for_scores(self.key_(h), self.num_comm_heads, self.comm_key_size)
        value_layer = self.transpose_for_scores(self.value_(h), self.num_comm_heads, self.comm_value_size)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / np.sqrt(self.comm_key_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # Zero the attention rows of inactive units. The mask is broadcast
        # over heads (axis 1) and over the attended-to units (axis 3).
        attention_probs = attention_probs * mask[:, None, :, None]
        attention_probs = self.comm_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # (batch, heads, units, size) -> (batch, units, heads * size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.num_comm_heads * self.comm_value_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        context_layer = self.comm_attention_output(context_layer)

        return context_layer + h

    def forward(self, x, hs, cs = None):
        """Run one RIM step.

        Args:
            x: (batch_size, 1, input_size)
            hs: (batch_size, num_units, hidden_size)
            cs: (batch_size, num_units, hidden_size) for an LSTM, or ``None``
                for a GRU.

        Returns:
            ``(hs, cs)`` for an LSTM, ``(hs, None)`` for a GRU.
        """
        # Append an all-zero "null" input that the units can attend to.
        size = x.size()
        null_input = torch.zeros(size[0], 1, size[2], dtype=torch.float, device=x.device)
        x = torch.cat((x, null_input), dim = 1)

        # Compute input attention
        inputs, mask = self.input_attention_mask(x, hs)

        # Remember the previous states so inactive units can be restored.
        h_old = hs * 1.0
        if cs is not None:
            c_old = cs * 1.0

        # Compute RNN (LSTM or GRU) output
        if cs is not None:
            hs, cs = self.rnn(inputs, (hs, cs))
        else:
            hs = self.rnn(inputs, hs)

        # Block gradient through inactive units
        mask = mask.unsqueeze(2)
        h_new = blocked_grad.apply(hs, mask)

        # Compute communication attention
        h_new = self.communication_attention(h_new, mask.squeeze(2))

        # Active units take the new state, inactive units keep the old one.
        hs = mask * h_new + (1 - mask) * h_old
        if cs is not None:
            cs = mask * cs + (1 - mask) * c_old
            return hs, cs

        return hs, None

class RimModel(nn.Module):
    """Sequence model: a ``RIMCell`` unrolled over time plus a linear read-out.

    Args:
        args: Mapping of hyper-parameters. Keys read here: ``cuda``,
            ``input_size``, ``hidden_size``, ``num_units``, ``k``,
            ``rnn_cell``, ``key_size_input``, ``value_size_input``,
            ``query_size_input``, ``num_input_heads``, ``input_dropout``,
            ``key_size_comm``, ``value_size_comm``, ``query_size_comm``,
            ``num_comm_heads`` (optional), ``comm_dropout`` and
            ``output_size``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # Only select CUDA when it was requested AND is actually available;
        # otherwise building the cell would fail on CPU-only machines.
        use_cuda = bool(args['cuda']) and torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        # The communication attention must use its own head count. It used to
        # receive num_input_heads by mistake. Fall back to that value only for
        # mappings that do not carry 'num_comm_heads'.
        if 'num_comm_heads' in args:
            num_comm_heads = args['num_comm_heads']
        else:
            num_comm_heads = args['num_input_heads']

        self.rim_model = RIMCell(self.device, args['input_size'], args['hidden_size'], args['num_units'], args['k'], args['rnn_cell'], args['key_size_input'], args['value_size_input'] , args['query_size_input'],
            args['num_input_heads'], args['input_dropout'], args['key_size_comm'], args['value_size_comm'], args['query_size_comm'], num_comm_heads, args['comm_dropout']).to(self.device)

        # Read-out over the concatenated hidden states of all units.
        self.Linear = nn.Sequential(
            nn.Linear(args['hidden_size'] * args['num_units'], args['output_size']),
        )

    def to_device(self, x):
        """Convert a numpy array to a tensor on the model's device."""
        return torch.from_numpy(x).to(self.device)

    def forward(self, x):
        """Run the full sequence and predict from the final hidden state.

        Args:
            x: (batch_size, num_timesteps, input_size)

        Returns:
            (batch_size, output_size) raw predictions.

        Note:
            The initial hidden (and cell) states are freshly sampled from a
            standard normal on every call, so two calls with the same input
            can give different outputs, including in eval mode.
        """
        x = x.float()
        batch_size = x.size(0)
        state_shape = (batch_size, self.args['num_units'], self.args['hidden_size'])

        # initialize hidden states on the same device as the input
        hs = torch.randn(*state_shape).to(x.device)
        cs = None
        if self.args['rnn_cell'] == 'LSTM':
            cs = torch.randn(*state_shape).to(x.device)

        # pass through RIMCell for all timesteps, one (batch, 1, input) slice at a time
        for x_t in torch.split(x, 1, 1):
            hs, cs = self.rim_model(x_t, hs, cs)
        preds = self.Linear(hs.contiguous().view(batch_size, -1))

        return preds

    def grad_norm(self):
        """Return the global L2 norm of all currently populated gradients."""
        total_norm = 0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** (1. / 2)
        return total_norm

def str2bool(v):
    """Parse a command-line style boolean.

    Args:
        v: A ``bool`` (returned unchanged) or a string. ``yes/true/t/y/1``
            mean ``True`` and ``no/false/f/n/0`` mean ``False``,
            case-insensitively.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If a string is not a recognised spelling.
    """
    # Real booleans used to crash on .lower(); pass them through instead.
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def _build_arg_parser(hidden_size):
    """Create the CLI parser; ``hidden_size`` is the default for ``--hidden_size``."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--cuda', type = str2bool, default = True, help = 'use gpu or not')
    # Default matches the 25 epochs that were previously hard-coded.
    parser.add_argument('--epochs', type = int, default = 25)
    parser.add_argument('--hidden_size', type = int, default = hidden_size)
    parser.add_argument('--input_size', type = int, default = 20)
    parser.add_argument('--output_size', type = int, default = 20)
    parser.add_argument('--num_units', type = int, default = 8)
    parser.add_argument('--rnn_cell', type = str, default = 'LSTM')
    parser.add_argument('--key_size_input', type = int, default = 64)
    parser.add_argument('--value_size_input', type = int, default = 400)
    parser.add_argument('--query_size_input', type = int, default = 64)
    parser.add_argument('--num_input_heads', type = int, default = 1)
    parser.add_argument('--num_comm_heads', type = int, default = 4)
    parser.add_argument('--input_dropout', type = float, default = 0.1)
    parser.add_argument('--comm_dropout', type = float, default = 0.1)

    parser.add_argument('--key_size_comm', type = int, default = 32)
    parser.add_argument('--value_size_comm', type = int, default = 100)
    parser.add_argument('--query_size_comm', type = int, default = 32)
    parser.add_argument('--k', type = int, default = 5)
    return parser


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # This must be a call: assigning to the name replaces the function and
    # enables nothing. warn_only (available in recent PyTorch releases) keeps
    # ops without a deterministic implementation from raising.
    torch.use_deterministic_algorithms(True, warn_only=True)


def _run_epoch(model, loader, criterion, device, optimizer = None, scheduler = None):
    """Run one pass over ``loader``; trains when ``optimizer`` is given.

    Returns ``(mean_loss, accuracy)``, both averaged over every target element.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    num_correct = 0
    num_elements = 0
    with torch.set_grad_enabled(training):
        for x, y in tqdm(loader):
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss = criterion(output, y)

            if training:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                if scheduler is not None:
                    scheduler.step()

            # A logit above 0 corresponds to a predicted probability above 0.5.
            num_correct += int(torch.sum((output > 0).int() == y))
            # criterion averages over all elements, so weight the batch loss
            # by the same element count that the totals are divided by.
            loss_sum += loss.item() * y.numel()
            num_elements += y.numel()

    num_elements = max(num_elements, 1)  # guard against an empty loader
    return loss_sum / num_elements, num_correct / num_elements


def _save_history(frame, label, primary_path, fallback_path):
    """Write ``frame`` to ``primary_path``; on failure report it and use ``fallback_path``."""
    try:
        frame.to_csv(primary_path)
    except Exception:
        # Best effort: the primary file could not be written, so say so and
        # write under the alternative name instead.
        print(f'Fail to save the file {label}.csv')
        frame.to_csv(fallback_path)


def main(version, hidden_size):
    """Train and evaluate a ``RimModel`` on the springs dataset.

    After every epoch the model is saved to ``<version>.pkl`` and the metric
    history so far to ``<version>_Train.csv`` / ``<version>_Test.csv``.

    Args:
        version: Prefix for every output file.
        hidden_size: Default hidden size (``--hidden_size`` overrides it).
    """
    args = vars(_build_arg_parser(hidden_size).parse_args())

    _seed_everything(42)

    # Keep the model's idea of the device and the training device in sync.
    use_cuda = bool(args['cuda']) and torch.cuda.is_available()
    args['cuda'] = use_cuda
    device = torch.device('cuda:0' if use_cuda else 'cpu')

    batch_size = 128
    num_epochs = args['epochs']

    train_loader, test_loader = load_data('spring', 'springs5', batch_size)

    print(device)

    model = RimModel(args).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # The scheduler is stepped once per batch, hence T_max counts batches.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = num_epochs * len(train_loader), eta_min = 0.00005, last_epoch = -1)

    criterion = nn.BCEWithLogitsLoss()

    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    for epoch in range(num_epochs):
        print(version, 'Epoch {}/{}'.format(epoch + 1, num_epochs))

        loss, acc = _run_epoch(model, train_loader, criterion, device, optimizer, scheduler)
        train_loss.append(loss)
        train_acc.append(acc)
        print(' ', train_loss[-1], train_acc[-1])

        loss, acc = _run_epoch(model, test_loader, criterion, device)
        test_loss.append(loss)
        test_acc.append(acc)
        print(' ', test_loss[-1], test_acc[-1])

        torch.save(model, f'{version}.pkl')

        _save_history(
            pd.DataFrame({'Train Loss': train_loss, 'Train Acc': train_acc}),
            'Train', f'{version}_Train.csv', f'{version}_Train_1.csv')
        _save_history(
            pd.DataFrame({'Test Loss': test_loss, 'Test Acc': test_acc}),
            'Test', f'{version}_Test.csv', f'{version}_Test_1.csv')

    gc.collect()

if __name__ == '__main__':
    # Train one model per hidden-state width; the width is also part of the
    # prefix used for that run's output files.
    for width in (128, 256, 512):
        main(f'RIM_Spring_{width}', width)