import torch
import imageio
from .render_utils import initialize_viewer,render
import numpy as np

class base:
    """Common scaffolding for agents that act through ``self.policy``.

    The class stores the environment dimensions, the torch device and the
    rollout horizon, and offers action selection plus two rollout-based
    evaluation helpers. ``self.policy`` starts out as ``None`` and must be
    assigned (presumably by a subclass) before any other method is used.

    NOTE(review): in PPO mode ``select_action`` also reads ``self.value``,
    which is never set in this class -- confirm that subclasses provide it.
    """

    def __init__(self,
                 obs_dim,
                 action_dim,
                 net_size,
                 latent_action_dim,
                 device,
                 ppo=False,
                 **kwargs
                 ):
        """Store the configuration shared by all agents.

        Args:
            obs_dim: Observation dimensionality (stored as ``self.obs_dim``).
            action_dim: Action dimensionality (stored as ``self.action_dim``).
            net_size: Not used by this class; accepted so that the signature
                stays uniform (presumably consumed by subclasses).
            latent_action_dim: Not used by this class; see ``net_size``.
            device: Torch device that observations are moved to.
            ppo: When true, ``select_action`` follows the PPO code path.
            **kwargs: Must contain ``max_path_length``, the maximum number of
                environment steps per rollout.

        Raises:
            KeyError: If ``max_path_length`` is missing from ``kwargs``.
        """
        self.max_path_length = kwargs['max_path_length']

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.device = device
        self.policy = None
        self.ppo = ppo

    def load_shared_network(self, policy_params):
        """Load ``policy_params`` into ``policy.shared_layer`` and freeze it.

        Args:
            policy_params: State dict accepted by
                ``self.policy.shared_layer.load_state_dict``.
        """
        self.policy.shared_layer.load_state_dict(policy_params)
        for param in self.policy.shared_layer.parameters():
            param.requires_grad = False

    def evaluate_head(self, env):
        """Roll out one episode and report its return and length.

        The rollout stops when the environment reports ``done`` or after
        ``self.max_path_length`` steps, whichever comes first.

        Args:
            env: Environment whose ``reset()`` returns an observation and
                whose ``step(action)`` returns
                ``(obs, reward, done, info)``.

        Returns:
            Tuple ``(total_reward, env_step)``.
        """
        obs = env.reset()
        env_step = 0
        total_reward = 0
        while env_step < self.max_path_length:
            env_step += 1
            # No gradients are needed while evaluating.
            with torch.no_grad():
                action = self.select_action(obs)
            # In PPO mode select_action returns (action, log_prob, value);
            # only the action may be handed to the environment.
            if isinstance(action, tuple):
                action = action[0]
            obs, reward, done, env_info = env.step(action)
            total_reward += reward
            if done:
                break
        return total_reward, env_step

    def _to_tensor(self, obs):
        """Return ``obs`` as a float32 tensor on ``self.device``."""
        if isinstance(obs, torch.Tensor):
            return obs.to(device=self.device, dtype=torch.float32)
        return torch.tensor(obs, dtype=torch.float32).to(self.device)

    def select_action(self, obs, deterministic=False):
        """Query the policy for an action.

        Args:
            obs: Observation as an array-like or a tensor.
            deterministic: Only consulted in PPO mode. When true, the first
                output of ``self.policy(obs)`` clamped to ``[-1, 1]`` is
                returned instead of a sample.

        Returns:
            * Non-PPO mode: the output of ``self.policy(obs)`` as a numpy
              array.
            * PPO mode, deterministic: the clamped mean as a numpy array with
              the batch axis removed.
            * PPO mode, otherwise: the tuple ``(action, log_prob, value)``
              where ``action`` is returned exactly as produced by
              ``self.policy.select_action``, ``log_prob`` is a numpy array
              and ``value`` a Python scalar.
        """
        obs = self._to_tensor(obs)
        if not self.ppo:
            # detach() so the conversion also works outside torch.no_grad().
            return self.policy(obs).detach().cpu().numpy()

        if obs.dim() == 1:
            # A single observation gets a leading batch axis.
            obs = obs.unsqueeze(0)
        if deterministic:
            mean, _ = self.policy(obs)
            mean = torch.clamp(mean, -1, 1)
            return mean.squeeze(0).detach().cpu().numpy()
        action, log_prob = self.policy.select_action(obs)
        value = self.value(obs)
        return action, log_prob.detach().squeeze().cpu().numpy(), value.item()

    def collect_data_and_train_filter(self, env):
        """Hook to be implemented by subclasses; always raises here.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def evaluation(self, env, test_env=None, rendering=True, render_type='human',
                   n_eval_epi=10, gif_path='rollout.gif'):
        """Run deterministic evaluation episodes and return the mean return.

        Args:
            env: Environment to roll out in (same interface as in
                ``evaluate_head``).
            test_env: Forwarded unchanged to ``render``.
            rendering: Whether to initialise the viewer and render each step.
            render_type: Forwarded to ``initialize_viewer`` and ``render``.
                With ``'rgb_array'`` the rendered frames are written to
                ``gif_path`` at 30 fps.
            n_eval_epi: Number of evaluation episodes.
            gif_path: Output file used when ``render_type == 'rgb_array'``.

        Returns:
            The sum of episode returns divided by ``n_eval_epi``.

        Raises:
            ValueError: If ``n_eval_epi`` is not positive.
        """
        if n_eval_epi <= 0:
            raise ValueError("n_eval_epi must be a positive integer")

        # Frames are only ever consumed when a GIF is written.
        save_gif = rendering and (render_type == 'rgb_array')
        with torch.no_grad():
            returns = 0
            frames = []
            if rendering:
                initialize_viewer(env, render_type)
            for _ in range(n_eval_epi):
                episode_returns = 0
                env_step = 0
                o = env.reset()
                while env_step < self.max_path_length:
                    if rendering:
                        frame = render(env, test_env, render_type)
                        if save_gif:
                            frames.append(frame)
                    env_step += 1
                    a = self.select_action(o, deterministic=True)
                    o, r, d, env_info = env.step(a)
                    episode_returns += r
                    if d:
                        break
                returns += episode_returns
            if save_gif and frames:
                imageio.mimsave(gif_path, frames, fps=30)
        return returns / n_eval_epi
    