import math
import numpy as np

import torch
import gpytorch
from gp_model import VariationalMultitaskGPModel


def gp_sampled(cfg, sa_grid):
    """Draw one sample of transitions and rewards from a multitask GP model.

    A new ``VariationalMultitaskGPModel`` is constructed from ``cfg`` (it is
    not trained anywhere in this function), put in eval mode, and sampled
    once at ``sa_grid``.

    Args:
        cfg: Configuration object. The attributes read here are ``kernel``,
            ``state_dim``, ``action_dim``, ``num_inducing_points`` and
            ``device``.
        sa_grid: Model inputs, passed to the model unchanged.
            NOTE(review): presumably shaped
            ``(N, cfg.state_dim + cfg.action_dim)`` -- confirm with caller.

    Returns:
        dict with two entries:
            ``"next_state"``: every column of the sample except the last.
            ``"reward"``: the last column of the sample, min-max rescaled to
            ``[0, 1]``. If all sampled rewards are equal, this is all zeros.
    """
    model = VariationalMultitaskGPModel(
        kernel=cfg.kernel,
        input_dim=cfg.state_dim + cfg.action_dim,
        # One output per state dimension plus one extra; the last sampled
        # column is used as the reward below.
        num_models=cfg.state_dim + 1,
        num_inducing_points=cfg.num_inducing_points,
        use_coregionalization=True,
        overwrite_lmc_coeffs=True,
    ).to(cfg.device)
    model.eval()
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        f_true = model(sa_grid).sample()

    rewards = f_true[:, -1]
    reward_min = rewards.min()
    shifted = rewards - reward_min
    reward_range = rewards.max() - reward_min
    if reward_range == 0:
        # Degenerate sample (e.g. a single grid point): dividing would give
        # 0 / 0 = NaN. ``shifted`` is already all zeros, so use it directly.
        normalized_rewards = shifted
    else:
        normalized_rewards = shifted / reward_range

    f_true_values = {"next_state": f_true[:, :-1], "reward": normalized_rewards}
    return f_true_values

def navigation(cfg, sa_grid):
    """Evaluate the 2-D navigation dynamics and reward on a state-action grid.

    The first two columns of ``sa_grid`` are used as the state and the
    remaining columns as the action. The next state is ``state + action``
    clipped element-wise to ``[0, 1]``. The reward is ``1.0`` where the next
    state lies strictly within ``0.1`` (Euclidean distance) of the goal
    ``(0.9, 0.9)``, and ``-0.01`` everywhere else.

    Args:
        cfg: Configuration object; only ``cfg.device`` is read.
        sa_grid: 2-D tensor of state-action rows.

    Returns:
        dict with ``"next_state"`` (the clipped next states) and ``"reward"``
        (one value per row, allocated on ``cfg.device``).
    """
    goal_radius = 0.1
    goal = torch.tensor([0.9, 0.9], device=cfg.device)

    state = sa_grid[:, :2]
    action = sa_grid[:, 2:]
    next_state = (state + action).clamp(0, 1)

    # Start from the per-step penalty, then overwrite rows that reach the goal.
    reward = torch.full((state.shape[0],), -0.01, device=cfg.device)
    reached_goal = (next_state - goal).norm(dim=1) < goal_radius
    reward[reached_goal] = 1.0

    return {"next_state": next_state, "reward": reward}

def maze(cfg, sa_grid):
    """Evaluate the grid-maze dynamics and reward on a state-action grid.

    The first two columns of ``sa_grid`` are used as the state and the
    remaining columns as the action. The proposed next state is
    ``state + action`` clipped element-wise to ``[0, 1]``. If it falls in a
    wall cell of a 12x12 occupancy grid, the move is rejected and the state
    is left unchanged. The reward is ``1.0`` where the resulting state lies
    strictly within ``0.05`` (Euclidean distance) of the goal
    ``(9/12, 9/12)``, and ``0.0`` otherwise.

    Args:
        cfg: Configuration object; only ``cfg.device`` is read.
        sa_grid: 2-D tensor of state-action rows.

    Returns:
        dict with ``"next_state"`` (states after wall rejection) and
        ``"reward"`` (one value per row, allocated on ``cfg.device``).
    """
    grid_size = 12

    # Occupancy grid: 1 marks a wall, 0 a walkable cell. Four two-cell-wide
    # corridors are carved along the sides of the square [2, 10) x [2, 10),
    # so the walkable cells form a closed loop around a walled 4x4 centre.
    walls = torch.ones((grid_size, grid_size), device=cfg.device)
    corridors = (
        (slice(2, 10), slice(2, 4)),
        (slice(2, 4), slice(2, 10)),
        (slice(8, 10), slice(2, 10)),
        (slice(2, 10), slice(8, 10)),
    )
    for rows, cols in corridors:
        walls[rows, cols] = 0

    state = sa_grid[:, :2]
    action = sa_grid[:, 2:]
    proposed = (state + action).clamp(0, 1)

    # Map continuous coordinates in [0, 1] to integer cell indices by scaling
    # with (grid_size - 1) and truncating; the clamp keeps indices in range.
    cells = (proposed * (grid_size - 1)).long().clamp(0, grid_size - 1)
    blocked = walls[cells[:, 0], cells[:, 1]] == 1

    # Rows whose proposed position is a wall keep their original state.
    next_state = torch.where(blocked.unsqueeze(1), state, proposed)

    goal = torch.tensor([9 / 12, 9 / 12], device=cfg.device)
    at_goal = (next_state - goal).norm(dim=1) < 0.05
    reward = torch.zeros(state.shape[0], device=cfg.device)
    reward[at_goal] = 1.0

    return {"next_state": next_state, "reward": reward}