from typing import Any, Dict, List
import torch
import torch.distributed as dist
import numpy as np
import copy
import math
import gzip
import json
import os
#from instruct_manipulate import convert_keywords_to_uppercase, convert_to_mask

def extract_instruction_tokens(
    observations: List[Dict],
    instruction_sensor_uuid: str,
    tokens_uuid: str = "tokens",
    max_length: int = 512,
    pad_id: int = 0,
):
    """Replace dict-valued instruction observations with fixed-length token lists.

    Observations are visited in order. While
    ``obs[instruction_sensor_uuid]`` is a dict containing ``tokens_uuid``,
    that entry is replaced (in place) by its token sequence truncated to
    ``max_length`` and right-padded with ``pad_id`` up to ``max_length``.
    Processing stops at the first observation whose entry is not such a
    dict.

    Args:
        observations: List of observation dicts; mutated in place.
        instruction_sensor_uuid: Key of the instruction entry in each
            observation.
        tokens_uuid: Key of the token sequence inside the instruction dict.
        max_length: Exact length of each resulting token list.
        pad_id: Value used to pad short token sequences.

    Returns:
        The same ``observations`` list object.
    """
    # Nothing to do for an empty batch or when the sensor is absent.
    if not observations or instruction_sensor_uuid not in observations[0]:
        return observations
    for obs in observations:
        instruction = obs[instruction_sensor_uuid]
        if not (isinstance(instruction, dict) and tokens_uuid in instruction):
            # NOTE(review): stopping here presumably assumes the remaining
            # entries were already extracted — confirm with callers.
            break
        # Honour ``tokens_uuid`` (previously the key "tokens" was hard-coded)
        # and copy into a fresh list so the caller's sequence is not mutated.
        tokens = list(instruction[tokens_uuid][:max_length])
        tokens.extend([pad_id] * (max_length - len(tokens)))
        obs[instruction_sensor_uuid] = tokens
    return observations


# Re-tokenize episode instruction text with the BERT-compatible word-level tokenizer.
def tokenize_instruction(json_data, bert_vocab_path, type, max_length=80):
    """
    Swap in an alternative instruction text for each episode and re-tokenize it.

    For every episode whose ``instruction`` dict contains
    ``instruction_text_{type}``, that text becomes ``instruction_text``. It is
    re-tokenized with ``tokenizer_instruction_bert``. The ``instruction`` dict
    is then reduced to just ``instruction_text`` and ``instruction_tokens``.
    Episodes lacking that key are written out unchanged. The result is saved
    next to the input as ``<base>_{type}.json.gz``.

    Args:
        json_data (str): Path to the input ``.json.gz`` episode file (must
            contain a top-level ``episodes`` list).
        bert_vocab_path (str): Path to a ``.json.gz`` file whose
            ``instruction_vocab`` entry holds ``word_list``, ``PAD_INDEX``
            and ``UNK_INDEX``.
        type (str): Suffix selecting ``instruction_text_{type}``; also used
            to name the output file.
        max_length (int): Length every token list is padded/truncated to.

    Returns:
        str: Path of the written ``.json.gz`` file.
    """
    text_key = f'instruction_text_{type}'

    # JSON is UTF-8; don't depend on the platform's default text encoding.
    with gzip.open(json_data, 'rt', encoding='utf-8') as f:
        data = json.load(f)

    with gzip.open(bert_vocab_path, 'rt', encoding='utf-8') as f:
        berter = json.load(f)
    bert_vocab = berter.get('instruction_vocab', {})

    print(f"Using BERT tokenization with vocabulary size: {len(bert_vocab.get('word_list', []))}")
    print(f"UNK_INDEX: {bert_vocab.get('UNK_INDEX')}, PAD_INDEX: {bert_vocab.get('PAD_INDEX')}")

    processed_count = 0
    for episode in data['episodes']:
        if 'instruction' not in episode or text_key not in episode['instruction']:
            # Episodes without the requested variant are left untouched.
            continue
        text = episode['instruction'][text_key]
        # The tokenizer adds [CLS]/[SEP] and pads/truncates to max_length.
        tokens = tokenizer_instruction_bert(text, bert_vocab, max_length=max_length)
        # Replace the whole dict, dropping every other key (including the
        # instruction_text_* variants).
        episode['instruction'] = {
            'instruction_text': text,
            'instruction_tokens': tokens,
        }
        processed_count += 1

    # Strip ".gz" and then ".json" so "x.json.gz" becomes "x_{type}.json.gz".
    base_path = os.path.splitext(json_data)[0]
    if base_path.endswith('.json'):
        base_path = os.path.splitext(base_path)[0]
    output_path = f"{base_path}_{type}.json.gz"

    with gzip.open(output_path, 'wt', encoding='utf-8') as f:
        json.dump(data, f)

    print(f"Processed {processed_count} episodes using BERT tokenization")
    print(f"Saved to: {output_path}")

    return output_path

def tokenize_instruction_old(json_data, bert_vocab_path, type='mask',
                             masking_percentage=75, max_length=80):
    """
    Rewrite each episode's instruction text (keyword-uppercase or mask) and re-tokenize it.

    For every episode with ``instruction['instruction_text']``, the text is
    first passed through ``convert_keywords_to_uppercase``. What happens next
    depends on ``type``:

    * ``'mask'``: the uppercased text is passed to ``convert_to_mask`` and
      the masked result is used.
    * ``'capitalize'``: the uppercased text is used.
    * anything else: the original text is kept.

    The chosen text is re-tokenized with ``tokenizer_instruction_bert``. The
    ``instruction`` dict is then reduced to ``instruction_text`` and
    ``instruction_tokens``. The result is saved as ``<base>_{type}.json.gz``.

    Args:
        json_data (str): Path to the input ``.json.gz`` episode file.
        bert_vocab_path (str): Path to a ``.json.gz`` file whose
            ``instruction_vocab`` entry holds ``word_list``, ``PAD_INDEX``
            and ``UNK_INDEX``.
        type (str): ``'mask'`` or ``'capitalize'``; also the output suffix.
        masking_percentage (int): Forwarded to ``convert_to_mask``.
        max_length (int): Length every token list is padded/truncated to.

    Returns:
        str: Path of the written ``.json.gz`` file.
    """
    # The module-level import of these helpers is commented out, which made
    # this function raise NameError. Importing lazily here restores it while
    # keeping the module importable without instruct_manipulate.
    from instruct_manipulate import convert_keywords_to_uppercase, convert_to_mask

    with gzip.open(json_data, 'rt', encoding='utf-8') as f:
        data = json.load(f)

    with gzip.open(bert_vocab_path, 'rt', encoding='utf-8') as f:
        berter = json.load(f)
    bert_vocab = berter.get('instruction_vocab', {})

    print(f"Using BERT tokenization with vocabulary size: {len(bert_vocab.get('word_list', []))}")
    print(f"UNK_INDEX: {bert_vocab.get('UNK_INDEX')}, PAD_INDEX: {bert_vocab.get('PAD_INDEX')}")

    processed_count = 0
    for episode in data['episodes']:
        if 'instruction' not in episode or 'instruction_text' not in episode['instruction']:
            continue
        original_text = episode['instruction']['instruction_text']
        uppercased = convert_keywords_to_uppercase(original_text)
        if type == 'mask':
            text, _ = convert_to_mask(uppercased, masking_percentage=masking_percentage)
        elif type == 'capitalize':
            text = uppercased
        else:
            # Unknown types re-tokenize the original text unchanged.
            text = original_text

        # The tokenizer adds [CLS]/[SEP] and pads/truncates to max_length.
        tokens = tokenizer_instruction_bert(text, bert_vocab, max_length=max_length)
        # Replace the whole dict, dropping every other key.
        episode['instruction'] = {
            'instruction_text': text,
            'instruction_tokens': tokens,
        }
        processed_count += 1

    # Strip ".gz" and then ".json" so "x.json.gz" becomes "x_{type}.json.gz".
    base_path = os.path.splitext(json_data)[0]
    if base_path.endswith('.json'):
        base_path = os.path.splitext(base_path)[0]
    output_path = f"{base_path}_{type}.json.gz"

    with gzip.open(output_path, 'wt', encoding='utf-8') as f:
        json.dump(data, f)

    print(f"Processed {processed_count} episodes using BERT tokenization")
    print(f"Saved to: {output_path}")

    return output_path

def tokenizer_instruction_bert(text, bert_vocab, max_length=80):
    """
    Tokenize ``text`` at word level and map each piece to a BERT vocabulary index.

    Tokenization steps:

    * Splitting: text is split into runs of word characters and single
      punctuation characters. Special tokens written literally in the text
      (``[MASK]``, ``[CLS]``, ``[SEP]``, ``[PAD]``, ``[UNK]``) are kept
      intact and looked up directly.
    * Lookup: every other piece is looked up case-sensitively first, then
      lower-cased, and finally falls back to ``UNK_INDEX``.
    * Framing: the sequence is wrapped in ``[CLS]`` ... ``[SEP]``.
    * Length: the result is padded with ``PAD_INDEX`` up to ``max_length``.
      Longer sequences are truncated, keeping a trailing ``[SEP]``.

    NOTE: the word->index map is rebuilt from ``word_list`` on every call.

    Args:
        text (str): Input instruction text.
        bert_vocab (dict): Must contain ``word_list`` (position = index),
            ``PAD_INDEX`` and ``UNK_INDEX``.
        max_length (int): Exact output length; must be >= 2 to hold
            ``[CLS]`` and ``[SEP]``.

    Returns:
        list[int]: Exactly ``max_length`` token indices.

    Raises:
        ValueError: If ``max_length < 2``.
        KeyError: If ``[CLS]``/``[SEP]`` (or a special token present in
            ``text``) is missing from ``word_list``.
    """
    import re

    if max_length < 2:
        raise ValueError(f"max_length must be >= 2 to hold [CLS] and [SEP], got {max_length}")

    word2idx = {word: idx for idx, word in enumerate(bert_vocab['word_list'])}
    cls_token = word2idx['[CLS]']
    sep_token = word2idx['[SEP]']
    pad_token = bert_vocab['PAD_INDEX']
    unk_token = bert_vocab['UNK_INDEX']

    special_tokens = ('[MASK]', '[CLS]', '[SEP]', '[PAD]', '[UNK]')
    # Special tokens are tried first, so "[MASK]" is not split into "[", "MASK", "]".
    pattern = '|'.join(re.escape(tok) for tok in special_tokens) + r'|\w+|[^\w\s]'
    pieces = re.findall(pattern, text)

    token_indices = []
    for piece in pieces:
        if piece in special_tokens:
            token_indices.append(word2idx[piece])
        elif piece in word2idx:
            # Exact match first preserves deliberately uppercased keywords.
            token_indices.append(word2idx[piece])
        elif piece.lower() in word2idx:
            token_indices.append(word2idx[piece.lower()])
        else:
            token_indices.append(unk_token)

    bert_tokens = [cls_token] + token_indices + [sep_token]
    if len(bert_tokens) > max_length:
        # Keep [CLS] and the first tokens, then re-append the closing [SEP].
        bert_tokens = bert_tokens[:max_length - 1] + [sep_token]
    else:
        bert_tokens.extend([pad_token] * (max_length - len(bert_tokens)))
    return bert_tokens

def gather_list_and_concat(list_of_nums, world_size):
    """Collect one tensor from every rank via ``dist.all_gather``.

    Despite the name, nothing is concatenated. The return value is a list of
    ``world_size`` tensors, one per rank, as filled in by ``all_gather``.

    Args:
        list_of_nums: A tensor, or anything ``torch.Tensor(...)`` accepts (a
            non-tensor becomes a default-float tensor). It is moved to CUDA
            if not already there.
        world_size: Number of ranks participating in the gather.

    Returns:
        list[torch.Tensor]: ``world_size`` CUDA tensors shaped like the input.

    Requires an initialized process group. ``all_gather`` expects every rank
    to contribute a tensor of the same size.
    """
    if torch.is_tensor(list_of_nums):
        local = list_of_nums if list_of_nums.is_cuda else list_of_nums.cuda()
    else:
        local = torch.Tensor(list_of_nums).cuda()
    # Pre-allocated receive buffers; all_gather overwrites their contents.
    buckets = [torch.ones_like(local) for _ in range(world_size)]
    dist.all_gather(buckets, local)
    return buckets

def dis_to_con(path, amount=0.25):
    """
    Densify a waypoint path by inserting evenly spaced intermediate points.

    Segment lengths are measured on components 0 and 2 only. Component 1 is
    presumably the vertical axis (y-up) — TODO confirm. For each segment
    ``s -> e``:

    * Let ``d`` be its horizontal length. Then ``int(d / amount) - 1``
      intermediate points are inserted at horizontal distances ``amount,
      2*amount, ...`` from ``s``.
    * The full displacement (including component 1) is scaled by the same
      factor, so component 1 is interpolated proportionally.
    * ``e`` itself is appended last, so the final step absorbs the
      remainder.
    * Segments with zero horizontal length get no intermediate points.

    Args:
        path: Sequence of 3-component waypoints.
        amount (float): Target horizontal spacing; must be positive.

    Returns:
        list: ``path[0]``, then for each segment its intermediate points (as
        lists of floats) followed by the original end waypoint.

    Raises:
        ValueError: If ``amount`` is not positive.
    """
    if amount <= 0:
        raise ValueError(f"amount must be positive, got {amount}")
    if len(path) == 0:
        return []

    new_path = [path[0]]
    for s, e in zip(path[:-1], path[1:]):
        start = np.asarray(s, dtype=float)
        vec = np.asarray(e, dtype=float) - start
        horizontal = np.linalg.norm(vec[[0, 2]])
        # Skip interpolation for purely vertical/zero segments (avoids 0-division).
        if horizontal > 0:
            steps = int(horizontal / amount)
            unit = vec * (amount / horizontal)
            # Compute each point from the start to avoid accumulating float drift.
            for k in range(1, steps):
                new_path.append((start + k * unit).tolist())
        new_path.append(e)

    return new_path

def get_camera_orientations12():
    """Map heading labels in degrees to ``[0.0, radians, 0.0]`` orientation triples.

    Produces 11 entries at 30-degree steps: keys ``"30"``, ``"60"``, ...,
    ``"330"``. The 0-degree heading is not included — presumably the forward
    camera is handled separately; TODO confirm. The angle sits in the middle
    component, presumably a rotation about the vertical axis (yaw).
    """
    step_deg = 30
    step_rad = math.pi / 6
    return {str(step_deg * k): [0.0, step_rad * k, 0.0] for k in range(1, 12)}
