import gym
import numpy as np
from gym import spaces
import torch
import os

def reward_maze2d_function_g(target, goal_radius=0.5):
    """Build a maze2d reward/termination function for a fixed goal position.

    Args:
        target: Tensor holding the goal position; it is subtracted from the
            first two columns of the input, so it must broadcast against a
            ``(batch, 2)`` slice.
        goal_radius: Distance at or below which a row is flagged as done.
            Defaults to 0.5, the value that used to be hard-coded.

    Returns:
        A callable mapping a 2-D tensor to a ``(batch, 2)`` tensor whose
        column 0 is the reward ``exp(-distance)`` and whose column 1 is the
        done flag encoded as float32 (1.0 when within ``goal_radius``).
    """
    def reward_maze2d(next_obs):
        # Only the first two columns are used as the position; the distance
        # is computed once and shared by the reward and the done flag.
        distance = torch.linalg.norm(next_obs[:, :2] - target, dim=1)
        rew = torch.exp(-distance)
        done = (distance <= goal_radius).to(torch.float32)
        return torch.cat([rew.unsqueeze(-1), done.unsqueeze(-1)], dim=1)
    return reward_maze2d

def termination_fn_halfcheetah(next_obs, use_torch=False):
    """Flag HalfCheetah transitions whose next observation left (-100, 100).

    Args:
        next_obs: Batch of next observations; a torch tensor when
            ``use_torch`` is true, otherwise a NumPy array. The last axis is
            reduced, and the NumPy path indexes the result with ``[:, None]``,
            so a 2-D ``(batch, obs_dim)`` input is expected.
        use_torch: Select the torch implementation instead of NumPy.

    Returns:
        Boolean ``(batch, 1)`` tensor/array; True where any component is not
        strictly inside (-100, 100). NaN components fail both comparisons and
        therefore also count as done.
    """
    # Element-wise mask of components strictly inside the allowed range.
    inside = (next_obs > -100) & (next_obs < 100)
    if use_torch:
        return torch.logical_not(torch.all(inside, dim=-1)).unsqueeze(-1)
    out_of_range = np.logical_not(np.all(inside, axis=-1))
    return out_of_range[:, None]

def termination_fn_hopper(next_obs, use_torch=False):
    """Flag Hopper transitions that end the episode.

    A row is *not done* only when every component is finite, every component
    except the first has absolute value below 100, the height (column 0) is
    above 0.7 and the absolute angle (column 1) is below 0.2.

    Both backends now apply the same rule. Previously the NumPy path bounded
    the state at 10 while the torch path used 100, and the torch path had no
    finiteness check, so an infinite height was reported as not done.

    Args:
        next_obs: 2-D batch of next observations; a torch tensor when
            ``use_torch`` is true, otherwise a NumPy array.
        use_torch: Select the torch implementation instead of NumPy.

    Returns:
        Boolean ``(batch, 1)`` tensor/array, True where the episode ends.
    """
    height = next_obs[:, 0]
    angle = next_obs[:, 1]

    if use_torch:
        finite = torch.isfinite(next_obs).all(dim=-1)
        # Height (column 0) is excluded from the magnitude bound.
        bounded = (torch.abs(next_obs[:, 1:]) < 100).all(dim=-1)
        healthy = (height > 0.7) & (torch.abs(angle) < 0.2)
        return torch.logical_not(finite & bounded & healthy).unsqueeze(-1)

    finite = np.isfinite(next_obs).all(axis=-1)
    bounded = (np.abs(next_obs[:, 1:]) < 100).all(axis=-1)
    healthy = (height > 0.7) & (np.abs(angle) < 0.2)
    done = ~(finite & bounded & healthy)
    return done[:, None]

def termination_fn_walker2d(next_obs, use_torch=False):
    """Flag Walker2d transitions that end the episode.

    A row is *not done* only when every component lies strictly inside
    (-100, 100), the height (column 0) lies strictly inside (0.8, 2.0) and
    the angle (column 1) lies strictly inside (-1.0, 1.0).

    Args:
        next_obs: 2-D batch of next observations; a torch tensor when
            ``use_torch`` is true, otherwise a NumPy array.
        use_torch: Select the torch implementation instead of NumPy.

    Returns:
        Boolean ``(batch, 1)`` tensor/array, True where the episode ends.
    """
    torso_height, torso_angle = next_obs[:, 0], next_obs[:, 1]
    # Posture test is written identically for both backends.
    upright = (
        (torso_height > 0.8)
        & (torso_height < 2.0)
        & (torso_angle > -1.0)
        & (torso_angle < 1.0)
    )

    if use_torch:
        bounded = (torch.abs(next_obs) < 100).all(dim=-1)
        return torch.logical_not(bounded & upright).unsqueeze(-1)

    bounded = np.logical_and(
        np.all(next_obs > -100, axis=-1),
        np.all(next_obs < 100, axis=-1),
    )
    terminated = ~(bounded & upright)
    return terminated[:, None]

class DiffEnv(gym.Env):
    """Gym-style environment whose transitions come from learned models.

    Next states are produced by ``diffuser`` (a callable taking a batch of
    concatenated ``[observation, action]`` rows) and rewards by ``rw_model``
    (a callable taking ``[observation, action, next_state]`` rows and
    returning one row per sample with the reward in column 0).

    ``dataset`` must expose ``observation_dim``, ``action_dim``, ``dim_max``,
    ``dim_min``, ``rew_max``, ``rew_min``, ``samplings0()``,
    ``get_observations()`` and a ``normalizer`` mapping with
    ``'observations'`` and ``'actions'`` entries.

    Note that ``step`` and ``multi_step`` are batched and take the
    observations explicitly, so they do not follow the standard
    ``gym.Env.step(action)`` signature.
    """

    def __init__(self, dataset, diffuser, rw_model, env_name, reward_norm=False, target=None, device="cuda"):
        """Load the model checkpoints and pick the termination rule.

        Args:
            dataset: Data source described in the class docstring.
            diffuser: Dynamics model; its ``load`` method is called with
                ``<cwd>/<env_name>/<env_name>_diffusion_model.pt``.
            rw_model: Reward model; for non-maze environments its ``load``
                method is called with
                ``<cwd>/<env_name>/<env_name>_reward_model.pt``. For maze
                environments it is replaced by an analytic reward function.
            env_name: Environment name; also selects the termination function
                by substring ("halfcheetah", "walker2d", "hopper").
            reward_norm: If truthy, predicted rewards are multiplied by it.
            target: Goal position; required when ``"maze"`` is in ``env_name``.
            device: Torch device for the tensors created by this class.
                Defaults to ``"cuda"``, the previously hard-coded value.

        Raises:
            ValueError: If a maze environment is requested without ``target``.
        """
        super().__init__()
        self.unnormalize = False
        self.device = device
        observation_dim = dataset.observation_dim
        action_dim = dataset.action_dim
        self.observation_space = spaces.Box(-1, 1, shape=(observation_dim,))
        self.action_space = spaces.Box(-1, 1, shape=(action_dim,))
        self.dataset = dataset
        self.dim_max = torch.tensor(self.dataset.dim_max, dtype=torch.float32, device=device)
        self.dim_min = torch.tensor(self.dataset.dim_min, dtype=torch.float32, device=device)
        self.rew_max = torch.tensor(self.dataset.rew_max, dtype=torch.float32, device=device)
        self.rew_min = torch.tensor(self.dataset.rew_min, dtype=torch.float32, device=device)

        self.diffuser = diffuser
        self.rw_model = rw_model
        self.current_state = self.dataset.samplings0()
        self.obs = self.get_observations()

        self.reward_norm = reward_norm
        self.path = os.path.join(os.getcwd(), env_name)

        self.diffuser.load(os.path.join(self.path, f'{env_name}_diffusion_model.pt'))
        if "maze" not in env_name:
            self.rw_model.load(os.path.join(self.path, f'{env_name}_reward_model.pt'))
        else:
            if target is None:
                raise ValueError("target must be provided for maze environments")
            target = torch.tensor(target, dtype=torch.float32, device=device)
            # NOTE(review): the maze reward reads columns [:2] of its input,
            # and this class feeds it [obs, action, next_state]; those columns
            # look like the *current* observation. Confirm this is intended.
            self.rw_model = reward_maze2d_function_g(target)

        if "halfcheetah" in env_name:
            self.terminal_fn = termination_fn_halfcheetah
        elif "walker2d" in env_name:
            self.terminal_fn = termination_fn_walker2d
        elif "hopper" in env_name:
            self.terminal_fn = termination_fn_hopper
        else:
            # No analytic rule: fall back to the reward model's done column.
            self.terminal_fn = None

        self.t = 0

    def setstate(self, state):
        """Overwrite the state used by the next ``original_step`` call."""
        self.current_state = state

    def reset(self):
        """Reset the step counter and draw a new start state from the dataset.

        Returns:
            The state returned by ``dataset.samplings0()``.
        """
        self.t = 0
        state = self.dataset.samplings0()
        self.current_state = state
        return state

    def set_unnormalize(self):
        """Set the ``unnormalize`` flag (not read anywhere in this class)."""
        self.unnormalize = True

    def reward_test(self, obs0, action, obs1):
        """Evaluate the reward model on a single ``(obs0, action, obs1)`` triple.

        Args:
            obs0, action, obs1: 1-D arrays, concatenated along axis 0.

        Returns:
            The first element of the reward model's output row.
        """
        conditions = np.concatenate([obs0, action, obs1], axis=0)
        reward_input = torch.tensor(conditions, device=self.device).unsqueeze(0)
        reward_done = self.rw_model(reward_input)
        reward = reward_done.detach().to("cpu").squeeze(0).numpy()
        return reward[0]

    def original_step(self, action, apply_noise=None):
        """Advance the internally tracked state by one action.

        Args:
            action: 1-D array, concatenated with ``self.current_state``.
            apply_noise: Forwarded unchanged to the diffuser.

        Returns:
            ``(state, reward, done, info)`` where ``done`` is always False and
            ``info`` is ``{"diff_timeout": False}``. ``self.current_state`` is
            updated to the returned state.
        """
        conditions = np.concatenate([self.current_state, action], axis=0)

        with torch.no_grad():
            # Build the batched condition once and share it between models.
            condition_batch = torch.tensor(conditions, device=self.device).unsqueeze(0)
            state = self.diffuser(condition_batch, apply_noise=apply_noise)
            reward_done = self.rw_model(torch.cat((condition_batch, state), dim=1))
        reward_done = reward_done.detach().to("cpu").squeeze(0).numpy()
        reward = reward_done[0]
        if self.reward_norm:
            reward = reward * self.reward_norm
        done = False
        state = state.detach().to('cpu').squeeze(0).numpy()
        timeout = False
        self.current_state = state

        return state, reward, done, {"diff_timeout": timeout}

    def _predict(self, conditions, std_threshold, repeated):
        """Sample the next state and score it with the reward model.

        When ``std_threshold`` is truthy the diffuser is sampled ``repeated``
        times; the mean is used as the state and the per-sample standard
        deviation (averaged over dim 1) is returned as an uncertainty value.
        Otherwise a single sample is drawn and the uncertainty is ``None``.
        Callers are expected to wrap this in ``torch.no_grad()``.

        Returns:
            ``(state, state_std, reward_done)``.
        """
        if std_threshold:
            samples = torch.stack([self.diffuser(conditions) for _ in range(repeated)], dim=0)
            state_std, state = torch.std_mean(samples, dim=0)
            state_std = torch.mean(state_std, dim=1)
        else:
            state = self.diffuser(conditions)
            state_std = None
        reward_done = self.rw_model(torch.cat((conditions, state), dim=1))
        return state, state_std, reward_done

    def _scaled_reward(self, reward_done):
        """Return column 0 of ``reward_done``, scaled by ``reward_norm`` if set."""
        reward = reward_done[:, 0]
        if self.reward_norm:
            # Out-of-place on purpose: ``reward`` is a view of ``reward_done``
            # and an in-place multiply would alter the model output itself.
            reward = reward * self.reward_norm
        return reward

    def _terminal(self, next_obs, reward_done):
        """Return a 1-D boolean done mask for a batch of unnormalized states."""
        if self.terminal_fn is not None:
            return self.terminal_fn(next_obs, use_torch=True).squeeze(-1)
        # Without an analytic rule, threshold the reward model's last column.
        return reward_done[:, -1] > 0.5

    def step(self, obss, actions, std_threshold=0.01, repeated=5):
        """Batched transition on unnormalized NumPy inputs.

        Args:
            obss: 2-D array of observations; normalized internally.
            actions: 2-D array of actions; normalized internally.
            std_threshold: If truthy, samples whose uncertainty is at or above
                this value receive ``rew_min`` instead of the predicted reward.
            repeated: Number of diffuser samples used for the uncertainty.

        Returns:
            ``(state, reward, done, None)`` as NumPy arrays: unnormalized next
            states, rewards of shape ``(batch, 1)`` and dones of shape
            ``(batch, 1)``.
        """
        normalizer = self.dataset.normalizer
        obss = normalizer['observations'].normalize(obss)
        actions = normalizer['actions'].normalize(actions)
        conditions = np.concatenate((obss, actions), axis=1)
        conditions = torch.tensor(conditions).to(device=self.device)

        with torch.no_grad():
            state, state_std, reward_done = self._predict(conditions, std_threshold, repeated)
            reward = self._scaled_reward(reward_done)
            # Unnormalize once; the result serves both the termination check
            # and the returned state.
            next_obs = normalizer['observations'].unnormalize_torch(state)
            done = self._terminal(next_obs, reward_done)

            if std_threshold:
                reward = self.rew_min * (state_std >= std_threshold) + (
                            state_std < std_threshold) * reward
            state = next_obs.detach().to("cpu").numpy()
            reward = reward.unsqueeze(-1).detach().to("cpu").numpy()
            done = done.unsqueeze(-1).detach().to("cpu").numpy()

        return state, reward, done, None

    # torch-based
    def multi_step(self, obss, actions, std_threshold=0, repeated=5):
        """Batched transition that stays in torch.

        Args:
            obss: 2-D tensor of observations, concatenated as given (no
                normalization is applied here — presumably already normalized;
                confirm with callers).
            actions: 2-D tensor of actions; normalized internally.
            std_threshold: If truthy, samples whose uncertainty is at or above
                this value are marked done and every reward is reduced by
                ``50 * uncertainty``.
            repeated: Number of diffuser samples used for the uncertainty.

        Returns:
            ``(state, reward, done, None)`` as tensors. ``state`` is the
            diffuser output as produced (it is not unnormalized), ``reward``
            and ``done`` are 1-D.
        """
        normalizer_obs = self.dataset.normalizer['observations']
        normalizer_act = self.dataset.normalizer['actions']
        actions = normalizer_act.normalize_torch(actions)
        conditions = torch.cat([obss, actions], dim=1)
        with torch.no_grad():
            state, state_std, reward_done = self._predict(conditions, std_threshold, repeated)
            reward = self._scaled_reward(reward_done)
            # Termination rules are defined on unnormalized observations.
            next_obs = normalizer_obs.unnormalize_torch(state)
            done = self._terminal(next_obs, reward_done)

            if std_threshold:
                done = done | (state_std >= std_threshold)
                reward = - state_std * 50 + reward
        # TODO: probably add reward normalization

        return state, reward, done, None

    def get_observations(self):
        """Return ``dataset.get_observations()``."""
        return self.dataset.get_observations()

    def update_on_weights(self, weights):
        """Placeholder; currently ignores ``weights`` and does nothing."""
        pass

    def render(self, mode="human"):
        """No-op; rendering is not implemented."""
        pass

    def close(self):
        """No-op; there are no resources to release."""
        pass
