from typing import List, Optional, Tuple
import numpy as np
import random
import copy
from collections import defaultdict
import itertools
from preferences_offlineRL.models.tabular_models import log, train_tabular_reward_model, train_tabular_reward_transition_models, train_tabular_reward_model_w_uncertainty, train_tabular_reward_transition_models_w_uncertainty
from preferences_offlineRL.models.mlp_models import train_reward_model, train_reward_model_w_uncertainty, convert_states_actions_one_hot
import gym
from tqdm import tqdm

from preferences_offlineRL.utils.preference_utils import compute_log_likelihood_traj, compute_rewards_traj,get_pessimistic_environment
from preferences_offlineRL.utils.experiment_utils import log_offline_rl



def annotate_buffer(traj_pairs_buffer: List[Tuple],
                    annotation_env: "gym.Env",
                    N_rollouts: int):
    """Annotate pairs of offline trajectories with preference labels.

    Only the first ``N_rollouts`` pairs of the buffer are annotated.

    Args:
        traj_pairs_buffer: Sequence of ``(traj_1, traj_2)`` pairs.
        annotation_env: Environment used as annotator. Its ``transitions``,
            ``rewards`` and ``discount_factor`` attributes are read.
        N_rollouts: Number of annotated pairs requested.

    Returns:
        A list of ``[traj_1, traj_2, y_t, y_r]`` entries, where ``y_t`` is 1
        when the value returned by ``compute_log_likelihood_traj`` for
        ``traj_1`` is at least that of ``traj_2``, and ``y_r`` is 1 when the
        value returned by ``compute_rewards_traj`` for ``traj_1`` is strictly
        larger (ties are broken uniformly at random).

        If the buffer holds fewer than ``N_rollouts`` pairs, a warning is
        printed and the result is padded up to ``N_rollouts`` entries by
        re-drawing already annotated entries (the padded items are the same
        list objects, not copies). If the buffer is empty there is nothing to
        pad from, so an empty list is returned.
    """
    # Clamp at zero so that a negative N_rollouts annotates nothing.
    n_to_annotate = max(min(len(traj_pairs_buffer), N_rollouts), 0)

    list_trajs = []
    for traj_1, traj_2 in traj_pairs_buffer[:n_to_annotate]:
        P_1 = compute_log_likelihood_traj(traj_1, annotation_env.transitions)
        R_1 = compute_rewards_traj(traj_1, annotation_env.rewards, annotation_env.discount_factor)
        P_2 = compute_log_likelihood_traj(traj_2, annotation_env.transitions)
        R_2 = compute_rewards_traj(traj_2, annotation_env.rewards, annotation_env.discount_factor)

        y_t = int(P_1 >= P_2)
        if R_1 == R_2:
            # No reward-based preference: pick a label at random.
            y_r = random.randint(0, 1)
        else:
            y_r = int(R_1 > R_2)
        list_trajs.append([traj_1, traj_2, y_t, y_r])

    if len(traj_pairs_buffer) < N_rollouts:
        print("Warning: not enough samples in buffer")
        if not list_trajs:
            # random.choice raises IndexError on an empty sequence; with no
            # annotated pair there is nothing to duplicate.
            return list_trajs
        while len(list_trajs) < N_rollouts:
            list_trajs.append(random.choice(list_trajs))
    return list_trajs


def random_sample_buffer(trajs_buffer: List[List]):
    """Return all unordered pairs of buffered trajectories in random order.

    Every 2-combination of ``trajs_buffer`` is materialised as a tuple and the
    resulting list is shuffled with the global ``random`` module state.
    """
    shuffled_pairs = [*itertools.combinations(trajs_buffer, 2)]
    random.shuffle(shuffled_pairs)
    return shuffled_pairs


def uncertainty_order_pairs(traj_pairs_buffer, uncertainty_fn, N_states, N_actions):
    """Return the trajectory pairs sorted by decreasing uncertainty score.

    For each pair, both trajectories are passed through
    ``convert_states_actions_one_hot`` (with ``N_states`` and ``N_actions``)
    and the two results are handed to ``uncertainty_fn``; its return value is
    the sort key. The input buffer itself is not modified.
    """
    def pair_uncertainty(pair):
        encoded_first = convert_states_actions_one_hot(pair[0], N_states, N_actions)
        encoded_second = convert_states_actions_one_hot(pair[1], N_states, N_actions)
        return uncertainty_fn(encoded_first, encoded_second)

    ranked_pairs = list(traj_pairs_buffer)
    ranked_pairs.sort(key=pair_uncertainty, reverse=True)
    return ranked_pairs


def uncertainty_sample_buffer(trajs_buffer, max_num_pairs, uncertainty_fn, N_states, N_actions):
    """Sample trajectory pairs from the buffer and rank them by uncertainty.

    All 2-combinations of ``trajs_buffer`` are formed, at most
    ``max_num_pairs`` of them are drawn uniformly without replacement (using
    NumPy's global random state), and the drawn pairs are ordered with
    ``uncertainty_order_pairs``.

    Args:
        trajs_buffer: Collection of trajectories to pair up.
        max_num_pairs: Upper bound on the number of candidate pairs to score.
        uncertainty_fn: Callable forwarded to ``uncertainty_order_pairs``.
        N_states: Forwarded to ``uncertainty_order_pairs``.
        N_actions: Forwarded to ``uncertainty_order_pairs``.

    Returns:
        The list of sampled pairs as ordered by ``uncertainty_order_pairs``;
        an empty list when the buffer yields no pair.
    """
    traj_pairs_buffer = list(itertools.combinations(trajs_buffer, 2))
    if not traj_pairs_buffer:
        # Fewer than two trajectories: nothing to sample or rank.
        return []

    # Sampling without replacement raises ValueError if more items are
    # requested than exist, so cap the request at the population size.
    n_candidates = min(max_num_pairs, len(traj_pairs_buffer))
    idx = np.random.choice(len(traj_pairs_buffer), n_candidates, replace=False)
    candidate_pairs = [traj_pairs_buffer[i] for i in idx]
    return uncertainty_order_pairs(candidate_pairs, uncertainty_fn, N_states, N_actions)



def record_preference_experiment_offline_trajs(trajs_buffer,
                                    env_1,
                                    env_eval,
                                    solution_pi_eval,
                                    sampling = 'random',
                                    N_experiments=10,
                                    N_samples=500,
                                    N_rollouts=50,
                                    tabular = True
                                    ):
    """Run preference sampling experiment over offline trajectories buffer.

    Each experiment starts from a deep copy of the initial environment and
    runs ``N_samples // N_rollouts`` rounds of: sample trajectory pairs,
    annotate ``N_rollouts`` of them with ``env_eval``, update the reward
    model, re-solve the environment and log.

    Args:
        trajs_buffer: Offline trajectories to build pairs from.
        env_1: Two-element sequence ``(env, transitions_ci)``; the environment
            is deep-copied per experiment and ``transitions_ci`` is forwarded
            to ``get_pessimistic_environment``.
        env_eval: Environment used to annotate pairs and passed to the
            logging helpers. Its ``N_states`` / ``N_actions`` are read for
            uncertainty sampling.
        solution_pi_eval: Passed through to the logging helpers.
        sampling: ``'random'`` or ``'uncertainty'``. Uncertainty sampling
            falls back to random sampling while no reward model exists, which
            is always the case when ``tabular`` is True.
        N_experiments: Number of independent repetitions; repetition ``i`` is
            seeded with ``i``.
        N_samples: Total number of annotated pairs per experiment.
        N_rollouts: Number of annotated pairs per round.
        tabular: If True use ``train_tabular_reward_model``; otherwise use
            ``train_reward_model_w_uncertainty``.

    Returns:
        A list with one logging dictionary per experiment.

    Raises:
        NotImplementedError: If ``sampling`` is not a supported strategy.
    """
    N_iterations = N_samples // N_rollouts
    env_1, transitions_ci = env_1
    # Uncertainty sampling scores this many times more candidates than it keeps.
    multiplication_factor = 10

    results = []
    for i in range(N_experiments):
        # The sampling/annotation helpers in this file draw from the stdlib
        # `random` module, so seed it together with NumPy.
        np.random.seed(i)
        random.seed(i)
        # Start every experiment without a learned reward model so that runs
        # do not depend on the ones executed before them.
        reward_model = None

        run_log = defaultdict(list)
        env_updated = copy.deepcopy(env_1)
        solution_pi_latest = env_updated.get_lp_solution(return_value=False)
        run_log = log(run_log, env_updated, solution_pi_latest, env_eval, solution_pi_eval)

        for _ in tqdm(range(N_iterations)):
            if sampling == 'random' or (sampling == 'uncertainty' and reward_model is None):
                list_trajs = random_sample_buffer(trajs_buffer)
            elif sampling == 'uncertainty':
                list_trajs = uncertainty_sample_buffer(trajs_buffer,
                                                    N_rollouts*multiplication_factor,
                                                    reward_model.uncertainty_traj_pair,
                                                    env_eval.N_states,
                                                    env_eval.N_actions)
            else:
                raise NotImplementedError(f"Unknown sampling strategy: {sampling!r}")
            list_trajs = annotate_buffer(list_trajs, env_eval, N_rollouts)

            if tabular:
                env_updated = train_tabular_reward_model(list_trajs, env_updated)
            else:
                env_updated, reward_model, rewards_ci, reward_losses = train_reward_model_w_uncertainty(
                    list_trajs, env_updated, reward_model=reward_model)
                run_log['reward_losses'].append(reward_losses)
            solution_pi_latest = env_updated.get_lp_solution(return_value=False)
            run_log = log(run_log, env_updated, solution_pi_latest, env_eval, solution_pi_eval)

            if not tabular:
                # Only the non-tabular branch produces `rewards_ci`; referencing
                # it in tabular mode would raise NameError.
                # NOTE(review): confirm whether the pessimistic evaluation
                # should also run in tabular mode with a tabular reward CI.
                pessimistic_env = get_pessimistic_environment(env_updated,
                                                              transitions_ci=transitions_ci,
                                                              rewards_ci=rewards_ci)[0]
                run_log = log_offline_rl(run_log,
                                         pessimistic_env,
                                         env_eval,
                                         solution_pi_eval)
        results.append(run_log)
    return results

