from tqdm.auto import tqdm

import wandb
import numpy as np

import torch
from torch.nn import functional

from utils import logs_handler, misc
from utils.ale_env import Env as AtariEnv

# Module-level logger, obtained from the project's ``logs_handler`` helper and
# keyed by this module's name.
logger = logs_handler.get_logger(__name__)

class AtariEvaluator:
    """Roll a model out in an environment and record per-episode statistics.

    The evaluator keeps a sliding window of the most recent states/actions
    (``history_length`` entries) that is fed to the model at every step, and a
    bounded buffer of per-episode records (``actions``, ``certainty``,
    ``returns``, ``done``) holding at most ``max_buffer_size`` episodes.

    The environment must provide ``reset()``, ``step(action)`` returning
    ``(state, reward, done)``, ``eval()``, ``actions`` (sized, with ``keys()``),
    ``observation_shape()``, ``isvalid_action()``, ``random_action()`` and,
    for ``evaluate(use_id=True)``, a ``game`` string. The model is called with
    ``stage``, ``states``, ``actions``, ``top_k`` and ``sample`` keyword
    arguments and must return a mapping with ``'actions'`` and ``'certainty'``
    entries indexable as ``[0, -1]``, and expose ``action_discrete``.
    """

    def __init__(self, stage, history_length, max_episode_length=None, device=None, max_buffer_size=None, top_k=None):
        """Store the evaluation settings; env and model are attached later.

        Args:
            stage: Passed through unchanged to the model on every call.
            history_length: Number of most recent states/actions kept as model
                input; ``None`` keeps the whole episode.
            max_episode_length: Per-episode step cap (defaults to 108000).
            device: Device used for the model and the input tensors.
            max_buffer_size: Maximum number of episodes kept in the buffer
                (defaults to 128).
            top_k: Passed through unchanged to the model on every call.
        """
        self.stage = stage
        self.device = device

        self.env: AtariEnv = None
        self.model = None

        self.history_length = history_length
        self.max_episode_length = max_episode_length or int(108e3)
        self.max_buffer_size = max_buffer_size or int(128)
        self.top_k = top_k

        self.states = None
        self.actions = None
        self.buffer = None
        # External episode id of ``self.buffer[0]``; becomes non-zero once old
        # episodes have been evicted from the bounded buffer.
        self._buffer_offset = 0

        self.last_episode = {}

    def set_device(self, device):
        """Set the device used for subsequently created tensors / set models."""
        self.device = device

    def set_model(self, model):
        """Attach ``model`` after moving it to the evaluator's device."""
        self.model = model.to(self.device)

    def set_env(self, env):
        """Attach the environment to evaluate in."""
        self.env = env

    def reset_history(self, initial_state):
        """Restart the state/action window for a new episode.

        The first action entry is ``len(self.env.actions)``.
        NOTE(review): this looks like a "no action taken yet" placeholder one
        past the valid action ids - confirm against the model's vocabulary.
        """
        self.states = [initial_state]
        self.actions = [len(self.env.actions)]

    def reset_buffer(self):
        """Drop all recorded episodes and start with one empty episode slot."""
        self.buffer = [{}]
        self._buffer_offset = 0

    def _buffer_slot(self, episode_id):
        """Translate an episode id into an index of ``self.buffer``.

        Negative ids are returned unchanged (they index from the end of the
        buffer). Non-negative ids are shifted by the number of evicted
        episodes; the result may be ``>= len(self.buffer)`` for an episode
        that has not been stored yet.

        Raises:
            IndexError: If the episode has already been evicted.
        """
        if episode_id < 0:
            return episode_id
        slot = episode_id - self._buffer_offset
        if slot < 0:
            raise IndexError(f'episode {episode_id} is no longer in the buffer')
        return slot

    def to_buffer(self, episode_id, key, value):
        """Append ``value`` to the list stored under ``key`` for an episode.

        An id beyond the last stored episode opens exactly one new episode at
        the end of the buffer; the oldest episodes are evicted once more than
        ``max_buffer_size`` are held. Later calls with the same id keep
        writing to that same episode, also after an eviction.

        Raises:
            IndexError: If ``episode_id`` refers to an evicted episode.
        """
        slot = self._buffer_slot(episode_id)
        if slot >= len(self.buffer):
            self.buffer.append({})
            if len(self.buffer) > self.max_buffer_size:
                self.buffer = self.buffer[-self.max_buffer_size:]
            # Re-anchor the id -> slot mapping on the episode just opened, so
            # subsequent writes with this id land in the last slot instead of
            # opening yet another episode.
            self._buffer_offset = episode_id - (len(self.buffer) - 1)
            slot = -1
        self.buffer[slot].setdefault(key, []).append(value)

    def from_buffer(self, key, episode_id=None):
        """Return the values recorded under ``key``.

        With ``episode_id`` the list of that single episode is returned,
        otherwise a list with one entry (list) per buffered episode.

        Raises:
            KeyError: If an inspected episode has nothing stored under ``key``.
            IndexError: If ``episode_id`` is not (or no longer) in the buffer.
        """
        if episode_id is not None:
            return self.buffer[self._buffer_slot(episode_id)][key]
        return [episode[key] for episode in self.buffer]

    def get_all_actions(self, summary=False, discrete=True):
        """Collect the actions of all buffered episodes.

        Returns:
            * ``summary`` and ``discrete``: list of ``(action, count)`` pairs,
              first the observed actions in sorted order, then every key of
              ``env.actions`` that was never taken with count 0.
            * ``summary`` only: the result of ``np.histogram`` on the actions.
            * otherwise: an ``(N, 1)`` array of all actions.
        """
        all_actions = [action for acts in self.from_buffer('actions') for action in acts]
        if summary and discrete:
            unique, counts = np.unique(all_actions, return_counts=True)
            actions_summary = list(zip(unique, counts))
            seen = set(unique)
            actions_summary.extend((k, 0) for k in self.env.actions.keys() if k not in seen)
            return actions_summary
        if summary:
            return np.histogram(all_actions)
        return np.array(all_actions).reshape(-1, 1)

    def get_all_returns(self):
        """Return an ``(E, 1)`` array with the final return of each buffered episode."""
        final_returns = [rets[-1] for rets in self.from_buffer('returns')]
        return np.array(final_returns).reshape(-1, 1)

    def next_transition(self, action):
        """Step the environment and push the new state/action into the window.

        Returns:
            The ``(reward, done)`` pair reported by ``env.step``.
        """
        state, reward, done = self.env.step(action)
        self.states.append(state)
        self.actions.append(action)
        if self.history_length is not None:
            # Keep only the most recent ``history_length`` entries.
            self.states = self.states[-self.history_length:]
            self.actions = self.actions[-self.history_length:]
        return reward, done

    def process_transitions_history(self):
        """Convert the current window into batched model inputs.

        Returns:
            ``states`` as a float32 tensor reshaped to
            ``(1, -1, *env.observation_shape())`` and ``actions`` as a long
            tensor reshaped to ``(1, -1)``, both on ``self.device``.
        """
        states = torch.as_tensor(np.array(self.states), dtype=torch.float32, device=self.device)
        states = states.reshape(1, -1, *self.env.observation_shape())
        actions = torch.as_tensor(np.array(self.actions), dtype=torch.long, device=self.device).reshape(1, -1)
        return states, actions

    def start_new_episode(self):
        """Reset the environment, the history window and the episode counters."""
        initial_state = self.env.reset()
        self.reset_history(initial_state)
        self.last_episode = {'return': 0.0, 'length': 0, 'done': False}

    def evaluate_episode(self, episode_id=None, num_transitions=None, progress_bar=None, task_id=None):
        """Run (or resume) one episode and record every step in the buffer.

        If the previous episode is finished (or none was started) a new one is
        started; otherwise the unfinished episode is resumed with its
        accumulated return and length. The episode ends when the environment
        reports ``done`` or when its total length reaches
        ``max_episode_length``.

        Args:
            episode_id: Buffer episode to record into; ``None`` records into
                the last buffered episode.
            num_transitions: Maximum number of steps for this call; a falsy
                value runs until the episode ends.
            progress_bar: Optional object with ``set_description`` (e.g. a
                tqdm bar) updated after every step.
            task_id: Optional label prefixed to the progress description.

        Returns:
            ``(episode_return, episode_length)`` accumulated so far.
        """
        self.env.eval()
        self.model.eval()
        if self.last_episode.get('done', True):
            self.start_new_episode()
        if episode_id is None:
            # "Last episode"; a negative id is independent of evictions.
            episode_id = -1
        episode_return = self.last_episode['return']
        episode_length = self.last_episode['length']

        # The step cap applies to the whole episode, not to a single call.
        remaining = self.max_episode_length - episode_length
        budget = min(num_transitions, remaining) if num_transitions else remaining
        prefix = f'{task_id} > ' if (task_id is not None) else ''
        done = False
        with torch.no_grad():
            for _ in range(budget):
                states, actions = self.process_transitions_history()
                outputs = self.model(stage=self.stage, states=states, actions=actions, top_k=self.top_k, sample=True)

                # Only the prediction for the most recent step is used.
                raw_action = outputs.pop('actions')[0, -1].cpu()
                action = int(raw_action) if self.model.action_discrete else float(raw_action)
                if not self.env.isvalid_action(action):
                    action = self.env.random_action()
                certainty = float(outputs.pop('certainty')[0, -1].cpu())

                reward, done = self.next_transition(action)
                episode_return += reward
                episode_length += 1

                self.to_buffer(episode_id, 'actions', action)
                self.to_buffer(episode_id, 'certainty', certainty)
                self.to_buffer(episode_id, 'returns', episode_return)
                self.to_buffer(episode_id, 'done', done)

                if progress_bar is not None:
                    progress_bar.set_description(
                        prefix + f'return = {episode_return} '
                        f'| timestep: {episode_length}/{self.max_episode_length}', refresh=True)
                if done:
                    break
        if episode_length >= self.max_episode_length:
            # Truncated by the step cap: the next call must start a fresh
            # episode instead of continuing (and re-counting) this one.
            done = True
        self.last_episode = {'return': episode_return, 'length': episode_length, 'done': done}
        return episode_return, episode_length

    def evaluate(self, num_trials, reduce_method=np.mean, epoch=None, progress=False, summary=False, use_id=False):
        """Evaluate the model for ``num_trials`` episodes and log the result.

        The buffer is reset first. The per-episode returns are reduced with
        ``reduce_method``; when that is ``np.mean`` the standard deviation is
        reported as well. Results go to the module logger and to
        ``misc.wandb_log``; with ``summary`` (and wandb ready) an action bar
        chart and a return histogram are logged too.

        Args:
            num_trials: Number of episodes to run.
            reduce_method: Callable reducing the list of returns to a number.
            epoch: Optional epoch number included in the log line.
            progress: Show a tqdm progress bar.
            summary: Also log the action/return distribution plots to wandb.
            use_id: Prefix log entries and metric names with ``env.game``.

        Returns:
            The list of episode returns (empty if ``num_trials`` is 0).
        """
        _id = self.env.game + '_' if use_id else ''
        # Without an id pass None, so the progress description gets no prefix.
        task_id = self.env.game if use_id else None
        self.reset_buffer()
        self.start_new_episode()
        returns = []
        pbar = tqdm(range(num_trials), disable=not progress)
        for trial_idx in pbar:
            episode_return, _ = self.evaluate_episode(trial_idx, None, pbar, task_id=task_id)
            returns.append(episode_return)

        if not returns:
            # Nothing was run, so there is nothing to reduce or plot.
            logger.info(f'{_id}eval skipped: num_trials={num_trials}')
            return returns

        eval_return = reduce_method(returns)
        eval_std = np.std(returns) if reduce_method == np.mean else None

        log_msg = f'epoch: {epoch} | ' if (epoch is not None) else ''
        log_msg += f'{_id}eval return : {eval_return:.2f} '
        # ``is not None`` so that a legitimate std of 0.0 is still reported.
        log_msg += f'| {_id}eval std: {eval_std:.2f} ' if eval_std is not None else ''
        log_msg += f'| trails: {len(returns)}/{num_trials}'
        logger.info(log_msg)
        wandb_logs = {f'{_id}eval_return': eval_return}
        if eval_std is not None:
            wandb_logs.update({f'{_id}eval_std': eval_std})
        misc.wandb_log(wandb_logs)

        if summary and misc.wandb_ready():
            actions_summary_table = wandb.Table(data=self.get_all_actions(summary=True, discrete=True), columns=['action', 'frequency'])
            returns_table = wandb.Table(data=self.get_all_returns(), columns=['return'])

            actions_summary_plot = wandb.plot.bar(actions_summary_table, 'action', 'frequency', title='Action Distribution')
            returns_summary_plot = wandb.plot.histogram(returns_table, 'return', title='Returns Distribution')

            wandb.log({f'{_id}/action/bar': actions_summary_plot, f'{_id}/return/histogram': returns_summary_plot})
        return returns