import numpy as np
import torch
import torch.nn.functional as F

    
def state_trans(env, player_type, state):
    """
    Pick the observation used by the policy and return it as a tensor.

    Args:
        env: environment name; for 'push', 'tag' and 'spread' the observation
            is read from state['agent_0'], otherwise from state[0].
        player_type: accepted for interface compatibility; not used here.
        state: container indexed by 'agent_0' or by 0, depending on env.

    Returns:
        torch.Tensor: the observation unchanged if it already is a tensor,
        otherwise a new float32 tensor built from it.
    """
    lookup_key = 'agent_0' if env in ('push', 'tag', 'spread') else 0
    observation = state[lookup_key]

    if isinstance(observation, torch.Tensor):
        return observation
    return torch.tensor(observation, dtype=torch.float32)
        
        
def select_action_with_mask(action, log_probs, info, use_random=True, action_dim=7):
    """
    Resolve the action to play, honouring the legality mask in `info`.

    Args:
        action: single-element tensor holding the proposed action index.
        log_probs: 1-D tensor of per-action scores; it is not modified.
        info: dict providing 'action_mask' (truthy = legal) and 'done'.
        use_random: when true and info['done'] is falsy, a uniformly random
            index in [0, action_dim) is returned instead.
        action_dim: exclusive upper bound for the random index.

    Returns:
        int: the random index, the proposed action if it is legal, or
        otherwise the legal action with the highest score.
    """
    proposed = action.item()
    legal = torch.tensor(info['action_mask'], dtype=torch.bool)

    # NOTE(review): the random draw does not consult the mask, so it can
    # return an illegal action — confirm that this is intended.
    explore = (not info['done']) and use_random
    if explore:
        return np.random.randint(0, action_dim)

    if not legal[proposed]:
        # Work on a copy so the caller's scores stay untouched, rule out every
        # illegal entry and fall back to the best remaining one.
        scores = log_probs.clone()
        scores[~legal] = float('-inf')
        proposed = torch.argmax(scores).item()

    return proposed


def action_one_hot(action, num_classes=11, epsilon=1e-6):
    """
    Encode discrete class indices as a layer-normalised one-hot vector.

    Args:
        action: class index/indices; a tensor is used as given, anything else
            is converted to a long tensor first.
        num_classes: width of the one-hot encoding.
        epsilon: constant added to every entry before normalisation. Layer
            normalisation subtracts the mean, so this uniform offset has
            practically no effect on the result.

    Returns:
        torch.Tensor: float tensor whose last dimension is num_classes,
        normalised over that dimension; a leading dimension of size 1 is
        squeezed away.
    """
    if not isinstance(action, torch.Tensor):
        action = torch.tensor(action, dtype=torch.long)

    shifted = F.one_hot(action, num_classes).float() + epsilon
    normalised = F.layer_norm(shifted, (num_classes,))
    return normalised.squeeze(0)


def badminton_process_state(state):
    """
    Flatten a badminton state tuple into a single feature tensor.

    Args:
        state (tuple): (class_index, (x1, y1), (x2, y2), (x3, y3)). Following
            the labels used in this file, the pairs are the player's own
            position, the opponent's position and the opponent's ball landing
            point.

    Returns:
        torch.Tensor: the encoding produced by action_one_hot for
        state[0], followed by the six normalised coordinates.
    """
    class_features = action_one_hot(state[0])

    own_xy, opponent_xy, landing_xy = state[1], state[2], state[3]

    coordinate_features = torch.tensor([
        # own position: x is mirrored, y becomes (240 + y) / 240
        own_xy[0] / -177.5,
        (240 + own_xy[1]) / 240,
        # opponent position: x / 177.5, y becomes (240 - y) / 240
        opponent_xy[0] / 177.5,
        (240 - opponent_xy[1]) / 240,
        # opponent ball landing point: same scaling as the opponent position
        landing_xy[0] / 177.5,
        (240 - landing_xy[1]) / 240,
    ], dtype=torch.float32)

    return torch.cat([class_features, coordinate_features])


def badminton_action_process(logits, info, state):
    """
    Decode a raw policy output into an action index and two coordinates.

    Args:
        logits (torch.Tensor): shape [15] or [1, 15]. The first 11 entries are
            class scores; the last 4 are normalised (x1, y1, x2, y2). The
            tensor may require grad.
        info (dict): only info['round'][-1] is read.
        state (tuple): only state[2][0] is read (x of the third state element).

    Returns:
        tuple: (action, (x1, y1), (x2, y2), probs) where
            action is the argmax class, with class 0 remapped to 1;
            (x1, y1) and (x2, y2) are the de-normalised coordinates;
            probs is a tuple with the softmax probabilities of classes 1..10.
    """
    if logits.dim() == 2:
        logits = logits.squeeze(0)

    class_scores = logits[:11]
    # detach() is required before numpy(): without it a tensor that is part of
    # an autograd graph raises RuntimeError here.
    class_probs = F.softmax(class_scores, dim=0).detach().cpu().numpy()
    probs = tuple(class_probs[1:])  # class 0 is never returned, so drop it

    # argmax over 11 scores lies in 0..10; only class 0 needs remapping.
    action = torch.argmax(class_scores, dim=0).item()
    if action == 0:
        action = 1

    x1, y1, x2, y2 = logits[11:]

    # First point: x scaled by 177.5, y mapped as 240 - 240 * y.
    x1 = x1.item() * 177.5
    y1 = -(y1.item() * 240) + 240

    # Second point: x mirrored and scaled, y mapped as 240 * y - 240.
    x2 = -x2.item() * 177.5
    y2 = (y2.item() * 240) - 240

    # When the latest round flag is 1, keep x1 on the same side (sign) as
    # state[2][0] by mirroring it if the signs differ.
    if info['round'][-1] == 1 and state[2][0] * x1 < 0:
        x1 = -x1

    return (action, (x1, y1), (x2, y2), probs)