import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import argoverse

import sys
import os
sys.path.append(os.path.dirname(__file__))

from CtsConv import *
from lstm_ctsconv import *

#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class ParticlesNetwork(nn.Module):
    """Continuous-convolution network that advances particles by one timestep.

    Each ``forward`` call integrates positions/velocities under a constant
    acceleration (``update_pos_vel``), predicts a learned position correction
    with a stack of ``CtsConv`` + ``nn.Linear`` layers
    (``compute_correction``), and applies it (``apply_correction``).
    """

    def __init__(self,
                 kernel_sizes=None,
                 radius_scale=40,
                 coordinate_mapping='ball_to_cube',
                 interpolation='linear',
                 use_window=True,
                 particle_radius=0.5,
                 timestep=1,
                 encoder_hidden_size=21,
                 correction_scale=1 / 128.,
                 layer_channels=None
                 ):
        """Create the layers.

        Args:
            kernel_sizes: kernel sizes passed to every ``CtsConv``
                (default ``[4, 4, 4]``).
            radius_scale: passed to every ``CtsConv`` as ``radius``.
            coordinate_mapping, interpolation, use_window: stored as
                attributes; not otherwise used in this class.
            particle_radius: only used to compute ``self.filter_extent``.
            timestep: integration step ``dt``.
            encoder_hidden_size: channel count of the ``feats`` tensor handed
                to ``compute_correction`` (the first layer's input is
                1 ones-channel + 3 velocity channels + this many channels).
            correction_scale: multiplier applied to the network output.
            layer_channels: output channels of each layer
                (default ``[16, 32, 32, 32, 2]``); the last entry is the
                number of predicted correction channels.
        """
        super().__init__()

        # None sentinels: list defaults would be shared by every instance.
        if kernel_sizes is None:
            kernel_sizes = [4, 4, 4]
        if layer_channels is None:
            layer_channels = [16, 32, 32, 32, 2]

        self.kernel_sizes = kernel_sizes
        self.radius_scale = radius_scale
        self.coordinate_mapping = coordinate_mapping
        self.interpolation = interpolation
        self.use_window = use_window
        self.particle_radius = particle_radius
        self.timestep = timestep
        self.layer_channels = layer_channels
        # Kept for callers that read it; not used by this class.
        self.filter_extent = np.float32(self.radius_scale * 6 *
                                        self.particle_radius)
        self.correction_scale = correction_scale

        self.encoder_hidden_size = encoder_hidden_size

        # ones-channel (1) + velocity (3) + extra features
        self.in_channel = 1 + 3 + self.encoder_hidden_size

        # First layer: a continuous conv and a dense branch on the same input.
        self.conv_fluid = CtsConv(in_channels=self.in_channel,
                                  out_channels=self.layer_channels[0],
                                  kernel_sizes=self.kernel_sizes,
                                  radius=self.radius_scale)

        self.dense_fluid = nn.Linear(self.in_channel, self.layer_channels[0])

        # The conv_fluid and dense_fluid outputs are concatenated, so the
        # second layer sees twice the first layer's channels.
        convs = []
        denses = []
        in_ch = 2 * self.layer_channels[0]
        for out_ch in self.layer_channels[1:]:
            # Construction order (dense, then conv) kept so seeded parameter
            # initialisation is unchanged.
            denses.append(nn.Linear(in_ch, out_ch))
            convs.append(CtsConv(in_channels=in_ch,
                                 out_channels=out_ch,
                                 kernel_sizes=self.kernel_sizes,
                                 radius=self.radius_scale))
            in_ch = out_ch

        self.convs = nn.ModuleList(convs)
        self.denses = nn.ModuleList(denses)

    def update_pos_vel(self, p0, v0, a):
        """Integrate position and velocity over one step of ``self.timestep``.

        Assumes constant acceleration ``a`` during the step, so the position
        advances with the average of the start and end velocities.

        Returns:
            ``(p1, v1)``: position and velocity after the step.
        """
        dt = self.timestep
        v1 = v0 + dt * a
        p1 = p0 + dt * (v0 + v1) / 2
        return p1, v1

    def apply_correction(self, p0, p1, correction):
        """Add ``correction`` to ``p1`` and recompute the velocity.

        Args:
            p0, p1: particle positions before/after basic integration.
            correction: additive position correction.

        Returns:
            ``(p_corrected, v_corrected)`` where the velocity is the
            displacement from ``p0`` divided by ``self.timestep``.
        """
        dt = self.timestep
        p_corrected = p1 + correction
        v_corrected = (p_corrected - p0) / dt
        return p_corrected, v_corrected

    def dense_forward(self, in_feats, dense_layer):
        """Apply ``dense_layer`` to the last axis of a 3-D tensor.

        The first two axes are flattened into one batch axis for the linear
        layer and restored afterwards.
        """
        flatten_in_feats = in_feats.reshape(
            in_feats.shape[0] * in_feats.shape[1], in_feats.shape[2])
        flatten_output = dense_layer(flatten_in_feats)
        return flatten_output.reshape(in_feats.shape[0], in_feats.shape[1], -1)

    def compute_correction(self, p, v, other_feats, fluid_mask):
        """Predict a position correction for every particle.

        Precondition: ``p`` and ``v`` were already advanced with the
        acceleration (see ``update_pos_vel``).

        Args:
            p: particle positions, indexed as a 3-D tensor here — presumably
                (batch, particles, coords); confirm against callers.
            v: particle velocities, used as input features.
            other_feats: optional extra per-particle features (may be None).
            fluid_mask: neighbour mask passed to every ``CtsConv``.

        Returns:
            ``correction_scale * network_output`` with
            ``layer_channels[-1]`` channels. Also stored as
            ``self.pos_correction``. As a side effect,
            ``self.num_fluid_neighbors`` is updated.
        """
        # Input features: constant ones-channel, velocity, optional extras.
        fluid_feats = [torch.ones_like(p[:, :, 0:1]), v]
        if other_feats is not None:
            fluid_feats.append(other_feats)
        fluid_feats = torch.cat(fluid_feats, -1)

        # First layer: conv and dense branches concatenated along channels.
        output_conv_fluid = self.conv_fluid(p, p, fluid_feats, fluid_mask)
        output_dense_fluid = self.dense_forward(fluid_feats, self.dense_fluid)
        output = torch.cat((output_conv_fluid, output_dense_fluid), -1)

        for conv, dense in zip(self.convs, self.denses):
            in_feats = F.relu(output)
            output_conv = conv(p, p, in_feats, fluid_mask)
            output_dense = self.dense_forward(in_feats, dense)

            # Residual connection whenever the channel count is unchanged.
            if output_dense.shape[-1] == output.shape[-1]:
                output = output_conv + output_dense + output
            else:
                output = output_conv + output_dense

        # Number of fluid neighbours per particle, used by the training loss.
        # The "- 1" presumably removes the particle itself from the count
        # (i.e. fluid_mask includes self) — TODO confirm.
        self.num_fluid_neighbors = torch.sum(fluid_mask, dim=-1) - 1

        # Scale to better match the scale of the output distribution.
        self.pos_correction = self.correction_scale * output

        return self.pos_correction

    @staticmethod
    def _assemble_feats(v0_enc, other_feats, states):
        """Build the ``feats`` tensor fed to ``compute_correction``.

        First step (``states is None``): the flattened velocity encoding,
        prefixed by ``other_feats`` when given.

        Later steps: the velocity history from the previous ``feats`` loses
        its oldest 3 channels, and the newly flattened ``v0_enc`` is appended
        (keeping the channel count fixed when ``v0_enc`` contributes 3
        channels). When ``other_feats`` is given, the previous prefix is
        stripped from the history and the current ``other_feats`` is used.
        This assumes ``other_feats`` has the same channel count on every
        step — NOTE(review): confirm.
        """
        vel_dim = 3  # channels of one velocity sample dropped per step
        new_enc = v0_enc.reshape(*v0_enc.shape[:2], -1)

        pieces = [] if other_feats is None else [other_feats]
        if states is not None:
            history = states[0]
            if other_feats is not None:
                # Previous feats were [other_feats_prev, history]; drop prefix.
                history = history[..., other_feats.shape[-1]:]
            pieces.append(history[..., vel_dim:])
        pieces.append(new_enc)
        return torch.cat(pieces, -1)

    def forward(self, inputs, states=None):
        """Compute one simulation timestep.

        Args:
            inputs: 7-element tuple
                ``(p0_enc, v0_enc, p0, v0, a, other_feats, fluid_mask)``.
                ``p0_enc`` is unpacked but not used here; ``other_feats``
                may be None.
            states: None on the first step, otherwise the ``(feats, None)``
                tuple returned by the previous call.

        Returns:
            ``(p_corrected, v_corrected, (feats, None))``.
        """
        p0_enc, v0_enc, p0, v0, a, other_feats, fluid_mask = inputs
        feats = self._assemble_feats(v0_enc, other_feats, states)

        p1, v1 = self.update_pos_vel(p0, v0, a)
        pos_correction = self.compute_correction(p1, v1, feats, fluid_mask)
        # The network predicts layer_channels[-1] channels; append one zero
        # channel so the correction can be added to p1 (with the default
        # layer_channels this pads 2 channels to 3). F.pad keeps dtype/device.
        pos_correction = F.pad(pos_correction, (0, 1))
        p_corrected, v_corrected = self.apply_correction(p0, p1, pos_correction)

        return p_corrected, v_corrected, (feats, None)


