from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import numpy as np
from numpy.lib.shape_base import take_along_axis
import gym

import torch
from torch import nn as nn
from torch.nn import functional as F


from DotmapUtils import get_required_argument
from config.utils import swish, get_affine_params

from env.pointbot.pointbot_const import *
from env.pointbot.gap import get_gap

# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.
TORCH_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("using TORCH_DEVICE", TORCH_DEVICE)

class PtModel(nn.Module):
    '''
    Ensemble of MLPs that each predict a mean and a bounded log-variance.

    Every one of the ``ensemble_size`` members maps ``in_features`` inputs
    through four hidden layers of width 200 (swish activations) to
    ``out_features`` outputs. The first half of the outputs is the mean and
    the second half is the log-variance, softly clamped to
    [min_logvar, max_logvar].
    '''

    def __init__(self, ensemble_size, in_features, out_features):
        super().__init__()

        self.num_nets = ensemble_size

        self.in_features = in_features
        self.out_features = out_features

        # Per-member affine parameters. The exact shapes come from
        # get_affine_params (presumably (ensemble_size, fan_in, fan_out)
        # weights -- TODO confirm).
        self.lin0_w, self.lin0_b = get_affine_params(ensemble_size, in_features, 200)
        self.lin1_w, self.lin1_b = get_affine_params(ensemble_size, 200, 200)
        self.lin2_w, self.lin2_b = get_affine_params(ensemble_size, 200, 200)
        self.lin3_w, self.lin3_b = get_affine_params(ensemble_size, 200, 200)
        self.lin4_w, self.lin4_b = get_affine_params(ensemble_size, 200, out_features)

        # Input normalisation statistics, overwritten by fit_input_stats().
        # sigma starts at ones (not zeros) so that a forward pass made before
        # the stats are fitted is an identity normalisation, not 0/0 -> NaN.
        self.inputs_mu = nn.Parameter(torch.zeros(in_features), requires_grad=False)
        self.inputs_sigma = nn.Parameter(torch.ones(in_features), requires_grad=False)

        # Learnable soft bounds on the predicted log-variance.
        self.max_logvar = nn.Parameter(torch.ones(1, out_features // 2, dtype=torch.float32) / 2.0)
        self.min_logvar = nn.Parameter(- torch.ones(1, out_features // 2, dtype=torch.float32) * 10.0)

    def compute_decays(self):
        '''
        Return the L2 weight-decay penalty, sum_i c_i * ||W_i||^2 / 2.

        Only the weight matrices are regularised (biases are not). Each
        layer has its own coefficient c_i, which grows with depth.
        '''
        weights = (self.lin0_w, self.lin1_w, self.lin2_w, self.lin3_w, self.lin4_w)
        coeffs = (0.000025, 0.00005, 0.000075, 0.000075, 0.0001)
        return sum(c * (w ** 2).sum() / 2.0 for c, w in zip(coeffs, weights))

    def fit_input_stats(self, data):
        '''
        Fit the input normalisation statistics from ``data``.

        Args:
            data: numpy array of inputs with samples along axis 0.
        '''
        mu = np.mean(data, axis=0, keepdims=True)
        sigma = np.std(data, axis=0, keepdims=True)
        # A constant feature has zero std; leave it unscaled instead.
        sigma[sigma < 1e-12] = 1.0

        # Keep the statistics on the same device as the model's parameters,
        # wherever the model has been moved, rather than a global device.
        device = self.inputs_mu.device
        self.inputs_mu.data = torch.from_numpy(mu).to(device).float()
        self.inputs_sigma.data = torch.from_numpy(sigma).to(device).float()

    def forward(self, inputs, ret_logvar=False):
        '''
        Run every ensemble member on its slice of ``inputs``.

        Args:
            inputs: tensor whose last dimension is ``in_features``. The
                outputs are sliced as 3-D, presumably
                (num_nets, batch, features) -- TODO confirm against callers.
            ret_logvar: if True, return the log-variance instead of the
                variance.

        Returns:
            ``(mean, logvar)`` if ``ret_logvar`` else ``(mean, variance)``,
            each with last dimension ``out_features // 2``.
        '''
        # Normalise with the statistics from fit_input_stats().
        hidden = (inputs - self.inputs_mu) / self.inputs_sigma

        hidden_layers = (
            (self.lin0_w, self.lin0_b),
            (self.lin1_w, self.lin1_b),
            (self.lin2_w, self.lin2_b),
            (self.lin3_w, self.lin3_b),
        )
        for weight, bias in hidden_layers:
            hidden = swish(hidden.matmul(weight) + bias)

        outputs = hidden.matmul(self.lin4_w) + self.lin4_b

        half = self.out_features // 2
        mean = outputs[:, :, :half]

        # Softly clamp the log-variance into [min_logvar, max_logvar] with
        # softplus, which keeps the bounds differentiable.
        logvar = outputs[:, :, half:]
        logvar = self.max_logvar - F.softplus(self.max_logvar - logvar)
        logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)

        if ret_logvar:
            return mean, logvar

        return mean, torch.exp(logvar)


class PointBotConfigModule:
    '''
    Configuration for the MBRLPointBot-v0 environment.

    Holds the experiment hyper-parameters plus the observation pre/post-
    processing, cost functions and dynamics-model constructor used by the
    training pipeline.
    '''
    ENV_NAME = "MBRLPointBot-v0"
    TASK_HORIZON = 100
    NTRAIN_ITERS = 100
    NROLLOUTS_PER_ITER = 1
    PLAN_HOR = 15
    MODEL_IN, MODEL_OUT = 6, 4
    VALUE_IN, VALUE_OUT = 4, 1

    ALPHA_THRESH = 3
    HAS_CONSTRAINTS = True
    BETA_THRESH = 1

    NUM_DEMOS = 100
    DEMO_LOW_COST = None
    DEMO_HIGH_COST = None
    DEMO_LOAD_PATH = None
    LOAD_SAMPLES = False
    GYM_ROBOTICS = False
    SS_BUFFER_SIZE = 20000

    VAL_BUFFER_SIZE = 200000
    USE_VALUE = True

    maneuver = 'straight'
    # Alternative maneuver: 'barcelona_c14'

    def __init__(self):
        self.ENV = gym.make(self.ENV_NAME)
        self.ENV.set_mode(mode=1)
        self.ENV.set_maneuver(PointBotConfigModule.maneuver, TORCH_DEVICE)

        # Trick to give the static methods below access to the environment
        # (racing line, step cost) through a module-level global.
        global ENV
        ENV = self.ENV
        self.NN_TRAIN_CFG = {"epochs": 5}
        self.OPT_CFG = {
            "Random": {
                "popsize": 2000
            },
            "CEM": {
                "popsize": 400,
                "num_elites": 40,
                "max_iters": 5,
                "alpha": 0.1
            }
        }

    @staticmethod
    def obs_preproc(obs):
        '''
        Prepare an observation batch as input to the NN.

        Keeps the first four columns and discards the gap column(s).
        '''
        return obs[:,:4] # discard the gap

    @staticmethod
    def targ_proc(obs, next_obs):
        '''
        Return the dynamics-model training target for a transition batch.

        The target is the change in the first four observation columns,
        next_obs[:, :4] - obs[:, :4]. The gap column is excluded here
        because obs_postproc recomputes it from the predicted position.
        '''
        return next_obs[:,:4] - obs[:,:4]

    @staticmethod
    def obs_postproc(obs, pred):
        '''
        Build the predicted next observation from the model output.

        The first four columns are obs[:, :4] + pred. A gap column is then
        appended. It is recomputed with get_gap from columns 0 and 2 (x, y)
        against the environment's racing line, so the result presumably has
        the same layout as an environment observation.

        Works on numpy arrays (using ENV.racing_line) and torch tensors
        (using ENV.racing_line_torch). Requires the global ENV set in
        __init__.
        '''
        next_obs_pred = obs[:,:4] + pred

        x = next_obs_pred[:, 0:1]
        y = next_obs_pred[:, 2:3]

        if isinstance(obs, np.ndarray):
            points = np.hstack([x, y])
            gap, _ = get_gap(points, ENV.racing_line)
            gap = gap.reshape(-1,1)
            return np.concatenate([next_obs_pred, gap], axis=1)
        else:
            points = torch.hstack([x, y])
            gap, _ = get_gap(points, ENV.racing_line_torch)
            gap = gap.reshape(-1,1)
            return torch.cat([next_obs_pred, gap], dim=1)

    @staticmethod
    def obs_cost_fn(obs):
        '''
        Compute the per-row cost of an observation batch (numpy or torch).

        In LINE_FOLLOWER_MODE this delegates to ENV.step_cost. Otherwise the
        cost is the Euclidean distance to GOAL_STATE. In HARD_MODE it is
        instead a 1.0/0.0 indicator of that distance exceeding GOAL_THRESH.
        '''
        if LINE_FOLLOWER_MODE:
            return ENV.step_cost(obs, None)

        if isinstance(obs, np.ndarray):
            # Broadcasting GOAL_STATE against every row is equivalent to
            # tiling it len(obs) times, but avoids the copy.
            dist = np.linalg.norm(np.subtract(np.asarray(GOAL_STATE), obs), axis=1)
            if not HARD_MODE:
                return dist
            return (dist > GOAL_THRESH).astype(np.float32)

        # Move only the goal vector onto obs's own device and broadcast it.
        # The old code copied an N-row tile to TORCH_DEVICE on every call and
        # failed when obs lived elsewhere. np.asarray keeps the original
        # dtype promotion: float goals stay float64, as before.
        target = torch.as_tensor(np.asarray(GOAL_STATE), device=obs.device)
        dist = torch.norm(torch.subtract(target, obs), dim=1)
        if not HARD_MODE:
            return dist
        return (dist > GOAL_THRESH).type(torch.float)

    @staticmethod
    def ac_cost_fn(acs):
        '''
        Compute the per-row quadratic action penalty (numpy or torch).

        The penalty is currently disabled (coefficient 0.0).
        '''
        ACT_PENALIZATION_CST = 0.00  # previously tried: 0.01
        if isinstance(acs, np.ndarray):
            return ACT_PENALIZATION_CST * np.sum(np.square(acs), axis=1)
        else:
            return ACT_PENALIZATION_CST * torch.sum(torch.square(acs), axis=1)

    def nn_constructor(self, model_init_cfg):
        '''
        Build the dynamics-model ensemble and attach an Adam optimiser.

        Args:
            model_init_cfg: config providing "num_nets" (required) and,
                optionally, "load_model" (only False is supported).

        Returns:
            A PtModel on TORCH_DEVICE with ``model.optim`` set.
        '''
        ensemble_size = get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size")

        load_model = model_init_cfg.get("load_model", False)

        assert load_model is False, 'Has yet to support loading model'

        # MODEL_OUT * 2 because the network outputs both mean and variance.
        model = PtModel(ensemble_size,
                        self.MODEL_IN, self.MODEL_OUT * 2).to(TORCH_DEVICE)

        model.optim = torch.optim.Adam(model.parameters(), lr=0.001)

        return model


# Module-level handle exposing this file's configuration class.
CONFIG_MODULE = PointBotConfigModule
