import json
import random
import math
from map.map import Map


class WorldGenerator:
    """Scatter elements over a map's nodes and append them to a world JSON file.

    ``config`` is indexed with the keys ``'pysbench.bbox'`` (path of the JSON
    file describing the available elements) and
    ``'pysbench.element_number_per_road'`` (number of elements to place per
    road of the map).
    """

    # Element names starting with this prefix are treated as trees.
    TREE_NAME_PREFIX = 'BP_Tree'
    # Probability that a non-tree element stays on its sampled sidewalk node
    # instead of being moved to the closest normal node.
    KEEP_ON_SIDEWALK_PROBABILITY = 0.5
    # Maximum absolute random offset added to the x and y of every element.
    POSITION_JITTER = 50
    # Fixed z coordinate written for every generated element.
    ELEMENT_Z = 20

    def __init__(self, config):
        self.config = config

    def generate_world(self, map: 'Map', input_world_path='data/progen_world.json', output_world_path='output/progen_world.json'):
        """Sample elements for ``map`` and write them into a copy of the input world.

        Args:
            map: Map whose nodes receive the elements.
            input_world_path: Existing world JSON file to extend.
            output_world_path: Where the extended world JSON is written.
        """
        elements = self.sample_elements(map)
        self.write_elements_to_progen_world(elements, input_world_path, output_world_path)
        print(f'world generated {len(elements)} elements')

    def sample_elements(self, map):
        """Choose nodes of ``map`` and assign a random element to each of them.

        Half of the element budget (rounded down) goes to trees, which are
        placed on the far-from-road sidewalk nodes; the rest goes to the other
        elements, placed on middle/near sidewalk nodes or moved to the closest
        normal node. Every chosen node gets ``obstacle = True``.

        A category with no element names in the bbox file is skipped instead
        of raising, and a map without normal nodes keeps all non-tree elements
        on their sidewalk nodes.

        Returns:
            A list of dicts with the keys ``'element'`` (element name),
            ``'node'`` (the chosen node) and ``'bbox'`` (the element's bbox
            entry from the bbox file).
        """
        # 1. Read elements and split their names into trees and the rest.
        elements_bbox = self.load_elements_bbox(self.config['pysbench.bbox'])
        tree_names = [name for name in elements_bbox if name.startswith(self.TREE_NAME_PREFIX)]
        other_names = [name for name in elements_bbox if not name.startswith(self.TREE_NAME_PREFIX)]

        # 2. Collect the candidate nodes from the map.
        sidewalk_far_nodes = map.get_sidewalk_far_road_nodes()
        sidewalk_middle_nodes = map.get_sidewalk_middle_nodes()
        sidewalk_near_nodes = map.get_sidewalk_near_road_nodes()
        normal_nodes = map.get_normal_nodes()
        other_nodes = sidewalk_middle_nodes + sidewalk_near_nodes

        # 3. Element budget; int() tolerates a non-integer per-road count,
        #    which random.sample would otherwise reject.
        num_elements = int(self.config['pysbench.element_number_per_road'] * len(map.roads))
        num_trees = num_elements // 2
        num_others = num_elements - num_trees

        # 4. Sample tree nodes (far from road). Without any tree element there
        #    is nothing to place, so no node is reserved for trees.
        sampled_tree_nodes = []
        if tree_names:
            tree_candidates = list(sidewalk_far_nodes)
            sampled_tree_nodes = random.sample(tree_candidates, min(len(tree_candidates), num_trees))

        # 5. Sample nodes for the other elements, never reusing a tree node.
        sampled_other_nodes = []
        if other_names:
            other_candidates = [node for node in other_nodes if node not in sampled_tree_nodes]
            picked_nodes = random.sample(other_candidates, min(len(other_candidates), num_others))
            sampled_other_nodes = self._snap_to_normal_nodes(picked_nodes, normal_nodes, sampled_tree_nodes)

        # 6. Write the sampled nodes back to the map.
        for node in sampled_tree_nodes:
            node.obstacle = True
        for node in sampled_other_nodes:
            node.obstacle = True

        # 7. Pair every node with a randomly chosen element of its category.
        placements = []
        for names, nodes in ((tree_names, sampled_tree_nodes), (other_names, sampled_other_nodes)):
            if not nodes:
                continue
            for name, node in zip(random.choices(names, k=len(nodes)), nodes):
                placements.append({'element': name, 'node': node, 'bbox': elements_bbox[name]['bbox']})

        return placements

    def _snap_to_normal_nodes(self, nodes, normal_nodes, reserved_nodes):
        """Randomly move some of ``nodes`` to their closest normal node.

        A node is moved only when normal nodes exist and the closest one is
        neither already chosen by this method nor part of ``reserved_nodes``;
        otherwise the original node is kept.
        """
        placed = []
        for node in nodes:
            target = node
            if normal_nodes and random.random() >= self.KEEP_ON_SIDEWALK_PROBABILITY:
                closest = min(normal_nodes,
                              key=lambda candidate: candidate.position.distance(node.position))
                if closest not in placed and closest not in reserved_nodes:
                    target = closest
            placed.append(target)
        return placed

    def load_elements_bbox(self, json_path):
        """Return the ``'elements'`` mapping stored in the JSON file at ``json_path``."""
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data['elements']

    def write_elements_to_progen_world(self, elements, input_world_path, output_world_path):
        """Append ``elements`` to the ``'nodes'`` list of the input world and save it.

        The input file is only read; the result goes to ``output_world_path``.
        New entries are numbered starting at the count of existing nodes.
        """
        with open(input_world_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        nodes = data.get('nodes', [])
        start_idx = len(nodes)
        nodes.extend(
            self.element_to_node_dict(element, start_idx + offset)
            for offset, element in enumerate(elements)
        )
        data['nodes'] = nodes
        with open(output_world_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    def element_to_node_dict(self, element, idx):
        """Build the world-file entry for one placement returned by ``sample_elements``.

        The id is ``GEN_<name without its last '_' segment>_<idx>``. The x/y
        location is the node position plus a uniform random offset and the
        yaw is random.
        """
        node = element['node']
        name = element['element']
        jitter = self.POSITION_JITTER
        return {
            "id": f"GEN_{'_'.join(name.split('_')[:-1])}_{idx}",
            "instance_name": name,
            "properties": {
                "location": {
                    "x": node.position.x + random.uniform(-jitter, jitter),
                    "y": node.position.y + random.uniform(-jitter, jitter),
                    "z": self.ELEMENT_Z
                },
                "orientation": {
                    "pitch": 0,
                    # NOTE(review): yaw is sampled in radians (0..2*pi);
                    # confirm the consumer of this file does not expect degrees.
                    "yaw": random.uniform(0, 2 * math.pi),
                    "roll": 0
                },
                "scale": {
                    "x": 1.0,
                    "y": 1.0,
                    "z": 1.0
                }
            }
        }
