from typing import List, Optional, Tuple
import numpy as np
import copy
from preferences_offlineRL.envs.common import BasePolicy

import gym

EPS = np.finfo(np.float32).eps

def compute_log_likelihood_traj(traj: List[List], transition_matrix: np.ndarray) -> float:
    """Return the log-likelihood of the transitions contained in a flat trajectory.

    ``traj`` is read as a flat sequence ``[s0, a0, r0, s1, a1, r1, ...]``. Every
    state ``s`` at position ``i`` that has a successor state at ``i + 3``
    contributes ``log(transition_matrix[a, s, s'] + EPS)``, where ``a`` is at
    ``i + 1``. The recorded rewards are ignored. ``EPS`` keeps zero-probability
    transitions finite.

    Args:
        traj: Flat trajectory of interleaved states, actions and rewards.
        transition_matrix: Array of shape ``(n_actions, n_states, n_states)``.

    Returns:
        Sum of the log transition probabilities (0.0 when no transition fits).
    """
    n_actions, n_states, _ = transition_matrix.shape
    counts = np.zeros((n_states * n_actions * n_states))

    # Tally every (s, a, s') occurrence at its C-order flat position in the
    # (action, state, next_state) layout.
    for idx in range(0, len(traj) - 3, 3):
        state, action, next_state = traj[idx], traj[idx + 1], traj[idx + 3]
        counts[action * (n_states * n_states) + state * n_states + next_state] += 1

    log_probs = np.log(transition_matrix + EPS).flatten()
    return np.dot(log_probs, counts)

def compute_rewards_traj(traj: List[List], reward_vector: np.ndarray, discount_factor: float) -> float:
    """Compute the discounted state-reward return of a flat trajectory.

    ``traj`` is read as a flat sequence ``[s0, a0, r0, s1, a1, r1, ...]``. The
    reward of the state at timestep ``t`` is looked up in ``reward_vector`` and
    weighted by ``discount_factor ** t``. The rewards recorded inside ``traj``
    are not used. Only complete (state, action, reward) triples are visited, so
    a trailing state without an action (length ``3T + 1``) is ignored instead of
    raising ``IndexError``.

    Args:
        traj: Flat trajectory of interleaved states, actions and rewards.
        reward_vector: Per-state rewards; presumably 1-D of length n_states
            (it is indexed by state and dotted with per-state weights).
        discount_factor: Per-timestep discount gamma.

    Returns:
        ``sum_t discount_factor**t * reward_vector[s_t]`` (0.0 for an empty traj).
    """
    state_weights = np.zeros(reward_vector.shape)
    # ``i`` is the flat position of the t-th state; the discount exponent must be
    # the timestep t (= i // 3), not the flat index i.
    for t, i in enumerate(range(0, len(traj) - 2, 3)):
        state_weights[traj[i]] += discount_factor ** t
    return reward_vector.dot(state_weights)

def rollout_policy_in_env(env: gym.Env, policy: BasePolicy, compute_pr : bool = False,
                          discount_factor: float = 0.9):
    """Roll out ``policy`` in ``env`` for one episode and record a flat trajectory.

    The trajectory is ``[s0, a0, r0, s1, a1, r1, ..., s_{T-1}, a_{T-1}, r_{T-1}]``.
    Each ``s`` is ``obs[0]`` (presumably the discrete state index — TODO confirm
    against the env). The state observed after the final step is dropped.

    Args:
        env: Environment using the 4-tuple ``step`` API (obs, reward, done, info).
        policy: Object whose ``get_action(obs)`` returns the action to take.
        compute_pr: If True, also return the trajectory's transition
            log-likelihood under ``env.transitions`` and its discounted return
            under ``env.rewards``.
        discount_factor: Discount used for the return when ``compute_pr`` is
            True. Defaults to the previously hard-coded 0.9.

    Returns:
        ``traj``, or ``(traj, log_likelihood, discounted_return)`` if ``compute_pr``.
    """
    obs = env.reset()
    done = False
    traj = [obs[0]]
    while not done:
        action = policy.get_action(obs)
        obs, reward, done, _ = env.step(action)
        traj.extend([action, reward, obs[0]])
    # Remove the state reached after the last step so the trajectory consists
    # only of complete (state, action, reward) triples.
    traj.pop()

    if compute_pr:
        log_likelihood = compute_log_likelihood_traj(traj, env.transitions)
        discounted_return = compute_rewards_traj(traj, env.rewards, discount_factor)
        return traj, log_likelihood, discounted_return
    return traj

def generate_rollouts(rollout_env1: gym.Env,
                      policy_1: BasePolicy,
                      rollout_env2: Optional[gym.Env] = None,
                      policy_2: Optional[BasePolicy] = None,
                      num_rollouts: int = 10):
    """Generate ``num_rollouts`` pairs of trajectories.

    Each pair is ``[traj_1, traj_2]``. ``traj_1`` comes from ``policy_1`` in
    ``rollout_env1``. ``traj_2`` comes from ``policy_2`` in ``rollout_env2``,
    each falling back to the first policy/env when not given. Within a pair,
    the first rollout is produced before the second.
    """
    second_env = rollout_env1 if rollout_env2 is None else rollout_env2
    second_policy = policy_1 if policy_2 is None else policy_2
    return [
        [rollout_policy_in_env(rollout_env1, policy_1),
         rollout_policy_in_env(second_env, second_policy)]
        for _ in range(num_rollouts)
    ]



def get_pessimistic_environment(env_base, transitions_ci, rewards_ci=None, u_weight_t=0.5, u_weight_r=0.1):
    """Build a reward-penalized copy of ``env_base`` and solve it.

    The copy's rewards are
    ``env_base.rewards - u_weight_t * transitions_ci [- u_weight_r * rewards_ci]``.
    The last term applies only when ``rewards_ci`` is given. ``env_base`` itself
    is left untouched (it is deep-copied).

    Returns:
        ``(pessimistic_policy, env_pessimistic)``, where the policy is the copy's
        ``get_lp_solution(return_value=False)``.
    """
    # The subtraction allocates a fresh array, so the in-place penalty below
    # never aliases env_base.rewards.
    penalized = env_base.rewards - u_weight_t * transitions_ci
    if rewards_ci is not None:
        penalized -= u_weight_r * rewards_ci

    env_pessimistic = copy.deepcopy(env_base)
    env_pessimistic.rewards = penalized
    return env_pessimistic.get_lp_solution(return_value=False), env_pessimistic