
import copy
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from preferences_offlineRL.envs.common import TabularMDP
# float32 machine epsilon; in this file it is added to transition probabilities
# before taking their log and to user-supplied initial transition matrices.
EPS = np.finfo(np.float32).eps



class TabularTransitionModel(nn.Module):
    """Transition model for tabular MDP.

    Stores one learnable matrix of shape (N_actions, N_states, N_states) that
    is kept log-normalised along its last axis, i.e. each [a, s, :] row is a
    vector of log-probabilities.  The logprob of a trajectory is computed by
    multiplying its flattened (s, a, s') triplet counts with the flattened
    log transition matrix.

    Args:
        N_states: number of states.
        N_actions: number of actions.
        init_matrix: optional NumPy array of shape
            (N_actions, N_states, N_states) used as the initial matrix (it is
            copied, shifted by EPS and then log-normalised).  When omitted the
            matrix is drawn from ``torch.rand``.
    """
    def __init__(self, N_states, N_actions, init_matrix=None):
        super(TabularTransitionModel, self).__init__()
        self.N_states = N_states
        self.N_actions = N_actions
        if init_matrix is None:
            init_matrix = torch.rand(N_actions, N_states, N_states)
        else:
            assert init_matrix.shape == (N_actions, N_states, N_states)
            # Deep-copy so the caller's array is never modified in place.
            init_matrix = torch.from_numpy(copy.deepcopy(init_matrix))
            init_matrix += EPS
        self.transition_matrix = nn.Parameter(init_matrix)
        self.normalize()

    def forward(self, triple_count):
        """Return trajectory log-probabilities.

        Args:
            triple_count: tensor whose last dimension has size
                N_actions * N_states * N_states (flattened triplet counts).

        Returns:
            Tensor with a trailing dimension of size 1 holding the dot product
            of each count vector with the flattened log transition matrix.
        """
        log_probs = self.transition_matrix.reshape(-1, 1)
        # The counts may arrive in a different dtype than the parameter (e.g.
        # float64 counts against a float32 default-initialised matrix);
        # matmul requires matching dtypes, so align the input to the parameter.
        return torch.matmul(triple_count.to(log_probs.dtype), log_probs)

    def normalize(self,):
        """Renormalise in place so that exp(matrix) sums to 1 over the last axis."""
        with torch.no_grad():
            self.transition_matrix -= torch.logsumexp(self.transition_matrix, -1, keepdim=True)
    
class TabularRewardModel(nn.Module):
    """Reward model for tabular MDP.

    Holds one learnable reward value per state.  The return of a trajectory is
    the dot product of its discounted state counts with that reward vector.

    Args:
        N_states: number of states.
        init_vec: optional NumPy array of shape (N_states,) used (as a copy)
            to initialise the reward vector; drawn from ``torch.rand`` when
            omitted.
    """
    def __init__(self, N_states, init_vec=None):
        super().__init__()
        self.N_states = N_states
        if init_vec is not None:
            assert init_vec.shape == (N_states,)
            # Copy first so the parameter does not share memory with the input.
            start_values = torch.from_numpy(copy.deepcopy(init_vec))
        else:
            start_values = torch.rand(N_states)
        self.reward_vector = nn.Parameter(start_values)

    def forward(self, state_count):
        """Return the product of ``state_count`` with the float32 reward vector."""
        weights = self.reward_vector.float()
        return state_count @ weights

    def normalize(self):
        """No-op; present so the training loop can call it on every model."""
        return None

class TabularPreferenceDataset(Dataset):
    """Preference dataset for tabular MDP.

    Each entry of ``annotated_trajs`` is a 4-tuple
    ``(traj_1, traj_2, y_T, y_R)``.  A trajectory is a flat sequence read in
    strides of three as ``s, a, r, s, a, r, ...``.  Trajectories are converted
    to (s,a,s') triplet counts and discounted state counts.

    Args:
        annotated_trajs: sequence of ``(traj_1, traj_2, y_T, y_R)`` tuples.
        N_states: number of states.
        N_actions: number of actions.
        discount_factor: per-time-step discount used for the state counts.
    """
    def __init__(self, annotated_trajs, N_states=5, N_actions=2, discount_factor=0.9):
        self.annotated_trajs = annotated_trajs
        self.N_states = N_states
        self.N_actions = N_actions
        self.discount_factor = discount_factor

    def __len__(self):
        return len(self.annotated_trajs)

    def get_triple_counts(self, traj):
        """Return flattened (a, s, s') visit counts of ``traj``.

        The result has length N_actions * N_states * N_states and is laid out
        like a row-major (N_actions, N_states, N_states) array.
        """
        triple_counts = np.zeros(self.N_states * self.N_actions * self.N_states)
        for i in range(0, len(traj) - 3, 3):
            s, a, sprime = traj[i], traj[i + 1], traj[i + 3]
            triple_counts[a * (self.N_states * self.N_states) + s * self.N_states + sprime] += 1
        return triple_counts

    def get_state_counts(self, traj):
        """Return discounted state visit counts of ``traj`` (length N_states).

        The state visited at time step ``t`` contributes ``discount_factor**t``.
        """
        state_counts = np.zeros(self.N_states)
        # States sit at every third position of the flat trajectory; the
        # exponent must be the time step, not the position in the flat list.
        for t, s in enumerate(traj[::3]):
            state_counts[s] += self.discount_factor ** t
        return state_counts

    def __getitem__(self, idx):
        """Return ``(triples_1, triples_2, states_1, states_2, y_T, y_R)`` tensors.

        Triplet counts keep NumPy's float64 dtype; state counts and both
        labels are float32, and each label has shape (1,).
        """
        traj_1, traj_2, y_T, y_R = self.annotated_trajs[idx]
        return (
            torch.tensor(self.get_triple_counts(traj_1)), torch.tensor(self.get_triple_counts(traj_2)),
            torch.tensor(self.get_state_counts(traj_1)).float(), torch.tensor(self.get_state_counts(traj_2)).float(),
            torch.tensor([y_T]).float(), torch.tensor([y_R]).float()
            )
    
    
def loss_fn_T(batch, log_transition):
    """Binary cross-entropy preference loss for the transition model.

    Args:
        batch: six-element batch; elements 0 and 1 are the inputs for the two
            trajectories and element 4 is the transition preference label.
        log_transition: callable mapping a trajectory input to its log-probability.

    Returns:
        BCE-with-logits loss where the logit is the difference between the
        first and second trajectory's log-probability.
    """
    first_counts, second_counts = batch[0], batch[1]
    labels = batch[4]
    logit = log_transition(first_counts) - log_transition(second_counts)
    criterion = nn.BCEWithLogitsLoss()
    return criterion(logit, labels)

def loss_fn_R(batch, reward_model):
    """Binary cross-entropy preference loss for the reward model.

    Args:
        batch: six-element batch; elements 2 and 3 are the inputs for the two
            trajectories and element 5 is the reward preference label.
        reward_model: callable mapping a trajectory input to its return.

    Returns:
        BCE-with-logits loss where the logit is the difference between the
        first and second trajectory's return.
    """
    _, _, traj_1, traj_2, _, y_R = batch
    y_hat_logit = reward_model(traj_1) - reward_model(traj_2)
    # Align the labels with the logits explicitly.  A bare squeeze() would
    # also drop the batch dimension when the batch holds a single pair,
    # which makes the loss reject the mismatching shapes.
    return nn.BCEWithLogitsLoss()(y_hat_logit, y_R.reshape(y_hat_logit.shape))


def train_model(model, dataloader, optimizer, loss_fn, n_epochs=10, verbose=False):
    """Run a plain gradient-descent training loop and return the model.

    Args:
        model: object passed to ``loss_fn``; must provide ``normalize()``,
            which is called after every optimizer step.
        dataloader: iterable of batches.
        optimizer: optimizer offering ``zero_grad()`` and ``step()``.
        loss_fn: callable ``loss_fn(batch, model)`` returning a scalar loss.
        n_epochs: number of passes over ``dataloader``.
        verbose: when true, print the mean batch loss after each epoch.

    Returns:
        The same ``model`` object after training.
    """
    for _ in range(n_epochs):
        epoch_losses = []
        for minibatch in dataloader:
            optimizer.zero_grad()
            batch_loss = loss_fn(minibatch, model)
            batch_loss.backward()
            optimizer.step()
            # Project the parameters back onto the model's constraint set.
            model.normalize()
            epoch_losses.append(batch_loss.item())
        if verbose:
            print("Loss:", np.mean(epoch_losses))
    return model

def train_tabular_reward_transition_models(list_trajs,
                                           env_emp,
                                           verbose=False):
    """Train reward and transitions models after sampling iteration.

    Both models are initialised from ``env_emp`` (log of its transitions plus
    EPS, and its rewards) and trained for two epochs with Adam (lr 0.01) on
    shuffled batches of four preference pairs.

    Args:
        list_trajs: annotated trajectory pairs fed to ``TabularPreferenceDataset``.
        env_emp: environment exposing ``N_states``, ``N_actions``,
            ``transitions`` and ``rewards``; it is not modified.
        verbose: when true, print progress messages and the learned arrays.

    Returns:
        A deep copy of ``env_emp`` whose ``transitions`` and ``rewards`` are
        replaced by the learned values.
    """
    n_states = env_emp.N_states
    n_actions = env_emp.N_actions
    loader = DataLoader(
        TabularPreferenceDataset(list_trajs, n_states, n_actions),
        batch_size=4,
        shuffle=True,
    )

    # --- transition model ---
    if verbose:
        print("Training transition model")
    initial_log_transitions = np.log(env_emp.transitions + EPS)
    transition_net = TabularTransitionModel(n_states, n_actions, initial_log_transitions)
    transition_opt = torch.optim.Adam(transition_net.parameters(), lr=0.01)
    transition_net = train_model(transition_net, loader, transition_opt, loss_fn_T, n_epochs=2)
    # The model stores log-probabilities; exponentiate to get probabilities.
    learned_transitions = np.exp(transition_net.transition_matrix.detach().numpy())
    if verbose:
        print(learned_transitions)

        print("Training reward model")

    # --- reward model ---
    reward_net = TabularRewardModel(n_states, env_emp.rewards)
    reward_opt = torch.optim.Adam(reward_net.parameters(), lr=0.01)
    reward_net = train_model(reward_net, loader, reward_opt, loss_fn_R, n_epochs=2)
    learned_rewards = reward_net.reward_vector.detach().numpy()
    if verbose:
        print(learned_rewards)

    updated_env = copy.deepcopy(env_emp)
    updated_env.transitions = learned_transitions
    updated_env.rewards = learned_rewards

    return updated_env


def train_tabular_reward_model(list_trajs,
                               env_emp,
                               verbose=False):
    """Train reward model after sampling iteration.

    The reward model is initialised from ``env_emp.rewards`` and trained for
    two epochs with Adam (lr 0.01) on shuffled batches of four preference pairs.

    Args:
        list_trajs: annotated trajectory pairs fed to ``TabularPreferenceDataset``.
        env_emp: environment exposing ``N_states``, ``N_actions`` and
            ``rewards``; it is not modified.
        verbose: when true, print a progress message and the learned rewards.

    Returns:
        A deep copy of ``env_emp`` whose ``rewards`` are the learned values.
    """
    n_states = env_emp.N_states
    n_actions = env_emp.N_actions
    loader = DataLoader(
        TabularPreferenceDataset(list_trajs, n_states, n_actions),
        batch_size=4,
        shuffle=True,
    )

    if verbose:
        print("Training reward model")
    reward_net = TabularRewardModel(n_states, env_emp.rewards)
    reward_opt = torch.optim.Adam(reward_net.parameters(), lr=0.01)
    reward_net = train_model(reward_net, loader, reward_opt, loss_fn_R, n_epochs=2)
    learned_rewards = reward_net.reward_vector.detach().numpy()
    if verbose:
        print(learned_rewards)

    updated_env = copy.deepcopy(env_emp)
    updated_env.rewards = learned_rewards
    return updated_env


def train_tabular_reward_transition_models_w_uncertainty(list_trajs,
                                                         env_emp,
                                                         transitions_ci,
                                                         rewards_ci,
                                                         verbose=False):
    """Train both models, then tighten both confidence intervals.

    Args:
        list_trajs: annotated trajectory pairs.
        env_emp: environment used to initialise the models.
        transitions_ci: current transition confidence intervals.
        rewards_ci: current reward confidence intervals.
        verbose: forwarded to the model-training routine.

    Returns:
        ``(updated_env, transitions_ci, rewards_ci)`` with the retrained
        environment copy and the updated intervals.
    """
    updated_env = train_tabular_reward_transition_models(list_trajs, env_emp, verbose=verbose)
    new_intervals = update_hoeffding_interval_from_trajs(
        list_trajs,
        transitions_ci,
        rewards_ci,
        N_states=updated_env.N_states,
        N_actions=updated_env.N_actions,
    )
    new_transitions_ci, new_rewards_ci = new_intervals
    return updated_env, new_transitions_ci, new_rewards_ci

def train_tabular_reward_model_w_uncertainty(list_trajs,
                                             env_emp,
                                             rewards_ci,
                                             verbose=False):
    """Train the reward model only, then tighten the reward confidence intervals.

    Args:
        list_trajs: annotated trajectory pairs.
        env_emp: environment used to initialise the reward model.
        rewards_ci: current reward confidence intervals.
        verbose: forwarded to the model-training routine.

    Returns:
        ``(updated_env, rewards_ci)`` with the retrained environment copy and
        the updated reward intervals.
    """
    updated_env = train_tabular_reward_model(list_trajs, env_emp, verbose=verbose)
    # Transition intervals are not tracked here, so None is passed and the
    # corresponding output is discarded.
    _unused_transitions_ci, new_rewards_ci = update_hoeffding_interval_from_trajs(
        list_trajs,
        transitions_ci=None,
        rewards_ci=rewards_ci,
        N_states=updated_env.N_states,
        N_actions=updated_env.N_actions,
    )
    return updated_env, new_rewards_ci

def log(logging, env_updated, solution_pi_latest, env_eval, solution_pi_eval):
    """Log metrics after the sampling iteration.

    Appends one entry to each of the lists stored under the keys
    'transitions', 'rewards', 'R_correct', 'T_dist_L1', 'subopt' and 'policy'
    of ``logging`` and returns that same mapping.
    """
    logging['transitions'].append(env_updated.transitions)
    logging['rewards'].append(env_updated.rewards)

    # this is specifically for double chain: have we identified the first state as having high reward?
    first_beats_last = env_updated.rewards[0] > env_updated.rewards[-1]
    logging['R_correct'].append(first_beats_last)

    # Element-wise L1 distance between learned and evaluation transitions.
    transition_gap = np.abs(env_updated.transitions - env_eval.transitions)
    logging['T_dist_L1'].append(np.sum(transition_gap))

    # Difference in evaluated value between the reference and latest policy.
    reference_value = env_eval.evaluate_policy(solution_pi_eval)
    latest_value = env_eval.evaluate_policy(solution_pi_latest)
    logging['subopt'].append(reference_value - latest_value)

    logging['policy'].append(solution_pi_latest.matrix)
    return logging



def update_hoeffding_interval_from_trajs(trajectories,
                                         transitions_ci,
                                         rewards_ci,
                                         N_states=5,
                                         N_actions=2,
                                         delta=0.05,
                                         R_max=10,
                                         ):
    """Tighten confidence intervals using the visits observed in ``trajectories``.

    Each element of ``trajectories`` is a 4-tuple whose first two entries are
    flat trajectories read in strides of three as ``s, a, r, s, ...``.  Every
    (a, s, s') transition and every successor state s' found in either
    trajectory is counted, and each interval width ``c`` is shrunk to
    ``sqrt(c**2 / (1 + 2 * n * c**2 / log(K / delta)))`` where ``n`` is the
    normalised count and ``K`` is ``4 * N_states**2 * N_actions`` for
    transitions and ``4 * N_states`` for rewards.  Reward widths are rescaled
    by ``R_max`` before and after the update.

    Args:
        trajectories: iterable of ``(traj1, traj2, _, _)`` tuples.
        transitions_ci: array of shape (N_actions, N_states, N_states), or
            None to skip the transition update.
        rewards_ci: array of shape (N_states,).
        N_states: number of states.
        N_actions: number of actions.
        delta: value used inside the logarithmic term of the update.
        R_max: scale used to normalise the reward intervals.

    Returns:
        ``(transitions_ci_new, rewards_ci_new)``; the first entry is None when
        ``transitions_ci`` is None.  When ``trajectories`` is empty both
        inputs are returned unchanged.
    """
    transitions_counts = np.zeros((N_actions, N_states, N_states))
    rewards_counts = np.zeros(N_states)
    reference_len = None
    for traj1, traj2, _, _ in trajectories:
        # Walk each trajectory over its own length so that pairs of unequal
        # length are neither truncated nor indexed out of range.
        for traj in (traj1, traj2):
            for i in range(0, len(traj) - 3, 3):
                s, a, sprime = traj[i], traj[i + 1], traj[i + 3]
                rewards_counts[sprime] += 1
                transitions_counts[a, s, sprime] += 1
        reference_len = len(traj1)

    # No trajectories means no new evidence: leave the intervals as they are.
    if reference_len is None:
        return transitions_ci, rewards_ci

    # Counts are floored at 1 and scaled by twice the length of the last
    # pair's first trajectory.
    # NOTE(review): this normalisation is kept exactly as it was; confirm
    # that using only the last trajectory's length is intended.
    normaliser = 2 * reference_len
    transitions_counts = np.maximum(transitions_counts, 1) / normaliser
    rewards_counts = np.maximum(rewards_counts, 1) / normaliser

    rewards_ci_old = rewards_ci / R_max
    if transitions_ci is not None:
        transitions_ci_new_sq = transitions_ci**2 \
            / (1 + (2 * transitions_counts * transitions_ci**2) / np.log(4 * N_states**2 * N_actions / delta))
        transitions_ci_new = np.sqrt(transitions_ci_new_sq)
    else:
        transitions_ci_new = None
    rewards_ci_new_sq = rewards_ci_old ** 2 \
        / (1 + (2 * rewards_counts * rewards_ci_old**2) / np.log(4 * N_states / delta))

    return transitions_ci_new, R_max * np.sqrt(rewards_ci_new_sq)
