import numpy as np
from typing import Sequence
from tqdm import tqdm
import random 

import flax.linen as nn
from flax.training import train_state
import jax 
import jax.numpy as jnp
import jax.random as jrandom
from jax import vmap, grad, jit 
import optax 

from opelab.core.baseline import Baseline
from opelab.core.data import DataType, to_numpy
from opelab.core.mlp import MLP
from opelab.core.policy import Policy
import pickle 

import os 

class MBR(Baseline):
    """Model-based rollout estimator of a target policy's return.

    Three MLPs are fitted on logged transitions: a dynamics model
    ``(s, a) -> s'``, a reward model ``(s, a) -> r`` and a termination model
    ``(s, a) -> logit(done)``.  ``evaluate`` then simulates rollouts of the
    target policy inside the learned model and averages their discounted
    returns.  A twin Q-network trainer that bootstraps from model-generated
    transitions is also provided (``train_q_network``); ``evaluate`` does not
    call it.
    """

    # Number of simulated rollouts averaged by ``evaluate``.
    # NOTE(review): ``self.N`` is stored but never read anywhere in this class;
    # confirm whether it was meant to control this count.
    _NUM_ROLLOUTS = 50
    # File (relative to the working directory) that ``evaluate`` dumps the
    # simulated rollouts into.
    _ROLLOUT_DUMP_PATH = 'model_based_rollouts.pkl'

    def __init__(self, state_dim: int, savepath = 'models/hopper/model_based.pkl', N: int = 100, tau=0.05, horizon:int= 100, lr: float=3e-4, layers: Sequence[int] = (600, 600), epochs:int=30, batch_size: int = 1024, verbose:int=0, seed:int=0, terminated_fn=None):
        """Build the networks and the jitted prediction / training functions.

        Args:
            state_dim: Output width of the dynamics model.
            savepath: Pickle file used by ``evaluate`` to cache trained weights.
            N: Stored as ``self.N``; not read by this class.
            tau: Interpolation factor used by ``self.soft_update``.
            horizon: Maximum number of simulated steps per rollout.
            lr: Adam learning rate.
            layers: Hidden layer widths shared by all networks.
            epochs: Number of passes over the data in each training routine.
            batch_size: Minibatch size.
            verbose: When truthy, extra training diagnostics are printed.
            seed: Seed for parameter initialisation, shuffling and rollouts.
            terminated_fn: Optional callable ``state -> bool``.  When given,
                rollouts use it for termination instead of the learned
                termination model.
        """
        self.terminate_fn = terminated_fn
        self.lr = lr
        self.epochs = epochs
        self.N = N
        self.batch_size = batch_size
        self.horizon = horizon
        self.verbose = verbose
        self.seed = seed
        self.tau = tau
        self.optimizer = optax.adam(lr)
        self.trained = False
        self.savepath = savepath

        # Copy into a list so tuples (or any other sequence) are accepted and
        # the caller's object is never shared.
        hidden = list(layers)
        self.model = MLP(hidden + [state_dim], nn.relu, output_activation=lambda s: s)
        self.reward_model = MLP(hidden + [1], nn.relu, output_activation=lambda s: s)
        self.done_model = MLP(hidden + [1], nn.relu, output_activation=lambda s: s)
        self.Q = MLP(hidden + [1], nn.relu, output_activation=lambda s: s)

        def batched_apply(module, params, states, actions):
            """Apply ``module`` row-wise to the concatenated (state, action) batch."""
            xus = jnp.concatenate([states, actions], axis=-1)
            return vmap(module.apply, in_axes=(None, 0))(params, xus)

        def predict_w_fn(params, states, actions):
            """Predicted next states from the dynamics model."""
            return batched_apply(self.model, params, states, actions)

        def predict_r_w_fn(params, states, actions):
            """Predicted rewards from the reward model."""
            return batched_apply(self.reward_model, params, states, actions)

        def predict_d_w_fn(params, states, actions):
            """Predicted termination probabilities (sigmoid of the logits)."""
            logits = batched_apply(self.done_model, params, states, actions)
            return jax.lax.logistic(logits)

        def predict_q_w_fn(params, states, actions):
            """Predicted Q-values from the Q network."""
            return batched_apply(self.Q, params, states, actions)

        def soft_update(x, y):
            """Leaf-wise ``tau * y + (1 - tau) * x`` over two parameter trees."""
            return jax.tree_util.tree_map(lambda a, b: self.tau*b + (1-self.tau)*a, x, y)

        def sample_from_p(params, key, states, actions):
            """Draw Bernoulli termination flags from the termination model."""
            return jrandom.bernoulli(key, predict_d_w_fn(params, states, actions))

        def model_loss(params, states, actions, next_states):
            """Half the summed squared error of the next-state prediction."""
            pred_next_state = predict_w_fn(params, states, actions)
            return 0.5 * jnp.sum(jnp.square(pred_next_state - next_states))

        def done_loss_fn(params, states, actions, dones, weights):
            """Weighted binary cross-entropy of the termination model."""
            pred_done = predict_d_w_fn(params, states, actions)
            # Align targets/weights with the prediction shape so that a (B,)
            # array cannot silently broadcast against (B, 1) into (B, B).
            dones = jnp.reshape(dones, pred_done.shape).astype(pred_done.dtype)
            weights = jnp.reshape(weights, pred_done.shape)
            epsilon = 1e-5  # keeps the logs finite at p == 0 or p == 1
            log_lik = dones*jnp.log(epsilon + pred_done) + (1-dones)*jnp.log(epsilon + 1-pred_done)
            return -1*jnp.mean(weights*log_lik)

        def reward_loss_fn(params, states, actions, rewards):
            """Half the mean squared error of the reward prediction."""
            pred_rewards = predict_r_w_fn(params, states, actions)
            # Same shape guard as in ``done_loss_fn``.
            rewards = jnp.reshape(rewards, pred_rewards.shape)
            return 0.5 * jnp.mean(jnp.square(pred_rewards - rewards))

        def fqe_loss_fn(params, q1_target_params, q2_target_params, states, actions, rewards, next_states, next_actions, dones, gamma, clip):
            """Squared TD error against a clipped double-Q bootstrap target."""
            # Q-values must come from the Q network, not the dynamics model.
            current_q = predict_q_w_fn(params, states, actions)
            next_q1 = jax.lax.stop_gradient(predict_q_w_fn(q1_target_params, next_states, next_actions))
            next_q2 = jax.lax.stop_gradient(predict_q_w_fn(q2_target_params, next_states, next_actions))
            next_q = jnp.minimum(next_q1, next_q2)

            rewards = jnp.reshape(rewards, current_q.shape)
            dones = jnp.reshape(dones, current_q.shape)
            target = rewards + gamma*(1-dones)*next_q
            target = jnp.clip(target, -1*clip, clip)
            return jnp.mean(jnp.square(current_q - target))

        def gradient_step(loss_fn, mlp_state, *loss_args):
            """One optimiser step on ``loss_fn``; returns (new state, loss)."""
            loss, grads = jax.value_and_grad(loss_fn, 0)(mlp_state.params, *loss_args)
            return mlp_state.apply_gradients(grads=grads), loss

        def train_fn(mlp_state, states, actions, next_states):
            return gradient_step(model_loss, mlp_state, states, actions, next_states)

        def train_reward_loss(mlp_state, states, actions, rewards):
            return gradient_step(reward_loss_fn, mlp_state, states, actions, rewards)

        def train_d_fn(mlp_state, states, actions, dones, weights):
            return gradient_step(done_loss_fn, mlp_state, states, actions, dones, weights)

        def train_q_fn(mlp_state, q1_target_params, q2_target_params, states, actions, rewards, next_states, next_actions, dones, gamma, clip):
            return gradient_step(fqe_loss_fn, mlp_state, q1_target_params, q2_target_params, states, actions, rewards, next_states, next_actions, dones, gamma, clip)

        self.soft_update = jit(soft_update)
        self.predict_w_fn = jit(predict_w_fn)
        self.predict_r_w_fn = jit(predict_r_w_fn)
        self.predict_d_w_fn = jit(predict_d_w_fn)
        self.sample_d = jit(sample_from_p)
        self.predict_q_w_fn = jit(predict_q_w_fn)
        self.model_loss = jit(model_loss)
        self.train_fn = jit(train_fn)
        self.train_reward_fn = jit(train_reward_loss)
        self.train_q_fn = jit(train_q_fn)
        self.train_d_fn = jit(train_d_fn)

    def _batches_per_epoch(self, num_samples: int) -> int:
        """Number of minibatches per epoch; at least one so small datasets still train."""
        return max(1, num_samples // self.batch_size)

    def train_dynamics_network(self, data: DataType, target: Policy, behavior: Policy):
        """Fit the dynamics, reward and termination models on ``data``.

        Returns:
            ``(dynamics_params, reward_params, done_params,
            (reward_min, reward_max), (state_min, state_max))`` where the
            min/max pairs are computed from the logged rewards and states
            (per state dimension) and are later used as clipping ranges.
        """
        _, states, actions, _, next_states, rewards, _, terminals = to_numpy(data, target, behavior, return_terminals=True)

        rew_min, rew_max = np.min(rewards), np.max(rewards)
        state_min, state_max = np.min(states, axis=0), np.max(states, axis=0)

        # Inverse-frequency weights so the rarer termination class is not
        # swamped by the more common one in the cross-entropy.
        done_probability = terminals.mean()
        weights = 1 / (done_probability*terminals + (1-done_probability)*(1-terminals))

        # A single (state, action) row is enough to initialise the networks.
        sample_input = jnp.concatenate([states[0], actions[0]], axis=-1)
        # The same key is used for all three networks, as before, so results
        # stay reproducible for a given seed.
        init_key = jax.random.PRNGKey(self.seed)

        mlp_state = train_state.TrainState.create(
            apply_fn=self.predict_w_fn,
            params=self.model.init(init_key, sample_input),
            tx=self.optimizer,
        )
        rew_mlp_state = train_state.TrainState.create(
            apply_fn=self.predict_r_w_fn,
            params=self.reward_model.init(init_key, sample_input),
            tx=self.optimizer,
        )
        dones_mlp_state = train_state.TrainState.create(
            apply_fn=self.predict_d_w_fn,
            params=self.done_model.init(init_key, sample_input),
            tx=self.optimizer,
        )

        num_samples = len(states)
        num_batches = self._batches_per_epoch(num_samples)
        if self.verbose:
            print(num_batches)

        key = jrandom.PRNGKey(self.seed)
        with tqdm(range(self.epochs)) as tp:
            for _ in tp:
                key, b_key = jrandom.split(key)
                batch_ordering = np.asarray(jrandom.permutation(b_key, jnp.arange(num_samples)))
                epoch_dyn_loss = []
                epoch_rew_loss = []
                epoch_done_loss = []

                for j in range(num_batches):
                    batch = batch_ordering[j*self.batch_size: (j+1)*self.batch_size]

                    batch_states = states[batch]
                    batch_actions = actions[batch]

                    mlp_state, loss_val = self.train_fn(mlp_state, batch_states, batch_actions, next_states[batch])
                    rew_mlp_state, rew_loss = self.train_reward_fn(rew_mlp_state, batch_states, batch_actions, rewards[batch])
                    dones_mlp_state, done_loss = self.train_d_fn(dones_mlp_state, batch_states, batch_actions, terminals[batch], weights[batch])

                    epoch_dyn_loss.append(loss_val)
                    epoch_rew_loss.append(rew_loss)
                    epoch_done_loss.append(done_loss)

                tp.set_postfix(dynamics_loss = np.mean(epoch_dyn_loss), rewards_loss = np.mean(epoch_rew_loss), done_loss = np.mean(epoch_done_loss))

        self.params = mlp_state.params
        self.rew_params = rew_mlp_state.params
        self.done_params = dones_mlp_state.params
        return self.params, self.rew_params, self.done_params, (rew_min, rew_max), (state_min, state_max)

    def train_q_network(self, data, target, behaviour, dynamics_params, reward_params, done_params, gamma):
        """Train twin Q networks on transitions generated by the learned model.

        For every minibatch of logged states, actions are sampled from
        ``target``; next states, rewards and termination flags come from the
        learned models.  Target networks are soft-updated every fifth batch.

        Returns:
            ``(q1_params, q2_params)`` — parameters of ``self.Q`` (usable with
            ``self.predict_q_w_fn``).
        """
        _, states, actions, _, _, rewards, _, _ = to_numpy(data, target, behaviour, return_terminals=True)

        init_input = jnp.concatenate([states[:20], actions[:20]], axis=-1)
        key = jrandom.PRNGKey(self.seed)
        key, init_key1, init_key2 = jrandom.split(key, 3)

        # Both critics are instances of the Q network (scalar output).
        q1_params = self.Q.init(init_key1, init_input)
        q1_mlp_state = train_state.TrainState.create(
            apply_fn=self.Q.apply,
            params=q1_params,
            tx=self.optimizer,
        )
        q2_params = self.Q.init(init_key2, init_input)
        q2_mlp_state = train_state.TrainState.create(
            apply_fn=self.Q.apply,
            params=q2_params,
            tx=self.optimizer,
        )
        q1_target_params = jax.tree_util.tree_map(lambda x: x, q1_params)
        q2_target_params = jax.tree_util.tree_map(lambda x: x, q2_params)

        # Bound on |Q| implied by the largest logged |reward|.  For gamma >= 1
        # the geometric bound does not exist, so clipping is disabled instead
        # of dividing by zero.
        max_rew = np.absolute(rewards).max()
        clip = max_rew/(1-gamma) if gamma < 1 else np.inf

        num_samples = len(states)
        num_batches = self._batches_per_epoch(num_samples)

        with tqdm(range(self.epochs)) as tp:
            for _ in tp:
                key, b_key = jrandom.split(key)
                batch_ordering = np.asarray(jrandom.permutation(b_key, jnp.arange(num_samples)))
                q1_epoch_loss = []
                q2_epoch_loss = []

                for j in range(num_batches):
                    batch = batch_ordering[j*self.batch_size: (j+1)*self.batch_size]
                    batch_states = states[batch]
                    batch_actions = target.sample(batch_states)

                    batch_next_states = np.array(self.predict_w_fn(dynamics_params, batch_states, batch_actions))
                    batch_rewards = self.predict_r_w_fn(reward_params, batch_states, batch_actions)
                    batch_next_actions = target.sample(batch_next_states)

                    key, done_key = jrandom.split(key)
                    batch_dones = self.sample_d(done_params, done_key, batch_states, batch_actions).astype(jnp.float32)

                    q1_mlp_state, q1_loss_val = self.train_q_fn(q1_mlp_state, q1_target_params, q2_target_params, batch_states, batch_actions, batch_rewards, batch_next_states, batch_next_actions, batch_dones, gamma, clip=clip)
                    q2_mlp_state, q2_loss_val = self.train_q_fn(q2_mlp_state, q1_target_params, q2_target_params, batch_states, batch_actions, batch_rewards, batch_next_states, batch_next_actions, batch_dones, gamma, clip=clip)

                    if j % 5 == 0:
                        q1_target_params = self.soft_update(q1_target_params, q1_mlp_state.params)
                        q2_target_params = self.soft_update(q2_target_params, q2_mlp_state.params)

                    q1_epoch_loss.append(q1_loss_val)
                    q2_epoch_loss.append(q2_loss_val)

                tp.set_postfix(q1_loss = np.mean(q1_epoch_loss), q2_loss = np.mean(q2_epoch_loss))

        self.q1_params = q1_mlp_state.params
        self.q2_params = q2_mlp_state.params
        return self.q1_params, self.q2_params

    def simulate_rollout_return(self, key, dynamics_params, rewards_params, done_params, data:DataType, target: Policy, gamma:float, reward_clip, state_clip) -> tuple:
        """Simulate one rollout of ``target`` inside the learned model.

        The start state is the first state of a randomly chosen logged
        trajectory.  Rewards are clipped to ``reward_clip`` and states to
        ``state_clip`` (each a ``(min, max)`` pair).

        Returns:
            ``(total_reward, traj)``: the discounted return accumulated from
            the reward model's output, and a ``(horizon, state_dim +
            action_dim)`` array of visited (state, action) rows.  Rows after
            an early termination stay zero.
        """
        # NOTE(review): the start trajectory is drawn with Python's global
        # ``random`` module, so it is not controlled by ``key``.
        tau = random.choice(data)
        state = tau['states'][0]

        traj = np.zeros((self.horizon, state.shape[0]+tau['actions'].shape[-1]))
        total_reward = 0
        for t in range(self.horizon):
            action = target.sample(state)
            xus = jnp.concatenate([state, action], axis=-1)
            traj[t] = xus

            reward = np.array(self.reward_model.apply(rewards_params, xus))
            reward = np.clip(reward, reward_clip[0], reward_clip[1])
            total_reward += (gamma ** t)*reward

            state = np.array(self.model.apply(dynamics_params, xus))

            if self.terminate_fn is not None:
                # A user-supplied termination rule replaces the learned
                # termination model (this is what ``evaluate`` announces).
                # It is checked on the raw, unclipped prediction.
                if self.terminate_fn(state):
                    break
                state = np.clip(state, state_clip[0], state_clip[1])
                continue

            state = np.clip(state, state_clip[0], state_clip[1])

            key, done_key = jrandom.split(key)
            done_p = jax.lax.logistic(self.done_model.apply(done_params, xus))
            # The termination model has a single output unit; take its scalar.
            if jrandom.bernoulli(done_key, done_p)[0] == 1:
                break

        return total_reward, traj

    def evaluate(self, data, target, behavior, gamma = 1, reward_estimator=None):
        """Estimate the return of ``target`` by averaging simulated rollouts.

        Model weights are loaded from ``self.savepath`` when that file exists;
        otherwise they are trained once (and saved there) and reused on later
        calls.  The simulated rollouts are additionally dumped to
        ``model_based_rollouts.pkl`` in the working directory.

        ``reward_estimator`` is accepted but not used.

        Returns:
            ``np.nanmean`` of the rollout returns.
        """
        print("evaluating")
        if self.terminate_fn is not None:
            print("using terminate function and not the trained model")
        print(self.savepath)

        if os.path.exists(self.savepath):
            print("loading model weights")
            # SECURITY: pickle can execute arbitrary code while loading; only
            # point ``savepath`` at files this process (or a trusted one) wrote.
            with open(self.savepath, 'rb') as f:
                weights_dict = pickle.load(f)

            self.dynamic_params = weights_dict['dynamics']
            self.reward_params = weights_dict['reward']
            self.done_params = weights_dict['done']
            self.reward_clip = weights_dict['reward_clip']
            self.state_clip = weights_dict['state_clip']

        elif not self.trained:
            print("training model")
            (self.dynamic_params, self.reward_params, self.done_params,
             self.reward_clip, self.state_clip) = self.train_dynamics_network(data, target, behavior)
            self.trained = True

            weights_dict = {
                'dynamics': self.dynamic_params,
                'reward': self.reward_params,
                'done': self.done_params,
                'reward_clip': self.reward_clip,
                'state_clip': self.state_clip,
            }
            # Create the target directory first so a finished training run is
            # not lost to a missing folder.
            save_dir = os.path.dirname(self.savepath)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            with open(self.savepath, 'wb') as f:
                pickle.dump(weights_dict, f)

        # Otherwise: already trained in this process; reuse the stored weights.
        dynamics_params = self.dynamic_params
        reward_params = self.reward_params
        done_params = self.done_params
        reward_clip = self.reward_clip
        state_clip = self.state_clip

        key = jrandom.PRNGKey(self.seed)
        rewards = []
        rollouts = []
        with tqdm(range(self._NUM_ROLLOUTS)) as tp:
            for _ in tp:
                key, rollout_key = jrandom.split(key)
                rollout_return, rollout = self.simulate_rollout_return(rollout_key, dynamics_params, reward_params, done_params, data, target, gamma, reward_clip, state_clip)

                rollouts.append(rollout)
                rewards.append(rollout_return)
                tp.set_postfix(reward = rollout_return)

        mean_return = np.nanmean(rewards)
        print(mean_return)

        with open(self._ROLLOUT_DUMP_PATH, 'wb') as f:
            pickle.dump(rollouts, f)
        return mean_return