import random
from typing import Optional, List, Tuple, Union

import numpy as np
import seaborn as sns
import torch
from torch.nn import functional as F


def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    The CUDA generators are additionally seeded on every device, but only
    when CUDA is reported as available.

    Args:
        seed: Value handed to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def gen_colors(num_colors: int) -> List[List[int]]:
    """
    Build `num_colors` RGB triples from seaborn's current color palette.

    The palette is requested with `sns.color_palette(None, num_colors)`; the
    first three channels of every entry are multiplied by 255 and truncated
    with `int()`.

    NOTE(review): per the seaborn documentation, asking the current palette
    for more colors than it holds makes it cycle, so colors may repeat for
    large `num_colors` -- confirm this is acceptable for callers.
    """
    colors = []
    for color in sns.color_palette(None, num_colors):
        red, green, blue = color[0], color[1], color[2]
        colors.append([int(red * 255), int(green * 255), int(blue * 255)])
    return colors

def convert_to_gds_cell_format_batch(generated_layout, resolution=40000):
    """Convert batched model output to integer GDS rectangle cells.

    Args:
        generated_layout: 3-D tensor indexed as [B, N, features]. Feature 0 is
            the layer class; features 1-4 are center x, center y, width and
            height, which are clamped to [0, 1] before being scaled.
        resolution: Factor the clamped coordinates are multiplied by; corner
            coordinates are also clamped to [0, resolution].

    Returns:
        Tensor [B, N, 9] converted with `.int()`, laid out as
        [layer, x1, y1, x2, y2, x3, y3, x4, y4] with the corners ordered
        bottom-left, top-left, top-right, bottom-right.
    """
    batch, count, _ = generated_layout.shape
    layer = generated_layout[..., 0].int()

    # Clamp to the unit square, then scale everything to GDS units at once.
    unit_boxes = generated_layout[..., 1:5].clamp(0.0, 1.0)
    center_x, center_y, width, height = (unit_boxes * resolution).unbind(dim=-1)

    dx = width / 2
    dy = height / 2

    # Each rectangle edge is needed by two corners, so compute it only once.
    left = torch.clamp(center_x - dx, 0, resolution)
    right = torch.clamp(center_x + dx, 0, resolution)
    bottom = torch.clamp(center_y - dy, 0, resolution)
    top = torch.clamp(center_y + dy, 0, resolution)

    cells = torch.zeros(batch, count, 9, device=generated_layout.device)
    columns = (
        layer,
        left, bottom,   # bottom-left
        left, top,      # top-left
        right, top,     # top-right
        right, bottom,  # bottom-right
    )
    for col, values in enumerate(columns):
        cells[..., col] = values

    # GDS coordinates are integers; the conversion truncates.
    return cells.int()

@torch.no_grad()
def sample(
    model: torch.nn.Module,
    x: torch.Tensor,
    steps: int,
    inference_noise_level: float = 0.0    
) -> torch.Tensor:
    """
    Sample from the model autoregressively.

    Every step runs the model on the sequences that are still active, takes
    its prediction at the last position, converts it with
    `transfer_to_onehot` and appends it. Sequences that are already finished
    receive a padding token (all zeros with the last feature set to 1).

    A sequence is marked as finished when the predicted token has its EOS bit
    (second-to-last feature) set, or when any of its four coordinates lies
    outside [0, 1].

    Side effect: the model is switched to eval mode and is not switched back.

    Args:
        model: Model to sample from; called with [n_active, seq_len, dim] and
            its output is indexed as [n_active, seq_len, 5].
        x: Initial sequence tensor [batch_size, seq_len, dim]
        steps: Total sequence length to reach (initial tokens included)
        inference_noise_level: Amount of noise to add during inference

    Returns:
        Sequences in category format [batch_size, max(steps, seq_len), 5].
        When all sequences finish before `steps` tokens exist, the remainder
        is filled with padding tokens so the length is still `steps`.
    """

    model.eval()
    batch_size, seq_len, dim = x.size()
    device = x.device

    # Track which sequences have finished
    finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=device)

    # Autoregressive sampling loop
    for _ in range(steps - seq_len):
        # Break early if all sequences are finished
        if finished_sequences.all():
            break

        # Only the unfinished sequences are run through the model
        unfinished_indices = torch.where(~finished_sequences)[0]
        x_unfinished = x[unfinished_indices]

        # Get model predictions for next token
        processed_logits = model(x_unfinished)  # [n_active, seq_len, 5]
        next_token = processed_logits[:, -1, :].unsqueeze(1)  # [n_active, 1, 5]

        # Convert to one-hot format: 4 coordinates followed by the class bits
        next_token_onehot = transfer_to_onehot(next_token)  # [n_active, 1, dim]

        # Default for every sequence is a padding token; active sequences
        # overwrite their slot below.
        new_token = torch.zeros((batch_size, 1, dim), device=device)
        new_token[:, :, -1] = 1      # Set padding bit

        # Process each unfinished sequence
        for idx, orig_idx in enumerate(unfinished_indices):
            # Check for EOS token
            if next_token_onehot[idx, 0, -2] == 1.0:
                finished_sequences[orig_idx] = True
                # NOTE(review): the predicted EOS itself is not written to the
                # sequence; this slot keeps the padding token.
                continue

            # Validate coordinates (must be in [0,1] range)
            coords = next_token_onehot[idx, 0, 0:4]
            if not ((coords >= 0.0) & (coords <= 1.0)).all():
                finished_sequences[orig_idx] = True
                # Use EOS token instead of invalid coordinates
                next_token_onehot[idx, 0] = torch.zeros_like(next_token_onehot[idx, 0])
                next_token_onehot[idx, 0, -2] = 1.0  # Set EOS bit

            # Noise injection (inference)
            # NOTE(review): the noise is blended into the whole token, i.e.
            # also into the class bits and into a forced EOS token.
            if inference_noise_level > 0:
                noise = torch.randn_like(next_token_onehot[idx, 0])
                next_token_onehot[idx, 0] = (
                    (1 - inference_noise_level) * next_token_onehot[idx, 0]
                    + inference_noise_level * noise
                )

            # Update the sequence with the new token
            new_token[orig_idx] = next_token_onehot[idx]

        # Append the new token to all sequences
        x = torch.cat((x, new_token), dim=1)

    # The loop may stop early once every sequence is finished; fill the rest
    # with padding tokens so the result has the documented length.
    missing = steps - x.size(1)
    if missing > 0:
        padding = torch.zeros((batch_size, missing, dim), dtype=x.dtype, device=device)
        padding[:, :, -1] = 1
        x = torch.cat((x, padding), dim=1)

    # Convert final sequences to category format
    result = transfer_to_category(x)
    return result


@torch.no_grad()
def importance_sampling(model: torch.nn.Module, x: torch.Tensor, functions: dict, batch_size : int, steps: int) -> torch.Tensor:
    """
    Extend one sequence autoregressively, keeping the lowest-penalty candidate
    out of `batch_size` model proposals at every step.

    Each step replicates the current sequence `batch_size` times, runs the
    model on the replicas and treats every replica's last-position prediction
    as a candidate next token. The candidate sequences are converted with
    `convert_to_gds_cell_format_batch`, scored as the weighted sum of all
    penalty functions, and the candidate with the smallest total is kept.

    Sampling stops after `steps - seq_len` new tokens, or as soon as any
    candidate predicts category 4.0 (EOS). The result is then padded to
    `steps` tokens with padding tokens (last feature set to 1).

    Side effects: switches the model to eval mode (not restored) and prints a
    progress line per step.

    NOTE(review): all replicas are identical, so candidates only differ if the
    model is stochastic in eval mode -- confirm against the model.

    Args:
        model: Model to sample from
        x: Initial sequence tensor [seq_len, dim]
        functions: Dictionary mapping penalty function -> weight. Each function
            is called with the GDS tensor of all candidates and its result
            must broadcast to [batch_size].
        batch_size: Number of candidates evaluated per step
        steps: Total sequence length to reach (initial tokens included)

    Returns:
        The selected sequence in category format [max(steps, final_len), 5]
    """

    model.eval()
    device = x.device

    initial_seq_len = x.size(0)
    total_new = steps - initial_seq_len

    # Autoregressive sampling loop
    for i in range(total_new):

        seq_len, dim = x.size()
        x_repeat = x.repeat(batch_size, 1).reshape(batch_size, seq_len, dim)  # [batch_size, seq_len, dim]
        x_cell = transfer_to_category(x_repeat)

        # Get model predictions for next token
        processed_logits = model(x_repeat)  # [batch_size, seq_len, 5]
        next_token = processed_logits[:, -1, :].unsqueeze(1)  # [batch_size, 1, 5]
        x_cell_next = torch.cat((x_cell, next_token), dim=1)  # [batch_size, seq_len+1, 5]

        # The GDS conversion does not depend on the penalty function, so do it
        # once per step. Penalty functions are expected not to modify it.
        gds_next = convert_to_gds_cell_format_batch(x_cell_next)

        best_next = torch.zeros((batch_size), device=device)
        for function, weight in functions.items():
            best_next += weight * function(gds_next)

        best_idx = torch.argmin(best_next)  # Index of the lowest-penalty candidate
        x = transfer_to_onehot(x_cell_next[best_idx])
        print(
            f"Step {i + 1}/{total_new}, Layer : {next_token[best_idx, 0, 0]}, "
            f"Best Index: {best_idx.item()}, Penalty: {best_next[best_idx].item():.6f}"
        )

        # NOTE(review): this stops when ANY candidate predicts EOS, not only
        # when the selected candidate does -- confirm this is intended.
        if torch.sum(next_token[:, 0, 0] == 4.0) > 0:  # EOS token
            break

    final_seq_len, dim = x.size()
    if final_seq_len < steps:
        # Pad the sequence to the required length
        padding = torch.zeros((steps - final_seq_len, dim), device=device)
        padding[:, -1] = 1
        x = torch.cat((x, padding), dim=0)

    # Convert final sequence to category format
    result = transfer_to_category(x)
    return result


def trim_tokens(tokens: torch.Tensor, bos: float = 5.0, eos: float = 6.0, pad: float = 7.0) -> torch.Tensor:
    """Drop every row whose category is one of the special markers.

    Args:
        tokens: Tensor whose column 0 (`tokens[:, 0]`) holds the category id.
        bos: Category value treated as begin-of-sequence.
        eos: Category value treated as end-of-sequence.
        pad: Category value treated as padding.

    Returns:
        The rows of `tokens` whose category is none of `bos`, `eos`, `pad`,
        in their original order.
    """
    kinds = tokens[:, 0]
    is_special = (kinds == bos) | (kinds == eos) | (kinds == pad)
    return tokens[~is_special]


def transfer_to_onehot(x: torch.Tensor, num_classes: int = 6) -> torch.Tensor:
    """Convert category-format tokens to coordinate + one-hot format.

    Args:
        x: Tensor [..., 5] whose last axis is (category, 4 coordinates). The
            category is truncated to an integer with `.long()`.
        num_classes: Width of the one-hot class encoding. `F.one_hot` rejects
            category ids outside [0, num_classes).

    Returns:
        Tensor [..., 4 + num_classes]: the four coordinates followed by the
        one-hot class bits (width 10 with the default of 6 classes).
    """
    x_categories = x[..., 0].long()
    x_coords = x[..., 1:5]

    x_onehot = F.one_hot(x_categories, num_classes=num_classes)

    # one_hot returns integers; cast so it can be concatenated with the coords.
    return torch.cat([x_coords, x_onehot.float()], dim=-1)


def transfer_to_category(x: torch.Tensor) -> torch.Tensor:
    """Convert coordinate + class-score tokens to category format.

    Args:
        x: Tensor [..., 4 + C]: four coordinates followed by C class scores
            (for example one-hot bits).

    Returns:
        Tensor [..., 5]: the argmax class index as a float, followed by the
        four coordinates.
    """
    coords, class_scores = x[..., :4], x[..., 4:]
    category = class_scores.argmax(dim=-1, keepdim=True).float()
    return torch.cat([category, coords], dim=-1)



