
import numpy as np
import torch
import gym
import torch.nn as nn
from tqdm import tqdm
import os
import dataclasses

import d3rlpy
from d3rlpy.dataset import TransitionMiniBatch

from delphicORL.algos import base

WINDOW_SIZE = 1024


class RewardFn(base.DemonstrationAlgorithm):
    """MLP reward model r(s, a) regressed onto demonstration rewards with MSE.

    The network consumes the concatenation of an observation and an action and
    outputs one scalar per row.
    """

    def __init__(self, observation_space, action_space,
                 demonstrations=None,
                 custom_logger=None,
                 test_demonstrations=None,
                 batch_size=32,
                 device='cuda'):
        """
        Args:
            observation_space: space whose ``shape[0]`` is the observation size.
            action_space: space whose ``shape[0]`` is the action size; a space
                with an empty shape (e.g. discrete) counts as one dimension.
            demonstrations, custom_logger, test_demonstrations, batch_size:
                forwarded unchanged to ``base.DemonstrationAlgorithm``.
            device: torch device for the reward network (default ``'cuda'``,
                matching the previous hard-coded behaviour).
        """
        # Set before the base constructor in case it inspects ``self.lstm``.
        self.lstm = False
        super().__init__(demonstrations=demonstrations, custom_logger=custom_logger,
                         test_demonstrations=test_demonstrations, batch_size=batch_size)
        obs_dim = observation_space.shape[0]
        act_dim = action_space.shape[0] if len(action_space.shape) >= 1 else 1
        self.reward_model = nn.Sequential(nn.Linear(obs_dim + act_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1)).to(device)

        self.loss_fn = nn.MSELoss()

    def _model_device(self):
        """Return the device holding the reward network's parameters."""
        return next(self.reward_model.parameters()).device

    def _unpack_batch(self, batch):
        """Return ``(obs, acts, rews)`` from a loader batch as float tensors on the model device."""
        device = self._model_device()
        return tuple(batch[key].float().to(device) for key in ('obs', 'acts', 'rews'))

    def forward(self, state, act):
        """Predict rewards; a scalar-per-row ``act`` is given a trailing axis before concatenation."""
        if len(act.shape) < len(state.shape):
            act = torch.unsqueeze(act, -1)
        x = torch.cat([state, act], -1)
        return self.reward_model(x)

    def fit(self, n_epochs=10, save_to=None):
        """Train with Adam on the demonstration loader, logging and optionally saving every epoch.

        Args:
            n_epochs: number of passes over ``self._demo_data_loader``.
            save_to: directory passed to ``save`` after each epoch, if truthy.
        """
        optimizer = torch.optim.Adam(self.reward_model.parameters(), lr=5e-4, weight_decay=1e-5)
        for epoch in tqdm(range(n_epochs)):
            loss = None  # stays None when the loader yields no batches
            for b, batch in enumerate(self._demo_data_loader):
                optimizer.zero_grad()
                state, action, reward = self._unpack_batch(batch)
                pred_rew = torch.squeeze(self.forward(state, action), -1)
                # TODO: the LSTM (padded sequence) format would need masking here.
                loss = self.loss_fn(pred_rew, reward)
                loss.backward()
                optimizer.step()
                if epoch == 0 and b == 0:
                    tqdm.write(f"Epoch {epoch}, first loss {loss.item():.3f}")
            if loss is not None:
                tqdm.write(f"Epoch {epoch}, last loss {loss.item():.3f}")
            tqdm.write(f"Training MSE: {self.reward_scorer(self._demo_data_loader):.3f}")
            # test_demonstrations is optional, so there may be no validation loader.
            if self._test_demo_data_loader is not None:
                tqdm.write(f"Validation MSE: {self.reward_scorer(self._test_demo_data_loader):.3f}")
            if save_to:
                self.save(save_to)

    def reward_scorer(self, dataloader):
        """Return the mean squared reward-prediction error over ``dataloader`` (NaN if it is empty)."""
        squared_errors = []
        with torch.no_grad():
            for batch in dataloader:
                state, action, reward = self._unpack_batch(batch)
                pred_rew = torch.squeeze(self.forward(state, action), -1)
                squared_errors += ((reward - pred_rew) ** 2).tolist()
        if not squared_errors:
            return float('nan')
        return float(np.mean(squared_errors))

    def save(self, save_to):
        """Write the network weights to ``<save_to>/reward_model.pt``."""
        torch.save(self.reward_model.state_dict(), os.path.join(save_to, "reward_model.pt"))

    def load(self, save_path, latest=False):
        """Load weights from ``<save_path>/reward_model.pt``.

        With ``latest`` set, ``save_path`` is first replaced by its
        lexicographically last entry (e.g. the newest timestamped run directory).
        """
        if latest:
            save_path = os.path.join(save_path, sorted(os.listdir(save_path))[-1])
        state_dict = torch.load(os.path.join(save_path, 'reward_model.pt'),
                                map_location=self._model_device())
        self.reward_model.load_state_dict(state_dict)


class PropensityFn(base.DemonstrationAlgorithm):
    """MLP behaviour-policy (propensity) model pi_b(a | s) fitted to demonstration actions.

    With two actions the head ends in a Sigmoid and is trained with BCELoss;
    otherwise it ends in a Softmax and is trained with CrossEntropyLoss.
    """

    def __init__(self, observation_space, action_space,
                 demonstrations=None,
                 custom_logger=None,
                 test_demonstrations=None,
                 batch_size=32,
                 device='cuda'):
        """
        Args:
            observation_space: space whose ``shape[0]`` is the observation size.
            action_space: action space; an integer ``n`` attribute (e.g.
                ``Discrete``) gives the action count, otherwise ``prod(shape)``.
            demonstrations, custom_logger, test_demonstrations, batch_size:
                forwarded unchanged to ``base.DemonstrationAlgorithm``.
            device: torch device for the network (default ``'cuda'``, matching
                the previous hard-coded behaviour).
        """
        # Set before the base constructor in case it inspects ``self.lstm``.
        self.lstm = False
        super().__init__(demonstrations=demonstrations, custom_logger=custom_logger,
                         test_demonstrations=test_demonstrations, batch_size=batch_size)
        obs_dim = observation_space.shape[0]
        act_dim = max(1, len(action_space.shape))
        # A Discrete space has an empty shape, so prod(shape) == 1 would collapse
        # every Discrete(n) to two actions; prefer its explicit action count.
        n_choices = getattr(action_space, 'n', None)
        if not isinstance(n_choices, (int, np.integer)):
            n_choices = np.prod(action_space.shape)
        self.n_actions = max(2, int(n_choices))
        final_activation = nn.Sigmoid() if self.n_actions == 2 else nn.Softmax(dim=-1)
        self.behav_policy = nn.Sequential(nn.Linear(obs_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, self.n_actions if act_dim == 1 else act_dim),
            final_activation).to(device)

        # NOTE(review): CrossEntropyLoss expects logits but receives Softmax
        # probabilities here (a double softmax). Kept because forward() is used
        # as a probability output by predict_probability and by callers.
        self.loss_fn = nn.BCELoss() if self.n_actions == 2 else nn.CrossEntropyLoss()

    def _model_device(self):
        """Return the device holding the network's parameters."""
        return next(self.behav_policy.parameters()).device

    def _unpack_batch(self, batch):
        """Return ``(obs, acts)`` from a loader batch as float tensors on the model device."""
        device = self._model_device()
        return batch['obs'].float().to(device), batch['acts'].float().to(device)

    def _loss_target(self, action):
        """Cast ``action`` to the target dtype the configured loss requires."""
        if self.n_actions == 2:
            # BCELoss needs a floating-point target with the same shape as its input.
            return action.float()
        # CrossEntropyLoss with class-index targets needs integer labels.
        return action.long()

    def forward(self, state):
        """Return the network's action probabilities for ``state``."""
        return self.behav_policy(state)

    def predict_probability(self, state, act):
        """Return pi_b(act | state) as a NumPy array, computed without gradient tracking."""
        # The original bare ``torch.no_grad()`` call created a context manager
        # without entering it; it must be used in a ``with`` statement.
        with torch.no_grad():
            probs = self.behav_policy(state)
            if self.n_actions == 2:
                act_dist = torch.distributions.Bernoulli(probs)
            else:
                act_dist = torch.distributions.Categorical(probs)
            return act_dist.log_prob(torch.squeeze(act)).exp().cpu().numpy()

    def _log_scores(self, split, dataloader):
        """Write the ``prop_scorer`` metrics for ``dataloader`` under the label ``split``."""
        scores = self.prop_scorer(dataloader)
        tqdm.write(f"{split} pred_proba: {scores['pred_proba']:.3f}, accuracy: {scores['accuracy']:.3f}")

    def fit(self, n_epochs=10, save_to=None):
        """Train with Adam on the demonstration loader, logging and optionally saving every epoch.

        Args:
            n_epochs: number of passes over ``self._demo_data_loader``.
            save_to: directory passed to ``save`` after each epoch, if truthy.
        """
        optimizer = torch.optim.Adam(self.behav_policy.parameters(), lr=5e-4, weight_decay=1e-5)
        for epoch in tqdm(range(n_epochs)):
            loss = None  # stays None when the loader yields no batches
            for b, batch in enumerate(self._demo_data_loader):
                optimizer.zero_grad()
                state, action = self._unpack_batch(batch)
                # TODO: the LSTM (padded sequence) format would need masking here.
                loss = self.loss_fn(self.forward(state), self._loss_target(action))
                loss.backward()
                optimizer.step()
                if epoch == 0 and b == 0:
                    tqdm.write(f"Epoch {epoch}, first loss {loss.item():.3f}")
            if loss is not None:
                tqdm.write(f"Epoch {epoch}, last loss {loss.item():.3f}")
            # prop_scorer returns a dict, which cannot be formatted with ':.3f'.
            self._log_scores("Training", self._demo_data_loader)
            if self._test_demo_data_loader is not None:
                self._log_scores("Validation", self._test_demo_data_loader)
            if save_to:
                self.save(save_to)

    def prop_scorer(self, dataloader):
        """Score the model on ``dataloader``.

        Returns:
            ``{"pred_proba": mean predicted probability of the logged actions,
            "accuracy": fraction of rows whose argmax matches the logged action}``;
            both are NaN when the loader is empty.
        """
        total_proba, total_acc = [], []
        with torch.no_grad():
            for batch in dataloader:
                state, action = self._unpack_batch(batch)
                total_proba += self.predict_probability(state, action).tolist()
                probs = self.forward(state)
                predicted = probs.argmax(-1)  # torch op: np.argmax fails on CUDA tensors
                if action.shape == probs.shape:
                    # One-hot actions: compare against the hot index.
                    target = action.argmax(-1)
                else:
                    # Class-index actions, possibly with a trailing singleton axis.
                    target = action.reshape(predicted.shape).long()
                total_acc += (predicted == target).tolist()
        pred_proba = float(np.mean(total_proba)) if total_proba else float('nan')
        accuracy = float(np.mean(total_acc)) if total_acc else float('nan')
        return {"pred_proba": pred_proba, "accuracy": accuracy}

    def save(self, save_to):
        """Write the network weights to ``<save_to>/behav_policy.pt``."""
        torch.save(self.behav_policy.state_dict(), os.path.join(save_to, "behav_policy.pt"))

    def load(self, save_path, latest=False):
        """Load weights from ``<save_path>/behav_policy.pt``.

        With ``latest`` set, ``save_path`` is first replaced by its
        lexicographically last entry.
        """
        if latest:
            save_path = os.path.join(save_path, sorted(os.listdir(save_path))[-1])
        state_dict = torch.load(os.path.join(save_path, 'behav_policy.pt'),
                                map_location=self._model_device())
        self.behav_policy.load_state_dict(state_dict)


    
##########################################################################


def imitation_rollout_stats(venv, n_episodes=10):
    """Build a scorer that summarises rollouts of a policy in ``venv``.

    The returned callable takes a policy. When ``venv`` is None or
    ``n_episodes`` is not positive it returns an empty dict without rolling
    anything out. Otherwise it returns ``rollout.rollout_stats`` over
    trajectories generated for at least ``n_episodes`` episodes.

    NOTE(review): ``rollout`` is not imported in this module, so the rollout
    path raises NameError as written. It presumably refers to
    ``imitation.data.rollout``; confirm and add the import.
    """
    def scorer(policy):
        if venv is None or not (n_episodes > 0):
            return {}
        trajectories = rollout.generate_trajectories(
            policy, venv, rollout.make_min_episodes(n_episodes)
        )
        return rollout.rollout_stats(trajectories)

    return scorer


def ope_dm_scorer(policy, dataloader, reward_fn, env, lstm=False, no_transitions_eval=1024, device="cuda"):
    """Direct-method (DM) off-policy estimate: mean reward_fn reward of ``policy``'s actions.

    For each batch, the policy's predicted actions are scored by ``reward_fn``
    on the states returned by ``env.recover_original_obs``. The predictions are
    averaged. Iteration stops once at least ``no_transitions_eval`` transitions
    have been counted.

    Args:
        policy: object with ``predict(observation) -> (actions, _)``; it receives
            the batch's ``'obs'`` entry unchanged and must return a NumPy array.
        dataloader: iterable of dict batches with ``'obs'``, ``'rews'`` and
            ``'infos'`` (plus ``'masks'`` when ``lstm`` is True).
        reward_fn: object with ``forward(state, action) -> torch.Tensor``.
        env: object with ``recover_original_obs(obs_array, infos) -> np.ndarray``.
        lstm: when True, only transitions whose mask is set are averaged/counted.
        no_transitions_eval: soft cap on the number of transitions evaluated.
        device: torch device the reward model lives on (default ``"cuda"``).

    Returns:
        ``{'ope_dm_return': float}``; NaN when the dataloader yields nothing.
    """
    samples = 0
    dm_reward = []
    # Pure inference: avoid building autograd graphs through reward_fn.
    with torch.no_grad():
        for batch in dataloader:
            observation, reward = batch['obs'], batch['rews']

            actions_pred, _ = policy.predict(observation)
            state = env.recover_original_obs(observation.detach().cpu().numpy(), batch['infos'])

            state = torch.from_numpy(state).to(device).float()
            actions_pred = torch.from_numpy(actions_pred).to(device).float()
            pred_rew = reward_fn.forward(state, actions_pred).detach().cpu().numpy()

            if lstm:
                mask = batch['masks'].detach().to(torch.bool).cpu().numpy()
                # Masking the rewards keeps the sample count to valid steps only.
                reward, pred_rew = reward[mask], pred_rew[mask]

            dm_reward.extend(pred_rew)

            samples += np.prod(reward.shape)
            if samples >= no_transitions_eval:
                break
    if not dm_reward:
        return {'ope_dm_return': float('nan')}
    return {'ope_dm_return': float(np.mean(dm_reward))}


def continuous_action_diff_scorer(policy, dataloader, input_name = "obs", lstm=False):
    """Mean squared difference between ``policy``'s actions and the logged actions.

    Args:
        policy: object with ``predict(x) -> (actions, _)`` and a ``device``
            attribute that is forwarded to ``base.get_bc_input_output``.
        dataloader: iterable of batches accepted by ``base.get_bc_input_output``.
        input_name: batch key used as the policy input (default ``"obs"``).
        lstm: when True, the third element returned by
            ``base.get_bc_input_output`` is used as a mask over the differences.

    Returns:
        ``{"cont_act_diff": float}`` averaged over every element; NaN when the
        dataloader is empty.
    """
    total_diffs = []
    for batch in dataloader:
        input_output = base.get_bc_input_output(batch, input_name, device=policy.device, lstm=lstm)
        x = input_output[0].detach().cpu().numpy()
        actions = input_output[1].detach().cpu().numpy()
        actions_pred, _ = policy.predict(x)
        diff = (actions - actions_pred) ** 2

        if lstm:
            # Cast to bool so the mask selects entries: a 0/1 float or int mask
            # would otherwise be treated as integer indices (rows 0 and 1).
            masks = input_output[2].detach().cpu().numpy().astype(bool)
            diff = diff[masks]
        total_diffs += diff.tolist()
    if not total_diffs:
        return {"cont_act_diff": float("nan")}
    return {"cont_act_diff": float(np.mean(total_diffs))}


def discrete_action_diff_scorer(policy, dataloader, input_name = "obs", lstm=False):
    """Fraction of entries where ``policy``'s action equals the logged action.

    Args:
        policy: object with ``predict(x) -> (actions, _)`` and a ``device``
            attribute that is forwarded to ``base.get_bc_input_output``.
        dataloader: iterable of batches accepted by ``base.get_bc_input_output``.
        input_name: batch key used as the policy input (default ``"obs"``).
        lstm: when True, the third element returned by
            ``base.get_bc_input_output`` is used as a mask over the matches.

    Returns:
        ``{"disc_act_acc": float}``; NaN when the dataloader is empty.
    """
    total_matches = []
    for batch in dataloader:
        input_output = base.get_bc_input_output(batch, input_name, device=policy.device, lstm=lstm)
        x = input_output[0].detach().cpu().numpy()
        actions = input_output[1].detach().cpu().numpy()
        actions_pred, _ = policy.predict(x)

        match = (actions == actions_pred).reshape(-1)

        if lstm:
            # ``match`` is flattened, so the mask must be flattened too (a 2-D
            # boolean index on a 1-D array raises). The bool cast stops a 0/1
            # numeric mask from acting as integer indices.
            masks = input_output[2].detach().cpu().numpy().astype(bool).reshape(-1)
            match = match[masks]
        total_matches += match.tolist()
    if not total_matches:
        return {"disc_act_acc": float("nan")}
    return {"disc_act_acc": float(np.mean(total_matches))}


##########################################################################


def d3rlpy_evaluate_on_environment(eval_env, eval_episodes=10, mean=0, std=1):
    """Build a scorer returning a policy's average undiscounted episode return.

    Observations are flattened to one row and standardised with ``mean`` and
    ``std`` before ``policy.predict``. The first element of the prediction is
    sent to ``eval_env.step``, cast to int when the action space is
    ``gym.spaces.Discrete``. The environment is driven with the 4-tuple
    ``(obs, reward, done, info)`` step API.

    Args:
        eval_env: environment exposing ``action_space``, ``reset`` and ``step``.
        eval_episodes: number of episodes averaged per call.
        mean, std: observation normalisation statistics.

    Returns:
        ``scorer(policy, *args) -> float``; extra positional args are ignored.
    """
    is_discrete = isinstance(eval_env.action_space, gym.spaces.Discrete)

    def run_episode(policy):
        obs = eval_env.reset()
        done = False
        total = 0.0
        while not done:
            normalized = (np.array(obs).reshape(1, -1) - mean) / std
            action = policy.predict(normalized)[0]
            if is_discrete:
                action = action.astype(int)
            obs, reward, done, _ = eval_env.step(action)
            total += reward
        return total

    def scorer(policy, *args):
        returns = [run_episode(policy) for _ in range(eval_episodes)]
        return float(np.mean(returns))

    return scorer


def _make_batches(
    episode, window_size: int, n_frames: int, infos=None
):
    n_batches = len(episode) // window_size
    if len(episode) % window_size != 0:
        n_batches += 1
    for i in range(n_batches):
        head_index = i * window_size
        last_index = min(head_index + window_size, len(episode))
        transitions = episode.transitions[head_index:last_index]
        batch = TransitionMiniBatch(transitions, n_frames)
        
        if infos is not None:
            info = infos[head_index:last_index]
            yield (batch, info)
        else:
            yield batch
        
def d3rlpy_discrete_action_match_scorer(
    algo, episodes
) -> float:
    r"""Return the fraction of dataset actions that the algorithm reproduces exactly.

    This metric indicates how far the greedy policy is from the given episodes
    in a discrete action space. If the episodes are near-optimal, a higher
    fraction is better.

    .. math::

        \frac{1}{N} \sum^N \parallel
            \{a_t = \text{argmax}_a Q_\theta (s_t, a)\}

    Args:
        algo: algorithm exposing ``predict`` and ``n_frames``.
        episodes: list of episodes.

    Returns:
        fraction of identical actions (NaN when ``episodes`` yields no data).

    """
    matches = []
    for episode in episodes:
        for minibatch in _make_batches(episode, WINDOW_SIZE, algo.n_frames):
            predicted = algo.predict(minibatch.observations)
            logged = minibatch.actions.reshape(-1)
            matches.extend((logged == predicted).tolist())
    return float(np.mean(matches))



def d3rlpy_continuous_action_diff_scorer(algo, episodes) -> float:
    r"""Return the mean squared action difference between the algorithm and the dataset.

    This metric indicates how far the greedy policy is from the given episodes
    in a continuous action space. If the episodes are near-optimal, a smaller
    difference is better. Squared errors are summed over action dimensions and
    then averaged over transitions.

    .. math::

        \mathbb{E}_{s_t, a_t \sim D} [(a_t - \pi_\phi (s_t))^2]

    Args:
        algo: algorithm exposing ``predict`` and ``n_frames``.
        episodes: list of episodes.

    Returns:
        mean squared action difference (NaN when ``episodes`` yields no data).

    """
    squared_diffs = []
    for episode in episodes:
        for minibatch in _make_batches(episode, WINDOW_SIZE, algo.n_frames):
            predicted = algo.predict(minibatch.observations)
            per_transition = ((minibatch.actions - predicted) ** 2).sum(axis=1)
            squared_diffs.extend(per_transition.tolist())
    return float(np.mean(squared_diffs))


def d3rlpy_ope_scorer(env, episode_infos, no_transitions_eval=1024):
    """Build a d3rlpy scorer returning a Fitted-Q-Evaluation (FQE) value estimate.

    Each call fits a fresh ``DiscreteFQE`` for ``algo`` on ``episodes`` for 50
    epochs. It then averages the FQE value of ``algo``'s predicted actions over
    windows of the episodes, stopping once at least ``no_transitions_eval``
    transitions have been visited.

    Args:
        env: object exposing ``recover_original_obs(observations, infos)``.
        episode_infos: per-episode info sequences, aligned with ``episodes``.
        no_transitions_eval: soft cap on the number of transitions evaluated.

    Returns:
        ``scorer(algo, episodes) -> float``.
    """
    def scorer(algo, episodes):
        # FQE's constructor takes keyword-only arguments, and ``fit`` needs the
        # dataset and returns training metrics rather than the estimator, so
        # build first and fit second. ``augmentation`` expects an augmentation
        # pipeline, not info dicts, so episode_infos is not passed to FQE.
        fqe = d3rlpy.ope.DiscreteFQE(algo=algo)
        fqe.fit(episodes, n_epochs=50)

        samples = 0
        dm_reward = []
        for episode, infos in zip(episodes, episode_infos):
            for batch, info in _make_batches(episode, WINDOW_SIZE, algo.n_frames, infos=infos):
                actions_pred = algo.predict(batch.observations)
                # NOTE(review): FQE is fit on the episodes' observations but queried
                # on recovered states; confirm both live in the same space.
                state = env.recover_original_obs(batch.observations, info)
                # predict_value takes and returns NumPy arrays.
                dm_reward.extend(fqe.predict_value(state, actions_pred))

                samples += np.prod(batch.rewards.shape)
                if samples >= no_transitions_eval:
                    break
            if samples >= no_transitions_eval:
                break
        return float(np.mean(dm_reward))
    return scorer