from ortools.linear_solver import pywraplp
from typing import Dict, Any, Optional, List

def _to_duration(value: Any, agent: str) -> float:
    """Return the manipulation time of one task for ``agent`` as a float.

    ``value`` is the raw ``mani_time`` entry of a task: either a per-agent
    mapping (a missing agent counts as 0) or a single value shared by all
    agents.
    """
    if isinstance(value, dict):
        return float(value.get(agent, 0))
    return float(value)


def _parse_nav_times(raw: Dict[str, Any]) -> Dict[tuple, Any]:
    """Convert ``{"locA -> locB": t}`` entries into ``{("locA", "locB"): t}``.

    Keys that do not split into exactly two parts on ``" -> "`` are ignored.
    Values are kept as given; callers convert them with ``float()`` on use.
    """
    parsed: Dict[tuple, Any] = {}
    for key, value in raw.items():
        parts = key.split(" -> ")
        if len(parts) == 2:
            parsed[(parts[0], parts[1])] = value
    return parsed


def solve_schedule(plan_data: Dict[str, Any],
                   scene_info: Dict[str, Any],
                   agent_constraints: Optional[Dict] = None,
                   *,
                   imbalance_penalty: float = 0.0,
                   time_limit_ms: int = 30000) -> Optional[Dict[str, Any]]:
    """
    Set up and solve a Mixed-Integer Linear Programming (MILP) problem that
    assigns tasks to the "robot" / "human" agents and sequences them so that
    the makespan (completion time of the last task) is minimized.

    Args:
        plan_data: Plan description. Keys read: "tasks" (list of task ids),
            "dependencies" (pairs (before, after)), "mani_time" (per-task
            manipulation time, a number or a per-agent dict), "capability"
            ({task: {agent: flag}}; a missing or 0/False flag forbids the
            assignment), "task_locations" ({task: {"start": loc, "end": loc}})
            and "nav_time" ({"locA -> locB": seconds}).
        scene_info: Currently not read by the model; kept for interface
            compatibility.
        agent_constraints: Optional {agent: {"unavailable_until": t}}. A task
            assigned to that agent cannot start before t. Entries for agents
            other than "robot"/"human" are ignored with a warning.
        imbalance_penalty: Weight of the |human - robot| manipulation
            workload difference added to the objective (0 = makespan only).
        time_limit_ms: Solver time limit in milliseconds.

    Returns:
        {"makespan": float, "schedule": {task: {"agent", "start", "end"}}}
        when an optimal or feasible solution is found; None when the SCIP
        backend cannot be created or no solution is found.
    """
    if agent_constraints is None:
        agent_constraints = {}

    # Earliest start time: the smallest "unavailable_until" among the given
    # constraints (0 when no constraints are supplied).
    start_time_offset = min(
        constraint.get("unavailable_until", 0) for constraint in agent_constraints.values()
    ) if agent_constraints else 0

    print(f"\n[Optimizer] Running MILP solver... (Start offset: {start_time_offset:.2f}s)")

    # --- 1. Data Preparation ---
    tasks = plan_data.get("tasks", [])
    if not tasks:
        print("[Optimizer] No tasks to schedule.")
        return {"makespan": start_time_offset, "schedule": {}}
    tasks = list(tasks)
    task_set = set(tasks)

    agents = ["robot", "human"]
    dependencies = plan_data.get("dependencies", [])
    manipulation_data = plan_data.get("mani_time", {})

    capability_flat = {
        (task, agent): can_do
        for task, perms in plan_data.get("capability", {}).items()
        for agent, can_do in perms.items()
    }

    locations = plan_data.get("task_locations", {})
    task_start_location = {t: v.get("start") for t, v in locations.items()}
    task_end_location = {t: v.get("end") for t, v in locations.items()}
    # Every agent's sequence begins at a virtual "initial_pos" task located
    # at "start_point".
    task_start_location["initial_pos"] = "start_point"
    task_end_location["initial_pos"] = "start_point"
    nav_time = _parse_nav_times(plan_data.get("nav_time", {}))

    tasks_with_initial = ["initial_pos"] + tasks

    def transition_duration(prev_task: str, task: str, agent: str) -> float:
        # Navigation (prev end location -> this task's start location, added
        # for every agent) plus manipulation time. Unknown locations or a
        # missing nav entry count as zero navigation, so the task still gets
        # its manipulation time instead of having its timing dropped.
        prev_loc = task_end_location.get(prev_task)
        cur_loc = task_start_location.get(task)
        navigation = 0.0
        if prev_loc is not None and cur_loc is not None:
            navigation = float(nav_time.get((prev_loc, cur_loc), 0))
        return navigation + _to_duration(manipulation_data.get(task, 0), agent)

    # duration[i, j, a]: time agent a needs for task j when j directly
    # follows i in a's sequence. Keys match the y variables one-to-one.
    duration = {
        (i, j, a): transition_duration(i, j, a)
        for i in tasks_with_initial for j in tasks if i != j for a in agents
    }

    # Big-M must exceed any time span in an optimal schedule. Running all
    # tasks one after another, starting once the latest-available agent is
    # free, is always feasible (given acyclic dependencies). Each task there
    # takes at most its longest transition, which bounds every end time by
    # offset + max_wait + sum(longest transitions).
    known_unavailable = [
        c["unavailable_until"] for ag, c in agent_constraints.items()
        if ag in agents and "unavailable_until" in c
    ]
    max_wait = max(0.0, max(known_unavailable, default=start_time_offset) - start_time_offset)
    longest_per_task = [
        max(duration[i, j, a] for i in tasks_with_initial if i != j for a in agents)
        for j in tasks
    ]
    M = max_wait + sum(longest_per_task) + 1.0

    # --- 2. Solver Initialization ---
    solver = pywraplp.Solver.CreateSolver("SCIP")
    if not solver:
        print("[Optimizer] Could not create the SCIP solver.")
        return None
    infinity = solver.infinity()

    # --- 3. Decision Variables ---
    # x[t, a] = 1 iff task t is assigned to agent a.
    x = {(t, a): solver.IntVar(0, 1, f"x_{t}_{a}") for t in tasks for a in agents}
    start_time = {t: solver.NumVar(start_time_offset, infinity, f"start_{t}") for t in tasks}
    end_time = {t: solver.NumVar(start_time_offset, infinity, f"end_{t}") for t in tasks}
    # y[i, j, a] = 1 iff agent a performs task j immediately after i.
    y = {(i, j, a): solver.IntVar(0, 1, f"y_{i}_{j}_{a}") for (i, j, a) in duration}

    # --- 4. Constraints ---
    # Each task is assigned to exactly one agent.
    for t in tasks:
        solver.Add(sum(x[t, a] for a in agents) == 1)

    # A task can only go to an agent whose capability flag is set.
    for (t, a), var in x.items():
        if capability_flat.get((t, a), 0) == 0:
            solver.Add(var == 0)

    # Precedence: a dependent task starts after its prerequisite ends.
    for t1, t2 in dependencies:
        if t1 in task_set and t2 in task_set:
            solver.Add(start_time[t2] >= end_time[t1])

    # Sequencing: each of a's tasks has exactly one predecessor (another task
    # or "initial_pos") and at most one successor; at most one task follows
    # "initial_pos" per agent.
    for a in agents:
        solver.Add(sum(y["initial_pos", j, a] for j in tasks) <= 1)
        for j in tasks:
            solver.Add(sum(y[i, j, a] for i in tasks_with_initial if i != j) == x[j, a])
            solver.Add(sum(y[j, k, a] for k in tasks if k != j) <= x[j, a])

    # Timing: when y[i, j, a] = 1, j starts after i ends and lasts exactly
    # duration[i, j, a]; the Big-M terms disable these rows otherwise.
    for (i, j, a), var in y.items():
        prev_end = end_time[i] if i in task_set else start_time_offset
        task_duration = duration[i, j, a]
        solver.Add(start_time[j] >= prev_end - M * (1 - var))
        solver.Add(end_time[j] >= start_time[j] + task_duration - M * (1 - var))
        solver.Add(end_time[j] <= start_time[j] + task_duration + M * (1 - var))

    # Agent unavailability (e.g. during a pause or while replanning).
    for agent, constraints in agent_constraints.items():
        if "unavailable_until" not in constraints:
            continue
        if agent not in agents:
            print(f"[Optimizer] Warning: ignoring constraint for unknown agent '{agent}'.")
            continue
        unavailable_time = constraints["unavailable_until"]
        print(f"[Optimizer] Constraint added: Agent '{agent}' is unavailable until t={unavailable_time:.2f}s")
        for task in tasks:
            solver.Add(start_time[task] >= unavailable_time - M * (1 - x[task, agent]))

    # --- 5. Objective Function ---
    # Makespan: completion time of the last task.
    makespan = solver.NumVar(start_time_offset, infinity, "makespan")
    for t in tasks:
        solver.Add(makespan >= end_time[t])

    # Workload imbalance (manipulation time only), an optional secondary term.
    total_human_workload = solver.NumVar(0, infinity, "total_human_workload")
    total_robot_workload = solver.NumVar(0, infinity, "total_robot_workload")
    solver.Add(total_human_workload == solver.Sum([
        _to_duration(manipulation_data.get(t, 0), "human") * x[t, "human"] for t in tasks
    ]))
    solver.Add(total_robot_workload == solver.Sum([
        _to_duration(manipulation_data.get(t, 0), "robot") * x[t, "robot"] for t in tasks
    ]))

    imbalance = solver.NumVar(0, infinity, "imbalance")
    solver.Add(imbalance >= total_human_workload - total_robot_workload)
    solver.Add(imbalance >= total_robot_workload - total_human_workload)

    solver.Minimize(makespan + imbalance_penalty * imbalance)

    # --- 6. Solve and Return Results ---
    # With a time limit the solver may stop at a feasible, non-optimal point.
    solver.set_time_limit(time_limit_ms)
    status = solver.Solve()
    status_map = {0: 'OPTIMAL', 1: 'FEASIBLE', 2: 'INFEASIBLE', 3: 'UNBOUNDED',
                  4: 'ABNORMAL', 5: 'MODEL_INVALID', 6: 'NOT_SOLVED'}
    status_name = status_map.get(status, 'UNKNOWN')

    if status not in (pywraplp.Solver.OPTIMAL, pywraplp.Solver.FEASIBLE):
        print(f"[Optimizer] No solution found. Status: {status_name}")
        return None

    print(f"[Optimizer] Solution found (Status: {status_name}).")
    schedule_output = {}
    for t in tasks:
        assigned_agent = next((a for a in agents if x[t, a].solution_value() > 0.5), "")
        schedule_output[t] = {
            "agent": assigned_agent,
            "start": start_time[t].solution_value(),
            "end": end_time[t].solution_value(),
        }
    return {"makespan": makespan.solution_value(), "schedule": schedule_output}