import random
from typing import Optional, List, Tuple, Union

import numpy as np
import seaborn as sns
import torch
from torch.nn import functional as F


def set_seed(seed: int) -> None:
    """
    Seed every random number generator this module relies on.

    Seeds, in order: Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG,
    and — only when CUDA is available — the RNGs of all CUDA devices.

    Args:
        seed: Seed value passed unchanged to each generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def gen_colors(num_colors: int) -> List[List[int]]:
    """
    Return ``num_colors`` colors as ``[r, g, b]`` lists of ints in 0-255.

    Colors come from ``sns.color_palette(None, num_colors)``. Each float channel
    is scaled by 255 and truncated toward zero with ``int``.

    NOTE(review): with ``palette=None`` seaborn draws from its current color
    cycle. The colors are therefore not guaranteed to be evenly spread, and they
    may repeat once ``num_colors`` exceeds the cycle length — confirm this is
    acceptable for large ``num_colors``.
    """
    palette = sns.color_palette(None, num_colors)
    return [
        [int(channel * 255) for channel in (rgb[0], rgb[1], rgb[2])]
        for rgb in palette
    ]


@torch.no_grad()
def sample(model: torch.nn.Module, x: torch.Tensor, steps: int) -> torch.Tensor:
    """
    Sample from the model autoregressively.

    Each step runs ``model`` on the sequences that have not finished yet. The
    last position of the model output is converted with ``transfer_to_onehot``
    and appended to those sequences. The output is assumed to be in category
    format ``[category, 4 coords]`` — TODO confirm against the model.

    A sequence finishes when:
      * the sampled token is EOS (one-hot bit ``-2``); the EOS token is appended;
      * any of its four coordinates is outside ``[0, 1]``; a clean EOS token
        (zero coordinates, EOS bit set) is appended in its place.

    Finished sequences receive a PAD token (one-hot bit ``-1``, zero elsewhere)
    on every later step.

    Args:
        model: Model to sample from. It is switched to eval mode for the
            duration of the call, and every submodule's previous train/eval
            flag is restored afterwards.
        x: Initial sequence tensor in one-hot format
            [batch_size, seq_len, dim]. ``dim`` presumably equals 12
            (4 coords + 8 class bits), matching ``transfer_to_onehot``.
        steps: Target total sequence length. At most ``steps - seq_len``
            tokens are appended (none if ``steps <= seq_len``).

    Returns:
        Sequences converted with ``transfer_to_category``,
        shape [batch_size, total_len, 5]. ``total_len`` is ``steps`` unless
        every sequence finished early (then shorter) or ``steps <= seq_len``
        (then ``seq_len``).
    """
    batch_size, seq_len, dim = x.size()
    device = x.device

    # Which sequences have already emitted EOS (or were terminated).
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    # Remember each submodule's train/eval flag so it can be restored exactly.
    # model.train(True) alone would clobber submodules deliberately kept in eval.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        for _ in range(steps - seq_len):
            if finished.all():
                break

            # Guaranteed non-empty here because of the early break above.
            active = torch.nonzero(~finished, as_tuple=True)[0]

            # The next-token prediction is the last position of the model output.
            output = model(x[active])  # [n_active, cur_len, 5] (assumed)
            next_token = transfer_to_onehot(output[:, -1:, :])  # [n_active, 1, 12]

            is_eos = next_token[:, 0, -2] == 1.0
            coords = next_token[:, 0, 0:4]
            coords_ok = ((coords >= 0.0) & (coords <= 1.0)).all(dim=-1)

            # Replace tokens whose coordinates are out of range by a clean EOS.
            eos_token = next_token.new_zeros(next_token.shape[-1])
            eos_token[-2] = 1.0
            next_token = torch.where(coords_ok.view(-1, 1, 1), next_token, eos_token)

            finished[active] = is_eos | ~coords_ok

            # Finished sequences get a PAD token. Active ones get their sampled
            # token, including the EOS that terminated them.
            new_token = torch.zeros((batch_size, 1, dim), dtype=x.dtype, device=device)
            new_token[:, :, -1] = 1
            new_token[active] = next_token.to(new_token.dtype)

            x = torch.cat((x, new_token), dim=1)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    return transfer_to_category(x)


def trim_tokens(tokens: torch.Tensor, bos: float = 5.0, eos: float = 6.0, pad: float = 7.0) -> torch.Tensor:
    """
    Drop every row whose category (column 0) equals ``bos``, ``eos`` or ``pad``.

    Args:
        tokens: Tensor indexed as ``tokens[:, 0]`` for the category. For a 2-D
            input [num_tokens, features], the result is [num_kept, features].
        bos: Category value treated as beginning-of-sequence.
        eos: Category value treated as end-of-sequence.
        pad: Category value treated as padding.

    Returns:
        The rows of ``tokens`` whose category is none of the special values,
        in their original order.
    """
    first_col = tokens[:, 0]

    keep = torch.ones_like(first_col, dtype=torch.bool)
    for special in (bos, eos, pad):
        keep &= first_col != special

    return tokens[keep]


def transfer_to_onehot(x: torch.Tensor) -> torch.Tensor:
    """
    Convert category-format tokens into coordinate + one-hot format.

    Args:
        x: Tensor [..., 5] laid out as ``[category, c0, c1, c2, c3]``. The
            category is truncated to an integer with ``.long()``, and
            ``F.one_hot`` requires it to lie in ``[0, 8)``.

    Returns:
        Tensor [..., 12] laid out as ``[c0, c1, c2, c3, onehot_0 .. onehot_7]``.
    """
    labels = x[..., 0].long()
    boxes = x[..., 1:5]
    onehot = F.one_hot(labels, num_classes=8).float()
    return torch.cat([boxes, onehot], dim=-1)


def transfer_to_category(x: torch.Tensor) -> torch.Tensor:
    """
    Convert coordinate + one-hot tokens back to category format.

    Args:
        x: Tensor [..., 4 + num_classes]. The first four channels are
            coordinates and the rest is a one-hot (or score) vector.

    Returns:
        Tensor [..., 5] laid out as ``[category, c0, c1, c2, c3]``. The
        category is the float-cast argmax of the one-hot channels.
    """
    boxes, onehot = x[..., :4], x[..., 4:]
    labels = onehot.argmax(dim=-1, keepdim=True).float()
    return torch.cat([labels, boxes], dim=-1)
