import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from torch.autograd import Variable
# from utils import my_softmax, get_offdiag_indices, gumbel_softmax

import logging
import itertools
import torch
import torch.nn as nn
from torch.nn.functional import softmax, relu
from torch.nn import Parameter
import numpy as np
import torch
from torch.utils.data.dataset import TensorDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.autograd import Variable

def mlp(input_dim, mlp_dims, last_relu=False):
    """Build a plain fully-connected network as an ``nn.Sequential``.

    Args:
        input_dim: Size of the input features of the first linear layer.
        mlp_dims: Iterable with the output size of every linear layer, in
            order. Any iterable is accepted (list, tuple, ...); an empty one
            yields an empty ``nn.Sequential``.
        last_relu: If True, a ReLU is also appended after the final linear
            layer; otherwise only the hidden layers are followed by a ReLU.

    Returns:
        An ``nn.Sequential`` of alternating ``nn.Linear`` / ``nn.ReLU``
        modules.
    """
    # Unpack instead of list concatenation so tuples and other iterables work.
    dims = [input_dim, *mlp_dims]
    last_layer = len(dims) - 2
    layers = []
    for idx, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(d_in, d_out))
        if idx != last_layer or last_relu:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)




def encode_onehot(labels, classes=None):
    """One-hot encode a sequence of labels.

    Args:
        labels: Iterable of hashable labels.
        classes: Optional number of classes. When truthy, the classes are
            ``0 .. classes - 1``; otherwise they are the distinct values of
            ``labels`` in set-iteration order.

    Returns:
        An ``int32`` array with one one-hot row per label.
    """
    class_values = range(classes) if classes else set(labels)
    eye = np.identity(len(class_values))
    row_for = {}
    for position, value in enumerate(class_values):
        row_for[value] = eye[position, :]
    encoded_rows = [row_for.get(label) for label in labels]
    return np.array(encoded_rows, dtype=np.int32)


def get_triu_indices(num_nodes):
    """Return flattened (row-major) indices of the strict upper triangle.

    The diagonal is excluded; each index is ``row * num_nodes + col`` of a
    ``num_nodes x num_nodes`` matrix.
    """
    strict_upper = torch.ones(num_nodes, num_nodes).triu(diagonal=1)
    rows, cols = strict_upper.nonzero().unbind(dim=1)
    return rows * num_nodes + cols


def get_tril_indices(num_nodes):
    """Return flattened (row-major) indices of the strict lower triangle.

    The diagonal is excluded; each index is ``row * num_nodes + col`` of a
    ``num_nodes x num_nodes`` matrix.
    """
    strict_lower = torch.ones(num_nodes, num_nodes).tril(diagonal=-1)
    rows, cols = strict_lower.nonzero().unbind(dim=1)
    return rows * num_nodes + cols


def get_offdiag_indices(num_nodes):
    """Return flattened (row-major) indices of all off-diagonal entries.

    Each index is ``row * num_nodes + col`` of a ``num_nodes x num_nodes``
    matrix with ``row != col``.
    """
    off_diagonal = 1.0 - torch.eye(num_nodes, num_nodes)
    rows, cols = off_diagonal.nonzero().unbind(dim=1)
    return rows * num_nodes + cols


def get_triu_offdiag_indices(num_nodes):
    """Locate the upper-triangular entries inside the off-diagonal vector.

    Returns the positions (as a column tensor produced by ``nonzero``) within
    the vector of off-diagonal elements that belong to the strict upper
    triangle.
    """
    flat_mask = torch.zeros(num_nodes * num_nodes)
    # Flag the strict-upper entries in the flattened matrix ...
    flat_mask[get_triu_indices(num_nodes)] = 1.
    # ... then keep only the off-diagonal slots and report which are flagged.
    offdiag_mask = flat_mask[get_offdiag_indices(num_nodes)]
    return torch.nonzero(offdiag_mask)


def get_tril_offdiag_indices(num_nodes):
    """Locate the lower-triangular entries inside the off-diagonal vector.

    Returns the positions (as a column tensor produced by ``nonzero``) within
    the vector of off-diagonal elements that belong to the strict lower
    triangle.
    """
    flat_mask = torch.zeros(num_nodes * num_nodes)
    # Flag the strict-lower entries in the flattened matrix ...
    flat_mask[get_tril_indices(num_nodes)] = 1.
    # ... then keep only the off-diagonal slots and report which are flagged.
    offdiag_mask = flat_mask[get_offdiag_indices(num_nodes)]
    return torch.nonzero(offdiag_mask)


def get_minimum_distance(data):
    """Minimum pairwise squared distance over dimension 2 of ``data``.

    Only the first two features of the last dimension are used. Dimensions 1
    and 2 are swapped, squared distances between all pairs along the original
    dimension 1 are computed, and the minimum over the original dimension 2
    is taken.

    Returns:
        A tensor of shape ``[data.size(0), -1]`` holding the flattened
        pairwise minimum squared distances.
    """
    coords = data[:, :, :, :2].transpose(1, 2)
    sq_norm = coords.pow(2).sum(-1, keepdim=True)
    gram = torch.matmul(coords, coords.transpose(2, 3))
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>
    sq_dist = sq_norm + sq_norm.transpose(2, 3) - 2 * gram
    closest = sq_dist.min(1)[0]
    return closest.view(closest.size(0), -1)


def edge_accuracy(preds, target):
    """Fraction of entries whose arg-max prediction equals the target.

    Args:
        preds: Score tensor; the predicted class is the arg-max over the last
            dimension.
        target: Tensor of target classes, viewable as the arg-max result.

    Returns:
        A Python float: number of matching entries divided by
        ``target.size(0) * target.size(1)``.
    """
    _, predicted = preds.max(-1)
    matches = predicted.float().eq(target.float().view_as(predicted))
    # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is the
    # documented replacement and yields the same Python float.
    correct = float(matches.sum().item())
    return correct / (target.size(0) * target.size(1))

class MLP(nn.Module):
    """Two fully-connected ELU layers followed by batch normalisation."""

    def __init__(self, n_in, n_hid, n_out, do_prob=0.):
        super().__init__()
        self.fc1 = nn.Linear(n_in, n_hid)
        self.fc2 = nn.Linear(n_hid, n_out)
        self.bn = nn.BatchNorm1d(n_out)
        self.dropout_prob = do_prob

        self.init_weights()

    def init_weights(self):
        """Xavier-normal weights / 0.1 biases for linears; identity batch norm."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm1d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight.data)
                module.bias.data.fill_(0.1)

    def batch_norm(self, inputs):
        """Apply batch norm over the merged first two dimensions of ``inputs``."""
        n_sims, n_things = inputs.size(0), inputs.size(1)
        flattened = inputs.view(n_sims * n_things, -1)
        normalised = self.bn(flattened)
        return normalised.view(n_sims, n_things, -1)

    def forward(self, inputs):
        # Input shape: [num_sims, num_things, num_features]
        hidden = F.elu(self.fc1(inputs))
        hidden = F.dropout(hidden, self.dropout_prob, training=self.training)
        activated = F.elu(self.fc2(hidden))
        return self.batch_norm(activated)

class MLPEncoder(nn.Module):
    """MLP-based graph encoder producing node embeddings and per-edge outputs.

    Node features are embedded, lifted to edges, optionally passed once more
    through nodes (``factor=True``), and finally projected per edge by
    ``fc_out``.
    """

    def __init__(self, n_in, n_hid, n_out, do_prob=0., factor=True):
        super().__init__()

        self.factor = factor
        self.mlp1 = MLP(n_in, n_hid, n_hid, do_prob)
        self.mlp2 = MLP(n_hid * 2, n_hid, n_hid, do_prob)
        self.mlp3 = MLP(n_hid, n_hid, n_hid, do_prob)
        # mlp4 consumes the skip connection (n_hid) concatenated with either
        # re-lifted edges (2 * n_hid, factor graph) or mlp3's output (n_hid).
        if self.factor:
            mlp4_in, banner = n_hid * 3, "Using factor graph MLP encoder."
        else:
            mlp4_in, banner = n_hid * 2, "Using MLP encoder."
        self.mlp4 = MLP(mlp4_in, n_hid, n_hid, do_prob)
        print(banner)
        self.fc_out = nn.Linear(n_hid, n_out)
        self.init_weights()

    def init_weights(self):
        """Re-initialise every linear layer (including those in sub-MLPs)."""
        linears = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linears:
            nn.init.xavier_normal_(layer.weight.data)
            layer.bias.data.fill_(0.1)

    def edge2node(self, x, rel_rec, rel_send):
        """Sum edge features into their receiver nodes, scaled by node count."""
        # NOTE: Assumes that we have the same graph across all samples.
        summed = rel_rec.t().matmul(x)
        return summed / summed.size(1)

    def node2edge(self, x, rel_rec, rel_send):
        """Concatenate sender and receiver node features for every edge."""
        # NOTE: Assumes that we have the same graph across all samples.
        from_nodes = torch.matmul(rel_send, x)
        to_nodes = torch.matmul(rel_rec, x)
        return torch.cat([from_nodes, to_nodes], dim=2)

    def forward(self, inputs, rel_rec, rel_send):
        # Input shape: [num_sims, num_atoms, num_timesteps, num_dims]
        flat = inputs.view(inputs.shape[0], inputs.shape[1], -1)
        # New shape: [num_sims, num_atoms, num_timesteps*num_dims]

        node_embeddings = self.mlp1(flat)  # 2-layer ELU net per node

        edge_feats = self.mlp2(self.node2edge(node_embeddings, rel_rec, rel_send))
        skip = edge_feats

        if self.factor:
            node_feats = self.mlp3(self.edge2node(edge_feats, rel_rec, rel_send))
            refined = self.node2edge(node_feats, rel_rec, rel_send)
        else:
            refined = self.mlp3(edge_feats)

        merged = torch.cat((refined, skip), dim=2)  # Skip connection
        return node_embeddings, self.fc_out(self.mlp4(merged))



class RNNDecoder(nn.Module):
    """Recurrent decoder module.

    At every step the current position and velocity are embedded, messages
    are passed along typed edges, a per-node hidden state is updated with a
    GRU-style gate, and an output MLP produces an intermediate ``Energy``
    tensor and a velocity prediction. The next position is obtained from the
    velocity by an explicit step of size ``delta_t``. Gradients of the energy
    w.r.t. the inputs are returned as additional constraint terms.
    """

    def __init__(self, n_in_node, edge_types, n_hid,
                 do_prob=0., skip_first=0):
        """
        Args:
            n_in_node: Number of features per node in the raw inputs.
            edge_types: Number of edge types (one message MLP per type).
            n_hid: Hidden size; ``n_hid // 2`` is used for each of the
                position and velocity embeddings.
            do_prob: Dropout probability used in the message and output MLPs.
            skip_first: Index of the first edge type that passes messages;
                edge types below it are skipped.
        """
        super(RNNDecoder, self).__init__()
        self.msg_fc1 = nn.ModuleList(
            [nn.Linear(2 * n_hid, n_hid) for _ in range(edge_types)])
        self.msg_fc2 = nn.ModuleList(
            [nn.Linear(n_hid, n_hid) for _ in range(edge_types)])
        self.msg_out_shape = n_hid
        self.skip_first = skip_first

        self.enc_x = nn.Linear(n_in_node, n_hid//2, bias=False)
        self.enc_vel = nn.Linear(n_in_node, n_hid//2, bias=False)
        # enc_input and pred_ball are not used by the methods of this class;
        # they are kept so that parameter names and initialisation order
        # stay the same.
        self.enc_input = nn.Linear(n_in_node, n_hid, bias=False)

        self.dec_Energy = nn.Linear(n_hid, n_hid, bias=False)
        self.pred_ball = nn.Linear(n_hid, n_hid, bias=False)

        self.hidden_r = nn.Linear(n_hid, n_hid, bias=False)
        self.hidden_i = nn.Linear(n_hid, n_hid, bias=False)
        self.hidden_h = nn.Linear(n_hid, n_hid, bias=False)

        self.input_r = nn.Linear(n_hid, n_hid, bias=True)
        self.input_i = nn.Linear(n_hid, n_hid, bias=True)
        self.input_n = nn.Linear(n_hid, n_hid, bias=True)

        self.out_fc1 = nn.Linear(n_hid, n_hid)
        self.out_fc2 = nn.Linear(n_hid, n_hid)
        self.out_fc3 = nn.Linear(n_hid, n_in_node)

        print('Using learned recurrent interaction net decoder.')

        self.dropout_prob = do_prob

    def _dropout(self, x):
        """Dropout that is only active in training mode (like ``MLP``)."""
        return F.dropout(x, p=self.dropout_prob, training=self.training)

    def _encode_state(self, pos, vel):
        """Detach position/velocity, re-enable grad on both and embed them.

        Returns the two fresh leaf tensors (needed later as differentiation
        targets) together with the concatenated embedding.
        """
        ins_x = pos.detach()
        ins_vel = vel.detach()
        ins_x.requires_grad = True
        ins_vel.requires_grad = True
        ins = torch.cat([self.enc_x(ins_x), self.enc_vel(ins_vel)], -1)
        return ins_x, ins_vel, ins

    @staticmethod
    def _integrate(ins_x, pred_vel, delta_t):
        """Advance positions by ``pred_vel * delta_t``.

        Features from index 2 onward are copied straight from ``pred_vel``
        instead of being integrated.
        """
        pred_x = ins_x + pred_vel * delta_t
        pred_x[:, :, 2:] = pred_vel[:, :, 2:]
        return pred_x

    def single_step_forward(self, inputs, rel_rec, rel_send,
                            rel_type, hidden):
        """Run one message-passing / GRU / output step.

        Args:
            inputs: Embedded node inputs (last dimension ``n_hid``).
            rel_rec: Edge-to-receiver matrix, multiplied as
                ``rel_rec @ hidden``.
            rel_send: Edge-to-sender matrix, multiplied as
                ``rel_send @ hidden``.
            rel_type: Per-edge weights; ``rel_type[:, :, i]`` scales the
                messages of edge type ``i``.
            hidden: Per-node hidden state from the previous step.

        Returns:
            Tuple ``(pred_vel, hidden, Energy)``: the velocity prediction,
            the updated hidden state and the output of ``dec_Energy``.
        """
        # node2edge
        receivers = torch.matmul(rel_rec, hidden)
        senders = torch.matmul(rel_send, hidden)
        pre_msg = torch.cat([senders, receivers], dim=-1)

        # Allocate on the same device/dtype as the messages.
        all_msgs = pre_msg.new_zeros(pre_msg.size(0), pre_msg.size(1),
                                     self.msg_out_shape)

        start_idx = self.skip_first if self.skip_first else 0

        # Run separate MLP for every edge type
        # NOTE: To exclude one edge type, simply offset range by 1
        for edge_type in range(start_idx, len(self.msg_fc2)):
            msg = torch.tanh(self.msg_fc1[edge_type](pre_msg))
            msg = self._dropout(msg)
            msg = torch.tanh(self.msg_fc2[edge_type](msg))
            msg = msg * rel_type[:, :, edge_type:edge_type + 1]
            all_msgs += msg

        agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2,
                                                                        -1)
        # NOTE(review): the divisor is the feature width of ``inputs``, not
        # the number of incoming edges - confirm this scaling is intended.
        agg_msgs = agg_msgs.contiguous() / inputs.size(2)

        # GRU-style gated aggregation
        reset = torch.sigmoid(self.input_r(inputs) + self.hidden_r(agg_msgs))
        update = torch.sigmoid(self.input_i(inputs) + self.hidden_i(agg_msgs))
        cand = torch.tanh(self.input_n(inputs) + reset * self.hidden_h(agg_msgs))
        hidden = (1 - update) * cand + update * hidden

        # Output MLP
        pred = self._dropout(F.relu(self.out_fc1(hidden)))

        # The energy is exposed to the caller so that its input gradients
        # can be used as constraint terms of the loss.
        Energy = self.dec_Energy(pred)
        pred = self._dropout(F.relu(self.out_fc2(Energy)))

        pred_vel = self.out_fc3(pred)

        return pred_vel, hidden, Energy

    def forward(self, data, data_vel, rel_type, rel_rec, rel_send, pred_steps=1,
                burn_in=False, burn_in_steps=1, dynamic_graph=False,
                encoder=None, temp=None, delta_t=0.25, use_intra=True,
                use_inter=True):
        """Roll the decoder out for ``pred_steps`` steps.

        Must be called with autograd enabled: the constraint terms are
        obtained with ``torch.autograd.grad``.

        Args:
            data: Positions, indexed as ``[batch, atoms, timesteps, dims]``.
            data_vel: Velocities with the same layout as ``data``.
            rel_type: Edge-type weights,
                ``[batch_size, num_atoms*(num_atoms-1), num_edge_types]``.
            rel_rec: Edge-to-receiver matrix.
            rel_send: Edge-to-sender matrix.
            pred_steps: Number of autoregressive prediction steps (>= 1).
            burn_in: If True, first run ``burn_in_steps - 1`` steps on the
                given frames to warm up the hidden state.
            burn_in_steps: Number of burn-in frames; must be >= 2 when
                ``burn_in`` is True, since the burn-in predictions are
                stacked.
            dynamic_graph: If True, ``encoder`` is called every step to
                re-estimate the first edge type.
            encoder: Callable ``encoder(ins, rel_rec, rel_send)`` returning
                ``(_, logits)``; only used when ``dynamic_graph`` is True.
            temp: Unused; kept for interface compatibility.
            delta_t: Integration step size.
            use_intra: Whether to compute the intra-agent motion term.
            use_inter: Whether to compute the inter-agent interaction term.

        Returns:
            Tuple ``(preds, preds_vel, intra_loss, inter_loss, obs)``:
            predicted positions and velocities with the step dimension at
            index 2, the stacked intra- and inter-agent terms (zero tensors
            when disabled), and the stacked burn-in predictions (a zero
            placeholder when ``burn_in`` is False).
        """
        # RFM_NEW: batch_size, num_atoms, hidden_dim
        # NRI: batch_size, num_atoms, timesteps, hidden_dim
        inputs = data.contiguous()
        inputs_vel = data_vel.contiguous()

        # Same device/dtype as the inputs.
        hidden = inputs.new_zeros(inputs.size(0), inputs.size(1),
                                  self.msg_out_shape)

        pred_all = []
        pred_vel_all = []
        pred_obs = []
        intra_loss = []
        inter_loss = []

        intra = torch.zeros_like(inputs[:, :, 0, :])

        # Placeholder returned when no burn-in is performed.
        obs = inputs.new_zeros(inputs.size(0), inputs.size(2) - 1,
                               inputs.size(1), inputs.size(3))

        if burn_in:
            for step in range(burn_in_steps - 1):
                ins_x, ins_vel, ins = self._encode_state(
                    inputs[:, :, step, :], inputs_vel[:, :, step, :])

                pred_vel, hidden, Energy = self.single_step_forward(
                    ins, rel_rec, rel_send, rel_type, hidden)

                pred_obs.append(self._integrate(ins_x, pred_vel, delta_t))

                # Only the inter-agent term is evaluated during burn-in.
                if use_inter:
                    energy_grad = torch.autograd.grad(
                        Energy.sum(), ins_x,
                        retain_graph=True, create_graph=True)[0]
                    inter_loss.append(energy_grad.sum(dim=1))

            obs = torch.stack(pred_obs, dim=1)

            assert obs.shape[1] == burn_in_steps - 1

            inputs = inputs[:, :, -1, :]
            inputs_vel = inputs_vel[:, :, -1, :]

        elif len(inputs.shape) == 4:
            inputs = inputs[:, :, -1, :]
            inputs_vel = inputs_vel[:, :, -1, :]

        for step in range(pred_steps):
            # The first step starts from the last observed frame, later
            # steps from the previous prediction.
            if step == 0:
                prev_x, prev_vel = inputs, inputs_vel
            else:
                prev_x, prev_vel = pred_all[step - 1], pred_vel_all[step - 1]
            ins_x, ins_vel, ins = self._encode_state(prev_x, prev_vel)

            if dynamic_graph:
                # NOTE(review): ``self.num_humans`` is not defined in this
                # class; it presumably has to be set on the instance by the
                # caller before ``dynamic_graph`` is used - confirm.
                _, logits = encoder(ins, rel_rec, rel_send)
                tmp_rel_type = F.softmax(
                    logits.reshape(-1, self.num_humans, self.num_humans - 1),
                    dim=-1)
                tmp_rel_type = tmp_rel_type.reshape(
                    -1, self.num_humans * (self.num_humans - 1))
                step_rel_type = torch.stack(
                    (tmp_rel_type, rel_type[..., 1]), dim=-1)
            else:
                step_rel_type = rel_type

            pred_vel, hidden, Energy = self.single_step_forward(
                ins, rel_rec, rel_send, step_rel_type, hidden)

            pred_x = self._integrate(ins_x, pred_vel, delta_t)
            pred_all.append(pred_x)
            pred_vel_all.append(pred_vel)

            if use_intra or use_inter:
                # Shared by both constraint terms; computed once per step.
                energy_grad = torch.autograd.grad(
                    Energy.sum(), ins_x,
                    retain_graph=True, create_graph=True)[0]

            ################## intra-agent motion constraint
            if use_intra:
                a = - energy_grad
                intra = 2 * delta_t + delta_t * delta_t * torch.autograd.grad(a.sum(), ins_vel, retain_graph=True, create_graph=True)[0] - \
                        torch.autograd.grad(pred_x.sum(), ins_vel, retain_graph=True, create_graph=True)[0]
                intra_loss.append(intra)

            ################## inter-agent interaction constraint
            if use_inter:
                inter_loss.append(energy_grad.sum(dim=1))

        preds = torch.stack(pred_all, dim=1)
        preds_vel = torch.stack(pred_vel_all, dim=1)

        ################## intra-agent motion constraint
        if use_intra:
            # Sanity check that the term was actually computed.
            assert intra.sum() != 0.0
            intra_loss = torch.stack(intra_loss, dim=1)
        else:
            intra_loss = torch.zeros_like(intra)

        ################### inter-agent interaction constraint
        if use_inter:
            inter_loss = torch.stack(inter_loss, dim=1)
        else:
            inter_loss = torch.zeros_like(intra)

        return preds.transpose(1, 2).contiguous(), preds_vel.transpose(1, 2).contiguous(), intra_loss, inter_loss, obs
