import os
import csv
import hydra
import numpy as np
import wandb
from omegaconf import DictConfig
from ortools.linear_solver import pywraplp
from tqdm import tqdm

from utils.experiment import setup_wandb, get_data


def solve_equation1_ortools(data, time_limit_ms=24 * 3600 * 1000):
    """
    Solves the Uniform Facility Location problem (Equation 1) using Google OR-Tools.
    Minimize: Sum(d(x, F)) + |F|

    The problem is modelled as a mixed-integer program and solved with the
    SCIP backend bundled with OR-Tools:

    * ``y[j]`` (binary): facility ``j`` is opened, at cost 1.
    * ``x[i, j]`` (binary): client ``i`` is served by facility ``j``, at cost
      ``d(i, j)``. Only the pairs listed in ``data.edge_index`` plus the
      self-pairs ``(i, i)`` (distance 0) are modelled.

    Args:
        data (torch_geometric.data.Data): Input graph.
                                          edge_index represents valid connections (d <= 1).
                                          edge_weight represents distances.
        time_limit_ms (int): Wall-clock limit handed to the solver, in
                             milliseconds. Defaults to one day.

    Returns:
        dict: Solution details with keys
            'solve_time'      -- solver wall time as reported by OR-Tools (milliseconds),
            'obj_val'         -- objective value of the returned solution,
            'term_opening'    -- number of opened facilities (|F|),
            'term_connection' -- summed connection distance,
            'facilities'      -- indices of the opened facilities,
            'optimal'         -- True if optimality was proven; False when the
                                 solver stopped (e.g. on the time limit) with
                                 only a feasible incumbent.

    Raises:
        RuntimeError: if the SCIP backend is unavailable, or if the solver
            terminates without any feasible solution.
    """
    num_nodes = data.num_nodes

    # --- 1. Setup Valid Connections ---
    # (client, facility) -> distance. Every node can always serve itself at distance 0.
    valid_links = {(i, i): 0.0 for i in range(num_nodes)}

    # Graph edges: distances from edge_weight.
    # Assumes the input graph respects d <= 1, as in the paper's graph construction.
    row, col = data.edge_index
    for i, j, w in zip(row.tolist(), col.tolist(), data.edge_weight.tolist()):
        if i == j:
            # d(i, i) = 0 by definition; an explicit self-loop weight in the
            # graph (e.g. one added for message passing) must not overwrite it.
            continue
        valid_links[(i, j)] = w

    # --- 2. Initialize OR-Tools Solver ---
    # CreateSolver returns None when the backend is unavailable, so check
    # before calling any method on it.
    solver = pywraplp.Solver.CreateSolver('SCIP')
    if not solver:
        raise RuntimeError("SCIP solver not found.")
    solver.SetTimeLimit(int(time_limit_ms))

    # --- 3. Variables ---
    # y[j]: 1 if facility j is opened, 0 otherwise.
    y = {j: solver.IntVar(0, 1, f'y_{j}') for j in range(num_nodes)}
    # x[i, j]: 1 if client i is served by facility j (only for valid links).
    x = {(i, j): solver.IntVar(0, 1, f'x_{i}_{j}') for (i, j) in valid_links}

    # --- 4. Objective Function (Equation 1) ---
    # |F| (unit opening cost per facility) + sum of dist_ij * x_ij.
    objective = solver.Objective()
    for y_var in y.values():
        objective.SetCoefficient(y_var, 1.0)
    for link, dist in valid_links.items():
        objective.SetCoefficient(x[link], float(dist))
    objective.SetMinimization()

    # --- 5. Constraints ---
    # (A) Assignment: every client i is served by exactly one facility.
    neighbors_of = {i: [] for i in range(num_nodes)}
    for client, facility in valid_links:
        neighbors_of[client].append(facility)

    for i in range(num_nodes):
        constraint = solver.Constraint(1, 1, f'assign_{i}')
        for j in neighbors_of[i]:
            constraint.SetCoefficient(x[i, j], 1)

    # (B) Open facility: x[i, j] <= y[j], i.e. -inf <= x[i, j] - y[j] <= 0.
    for (i, j), x_var in x.items():
        constraint = solver.Constraint(-solver.infinity(), 0, f'link_{i}_{j}')
        constraint.SetCoefficient(x_var, 1)
        constraint.SetCoefficient(y[j], -1)

    # --- 6. Solve ---
    status = solver.Solve()
    if status not in (pywraplp.Solver.OPTIMAL, pywraplp.Solver.FEASIBLE):
        # Both OPTIMAL and FEASIBLE are accepted, so only "no solution at all"
        # ends up here; keep the status code for diagnosis.
        raise RuntimeError(f"Solver returned no feasible solution (status code {status}).")

    # --- 7. Output ---
    # Binary variables come back as floats; threshold at 0.5 to read them.
    open_facilities = [j for j, y_var in y.items() if y_var.solution_value() > 0.5]
    final_connection_cost = sum(
        dist for link, dist in valid_links.items() if x[link].solution_value() > 0.5
    )

    return {
        'solve_time': solver.wall_time(),
        'obj_val': objective.Value(),
        'term_opening': len(open_facilities),
        'term_connection': final_connection_cost,
        'facilities': open_facilities,
        'optimal': status == pywraplp.Solver.OPTIMAL,
    }


@hydra.main(version_base=None, config_path='./config', config_name="mpnn")
def main(args: DictConfig):
    """
    Solve every test graph exactly with OR-Tools and log cost statistics.

    Per-graph results are written to ``./logs/<dataset>_solver_results.csv``.
    Each row is flushed as soon as its graph is solved, so partial results
    survive an interrupted run. Mean/std of solve time, opening cost,
    connection cost and total cost are logged to Weights & Biases at the end.

    Args:
        args: Hydra config; this function reads ``train.datapath`` and
              ``train.debug``.
    """
    setup_wandb(args)

    _, _, test_set = get_data(args.train.datapath)

    if args.train.debug:
        # Keep debug runs short.
        test_set = test_set[:20]

    solutions = []
    open_costs = []
    trans_costs = []
    sol_times = []

    log_dir = "./logs"
    os.makedirs(log_dir, exist_ok=True)
    # normpath strips a trailing separator, so "data/foo/" still yields "foo"
    # instead of an empty dataset name.
    dataset_name = os.path.basename(os.path.normpath(args.train.datapath))
    output_file = os.path.join(log_dir, f"{dataset_name}_solver_results.csv")

    with open(output_file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Graph_ID', 'Open_Cost', 'Trans_Cost', 'Total_Cost'])
        f.flush()

        pbar = tqdm(test_set)
        for i, data in enumerate(pbar):
            sol = solve_equation1_ortools(data)
            sol_times.append(sol['solve_time'])
            solutions.append(sol['obj_val'])
            open_costs.append(sol['term_opening'])
            trans_costs.append(sol['term_connection'])
            pbar.set_postfix({'open': sol['term_opening'], 'trans': sol['term_connection'], 'total': sol['obj_val']})

            writer.writerow([
                i,  # Graph ID
                f"{sol['term_opening']}",
                f"{sol['term_connection']}",
                f"{sol['obj_val']}"
            ])
            # Individual solves can take very long; flush so finished rows
            # are on disk even if the run is killed mid-way.
            f.flush()

    wandb.log({
        'time_mean': np.mean(sol_times),
        'time_std': np.std(sol_times),
        'open_mean': np.mean(open_costs),
        'open_std': np.std(open_costs),
        'trans_mean': np.mean(trans_costs),
        'trans_std': np.std(trans_costs),
        'total_mean': np.mean(solutions),
        'total_std': np.std(solutions),
    })


if __name__ == '__main__':
    # The @hydra.main decorator builds and injects the DictConfig argument,
    # so main() is invoked here without arguments.
    main()
