
import copy
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from preferences_offlineRL.envs.common import TabularMDP
# Machine epsilon of float32: the gap between 1.0 and the next representable float32.
EPS = np.finfo(np.float32).eps



class MLPTransitionModel(nn.Module):
    """Linear next-state model over concatenated state and action encodings.

    A single bias-free linear layer maps the concatenation of a state vector
    (size ``N_states``) and an action vector (size ``N_actions``) to
    ``N_states`` unnormalised scores (logits) for the next state.
    """

    def __init__(self, N_states, N_actions):
        super(MLPTransitionModel, self).__init__()
        self.N_states = N_states
        self.N_actions = N_actions
        self.fc = nn.Linear(N_states + N_actions, N_states, bias=False)
        torch.nn.init.kaiming_uniform_(self.fc.weight)

    def forward(self, state, action):
        """Return next-state logits for ``state`` and ``action``.

        The two inputs are concatenated along their last axis before being
        passed through the linear layer.
        """
        features = torch.cat((state, action), dim=-1)
        return self.fc(features)

    def log_forward_traj(self, traj):
        """Return the summed log-probability of the transitions in ``traj``.

        ``traj`` is indexed as an alternating sequence
        ``[state, action, state, action, ..., state]``: even positions are
        read as states and odd positions as actions. Each transition
        contributes ``sum(log_softmax(logits) * next_state)`` over the last
        axis, i.e. the log-probability of the next state when ``next_state``
        is one-hot.

        Returns the integer ``0`` when ``traj`` contains no complete
        transition, otherwise a tensor.
        """
        logprob = 0
        for state, action, next_state in zip(traj[:-2:2], traj[1:-1:2], traj[2::2]):
            # Normalise explicitly over the state axis. Relying on the implicit
            # dim is deprecated and picks axis 0 for 3-D inputs.
            next_state_logprob = torch.log_softmax(self.forward(state, action), dim=-1)
            logprob = logprob + torch.sum(next_state_logprob * next_state, -1)
        return logprob

    def extract_transition_matrix(self):
        """Return the learned transition probabilities as a NumPy array.

        The result has shape ``(N_actions, N_states, N_states)`` and dtype
        ``float64``; entry ``[a, s, :]`` is the softmax distribution over next
        states for one-hot state ``s`` and one-hot action ``a``.
        """
        weight = self.fc.weight
        # Pure read-out: no autograd graph is needed.
        with torch.no_grad():
            # One row per state, so all states are evaluated at once per action.
            states = torch.eye(self.N_states, dtype=weight.dtype, device=weight.device)
            transition_matrix = torch.zeros(
                self.N_actions, self.N_states, self.N_states,
                dtype=weight.dtype, device=weight.device,
            )
            for i in range(self.N_actions):
                action = torch.zeros(
                    self.N_states, self.N_actions,
                    dtype=weight.dtype, device=weight.device,
                )
                action[:, i] = 1  # One-hot encoding for action i, repeated per state
                transition_matrix[i] = torch.softmax(self.forward(states, action), dim=-1)
        return transition_matrix.cpu().numpy().astype(np.double)

class MLPRewardModel(nn.Module):
    """Linear per-state reward model.

    A bias-free linear layer maps a state vector of size ``N_states`` to a
    single scalar reward. ``discount_factor`` weights rewards along a
    trajectory in ``forward_traj``.
    """

    def __init__(self, N_states, discount_factor=0.9):
        super().__init__()
        self.N_states = N_states
        self.discount_factor = discount_factor
        self.fc = nn.Linear(N_states, 1, bias=False)
        torch.nn.init.kaiming_uniform_(self.fc.weight)

    def forward(self, state):
        """Return the predicted reward for ``state`` (trailing axis of size 1)."""
        predicted = self.fc(state)
        return predicted

    def extract_reward_vector(self):
        """Return the reward of every one-hot state as a float64 array of length ``N_states``."""
        rewards = torch.zeros(self.N_states)
        # Each row of the identity matrix is the one-hot encoding of one state.
        for index, one_hot_state in enumerate(torch.eye(self.N_states)):
            rewards[index] = self.forward(one_hot_state.unsqueeze(0)).squeeze()
        return rewards.detach().numpy().astype(np.double)

    def forward_traj(self, traj):
        """Return the discounted sum of predicted rewards along ``traj``.

        Only the even positions of ``traj`` are read (treated as states); the
        state at step ``t`` is weighted by ``discount_factor ** t``. An empty
        trajectory yields the integer ``0``.
        """
        visited_states = traj[::2]
        return sum(
            self.discount_factor ** step * self.forward(state)
            for step, state in enumerate(visited_states)
        )

    def preference_traj_pair(self, traj_1, traj_2):
        """Return the sigmoid of the return difference between ``traj_1`` and ``traj_2``."""
        return_gap = self.forward_traj(traj_1) - self.forward_traj(traj_2)
        return torch.sigmoid(return_gap)
    

class EnsembleModel(nn.Module):
    """Container holding ``ensemble_size`` independently constructed models.

    Every member is built as ``model_module(**init_args)`` and stored in
    ``self.models`` (an ``nn.ModuleList``).
    """

    def __init__(self, model_module, ensemble_size=5, **init_args):
        super().__init__()
        members = []
        for _ in range(ensemble_size):
            members.append(model_module(**init_args))
        self.models = nn.ModuleList(members)

    def extract_meanvar_matrix(self, extract_fn_name):
        """Call method ``extract_fn_name`` on each member and aggregate the results.

        Returns a ``(mean, variance)`` pair taken across the ensemble axis.
        """
        per_member = []
        for member in self.models:
            extract = getattr(member, extract_fn_name)
            per_member.append(extract())
        stacked = np.stack(per_member)
        return np.mean(stacked, axis=0), np.var(stacked, axis=0)

    def uncertainty_traj_pair(self, traj_1, traj_2):
        """Return the across-member variance of ``preference_traj_pair(traj_1, traj_2)``."""
        member_predictions = []
        for member in self.models:
            prediction = member.preference_traj_pair(traj_1, traj_2)
            member_predictions.append(prediction.detach().numpy())
        return np.var(np.stack(member_predictions), axis=0)




class TrajectoriesDataset(Dataset):
    """Dataset of single transitions extracted from flat trajectories.

    Each trajectory is read with a stride of three: positions ``0, 3, 6, ...``
    are states, ``1, 4, ...`` actions and ``2, 5, ...`` rewards, with the
    state three positions later taken as the next state. All transitions from
    all trajectories are concatenated, so trajectories may differ in length.

    Items are ``(state, action, reward, next_state)`` where states and actions
    are one-hot float tensors and the reward is a float scalar tensor.
    """

    def __init__(self, trajs, N_states=5, N_actions=2):
        states, actions, rewards, next_states = [], [], [], []
        # Flatten per trajectory rather than building one rectangular tensor,
        # which would fail when trajectories have different lengths.
        for traj in trajs:
            states.extend(traj[:-3:3])
            actions.extend(traj[1:-2:3])
            rewards.extend(traj[2:-1:3])
            next_states.extend(traj[3::3])

        self.rewards = torch.tensor(rewards).float()

        # One-hot encode. one_hot requires int64 indices, so cast explicitly in
        # case the trajectories store states/actions as floats.
        self.states = nn.functional.one_hot(
            torch.tensor(states).to(torch.int64), num_classes=N_states
        ).float()
        self.actions = nn.functional.one_hot(
            torch.tensor(actions).to(torch.int64), num_classes=N_actions
        ).float()
        self.next_states = nn.functional.one_hot(
            torch.tensor(next_states).to(torch.int64), num_classes=N_states
        ).float()

    def __len__(self):
        """Return the total number of transitions."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return ``(state, action, reward, next_state)`` for transition ``idx``."""
        return self.states[idx], self.actions[idx], self.rewards[idx], self.next_states[idx]

def convert_states_actions_one_hot(traj, N_states, N_actions):
    """One-hot encode the states and actions of a flat trajectory.

    Positions ``0 mod 3`` of ``traj`` are encoded with ``N_states`` classes
    and positions ``1 mod 3`` with ``N_actions`` classes; positions
    ``2 mod 3`` are skipped. Returns the float one-hot tensors as a list in
    trajectory order.
    """
    classes_by_slot = {0: N_states, 1: N_actions}
    encoded = []
    for position, value in enumerate(traj):
        num_classes = classes_by_slot.get(position % 3)
        if num_classes is None:
            # Third slot of every triple is not encoded.
            continue
        index = torch.tensor(value).to(torch.int64)
        encoded.append(nn.functional.one_hot(index, num_classes=num_classes).float())
    return encoded

class PreferenceDataset(Dataset):
    """Dataset of annotated trajectory pairs for a general MDP.

    Every entry of ``annotated_trajs`` unpacks as
    ``(traj_1, traj_2, y_T, y_R)``. Items returned by ``__getitem__`` hold the
    one-hot encoded states/actions of both trajectories followed by the two
    labels as float tensors of shape ``(1,)``.
    """

    def __init__(self, annotated_trajs, N_states=5, N_actions=2, discount_factor=0.9):
        self.annotated_trajs = annotated_trajs
        self.N_states = N_states
        self.N_actions = N_actions
        self.discount_factor = discount_factor

    def __len__(self):
        """Return the number of annotated pairs."""
        return len(self.annotated_trajs)

    def __getitem__(self, idx):
        """Return ``(encoded_traj_1, encoded_traj_2, y_T, y_R)`` for pair ``idx``."""
        first, second, label_T, label_R = self.annotated_trajs[idx]
        encoded_first = convert_states_actions_one_hot(first, self.N_states, self.N_actions)
        encoded_second = convert_states_actions_one_hot(second, self.N_states, self.N_actions)
        target_T = torch.tensor([label_T]).float()
        target_R = torch.Tensor([label_R]).float()
        return (encoded_first, encoded_second, target_T, target_R)

def loss_fn_T_trajs(batch, transition):
    """Cross-entropy between predicted next-state logits and the observed next states.

    ``batch`` unpacks as ``(state, action, reward, next_state)``; the reward is
    ignored. ``transition`` is called as ``transition(state, action)``.
    """
    states, actions, _unused_reward, observed_next = batch
    logits = transition(states, actions)
    criterion = nn.CrossEntropyLoss()
    return criterion(logits, observed_next)

def loss_fn_T_prefs(batch, transition):
    """Binary cross-entropy on transition-likelihood preferences.

    ``batch`` unpacks as ``(traj_1, traj_2, y_T, y_R)``; ``y_R`` is ignored.
    The logit is the difference of ``transition.log_forward_traj`` between the
    two trajectories, and ``y_T`` is the target.
    """
    traj_1, traj_2, y_T, _ = batch
    y_hat_logit = transition.log_forward_traj(traj_1) - transition.log_forward_traj(traj_2)
    # Match the target to the logit shape explicitly. A bare squeeze() collapses
    # a (1, 1) target to 0-d for single-element batches, which BCEWithLogitsLoss
    # rejects against (1,) logits.
    target = y_T.reshape(y_hat_logit.shape)
    return nn.BCEWithLogitsLoss()(y_hat_logit, target)

def loss_fn_R_trajs(batch, reward_model):
    """Mean squared error between predicted and observed per-step rewards.

    ``batch`` unpacks as ``(state, action, reward, next_state)``; only the
    state and reward are used. ``reward_model`` is called as
    ``reward_model(state)``.
    """
    state, _, reward, _ = batch
    pred_rew = reward_model(state)
    # Align the prediction with the target shape. If the prediction keeps a
    # trailing singleton axis, (batch, 1) vs (batch,) would broadcast to
    # (batch, batch) and average over every prediction/target pair.
    return nn.MSELoss()(pred_rew.reshape(reward.shape), reward)

def loss_fn_R_prefs(batch, reward_model):
    """Binary cross-entropy on return-based preferences.

    ``batch`` unpacks as ``(traj_1, traj_2, y_T, y_R)``; ``y_T`` is ignored.
    The logit is ``reward_model.forward_traj(traj_1)`` minus
    ``reward_model.forward_traj(traj_2)`` and ``y_R`` is the target.
    """
    first, second, _unused_label, labels = batch
    return_first = reward_model.forward_traj(first)
    return_second = reward_model.forward_traj(second)
    criterion = nn.BCEWithLogitsLoss()
    return criterion(return_first - return_second, labels)


def train_model(model,
                dataloader,
                optimizer,
                loss_fn,
                n_epochs=10,
                verbose=False):
    """Run a plain gradient-descent training loop.

    For every batch of ``dataloader`` the loss ``loss_fn(batch, model)`` is
    back-propagated and ``optimizer`` takes one step. When ``verbose`` is set,
    the mean batch loss is printed after each epoch.

    Returns ``(model, loss_per_epoch)`` where ``loss_per_epoch`` lists the mean
    batch loss of each of the ``n_epochs`` epochs.
    """

    def _step(batch):
        # One optimisation step; returns the scalar loss value for logging.
        optimizer.zero_grad()
        batch_loss = loss_fn(batch, model)
        batch_loss.backward()
        optimizer.step()
        return batch_loss.item()

    epoch_means = []
    for _ in range(n_epochs):
        batch_losses = [_step(batch) for batch in dataloader]
        if verbose:
            print("Loss:", np.mean(batch_losses))
        epoch_means.append(np.mean(batch_losses))
    return model, epoch_means

def train_ensemble_model(ensemble_model,
                         dataloader,
                         optimizer,
                         loss_fn,
                         bootstrap_rate=0.1,
                         n_epochs=10,
                         verbose=False):
    """Train every member of ``ensemble_model.models`` with random batch skipping.

    For each batch and each member, a uniform random number is drawn; the
    member skips that batch when the draw is below ``bootstrap_rate``.
    Otherwise ``loss_fn(batch, member)`` is back-propagated and the shared
    ``optimizer`` takes one step. When ``verbose`` is set, the mean loss of the
    epoch is printed.

    Returns ``(ensemble_model, loss_per_epoch)`` where each entry is the mean
    over all member/batch losses recorded in that epoch.
    """
    epoch_means = []
    for _ in range(n_epochs):
        recorded = []
        for batch in dataloader:
            for member in ensemble_model.models:
                skip_batch = np.random.rand() < bootstrap_rate
                if skip_batch:
                    continue
                optimizer.zero_grad()
                member_loss = loss_fn(batch, member)
                member_loss.backward()
                optimizer.step()
                recorded.append(member_loss.item())
        if verbose:
            print("Loss:", np.mean(recorded))
        epoch_means.append(np.mean(recorded))
    return ensemble_model, epoch_means

def train_reward_transition_models(list_trajs,
                                   env_emp,
                                   transition_model=None,
                                   reward_model=None,
                                   input='trajectories',
                                   verbose=False):
    """Train reward and transition models from observed trajectories/sampling iterations.

    Args:
        list_trajs: Training data. Wrapped in ``TrajectoriesDataset`` when
            ``input == 'trajectories'`` and in ``PreferenceDataset`` when
            ``input == 'preferences'``.
        env_emp: Environment providing ``N_states`` and ``N_actions``. It is
            deep-copied, never modified.
        transition_model: Existing transition model to continue training; a new
            ``MLPTransitionModel`` is created when ``None``.
        reward_model: Existing reward model to continue training; a new
            ``MLPRewardModel`` is created when ``None``.
        input: Either ``'trajectories'`` or ``'preferences'``; selects the
            dataset type and the matching loss functions.
        verbose: Print progress, losses and the extracted parameters.

    Returns:
        ``(env_updated, transition_model, reward_model, transition_losses,
        reward_losses)`` where ``env_updated`` is a copy of ``env_emp`` whose
        ``transitions`` and ``rewards`` are replaced by the learned values.

    Raises:
        ValueError: If ``input`` is not one of the two supported modes.
    """
    # ``input`` shadows the builtin but is part of the public signature, so the
    # name is kept. Validate it up front: an unknown mode would otherwise
    # surface later as an unbound ``dataset`` variable.
    if input not in ('trajectories', 'preferences'):
        raise ValueError(
            f"input must be 'trajectories' or 'preferences', got {input!r}"
        )
    use_trajectories = input == 'trajectories'
    if verbose:
        print(input)

    N_states, N_actions = env_emp.N_states, env_emp.N_actions
    if use_trajectories:
        dataset = TrajectoriesDataset(list_trajs, N_states, N_actions)
    else:
        dataset = PreferenceDataset(list_trajs, N_states, N_actions)
    dataloader = DataLoader(dataset, batch_size=min(4, len(list_trajs)), shuffle=True)

    if verbose:
        print("Training transition model")
    if transition_model is None:
        transition_model = MLPTransitionModel(N_states, N_actions)

    optimizer = torch.optim.Adam(transition_model.parameters(), lr=0.01)
    transition_model, transition_losses = train_model(
        transition_model,
        dataloader,
        optimizer,
        loss_fn_T_trajs if use_trajectories else loss_fn_T_prefs,
        n_epochs=5,
        verbose=verbose,
    )
    transitions = transition_model.extract_transition_matrix()
    if verbose:
        print(transitions)

        print("Training reward model")

    if reward_model is None:
        reward_model = MLPRewardModel(N_states)

    # A fresh optimizer is needed: the reward model has its own parameters.
    optimizer = torch.optim.Adam(reward_model.parameters(), lr=0.01)
    reward_model, reward_losses = train_model(
        reward_model,
        dataloader,
        optimizer,
        loss_fn_R_trajs if use_trajectories else loss_fn_R_prefs,
        n_epochs=5,
        verbose=verbose,
    )
    rewards = reward_model.extract_reward_vector()
    if verbose:
        print(rewards)

    env_updated = copy.deepcopy(env_emp)
    env_updated.transitions = transitions
    env_updated.rewards = rewards

    return env_updated, transition_model, reward_model, transition_losses, reward_losses


def train_transition_model(list_trajs,
                           env_emp,
                           verbose=False):
    """Fit a fresh ``MLPTransitionModel`` on observed trajectories.

    ``list_trajs`` is wrapped in a ``TrajectoriesDataset`` and the model is
    trained for 5 epochs with Adam (lr 0.01) on ``loss_fn_T_trajs``.

    Returns ``(env_updated, transition_model, transition_losses)`` where
    ``env_updated`` is a deep copy of ``env_emp`` whose ``transitions``
    attribute holds the extracted transition matrix.
    """
    n_states = env_emp.N_states
    n_actions = env_emp.N_actions

    loader = DataLoader(
        TrajectoriesDataset(list_trajs, n_states, n_actions),
        batch_size=min(4, len(list_trajs)),
        shuffle=True,
    )

    if verbose:
        print("Training transition model")

    net = MLPTransitionModel(n_states, n_actions)
    adam = torch.optim.Adam(net.parameters(), lr=0.01)
    net, epoch_losses = train_model(
        net, loader, adam, loss_fn_T_trajs, n_epochs=5, verbose=verbose
    )

    learned_transitions = net.extract_transition_matrix()
    if verbose:
        print(learned_transitions)

    # Leave the caller's environment untouched; only the copy is updated.
    env_updated = copy.deepcopy(env_emp)
    env_updated.transitions = learned_transitions
    return env_updated, net, epoch_losses


def train_transition_model_w_uncertainty(list_trajs,
                                         env_emp,
                                         verbose=False):
    """Fit an ensemble of transition models and report their disagreement.

    An ``EnsembleModel`` of ``MLPTransitionModel`` members is trained for
    5 epochs with a single Adam optimizer (lr 0.01) on ``loss_fn_T_trajs``.

    Returns ``(env_updated, transition_model, transitions_ci,
    transition_losses)``. ``env_updated`` is a deep copy of ``env_emp`` whose
    ``transitions`` is the first output of
    ``extract_meanvar_matrix("extract_transition_matrix")``;
    ``transitions_ci`` is the second output of that call.
    """
    n_states = env_emp.N_states
    n_actions = env_emp.N_actions

    loader = DataLoader(
        TrajectoriesDataset(list_trajs, n_states, n_actions),
        batch_size=min(4, len(list_trajs)),
        shuffle=True,
    )

    if verbose:
        print("Training transition model")

    ensemble = EnsembleModel(MLPTransitionModel, N_states=n_states, N_actions=n_actions)
    adam = torch.optim.Adam(ensemble.parameters(), lr=0.01)
    ensemble, epoch_losses = train_ensemble_model(
        ensemble, loader, adam, loss_fn_T_trajs, n_epochs=5, verbose=verbose
    )

    aggregated = ensemble.extract_meanvar_matrix("extract_transition_matrix")
    mean_transitions, transitions_ci = aggregated
    if verbose:
        print(mean_transitions)

    # Leave the caller's environment untouched; only the copy is updated.
    env_updated = copy.deepcopy(env_emp)
    env_updated.transitions = mean_transitions
    return env_updated, ensemble, transitions_ci, epoch_losses

def train_reward_model(list_trajs,
                       env_emp,
                       reward_model=None,
                       verbose=False):
    """Train reward model after sampling iteration.

    Args:
        list_trajs: Annotated trajectory pairs, wrapped in a
            ``PreferenceDataset``.
        env_emp: Environment providing ``N_states`` and ``N_actions``. It is
            deep-copied, never modified.
        reward_model: Existing model to continue training; a new
            ``MLPRewardModel`` is created when ``None``.
        verbose: Print progress, per-epoch losses and the extracted rewards.

    Returns:
        ``(env_updated, reward_model, reward_losses)`` where ``env_updated``
        is a copy of ``env_emp`` whose ``rewards`` holds the extracted reward
        vector.
    """
    N_states, N_actions = env_emp.N_states, env_emp.N_actions

    dataset = PreferenceDataset(list_trajs, N_states, N_actions)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    if verbose:
        print("Training reward model")

    if reward_model is None:
        reward_model = MLPRewardModel(N_states)
    optimizer = torch.optim.Adam(reward_model.parameters(), lr=0.01)
    # Forward ``verbose`` so per-epoch losses are printed, as the transition
    # trainers in this file do.
    reward_model, reward_losses = train_model(
        reward_model,
        dataloader,
        optimizer,
        loss_fn_R_prefs,
        n_epochs=2,
        verbose=verbose,
    )
    rewards = reward_model.extract_reward_vector()
    if verbose:
        print(rewards)

    env_updated = copy.deepcopy(env_emp)
    env_updated.rewards = rewards
    return env_updated, reward_model, reward_losses


def train_reward_model_w_uncertainty(list_trajs,
                                     env_emp,
                                     reward_model=None,
                                     verbose=False):
    """Train an ensemble of reward models on preferences and report their disagreement.

    Args:
        list_trajs: Annotated trajectory pairs, wrapped in a
            ``PreferenceDataset``.
        env_emp: Environment providing ``N_states`` and ``N_actions``. It is
            deep-copied, never modified.
        reward_model: Existing ensemble to continue training; a new
            ``EnsembleModel`` of ``MLPRewardModel`` members is created when
            ``None``.
        verbose: Print progress, per-epoch losses and the extracted rewards.

    Returns:
        ``(env_updated, reward_model, rewards_ci, reward_losses)``.
        ``env_updated`` is a copy of ``env_emp`` whose ``rewards`` is the
        first output of ``extract_meanvar_matrix("extract_reward_vector")``;
        ``rewards_ci`` is the second output of that call.
    """
    N_states, N_actions = env_emp.N_states, env_emp.N_actions
    dataset = PreferenceDataset(list_trajs, N_states, N_actions)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    if verbose:
        print("Training reward model")

    if reward_model is None:
        reward_model = EnsembleModel(MLPRewardModel, N_states=N_states)
    optimizer = torch.optim.Adam(reward_model.parameters(), lr=0.01)
    # Forward ``verbose`` so per-epoch losses are printed, as the transition
    # ensemble trainer in this file does.
    reward_model, reward_losses = train_ensemble_model(
        reward_model,
        dataloader,
        optimizer,
        loss_fn_R_prefs,
        n_epochs=2,
        verbose=verbose,
    )
    rewards, rewards_ci = reward_model.extract_meanvar_matrix("extract_reward_vector")
    if verbose:
        print(rewards)

    env_updated = copy.deepcopy(env_emp)
    env_updated.rewards = rewards
    return env_updated, reward_model, rewards_ci, reward_losses