import torch
import numpy as np
import io
import contextlib
import warnings
import d4rl
import d4rl_ext

from PIL import Image

from contextlib import redirect_stdout
from jaxrl_m.evaluation import kitchen_render

import matplotlib 
import matplotlib.pyplot as plt

def convert_to_pytorch(v):
    """Convert a Python/NumPy value into the matching torch representation.

    Conversion rules, by input type:

    * ``bool``          -> 0-dim ``torch.bool`` tensor
    * ``int``           -> 0-dim ``torch.int64`` tensor
    * ``float``         -> 0-dim ``torch.float32`` tensor
    * ``list``          -> ``torch.float32`` tensor
    * ``torch.Tensor``  -> returned unchanged
    * ``np.ndarray``    -> tensor sharing memory with the array, except that
      ``float64`` arrays are down-cast to ``float32``
    * ``dict``          -> new dict with every value converted recursively

    Args:
        v: The value to convert.

    Returns:
        The converted tensor (or dict of tensors), or ``None`` when ``v`` is
        of a type not listed above.
    """
    # bool must be tested before int: bool is a subclass of int, so checking
    # int first would turn True/False into int64 tensors and leave the bool
    # branch unreachable.
    if isinstance(v, bool):
        return torch.tensor(v).bool()
    if isinstance(v, int):
        return torch.tensor(v).long()
    if isinstance(v, float):
        return torch.tensor(v).float()
    if isinstance(v, list):
        return torch.tensor(v).float()
    if isinstance(v, torch.Tensor):
        return v
    if isinstance(v, np.ndarray):
        if v.dtype == "float64":  # always convert double to float
            return torch.from_numpy(v).float()
        return torch.from_numpy(v)
    if isinstance(v, dict):
        return {k: convert_to_pytorch(_v) for k, _v in v.items()}
    # Unsupported types historically fell through and produced None; keep
    # that contract, but make it explicit.
    return None


def concat_frames(frames):
    """Merge a sequence of per-step frame dicts into one dict of batched tensors.

    The keys are taken from the first frame. For each key, the tensors from
    all frames get a new leading dimension and are concatenated along it, so
    the result for a key has shape ``(len(frames), *per_frame_shape)``.

    Args:
        frames: Non-empty sequence of dicts mapping keys to torch tensors.
            Every frame must contain every key present in ``frames[0]``.

    Returns:
        Dict with the same keys as ``frames[0]`` and the batched tensors as
        values.
    """
    return {
        key: torch.cat([frame[key].unsqueeze(0) for frame in frames], dim=0)
        for key in frames[0]
    }

class Bot:
    """Base class for action-producing agents; subclasses override both methods."""

    def reset(self, seed, **args):
        """Prepare the bot for a new episode; must be overridden."""
        raise NotImplementedError

    def _action(self, frame, **args):
        """Return the action to take for ``frame``; must be overridden."""
        raise NotImplementedError

class PytorchD4RLGymEnv:
    """Adapter that exposes a gym-style environment through torch tensors.

    The wrapped environment is built by ``create_env_fn`` and is expected to
    offer ``seed``, ``reset``, ``step`` and ``render``. Observations and
    actions are converted with ``convert_to_pytorch``. For environments
    detected as "kitchen", the flat observation is split into an observation
    part and a goal part at ``_KITCHEN_STATE_DIM``.
    """

    # Index at which a kitchen observation is split into state / goal.
    # NOTE(review): presumably the first 30 entries are the state and the rest
    # the goal, per the D4RL kitchen layout - confirm against the env in use.
    _KITCHEN_STATE_DIM = 30

    def __init__(self, create_env_fn, **args):
        """Create the wrapped environment.

        Args:
            create_env_fn: Callable taking the environment id and returning
                the environment instance.
            **args: Must contain ``id``. If the id contains ``'topview'``,
                ``self.visual`` is set and the first dash-separated token is
                stripped before the id is handed to ``create_env_fn``.

        Raises:
            KeyError: If ``id`` is not supplied.
        """
        env_id = args['id']
        print(env_id)
        self.visual = 'topview' in env_id
        if self.visual:
            env_id = '-'.join(env_id.split('-')[1:])
        # Environment construction can be chatty; discard stdout and warnings.
        with contextlib.redirect_stdout(io.StringIO()):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                self._env = create_env_fn(env_id)
        self._observation = None
        self._goal = torch.tensor([0.0])
        self._render = False

    @staticmethod
    def _call_env(fn, *fn_args, verbose=False):
        """Call ``fn(*fn_args)``, discarding its stdout unless ``verbose``."""
        if verbose:
            return fn(*fn_args)
        with contextlib.redirect_stdout(io.StringIO()):
            return fn(*fn_args)

    def _is_kitchen(self):
        """Return True when the doubly-unwrapped env's type name mentions 'kitchen'."""
        inner = getattr(self._env, 'env', None)
        return hasattr(inner, 'env') and 'kitchen' in str(type(inner.env))

    def _observation_and_goal(self):
        """Return the current ``observation``/``goal`` pair as a new dict."""
        if self._is_kitchen():
            split = self._KITCHEN_STATE_DIM
            return {"observation": self._observation[:split],
                    "goal": self._observation[split:]}
        return {"observation": self._observation, "goal": self._goal}

    def reset(self, seed, verbose=False) -> dict:
        """Seed and reset the environment.

        Args:
            seed: Seed forwarded to ``self._env.seed``.
            verbose: When False, stdout produced by the environment during
                reset / target setting is discarded.

        Returns:
            Dict with keys ``observation`` and ``goal``.
        """
        # Seed before resetting so the seed can influence the initial state.
        self._env.seed(seed)
        o = self._call_env(self._env.reset, verbose=verbose)
        self._observation = convert_to_pytorch(o)
        if hasattr(self._env, 'target_goal'):
            self._goal = torch.tensor(self._env.target_goal)
            # Re-apply the goal that reset chose, via the env's own setter.
            self._call_env(self._env.set_target, self._goal.numpy(),
                           verbose=verbose)
        # Kitchen envs use the same state/goal split as step().
        return self._observation_and_goal()

    def step(self, action) -> dict:
        """Advance the environment by one step.

        Args:
            action: Action passed unchanged to ``self._env.step``.

        Returns:
            Dict with keys ``observation``, ``goal``, ``action``,
            ``terminated``, ``truncated`` and ``reward``; all values are torch
            tensors (or whatever ``convert_to_pytorch`` yields for the
            observation and action).
        """
        observation, reward, done, info = self._env.step(action)
        if self._render:
            self._env.render()
        # WARNING: truncation is treated as False when this info key is absent.
        truncated = torch.tensor(info.get("TimeLimit.truncated", False))
        # Store the new observation before building the frame so the frame
        # carries the post-step observation.
        self._observation = convert_to_pytorch(observation)

        r = self._observation_and_goal()
        r["action"] = convert_to_pytorch(action)
        r["terminated"] = torch.tensor(done)
        r["truncated"] = truncated
        r["reward"] = torch.tensor(reward)
        return r

    def gather_episode(self, seed, bot, bot_args=None):
        """Run one full episode driven by ``bot`` and return the stacked frames.

        Args:
            seed: Seed passed to both ``self.reset`` and ``bot.reset``.
            bot: A ``Bot`` instance; ``bot._action(frame, **bot_args)`` is
                called once per step.
            bot_args: Optional dict of keyword arguments for ``bot._action``.

        Returns:
            Dict of per-step tensors stacked along a new leading dimension by
            ``concat_frames``. When the environment provides
            ``get_normalized_score``, the dict also holds ``normalized_score``:
            that function applied to the summed reward, times 100.
        """
        assert isinstance(bot, Bot)
        # Avoid a shared mutable default argument.
        bot_args = {} if bot_args is None else bot_args
        frame = self.reset(seed)
        bot.reset(seed)

        frames = []
        done = False
        while not done:
            action = bot._action(frame, **bot_args)
            frame = self.step(action)
            frames.append(frame)
            done = bool(frame["terminated"]) or bool(frame["truncated"])

        episode_frames = concat_frames(frames)
        if hasattr(self._env, 'get_normalized_score'):
            ret = episode_frames["reward"].sum().numpy()
            episode_frames["normalized_score"] = self._env.get_normalized_score(ret) * 100
        return episode_frames

    def set_render(self, render):
        """Enable or disable calling ``self._env.render()`` after each step."""
        self._render = render