"""
CNN feature extractor followed by a recurrent network, with an optional
differentiable CLF quadratic-program layer (dCLF) that refines the action mean.
"""
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from ray.rllib.policy.rnn_sequencing import add_time_dimension
try:
    from ray.rllib.utils.torch_ops import atanh
except DeprecationWarning:
    from ray.rllib.utils.torch_utils import atanh
from ray.rllib.utils.numpy import SMALL_NUMBER
from qpth.qp import QPFunction, QPSolvers

from . import utils
from math import log


class CnnRnnDclf(RecurrentNetwork, nn.Module):
    """RLlib recurrent model: CNN/MLP features -> RNN -> optional dCLF head.

    Forward pass:

    * ``obs[0]`` gets a singleton dim inserted at position 1 and goes through a
      conv net (``module_1``); ``obs[1]`` goes through an MLP (``module_2``).
    * Both feature vectors are concatenated and passed through another MLP
      (``module_3``), then through the recurrent core selected by
      ``module_4['type']`` (``'lstm'``, ``'ncp'`` or ``'cfc'``).
    * If ``module_4['use_dclf']`` is truthy, the action mean is refined by
      :func:`dCLF`. Its weights are predicted from the ``module_3`` features
      by ``module_p`` / ``module_w``.

    The value head never uses the RNN. It reads either the policy's
    ``module_3`` features (``vf_share_layers=True``) or the output of a
    separate copy of modules 1-3.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):
        # The ModelV2 and nn.Module bases are initialized explicitly;
        # RecurrentNetwork's own constructor is not called here.
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        utils.check_obs_space(obs_space, model_config)
        custom_model_config = model_config['custom_model_config']
        # The leading size of each observation component must agree with the
        # first layer config of the branch that consumes it.
        assert custom_model_config['obs_shape'][0][0] == custom_model_config['module_1']['conv_filters'][0][0], \
            'obs_shape[0][0] must equal module_1 conv_filters[0][0]'
        assert custom_model_config['obs_shape'][1][0] == custom_model_config['module_2']['fc_filters'][0], \
            'obs_shape[1][0] must equal module_2 fc_filters[0]'

        # Non-recurrent feature extractor.
        self._module_1 = utils.build_convnet(custom_model_config['module_1'])
        self._module_2 = utils.build_fcnet(custom_model_config['module_2'])
        self._module_3 = utils.build_fcnet(custom_model_config['module_3'])

        # Heads predicting the dCLF QP cost weights (p) and Lyapunov weights (w).
        self._module_p = utils.build_fcnet(custom_model_config['module_p'])
        self._module_w = utils.build_fcnet(custom_model_config['module_w'])

        # Recurrent core.
        rnn_config = custom_model_config['module_4']
        rnn_type = rnn_config['type']
        rnn_in_channel = custom_model_config['module_3']['fc_filters'][-1]
        if rnn_type == 'lstm':
            cell_size = rnn_config['cell_size']
            self._rnn = nn.LSTM(rnn_in_channel,
                                cell_size,
                                batch_first=not self.time_major)
            rnn_out_channel = cell_size
            self._use_output_fc = True
        elif rnn_type == 'ncp':
            import kerasncp as kncp
            from kerasncp.torch import LTCCell

            # The motor neurons emit num_outputs values directly, so no
            # output projection is needed.
            wiring = kncp.wirings.NCP(
                inter_neurons=rnn_config['inter_neurons'],
                command_neurons=rnn_config['command_neurons'],
                motor_neurons=num_outputs,
                sensory_fanout=rnn_config['sensory_fanout'],
                inter_fanout=rnn_config['inter_fanout'],
                recurrent_command_synapses=rnn_config['recurrent_command_synapses'],
                motor_fanin=rnn_config['motor_fanin'],
            )
            self._rnn = LTCCell(wiring, rnn_in_channel)
            self._use_output_fc = False
        elif rnn_type == 'cfc':
            from .cfc import Cfc

            self._rnn = Cfc(in_features=rnn_in_channel,
                            hidden_size=rnn_config['hidden_size'],
                            out_feature=rnn_config['out_size'],
                            hparams=rnn_config,
                            return_sequences=True,
                            use_mixed=rnn_config['use_mixed'],
                            use_ltc=rnn_config['use_ltc'])
            rnn_out_channel = rnn_config['out_size']
            self._use_output_fc = True
        else:
            raise NotImplementedError(f'Invalid RNN type {rnn_type}')

        if self._use_output_fc:
            self._action_fc = nn.Linear(rnn_out_channel, num_outputs)

        # NOTE: the value function always uses the same architecture (w/o rnn).
        if not model_config['vf_share_layers']:
            self._value_module_1 = utils.build_convnet(custom_model_config['module_1'])
            self._value_module_2 = utils.build_fcnet(custom_model_config['module_2'])
            self._value_module_3 = utils.build_fcnet(custom_model_config['module_3'])
        self._value_fc = nn.Linear(custom_model_config['module_3']['fc_filters'][-1], 1)

        if rnn_config.get('only_train_dclf', False):
            # Freeze everything, then unfreeze only the dCLF weight heads.
            for param in self.parameters():
                param.requires_grad = False
            for param in self._module_p.parameters():
                param.requires_grad = True
            for param in self._module_w.parameters():
                param.requires_grad = True

    def forward(self, input_dict, state, seq_lens):
        """Compute action-distribution inputs and the updated recurrent state.

        Args:
            input_dict: RLlib input dict. ``obs[0]`` feeds the conv branch
                (after inserting a dim at position 1) and ``obs[1]`` feeds
                the MLP branch.
            state: List of recurrent state tensors, laid out as in
                ``get_initial_state``.
            seq_lens: numpy array or tensor of sequence lengths. Only its
                length is used, to infer the padded max sequence length.

        Returns:
            ``(outputs, state)``: a ``[-1, num_outputs]`` tensor of
            distribution inputs and the new state list.
        """
        custom_model_config = self.model_config['custom_model_config']
        rnn_config = custom_model_config['module_4']
        obs_img = input_dict['obs'][0]
        obs_vec = input_dict['obs'][1]

        # Non-recurrent feature extractor.
        module_1_feat = self._module_1(obs_img[:, None, :, :]).flatten(start_dim=1)
        module_2_feat = self._module_2(obs_vec)
        module_3_feat = self._module_3(torch.cat([module_1_feat, module_2_feat], dim=1))

        # Reshape the flat features into padded sequences for the RNN.
        if isinstance(seq_lens, np.ndarray):
            seq_lens = torch.Tensor(seq_lens).int()
        max_seq_len = module_3_feat.shape[0] // seq_lens.shape[0]
        self.time_major = self.model_config.get('_time_major', False)
        rnn_inputs = add_time_dimension(
            module_3_feat,
            max_seq_len=max_seq_len,
            framework='torch',
            time_major=self.time_major,
        )

        rnn_out, state = self._rnn_inference(rnn_inputs, state, rnn_config['type'])

        logits = self._action_fc(rnn_out) if self._use_output_fc else rnn_out
        logits = torch.reshape(logits, [-1, self.num_outputs])

        if rnn_config.get('use_dclf', False):
            outputs = self._dclf_outputs(logits, obs_vec, module_3_feat)
        else:
            outputs = logits

        # Value function inference (never goes through the RNN).
        if self.model_config['vf_share_layers']:
            self._value = self._value_fc(module_3_feat)
        else:
            value_module_1_feat = self._value_module_1(obs_img[:, None, :, :]).flatten(start_dim=1)
            value_module_2_feat = self._value_module_2(obs_vec)
            value_module_3_feat = self._value_module_3(
                torch.cat([value_module_1_feat, value_module_2_feat], dim=1))
            self._value = self._value_fc(value_module_3_feat)

        return outputs, state

    def _rnn_inference(self, rnn_inputs, state, rnn_type):
        """Run the configured recurrent core; return ``(rnn_out, state_list)``."""
        if rnn_type == 'lstm':
            # nn.LSTM expects a leading num_layers dim on h and c.
            h_c = [v.unsqueeze(0) for v in state]
            rnn_out, h_c = self._rnn(rnn_inputs, h_c)
            return rnn_out, [v.squeeze(0) for v in h_c]

        if rnn_type == 'ncp':
            # Keep the LTC sparsity masks on the same device as the state.
            device = state[0].device
            for key in ('sparsity_mask', 'sensory_sparsity_mask'):
                self._rnn._params[key] = self._rnn._params[key].to(device)
            # LTCCell is a single-step cell: unroll it over dim 1.
            # NOTE(review): this assumes batch-major inputs (time_major False).
            hidden = state[0]
            step_outputs = []
            for t in range(rnn_inputs.shape[1]):
                out_t, hidden = self._rnn(rnn_inputs[:, t], hidden)
                step_outputs.append(out_t)
            return torch.stack(step_outputs, dim=1), [hidden]

        if rnn_type == 'cfc':
            batch_size, seq_len = rnn_inputs.shape[:2]
            # NOTE: constant timespan of 0.1 per step; not verified to be correct.
            timespans = torch.ones((batch_size, seq_len)).to(rnn_inputs) * 0.1
            if self._rnn.use_mixed:
                rnn_out, new_state = self._rnn(rnn_inputs, h_state=state[0], c_state=state[1],
                                               timespans=timespans)
                return rnn_out, new_state
            rnn_out, hidden = self._rnn(rnn_inputs, h_state=state[0], timespans=timespans)
            return rnn_out, [hidden]

        raise NotImplementedError(f'Invalid RNN type {rnn_type}')

    def _dclf_outputs(self, logits, obs_vec, dclf_feat):
        """Refine the action mean held in ``logits`` with :func:`dCLF`.

        ``logits`` is split in half along the last dim. The first half is the
        action mean. How the second half is interpreted, and how the result is
        re-encoded, depends on ``module_4['logits_mode']`` (default ``'beta'``).
        """
        custom_model_config = self.model_config['custom_model_config']
        rnn_config = custom_model_config['module_4']
        logits_mode = rnn_config.get('logits_mode', 'beta')
        dclf_dt = rnn_config.get('dclf_dt', 0.02)
        low = custom_model_config['custom_action_dist_config']['low']
        high = custom_model_config['custom_action_dist_config']['high']

        p = self._module_p(dclf_feat)
        w = self._module_w(dclf_feat)
        # NOTE: the (pitch, roll, yaw) effect gains are hard-coded.
        effect = [1.0, 1.0, 0.2]
        # Columns of obs[1] in the order dCLF unpacks them:
        # engine_power, pitch, yaw, roll, throttle.
        dclf_input = [obs_vec[:, 7], obs_vec[:, 15], obs_vec[:, 19], obs_vec[:, 23], obs_vec[:, 0]]

        if logits_mode == 'beta':
            mean, raw_alpha = torch.chunk(logits, 2, dim=-1)
            mean = dCLF(mean, dclf_input, effect, p, w, dclf_dt)
            # Map the mean from [low, high] into (0, 1). Clamping away from the
            # edges keeps beta below finite and positive.
            mean_unit = torch.clamp((mean - low) / (high - low), 0.001, 0.999)
            # alpha = softplus(raw) + 1 > 1. softplus avoids exp() overflow.
            alpha = nn.functional.softplus(raw_alpha) + 1.0
            # Choose beta so the Beta mean alpha / (alpha + beta) equals mean_unit.
            beta = alpha * (1.0 / mean_unit - 1.0)
            return torch.cat([alpha, beta], dim=1)

        if logits_mode == 'gaussian':
            mean, log_std = torch.chunk(logits, 2, dim=-1)
            mean = dCLF(mean, dclf_input, effect, p, w, dclf_dt)
            return torch.cat([mean, log_std], dim=1)

        if logits_mode == 'squashed_gaussian':
            mean, log_std = torch.chunk(logits, 2, dim=-1)
            # Squash the raw mean into [low, high] with tanh and refine it
            # there, then invert the squash so the distribution receives a
            # raw (unsquashed) mean again.
            squashed_mean = torch.clamp((torch.tanh(mean) + 1.0) / 2.0 * (high - low) + low,
                                        low, high)
            squashed_mean = dCLF(squashed_mean, dclf_input, effect, p, w, dclf_dt)
            normed = (squashed_mean - low) / (high - low) * 2.0 - 1.0
            # Stabilize input to atanh.
            normed = torch.clamp(normed, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER)
            return torch.cat([atanh(normed), log_std], dim=1)

        raise NotImplementedError(f'Unrecognized logits mode {logits_mode}')

    @override(ModelV2)
    def get_initial_state(self):
        """Return zero-filled 1-D float32 state tensors on the model's device."""
        device = next(self.parameters()).device
        rnn_config = self.model_config['custom_model_config']['module_4']
        rnn_type = rnn_config['type']
        if rnn_type == 'lstm':
            size, n_tensors = rnn_config['cell_size'], 2  # h and c
        elif rnn_type == 'ncp':
            size, n_tensors = self._rnn.state_size, 1
        elif rnn_type == 'cfc':
            size = self._rnn.hidden_size
            n_tensors = 2 if self._rnn.use_mixed else 1
        else:
            raise NotImplementedError(f'Invalid RNN type {rnn_type}')
        return [torch.zeros(size, dtype=torch.float32, device=device) for _ in range(n_tensors)]

    def value_function(self):
        """Return the value estimates computed by the last ``forward`` call."""
        return self._value.squeeze(1)

    
def dCLF(logits, input, effect, p, w, dt, *, clf_rate=10.0, slack_weight=10.0,
         control_bound=1.0):
    """Refine a 4-D action with a differentiable control-Lyapunov-function QP.

    Definitions:

    * ``T = (pitch * pitch_effect, roll * roll_effect, yaw * yaw_effect, engine_power)``
    * ``u = logits[:, :4]``
    * ``w = sigmoid(w)`` and ``p = sigmoid(p)``
    * The candidate Lyapunov function is ``V = sum_i w_i * (T_i - u_i) ** 2``.

    For every batch element the following QP is solved with qpth::

        min_x  0.5 * x^T Q x
        s.t.   LgV . x[:4] - x[4] <= -LfV - clf_rate * V    (LfV = 0)
               -control_bound <= x[i] <= control_bound,    i = 0..3

    Here ``LgV = 2 * w * (T - u)``, ``x[4]`` is a slack variable, and
    ``Q = diag(p_0, p_1, p_2, p_3, slack_weight * p_4)``. The QP solution is
    integrated one step: ``dt * x[:4] + (pitch, roll, yaw, throttle)``.

    Args:
        logits: Action means, ``(batch, >=4)``; only columns 0-3 are read.
        input: Five per-sample tensors (presumably 1-D of length batch),
            unpacked as ``(engine_power, pitch, yaw, roll, throttle)``.
        effect: Three gains ``(pitch_effect, roll_effect, yaw_effect)``.
        p: Raw QP cost weights, ``(batch, >=5)``.
        w: Raw Lyapunov weights, ``(batch, >=4)``.
        dt: Step size applied to the QP solution.
        clf_rate: CLF decay-rate coefficient on ``V`` (default 10, as before).
        slack_weight: Extra cost multiplier on the slack variable (default 10).
        control_bound: Box bound on each of the four QP controls (default 1).

    Returns:
        Tensor of shape ``(batch, 4)``.

    Raises:
        Any exception raised by the QP solver is re-raised after printing
        which QP inputs contain NaNs.
    """
    pitch_effect, roll_effect, yaw_effect = effect[0], effect[1], effect[2]
    engine_power, pitch, yaw, roll, throttle = input[0], input[1], input[2], input[3], input[4]
    w = torch.sigmoid(w)  # weights in (0, 1)
    p = torch.sigmoid(p)

    device = p.device
    n_batch = logits.shape[0]

    # Diagonal QP cost: p_0..p_3 on the controls, slack_weight * p_4 on the slack.
    cost_scale = torch.tensor([1.0, 1.0, 1.0, 1.0, slack_weight], dtype=p.dtype, device=device)
    Q = torch.diag_embed(p[:, :5] * cost_scale)

    # Lyapunov function and its control derivative.
    targets = torch.stack([
        (pitch * pitch_effect).reshape(n_batch),
        (roll * roll_effect).reshape(n_batch),
        (yaw * yaw_effect).reshape(n_batch),
        engine_power.reshape(n_batch),
    ], dim=1)
    err = targets - logits[:, :4]
    w4 = w[:, :4]
    V = (w4 * err ** 2).sum(dim=1)
    LfV = 0.0  # drift term is taken as zero
    LgV = 2.0 * w4 * err

    # CLF constraint row: [LgV, -1] . x <= -LfV - clf_rate * V.
    slack_col = -torch.ones(n_batch, 1, dtype=LgV.dtype, device=LgV.device)
    G_clf = torch.cat([LgV, slack_col], dim=1).reshape(n_batch, 1, 5)
    h_clf = (-LfV - clf_rate * V).reshape(n_batch, 1)

    # Box constraints on x[0..3], interleaved as +e_0, -e_0, +e_1, -e_1, ...
    eye = torch.eye(4, 5, dtype=G_clf.dtype, device=device)
    G_bounds = torch.stack([eye, -eye], dim=1).reshape(8, 5)
    G = torch.cat([G_clf.to(device), G_bounds.expand(n_batch, 8, 5)], dim=1)
    h_bounds = torch.full((n_batch, 8), control_bound, dtype=h_clf.dtype, device=device)
    h = torch.cat([h_clf.to(device), h_bounds], dim=1)

    q = torch.zeros(n_batch, 5, dtype=Q.dtype, device=device)
    e = torch.empty(0, dtype=Q.dtype, device=device)  # no equality constraints
    try:
        x = QPFunction(verbose=-1, solver=QPSolvers.PDIPM_BATCHED)(Q, q, G, h, e, e)
    except Exception:
        # Report which inputs contain NaNs to help diagnose solver failures,
        # then re-raise. Without a solution there is nothing to return.
        diagnostics = {
            'Q': Q, 'q': q, 'G': G, 'G_bounds': G_bounds, 'h': h, 'e': e,
            'logits': logits, 'input': torch.cat(list(input)), 'p': p, 'w': w,
            'pitch_effect': pitch_effect, 'roll_effect': roll_effect,
            'yaw_effect': yaw_effect, 'engine_power': engine_power,
            'pitch': pitch, 'yaw': yaw, 'roll': roll, 'throttle': throttle,
        }
        for label, value in diagnostics.items():
            print('{} has nan: {}'.format(label, torch.isnan(torch.as_tensor(value)).any()))
        raise

    # NOTE(review): the CLF target for channel 3 is engine_power, but the
    # integration base below uses throttle. Confirm which one is intended.
    base = torch.stack([
        pitch.reshape(n_batch),
        roll.reshape(n_batch),
        yaw.reshape(n_batch),
        throttle.reshape(n_batch),
    ], dim=1)
    return dt * x[:, :4] + base

