import torch
import numpy as np
import subprocess
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import random


def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    # Same order as always: Python `random`, NumPy, PyTorch CPU,
    # the current CUDA device, then all CUDA devices.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN: disable the benchmark autotuner and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def to_tensor_safe(x, dtype, device):
    """Return ``x`` as a tensor with the requested dtype on the requested device.

    Args:
        x: An existing ``torch.Tensor`` or any data ``torch.tensor`` accepts.
        dtype: Target torch dtype.
        device: Target torch device (string or ``torch.device``).

    Returns:
        ``x.to(...)`` when ``x`` is already a tensor, otherwise a new tensor
        built from ``x`` with ``torch.tensor``.
    """
    target = {"dtype": dtype, "device": device}
    if not isinstance(x, torch.Tensor):
        # Non-tensor input: build a fresh tensor from the data.
        return torch.tensor(x, **target)
    # Tensor input: convert/move it instead of re-wrapping it.
    return x.to(**target)

def logit2onehot(action, action_dim=5):
    """Convert discrete action indices into one-hot float vectors.

    Args:
        action: Action index or indices. Accepted forms:
            * Python ``int`` or NumPy integer scalar -> treated as a batch of one;
            * ``np.ndarray`` or ``torch.Tensor`` that is 0-d, 1-D ``(B,)`` or
              2-D ``(B, 1)``. Values are cast to ``long``.
        action_dim: Number of discrete actions (length of each one-hot vector).

    Returns:
        A ``torch.float`` tensor on the same device as the converted action:
        shape ``(action_dim,)`` for a 0-d array/tensor, ``(1, action_dim)`` for
        an integer scalar, and ``(B, action_dim)`` for batched input.

    Raises:
        TypeError: If ``action`` is not an int, NumPy integer, ndarray or tensor.
        ValueError: If the converted action is neither 0-d, 1-D nor ``(B, 1)``.
    """
    # NumPy integer scalars (e.g. np.int64 from argmax or np.random.randint)
    # are not instances of `int`, so they are accepted explicitly and handled
    # exactly like a Python int.
    if isinstance(action, (int, np.integer)):
        action = torch.tensor([int(action)], dtype=torch.long)
    elif isinstance(action, np.ndarray):
        action = torch.from_numpy(action).long()
    elif isinstance(action, torch.Tensor):
        action = action.long()
    else:
        raise TypeError("Unsupported action type.")

    # Collapse a column vector (B, 1) to (B,).
    if action.dim() == 2 and action.shape[1] == 1:
        action = action.squeeze(1)

    if action.dim() == 0:
        one_hot = torch.zeros(action_dim, dtype=torch.float, device=action.device)
        one_hot[action] = 1
        return one_hot

    if action.dim() == 1:
        batch_size = action.shape[0]
        one_hot = torch.zeros((batch_size, action_dim), dtype=torch.float, device=action.device)
        # Build the row index on the same device as the actions so the
        # advanced-indexing write does not mix devices.
        rows = torch.arange(batch_size, device=action.device)
        one_hot[rows, action] = 1
        return one_hot

    raise ValueError("Action tensor must be scalar or 1D/2D shape.")

def get_available_gpu():
    """Pick the GPU with the least memory in use, falling back to the CPU.

    Queries ``nvidia-smi`` for the used memory of every GPU and returns the
    index of the smallest value as a torch device string.

    Returns:
        ``'cuda:<index>'`` for the least-used GPU, or ``'cpu'`` when PyTorch
        cannot use CUDA or the ``nvidia-smi`` query fails or cannot be parsed.
    """
    # nvidia-smi can succeed on a machine where this PyTorch build cannot use
    # CUDA; a 'cuda:N' string would then be unusable by the caller.
    if not torch.cuda.is_available():
        return 'cpu'

    # NOTE(review): the index is nvidia-smi's ordering; it presumably differs
    # from PyTorch's ordering when CUDA_VISIBLE_DEVICES / CUDA_DEVICE_ORDER
    # remap devices - confirm for the target cluster.
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'],
            stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, check=True
        )
        lines = result.stdout.decode('utf-8').strip().splitlines()
        memory_used = [int(line) for line in lines]
        available_gpu = memory_used.index(min(memory_used))
    except (OSError, subprocess.SubprocessError, ValueError):
        # Best-effort: missing binary, non-zero exit status, or output that is
        # not a list of integers all mean "no usable GPU information".
        return 'cpu'
    return f'cuda:{available_gpu}'

def tsne_visualization(policy_rep, policy_idx, file_name="policy_rep"):
    """Embed representation vectors in 2-D with t-SNE and save a scatter plot.

    Args:
        policy_rep: Tensor or array-like accepted by ``torch.as_tensor``. The
            first axis indexes samples; all remaining axes are flattened into
            one feature vector per sample.
        policy_idx: Labels used to colour the points (tensor or array-like).
            It is flattened; presumably it holds one label per sample -
            verify against the caller.
        file_name: Prefix of the output image; the plot is written to
            ``"{file_name}_t-SNE_visualization.png"``.
    """
    prey_rep = torch.as_tensor(policy_rep)
    prey_rep_np = prey_rep.detach().cpu().numpy().reshape(len(prey_rep), -1)

    # Newer scikit-learn releases renamed `n_iter` to `max_iter` and later
    # dropped the old name; try the new keyword first and fall back for older
    # installations, which reject it with a TypeError.
    try:
        tsne = TSNE(n_components=2, random_state=42, max_iter=2000)
    except TypeError:
        tsne = TSNE(n_components=2, random_state=42, n_iter=2000)
    prey_rep_2d = tsne.fit_transform(prey_rep_np)

    # Tensors (possibly on a GPU or attached to a graph) must be moved to host
    # memory before NumPy can read them.
    if isinstance(policy_idx, torch.Tensor):
        policy_idx = policy_idx.detach().cpu().numpy()
    prey_idx_flat = np.asarray(policy_idx).flatten()

    # `return_inverse` yields, for every point, the position of its label in
    # the sorted unique labels - i.e. a contiguous colour index 0..K-1.
    unique_labels, mapped_labels = np.unique(prey_idx_flat, return_inverse=True)
    num_classes = len(unique_labels)

    # `plt.cm.get_cmap` is no longer available in recent Matplotlib releases;
    # `plt.get_cmap` takes the same (name, lut) arguments.
    cmap = plt.get_cmap('tab20', num_classes)
    norm = plt.Normalize(vmin=0, vmax=num_classes - 1)

    fig = plt.figure(figsize=(10, 8))
    try:
        scatter = plt.scatter(prey_rep_2d[:, 0], prey_rep_2d[:, 1], c=mapped_labels, cmap=cmap, norm=norm, s=30)
        cbar = plt.colorbar(scatter, ticks=range(num_classes))
        cbar.set_ticklabels(unique_labels)

        plt.title("t-SNE visualization of policy representation vectors")
        plt.xlabel("t-SNE dimension 1")
        plt.ylabel("t-SNE dimension 2")
        plt.savefig(f"{file_name}_t-SNE_visualization.png")
    finally:
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls do not accumulate figures in memory.
        plt.close(fig)

def policy_encoder_data_process(sample, device, max_len=25, min_len=5):
    """Create two randomly cropped, zero-padded views of every trajectory.

    Args:
        sample: Pair ``(traj_samples, contrast_label)``. ``traj_samples`` is an
            iterable of per-trajectory arrays indexed by time step along
            axis 0 (the stacked result is ``(B, max_len, feat_dim)``).
            ``contrast_label`` is converted to a ``long`` tensor.
        device: Torch device for all returned tensors.
        max_len: Maximum crop length and the padded length of every crop.
        min_len: Minimum crop length.

    Returns:
        Tuple ``(traj_tensor, contrast_label, mask_tensor)``:
            * ``traj_tensor``: float tensor ``(2 * B, max_len, feat_dim)``; the
              first ``B`` rows are the first view, the last ``B`` the second.
            * ``contrast_label``: long tensor built from the input labels; it
              is NOT duplicated for the second view.
            * ``mask_tensor``: float tensor ``(2 * B, max_len)`` with 1 on real
              time steps and 0 on padding.

    Raises:
        ValueError: If a trajectory has no time steps.
    """
    def sample_generator(traj_samples, min_len, max_len):
        # One random crop (padded to max_len) plus its validity mask per trajectory.
        traj_list = []
        mask_list = []
        for traj in traj_samples:
            traj = np.asarray(traj)
            n_steps = len(traj)
            if n_steps == 0:
                raise ValueError("Cannot crop an empty trajectory.")

            # Cap both bounds at the trajectory length: without this a
            # trajectory shorter than the sampled crop length makes the start
            # index range empty and np.random.randint raises. Trajectories of
            # at least max_len steps draw exactly the same random numbers as
            # before.
            low = min(min_len, n_steps)
            high = min(max_len, n_steps)
            L = np.random.randint(low, high + 1)  # crop length
            start_idx = np.random.randint(0, n_steps - L + 1)
            crop = traj[start_idx:start_idx + L]

            pad_len = max_len - L
            if pad_len > 0:
                # Zero-pad along time; trailing feature axes are kept as-is.
                padding = np.zeros((pad_len,) + crop.shape[1:])
                crop = np.concatenate([crop, padding], axis=0)
            mask = np.concatenate([np.ones(L), np.zeros(pad_len)], axis=0)

            traj_list.append(crop)
            mask_list.append(mask)

        # Stack and convert to tensor
        traj_tensor = to_tensor_safe(np.stack(traj_list), torch.float, device)      # (B, max_len, feat_dim)
        mask_tensor = to_tensor_safe(np.stack(mask_list), torch.float, device)      # (B, max_len)
        return traj_tensor, mask_tensor

    traj_samples, contrast_label = sample
    # Two independent random crops of the same trajectories form the two views.
    traj_tensor_1, mask_tensor_1 = sample_generator(traj_samples, min_len, max_len)
    traj_tensor_2, mask_tensor_2 = sample_generator(traj_samples, min_len, max_len)
    traj_tensor = torch.cat((traj_tensor_1, traj_tensor_2), dim=0)
    mask_tensor = torch.cat((mask_tensor_1, mask_tensor_2), dim=0)
    contrast_label = to_tensor_safe(contrast_label, torch.long, device)
    return traj_tensor, contrast_label, mask_tensor

def z_sample(dis_z_dim, batch_size, device, n_dis_z=1):
    """Sample random one-hot discrete latent codes.

    Args:
        dis_z_dim: Number of categories of each discrete code.
        batch_size: Number of samples to draw.
        device: Torch device on which ``dis_z`` is created.
        n_dis_z: Number of independent discrete codes per sample.

    Returns:
        Tuple ``(dis_z, idx)``:
            * ``dis_z``: float tensor ``(batch_size, n_dis_z * dis_z_dim, 1, 1)``
              holding the concatenated one-hot codes.
            * ``idx``: float64 NumPy array ``(n_dis_z, batch_size)`` with the
              sampled category of every code (integral values).
    """
    idx = np.zeros((n_dis_z, batch_size))
    # Allocated unconditionally so that n_dis_z == 0 returns an empty code
    # tensor instead of failing on an unassigned `dis_z`.
    dis_z = torch.zeros(batch_size, n_dis_z, dis_z_dim, device=device)
    rows = torch.arange(batch_size, device=device)

    for i in range(n_dis_z):
        categories = np.random.randint(dis_z_dim, size=batch_size)
        idx[i] = categories
        # Index with an explicit long tensor: `idx` is a float64 array, and
        # float indices are not a reliable way to address tensor elements.
        cols = torch.as_tensor(categories, dtype=torch.long, device=device)
        dis_z[rows, i, cols] = 1.0

    # Explicit size instead of -1 so the reshape is also valid when empty.
    dis_z = dis_z.view(batch_size, n_dis_z * dis_z_dim, 1, 1)

    return dis_z, idx
