import numpy as np
import copy
from collections import defaultdict
from preferences_offlineRL.envs.common import BasePolicy, EpsGreedyPolicy
from preferences_offlineRL.models.tabular_models import log, train_tabular_reward_model_w_uncertainty
from preferences_offlineRL.models.mlp_models import train_reward_model_w_uncertainty
from tqdm import tqdm
import itertools

from preferences_offlineRL.utils.preference_utils import generate_rollouts, annotate_buffer, uncertainty_order_pairs, get_pessimistic_environment
from preferences_offlineRL.utils.experiment_utils import log_offline_rl

def get_optimistic_environment(env_base, transitions_ci, rewards_ci, u_weight_t=0.5, u_weight_r= 0.1):
    """Build an optimistic copy of ``env_base`` and solve it.

    The copy's rewards are ``env_base.rewards + u_weight_t * transitions_ci``,
    further increased by ``u_weight_r * rewards_ci`` when ``rewards_ci`` is
    not None. ``env_base`` itself is left untouched.

    Returns:
        (policy, env_optimistic): the LP solution of the optimistic copy
        (``get_lp_solution(return_value=False)``) and the copy itself.
    """
    env_optimistic = copy.deepcopy(env_base)

    # Bonus for transition uncertainty first, then (optionally) reward uncertainty.
    optimistic_rewards = env_base.rewards + u_weight_t * transitions_ci
    if rewards_ci is not None:
        optimistic_rewards += u_weight_r * rewards_ci
    env_optimistic.rewards = optimistic_rewards

    optimistic_policy = env_optimistic.get_lp_solution(return_value=False)
    return optimistic_policy, env_optimistic

def get_optimistic_R_pessimistic_T_environment(env_base, transitions_ci, rewards_ci, u_weight_t=0.1, u_weight_r=0.01,
                                               N_samples=10):
    """Combine the pessimistic environment with optimistic rewards and solve it.

    Takes the environment returned by ``get_pessimistic_environment`` and
    replaces its ``rewards`` with those of ``get_optimistic_environment``
    (base rewards plus uncertainty bonuses). Whatever else the pessimistic
    helper changes on its returned environment is kept. That is presumably
    the transition model, but it is not visible from here.

    Returns:
        (policy, env): the LP solution (``return_value=False``) of the mixed
        environment, and that environment.
    """
    _pessimistic_policy, env_mixed = get_pessimistic_environment(
        env_base, transitions_ci, rewards_ci, u_weight_t, u_weight_r, N_samples)
    _optimistic_policy, env_optimistic = get_optimistic_environment(
        env_base, transitions_ci, rewards_ci, u_weight_t, u_weight_r)

    # Swap in the optimistic rewards, keeping the rest of the pessimistic env.
    env_mixed.rewards = env_optimistic.rewards
    mixed_policy = env_mixed.get_lp_solution(return_value=False)
    return mixed_policy, env_mixed



def get_optimal_policy_set(reward_model, _env):
    """Solve one LP policy per reward model in ``reward_model.models``.

    A single deep copy of ``_env`` is reused as scratch space. For each member,
    its ``extract_reward_vector()`` replaces the copy's ``rewards`` entirely,
    and ``get_lp_solution(return_value=False)`` is called. ``_env`` is not
    modified.

    Returns:
        list of policies, in the same order as ``reward_model.models``.
    """
    scratch_env = copy.deepcopy(_env)

    def _solve_for(member):
        scratch_env.rewards = member.extract_reward_vector()
        return scratch_env.get_lp_solution(return_value=False)

    return [_solve_for(member) for member in reward_model.models]


def _select_rollout_environment(transitions1, it, env_eval, env_updated, transitions_ci):
    """Pick the environment in which this iteration's preference rollouts are simulated."""
    if transitions1 == 'online':  # PbOP: roll out in the evaluation environment
        return env_eval
    if it == 0 or transitions1 == 'default':  # Sim-OPRL without pessimism in rollouts
        return env_updated
    # transitions1 == 'pessimistic' (validated by the caller): Sim-OPRL.
    # No reward-uncertainty term is passed for rollouts (rewards_ci=None).
    _, env_rollout = get_pessimistic_environment(env_updated, transitions_ci, rewards_ci=None)
    return env_rollout


def _penalised_optimal_policies(env_updated, transitions_ci, reward_model, transition_penalty):
    """Solve for policies under rewards reduced by ``transition_penalty * transitions_ci``.

    While no learned reward ensemble exists (``reward_model is None``, e.g. on
    the first iteration or in the tabular setting), returns the single LP
    policy of the penalised current model. Otherwise returns a list with one
    LP policy per member of ``reward_model.models``.
    """
    penalty = transition_penalty * transitions_ci
    env_penalised = copy.deepcopy(env_updated)
    if reward_model is None:
        env_penalised.rewards = env_updated.rewards - penalty
        return env_penalised.get_lp_solution(return_value=False)

    policies = []
    for member in reward_model.models:
        # Subtract the penalty from each member's reward. Assigning the bare
        # member reward (as get_optimal_policy_set does) would silently drop
        # the transition-uncertainty penalty.
        env_penalised.rewards = member.extract_reward_vector() - penalty
        policies.append(env_penalised.get_lp_solution(return_value=False))
    return policies


def _select_rollout_policy(policy1, env_updated, transitions_ci, reward_model,
                           solution_pi_latest, action_space, transition_penalty):
    """Pick the policy (or list of policies) used to generate this iteration's rollouts."""
    if policy1 == 'optimal':  # Sim-OPRL without pessimism in rollouts
        return solution_pi_latest
    if policy1 in ('optimistic', 'optimistic_R_pessimistic_T'):  # PbOP / Sim-OPRL
        return _penalised_optimal_policies(env_updated, transitions_ci, reward_model,
                                           transition_penalty)
    # 'eps<value>' (validated by the caller): EpsGreedyPolicy around the current LP policy.
    return EpsGreedyPolicy(solution_pi_latest,
                           eps=float(policy1.replace('eps', '')),
                           action_space=action_space)


def _collect_trajectory_pairs(sampling, env_rollout, policy, reward_model, env_eval,
                              N_rollouts, oversampling):
    """Generate candidate trajectory pairs, ordered by model uncertainty when requested."""
    if sampling == 'all' or reward_model is None:
        return generate_rollouts(env_rollout, policy, env_rollout, policy, N_rollouts)

    # sampling == 'uncertainty': over-generate, then hand the candidates to
    # uncertainty_order_pairs (presumably ranks them by reward-model uncertainty).
    if isinstance(policy, list):
        # Contrast every pair of ensemble policies. A single-member list yields
        # no combinations, so pair that member with itself instead.
        policy_pairs = list(itertools.combinations(policy, 2)) or [(pi, pi) for pi in policy]
    else:
        policy_pairs = [(policy, policy)]

    candidate_pairs = []
    for pi_a, pi_b in policy_pairs:
        candidate_pairs.extend(generate_rollouts(env_rollout, pi_a, env_rollout, pi_b,
                                                 N_rollouts * oversampling))
    return uncertainty_order_pairs(candidate_pairs,
                                   reward_model.uncertainty_traj_pair,
                                   env_eval.N_states,
                                   env_eval.N_actions)


def record_preference_experiment_w_uncertainty(env_1,
                                    env_eval,
                                    solution_pi_eval,
                                    policy1='optimal',
                                    transitions1='default',
                                    tabular=True,
                                    sampling = 'all',
                                    N_experiments=10,
                                    N_samples=500,
                                    N_rollouts=50,
                                    transition_penalty=0.5,
                                    uncertainty_oversampling=10,
                                    ):
    """Run repeated preference-elicitation experiments and collect their logs.

    Each experiment is seeded with ``np.random.seed(i)`` and starts from a fresh
    copy of the base environment, with no learned reward model or reward CI.
    It then runs ``N_samples // N_rollouts`` iterations. Each iteration:
      1. picks a rollout environment (``transitions1``),
      2. picks a rollout policy or policy list (``policy1``),
      3. generates trajectory pairs (``sampling``) and annotates them via
         ``annotate_buffer`` against ``env_eval``,
      4. retrains the reward model (tabular or ensemble, per ``tabular``),
      5. re-solves the LP and records metrics with ``log`` and ``log_offline_rl``.

    Args:
        env_1: pair ``(base_env, transitions_ci)``.
        env_eval: environment used for rollouts when ``transitions1='online'``,
            for annotation, and for evaluation.
        solution_pi_eval: reference policy passed to the logging helpers.
        policy1: 'optimal', 'optimistic', 'optimistic_R_pessimistic_T', or a
            string containing 'eps' such as 'eps0.1'.
        transitions1: 'online', 'default' or 'pessimistic'.
        tabular: use ``train_tabular_reward_model_w_uncertainty`` if True,
            else ``train_reward_model_w_uncertainty``. In the tabular case no
            reward ensemble exists, so 'uncertainty' sampling falls back to
            'all' and optimistic policies use the penalised LP solution.
        sampling: 'all' or 'uncertainty'.
        N_experiments, N_samples, N_rollouts: experiment and budget sizes.
        transition_penalty: weight on ``transitions_ci`` subtracted from
            rewards for the optimistic policy variants (previously hard-coded 0.5).
        uncertainty_oversampling: factor by which rollouts are over-generated
            before uncertainty ordering (previously hard-coded 10).

    Returns:
        list with one ``defaultdict(list)`` of logged metrics per experiment.
        Non-tabular runs also include a 'reward_losses' entry.

    Raises:
        ValueError: on an unrecognised ``transitions1``, ``policy1`` or ``sampling``.
    """
    # Validate up front: unknown options used to surface much later as
    # unbound locals, or as silently reused pairs from an earlier iteration.
    if transitions1 not in ('online', 'default', 'pessimistic'):
        raise ValueError(f"Unknown transitions1 mode: {transitions1!r}")
    if policy1 not in ('optimal', 'optimistic', 'optimistic_R_pessimistic_T') and 'eps' not in policy1:
        raise ValueError(f"Unknown policy1 mode: {policy1!r}")
    if sampling not in ('all', 'uncertainty'):
        raise ValueError(f"Unknown sampling mode: {sampling!r}")

    results = []
    N_iterations = N_samples // N_rollouts
    env_base, transitions_ci = env_1
    for i in range(N_experiments):
        np.random.seed(i)
        # Reset learned state per experiment so runs are independent; these were
        # previously carried over from the preceding experiment.
        reward_model = None
        rewards_ci = None
        run_log = defaultdict(list)
        env_updated = copy.deepcopy(env_base)
        solution_pi_latest = env_updated.get_lp_solution(return_value=False)
        run_log = log(run_log, env_updated, solution_pi_latest, env_eval, solution_pi_eval)

        for it in tqdm(range(N_iterations)):
            env_rollout = _select_rollout_environment(transitions1, it, env_eval,
                                                      env_updated, transitions_ci)
            rollout_policy = _select_rollout_policy(policy1, env_updated, transitions_ci,
                                                    reward_model, solution_pi_latest,
                                                    env_eval.action_space, transition_penalty)
            list_trajs_pairs = _collect_trajectory_pairs(sampling, env_rollout, rollout_policy,
                                                         reward_model, env_eval, N_rollouts,
                                                         uncertainty_oversampling)
            list_trajs = annotate_buffer(list_trajs_pairs, env_eval, N_rollouts)

            if tabular:
                env_updated, rewards_ci = train_tabular_reward_model_w_uncertainty(
                    list_trajs, env_updated, rewards_ci)
            else:
                env_updated, reward_model, rewards_ci, reward_losses = train_reward_model_w_uncertainty(
                    list_trajs, env_updated, reward_model=reward_model)
                run_log['reward_losses'].append(reward_losses)

            solution_pi_latest = env_updated.get_lp_solution(return_value=False)
            run_log = log(run_log, env_updated, solution_pi_latest, env_eval, solution_pi_eval)
            pessimistic_policy = get_pessimistic_environment(env_updated,
                                                             transitions_ci=transitions_ci,
                                                             rewards_ci=rewards_ci)[0]
            run_log = log_offline_rl(run_log, pessimistic_policy, env_eval, solution_pi_eval)
        results.append(run_log)

    return results
