import gc
import random
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.utils.data as Data
from torch.utils.data import Dataset

class blocked_grad(torch.autograd.Function):
    """
    Identity in the forward pass, gradient gate in the backward pass.

    ``blocked_grad.apply(x, mask)`` returns ``x`` unchanged; during
    back-propagation the gradient flowing into ``x`` is multiplied by ``mask``,
    so positions where the mask is 0 receive no gradient.
    """

    @staticmethod
    def forward(ctx, x, mask):
        # Both tensors are stashed; only the mask is needed again in backward().
        ctx.save_for_backward(x, mask)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        _, gate = ctx.saved_tensors
        grad_x = grad_output * gate
        # The mask is not trained: hand back an all-zero gradient of its shape.
        grad_mask = gate * 0.0
        return grad_x, grad_mask

class GroupLinearLayer(nn.Module):
    """
    A bank of ``num_blocks`` independent, bias-free linear maps.

    Block ``b`` owns its own ``(din, dout)`` weight matrix ``w[b]`` and is applied
    only to slice ``x[:, b, :]`` of the input.
    """

    def __init__(self, din, dout, num_blocks):
        super().__init__()

        # One weight matrix per block, drawn from a normal distribution scaled by 0.01.
        self.w = nn.Parameter(0.01 * torch.randn(num_blocks, din, dout))

    def forward(self, x):
        """
        input:  x (batch_size, num_blocks, din)
        output: (batch_size, num_blocks, dout)
        """
        # torch.bmm batches over the leading dimension, so the block axis has to
        # come first: (batch, blocks, din) -> (blocks, batch, din).
        blocks_first = x.transpose(0, 1)

        # (blocks, batch, din) @ (blocks, din, dout) -> (blocks, batch, dout)
        projected = torch.bmm(blocks_first, self.w)

        # Swap the axes back so the batch dimension leads again.
        return projected.transpose(0, 1)

class GroupLSTMCell(nn.Module):
    """
    GroupLSTMCell computes one step of N independent LSTM cells at once.

    Every cell has its own input-to-hidden and hidden-to-hidden weights (no
    biases), stored in two GroupLinearLayer banks.
    """
    def __init__(self, inp_size, hidden_size, num_lstms):
        super().__init__()
        self.inp_size = inp_size
        self.hidden_size = hidden_size

        # Each projection emits the four gate pre-activations side by side,
        # hence the factor 4 on the output width.
        self.i2h = GroupLinearLayer(inp_size, 4 * hidden_size, num_lstms)
        self.h2h = GroupLinearLayer(hidden_size, 4 * hidden_size, num_lstms)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every weight from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        bound = 1.0 / np.sqrt(self.hidden_size)
        for param in self.parameters():
            param.data.uniform_(-bound, bound)

    def forward(self, x, hid_state):
        """
        input: x (batch_size, num_lstms, input_size)
            hid_state (tuple of length 2 with each element of size (batch_size, num_lstms, hidden_state))
        output: h (batch_size, num_lstms, hidden_state)
                c ((batch_size, num_lstms, hidden_state))
        """
        h_prev, c_prev = hid_state

        # preact.shape = (batch_size, num_lstms, 4 * hidden_size)
        preact = self.i2h(x) + self.h2h(h_prev)

        # The last axis holds four hidden_size-wide slabs, in this order:
        # input gate | forget gate | output gate | candidate cell state.
        i_pre, f_pre, o_pre, g_pre = preact.split(self.hidden_size, dim=2)

        i_t = i_pre.sigmoid()
        f_t = f_pre.sigmoid()
        o_t = o_pre.sigmoid()
        g_t = g_pre.tanh()

        # All products below are element-wise.
        c_t = c_prev * f_t + i_t * g_t
        h_t = o_t * c_t.tanh()

        return h_t, c_t

class GroupGRUCell(nn.Module):
    """
    GroupGRUCell can compute the operation of N GRU Cells at once.

    Every cell has its own input-to-hidden and hidden-to-hidden weights (no
    biases), stored in two GroupLinearLayer banks.
    """
    def __init__(self, input_size, hidden_size, num_grus):
        super(GroupGRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Each projection emits three hidden_size-wide slabs on its last axis;
        # forward() uses them as reset gate | update gate | candidate state.
        self.x2h = GroupLinearLayer(input_size, 3 * hidden_size, num_grus)
        self.h2h = GroupLinearLayer(hidden_size, 3 * hidden_size, num_grus)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Re-draw every weight from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).

        This mirrors GroupLSTMCell.reset_parameters. The previous code replaced
        every weight with a constant 1.0 (a debugging leftover), which made all
        hidden features of a unit identical and silently moved the weights to
        the CPU; the in-place uniform_ keeps each tensor on its device.
        """
        std = 1.0 / np.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(self, x, hidden):
        """
        input: x (batch_size, num_grus, input_size)
            hidden (batch_size, num_grus, hidden_size)
        output: hidden (batch_size, num_grus, hidden_size)
        """
        gate_x = self.x2h(x)
        gate_h = self.h2h(hidden)

        # Split the last axis into three equal slabs: reset, update, candidate.
        i_r, i_i, i_n = gate_x.chunk(3, 2)
        h_r, h_i, h_n = gate_h.chunk(3, 2)

        resetgate = torch.sigmoid(i_r + h_r)
        inputgate = torch.sigmoid(i_i + h_i)
        # The reset gate scales only the hidden contribution to the candidate.
        newgate = torch.tanh(i_n + (resetgate * h_n))

        # Interpolate between candidate and previous state:
        # inputgate == 1 keeps `hidden`, inputgate == 0 takes `newgate`.
        hy = newgate + inputgate * (hidden - newgate)

        return hy

# Core module: given a new input x, RIMCell advances the hidden state(s) by one time step.
class RIMCell(nn.Module):
    """
    Single-step cell of a RIM layer.

    On every call the ``num_units`` recurrent units compete for the input via
    input attention; only the ``k`` best-scoring units are updated by their RNN
    (GRU or LSTM) and then read from the other units through communication
    attention. Every other unit keeps its previous state.
    """
    def __init__(self,
        device, input_size, hidden_size, num_units, k, rnn_cell, input_key_size = 64, input_value_size = 400, input_query_size = 64,
        num_input_heads = 1, input_dropout = 0.1, comm_key_size = 32, comm_value_size = 100, comm_query_size = 32, num_comm_heads = 4, comm_dropout = 0.1
    ):
        super().__init__()
        # Queries are dotted with keys, so their widths must agree. Mismatched
        # sizes used to fail deep inside forward() with a shape error.
        if input_key_size != input_query_size:
            raise ValueError(
                'input_key_size ({}) must equal input_query_size ({})'.format(input_key_size, input_query_size)
            )
        if comm_key_size != comm_query_size:
            raise ValueError(
                'comm_key_size ({}) must equal comm_query_size ({})'.format(comm_key_size, comm_query_size)
            )
        if comm_value_size != hidden_size:
            # The communication output is added to the hidden state (residual
            # connection), so its width is forced to hidden_size.
            comm_value_size = hidden_size
        self.device = device
        self.hidden_size = hidden_size
        self.num_units = num_units
        self.rnn_cell = rnn_cell
        self.key_size = input_key_size
        self.k = k
        self.num_input_heads = num_input_heads
        self.num_comm_heads = num_comm_heads
        self.input_key_size = input_key_size
        self.input_query_size = input_query_size
        self.input_value_size = input_value_size

        self.comm_key_size = comm_key_size
        self.comm_query_size = comm_query_size
        self.comm_value_size = comm_value_size

        # Input attention: keys/values come from the input, queries from the units.
        self.key = nn.Linear(input_size, num_input_heads * input_key_size).to(self.device)
        self.value = nn.Linear(input_size, num_input_heads * input_value_size).to(self.device)

        # Any rnn_cell other than 'GRU' selects the LSTM variant.
        if self.rnn_cell == 'GRU':
            self.rnn = GroupGRUCell(input_value_size, hidden_size, num_units)
        else:
            self.rnn = GroupLSTMCell(input_value_size, hidden_size, num_units)
        self.query = GroupLinearLayer(hidden_size, input_query_size * num_input_heads, self.num_units)

        # Communication attention: queries, keys and values all come from the units.
        self.query_ = GroupLinearLayer(hidden_size, comm_query_size * num_comm_heads, self.num_units)
        self.key_ = GroupLinearLayer(hidden_size, comm_key_size * num_comm_heads, self.num_units)
        self.value_ = GroupLinearLayer(hidden_size, comm_value_size * num_comm_heads, self.num_units)
        self.comm_attention_output = GroupLinearLayer(num_comm_heads * comm_value_size, comm_value_size, self.num_units)
        # Each dropout uses its own probability (they used to be swapped).
        self.comm_dropout = nn.Dropout(p = comm_dropout)
        self.input_dropout = nn.Dropout(p = input_dropout)

    def transpose_for_scores(self, x, num_attention_heads, attention_head_size):
        """Reshape (..., n, heads * size) to (..., heads, n, size) for a 3-D input."""
        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def input_attention_mask(self, x, h):
        """
        Input : x (batch_size, 2, input_size) [The null input is appended along the first dimension]
                h (batch_size, num_units, hidden_size)
        Output: inputs tensor of shape (batch_size, num_units, input_value_size);
                       rows of inactive units are zero
                mask_ binary tensor of shape (batch_size, num_units) where 1 indicates active and 0 indicates inactive
        """
        key_layer = self.key(x)
        value_layer = self.value(x)
        query_layer = self.query(h)

        key_layer = self.transpose_for_scores(key_layer, self.num_input_heads, self.input_key_size)
        # Values and scores are averaged over the input-attention heads.
        value_layer = torch.mean(self.transpose_for_scores(value_layer, self.num_input_heads, self.input_value_size), dim = 1)
        query_layer = self.transpose_for_scores(query_layer, self.num_input_heads, self.input_query_size)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / np.sqrt(self.input_key_size)
        attention_scores = torch.mean(attention_scores, dim = 1)

        # Created on the input's device so the cell works wherever its data lives.
        mask_ = torch.zeros(x.size(0), self.num_units, device = x.device)

        # Column 0 is the score on the real input (column 1 is the null input);
        # the k units attending most strongly to the real input become active.
        not_null_scores = attention_scores[:, :, 0]
        topk1 = torch.topk(not_null_scores, self.k, dim = 1)
        row_index = torch.arange(x.size(0), device = x.device).repeat_interleave(self.k)

        # Pair every batch row with each of its k winning unit indices.
        mask_[row_index, topk1.indices.reshape(-1)] = 1

        attention_probs = self.input_dropout(nn.Softmax(dim = -1)(attention_scores))
        inputs = torch.matmul(attention_probs, value_layer) * mask_.unsqueeze(2)

        return inputs, mask_

    def communication_attention(self, h, mask):
        """
        Input : h (batch_size, num_units, hidden_size)
                mask obtained from the input_attention_mask() function
        Output: context_layer (batch_size, num_units, hidden_size). New hidden states after communication
        """
        query_layer = self.transpose_for_scores(self.query_(h), self.num_comm_heads, self.comm_query_size)
        key_layer = self.transpose_for_scores(self.key_(h), self.num_comm_heads, self.comm_key_size)
        value_layer = self.transpose_for_scores(self.value_(h), self.num_comm_heads, self.comm_value_size)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / np.sqrt(self.comm_key_size)

        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # Zero the attention rows of inactive units so they read nothing.
        # (batch, units) -> (batch, 1, units, 1) broadcasts over heads and keys.
        attention_probs = attention_probs * mask.unsqueeze(1).unsqueeze(3)
        attention_probs = self.comm_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # (batch, heads, units, size) -> (batch, units, heads * size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.num_comm_heads * self.comm_value_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        context_layer = self.comm_attention_output(context_layer)
        # Residual connection around the communication step.
        context_layer = context_layer + h

        return context_layer

    def forward(self, x, hs, cs = None):
        """
        Input : x (batch_size, 1 , input_size)
                hs (batch_size, num_units, hidden_size)
                cs (batch_size, num_units, hidden_size), or None when the cell was built with a GRU
        Output: new hs, cs for LSTM
                new hs, None for GRU
        """
        # Append an all-zero "null" input that units can attend to instead of x.
        null_input = torch.zeros(x.size(0), 1, x.size(2), device = x.device)
        x = torch.cat((x, null_input), dim = 1)

        # Compute input attention
        inputs, mask = self.input_attention_mask(x, hs)

        # Keep copies of the incoming state for the units that stay inactive.
        h_old = hs * 1.0
        c_old = cs * 1.0 if cs is not None else None

        # Compute RNN(LSTM or GRU) output
        if cs is not None:
            hs, cs = self.rnn(inputs, (hs, cs))
        else:
            hs = self.rnn(inputs, hs)

        # Block gradient through inactive units
        mask = mask.unsqueeze(2)
        h_new = blocked_grad.apply(hs, mask)

        # Compute communication attention
        h_new = self.communication_attention(h_new, mask.squeeze(2))

        # Active units take the new state, inactive units keep the old one.
        hs = mask * h_new + (1 - mask) * h_old
        if cs is not None:
            cs = mask * cs + (1 - mask) * c_old
            return hs, cs

        return hs, None

class RIM(nn.Module):
    """
    Sequence-to-sequence forecaster built around a single RIMCell.

    The input sequence is fed to the cell one time step at a time; the final
    hidden state is mapped back to input space by a linear read-out, and that
    prediction is fed back in to roll out ``pred_n`` future steps.
    """
    def __init__(self, args):
        """
        args: dict holding the hyper-parameters read below ('cuda', 'input_size',
              'hidden_size', 'num_units', 'k', 'rnn_cell', 'pred_n', the
              '*_size_input' / '*_size_comm' attention sizes, the head counts
              and the dropout rates). 'num_comm_heads' is optional and falls
              back to 'num_input_heads' when absent.
        """
        super().__init__()
        self.args = args
        # Fall back to the CPU instead of crashing when CUDA is requested but
        # no GPU is available.
        use_cuda = bool(args['cuda']) and torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        # Keyword arguments keep each value attached to the right parameter:
        # the positional call used to pass num_input_heads as num_comm_heads.
        self.rim_model = RIMCell(
            self.device,
            args['input_size'],
            args['hidden_size'],
            args['num_units'],
            args['k'],
            args['rnn_cell'],
            input_key_size = args['key_size_input'],
            input_value_size = args['value_size_input'],
            input_query_size = args['query_size_input'],
            num_input_heads = args['num_input_heads'],
            input_dropout = args['input_dropout'],
            comm_key_size = args['key_size_comm'],
            comm_value_size = args['value_size_comm'],
            comm_query_size = args['query_size_comm'],
            num_comm_heads = args.get('num_comm_heads', args['num_input_heads']),
            comm_dropout = args['comm_dropout'],
        ).to(self.device)

        # Read-out from the flattened hidden state of all units to input space.
        self.Linear = nn.Sequential(
            nn.Linear(args['hidden_size'] * args['num_units'], args['input_size']),
        )
        self.pred_n = args['pred_n']
        # Not used inside this class; kept so the attribute stays available.
        self.Loss = nn.CrossEntropyLoss()

    def to_device(self, x):
        """Convert a numpy array to a tensor on this model's device."""
        return torch.from_numpy(x).to(self.device)

    def _readout(self, hs):
        """Map hidden states (batch, units, hidden) to one prediction (batch, 1, input_size)."""
        flat = hs.contiguous().view(hs.size(0), -1)
        return self.Linear(flat).unsqueeze(dim=1)

    def forward(self, x):
        """
        Input : x (batch_size, seq_len, input_size)
        Output: predictions (batch_size, max(pred_n, 1), input_size)

        The recurrent state starts from fresh standard-normal noise on every call.
        """
        x = x.float()
        state_shape = (x.size(0), self.args['num_units'], self.args['hidden_size'])

        # initialize hidden states on the same device as the input
        hs = torch.randn(*state_shape).to(x.device)
        cs = None
        # RIMCell builds an LSTM for every rnn_cell other than 'GRU', so a cell
        # state is needed in exactly those cases.
        if self.args['rnn_cell'] != 'GRU':
            cs = torch.randn(*state_shape).to(x.device)

        # Consume the observed sequence one time step at a time.
        for step in torch.split(x, 1, 1):
            hs, cs = self.rim_model(step, hs, cs)

        preds = [self._readout(hs)]

        # Roll out the remaining steps by feeding each prediction back in.
        for _ in range(1, self.pred_n):
            hs, cs = self.rim_model(preds[-1], hs, cs)
            preds.append(self._readout(hs))
        return torch.cat(preds, dim=1)

    def grad_norm(self):
        """Return the global L2 norm of all parameter gradients currently set."""
        squared_sum = 0
        for p in self.parameters():
            if p.grad is not None:
                squared_sum += p.grad.data.norm(2).item() ** 2
        return squared_sum ** (1. / 2)

class ThreeBodyDataset(Dataset):
    """
    Dataset backed by three ``.npy`` files that are loaded fully into memory.

    Item ``i`` is the tuple ``(t[i], X[i], y[i])``; the dataset length is the
    length of the ``t`` array.
    """

    def __init__(self, t_path: str, X_path: str, y_path: str):
        # torch.Tensor(...) yields tensors of the default dtype.
        self.t, self.X, self.y = (
            torch.Tensor(np.load(path)) for path in (t_path, X_path, y_path)
        )

    def __getitem__(self, index):
        return tuple(field[index] for field in (self.t, self.X, self.y))

    def __len__(self):
        return self.t.size(0)

def str2bool(v):
    """
    Parse a command-line style boolean for use as an argparse ``type=``.

    Accepts real ``bool`` values unchanged, and the strings
    yes/true/t/y/1 and no/false/f/n/0 (case-insensitive, surrounding
    whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: if the value is none of the above.
    """
    # Already a boolean (e.g. when called directly with a default value):
    # str methods would fail on it, so pass it straight through.
    if isinstance(v, bool):
        return v

    text = v.strip().lower()
    if text in ('yes', 'true', 't', 'y', '1'):
        return True
    if text in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def _save_loss_history(column, losses, version, tag):
    """Write one loss curve to '{version}_{tag}.csv', or to '{version}_{tag}_1.csv' if that fails."""
    frame = pd.DataFrame({column: losses})
    try:
        frame.to_csv(f'{version}_{tag}.csv')
    except OSError:
        # Best effort: e.g. the primary file is locked by another program.
        print(f'Fail to save the file {tag}.csv')
        frame.to_csv(f'{version}_{tag}_1.csv')


def main(version, hidden_size):
    """
    Train and evaluate one RIM model on the three-body dataset.

    version:     prefix for every output file ('{version}.pkl',
                 '{version}_Train.csv', '{version}_Test.csv')
    hidden_size: default for the --hidden_size command-line option

    Samples are loaded one at a time and their gradients are accumulated over
    ``batch_size`` samples before each optimizer step. After every epoch the
    model and both loss histories are written to disk.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--cuda', type = str2bool, default = True, help = 'use gpu or not')
    parser.add_argument('--epochs', type = int, default = 100)
    parser.add_argument('--hidden_size', type = int, default = hidden_size)
    parser.add_argument('--input_size', type = int, default = 9)
    parser.add_argument('--pred_n', type = int, default = 3)
    parser.add_argument('--num_units', type = int, default = 6)
    parser.add_argument('--rnn_cell', type = str, default = 'LSTM')
    parser.add_argument('--key_size_input', type = int, default = 64)
    parser.add_argument('--value_size_input', type = int, default = 400)
    parser.add_argument('--query_size_input', type = int, default = 64)
    parser.add_argument('--num_input_heads', type = int, default = 1)
    parser.add_argument('--num_comm_heads', type = int, default = 4)
    parser.add_argument('--input_dropout', type = float, default = 0.1)
    parser.add_argument('--comm_dropout', type = float, default = 0.1)

    parser.add_argument('--key_size_comm', type = int, default = 32)
    parser.add_argument('--value_size_comm', type = int, default = 100)
    parser.add_argument('--query_size_comm', type = int, default = 32)
    parser.add_argument('--k', type = int, default = 3)

    args = vars(parser.parse_args())

    # Honour --cuda, but only use the GPU when one is actually present. The
    # model reads args['cuda'] too, so both end up on the same device.
    args['cuda'] = bool(args['cuda']) and torch.cuda.is_available()
    device = torch.device('cuda' if args['cuda'] else 'cpu')

    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # This must be *called*; assigning to it only overwrote the function.
    # warn_only avoids hard failures on ops without a deterministic kernel.
    torch.use_deterministic_algorithms(True, warn_only=True)

    batch_size = 128                # samples accumulated per optimizer step
    num_epochs = args['epochs']

    train_dataset = ThreeBodyDataset(
        t_path='three_body/train_t.npy',
        X_path='three_body/train_x.npy',
        y_path='three_body/train_y.npy'
    )

    test_dataset = ThreeBodyDataset(
        t_path='three_body/test_t.npy',
        X_path='three_body/test_x.npy',
        y_path='three_body/test_y.npy'
    )

    # Loader batch size is 1; batching is emulated by gradient accumulation.
    train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0
    )

    test_loader = Data.DataLoader(
        dataset=test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0
    )

    print(device)

    model = RIM(args).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # One scheduler step per optimizer step, so the cosine period spans the run.
    steps_per_epoch = -(-len(train_loader) // batch_size)   # ceiling division
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max = max(1, num_epochs * steps_per_epoch), eta_min = 0.00005, last_epoch = -1
    )

    criterion = nn.MSELoss()

    train_loss = []
    test_loss = []

    last_index = len(train_loader) - 1
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        batch_loss = 0
        accumulated = 0             # samples whose gradients are pending
        model.train()
        optimizer.zero_grad()
        for i, (_, x, y) in enumerate(tqdm(train_loader)):
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss = criterion(output.view(-1), y.view(-1))
            loss.backward()
            batch_loss += loss.item()
            accumulated += 1

            # Step after exactly batch_size samples, and once more for the
            # final partial batch so no gradient leaks into the next epoch.
            if accumulated == batch_size or i == last_index:
                print(' ', version, 'Epoch {}/{}'.format(epoch + 1, num_epochs))
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1, norm_type=2)
                optimizer.step()
                optimizer.zero_grad()
                train_loss.append(batch_loss / accumulated)
                print(' ', train_loss[-1])
                batch_loss = 0
                accumulated = 0
                scheduler.step()

        epoch_loss = 0
        model.eval()
        with torch.no_grad():
            for _, x, y in tqdm(test_loader):
                x, y = x.to(device), y.to(device)
                output = model(x)
                loss = criterion(output.view(-1), y.view(-1))
                epoch_loss += loss.item()
        # max(..., 1) guards against an empty test set.
        test_loss.append(epoch_loss / max(len(test_loader), 1))
        print(' ', test_loss[-1])

        # Checkpoint the whole model object and both loss curves every epoch.
        torch.save(model, f'{version}.pkl')
        _save_loss_history('Train Loss', train_loss, version, 'Train')
        _save_loss_history('Test Loss', test_loss, version, 'Test')

    gc.collect()

if __name__ == '__main__':
    # Sweep over hidden-state widths, training one model per width. The width
    # is part of the version tag, so each run writes its own output files.
    for hidden_size in (512, 1024, 2048):
        main(f'RIM_Body_{hidden_size}', hidden_size)