from __future__ import print_function

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from mm_cvrp.network import Net
from mm_cvrp.ortools_tsp import solve
from ortools.constraint_solver import pywrapcp
from ortools.constraint_solver import routing_enums_pb2
from torch.distributions import Categorical


class Agentembedding(nn.Module):
    """Single-head scaled dot-product attention that pools node features into one embedding.

    The query is projected from a context tensor ``f_c``; keys and values are
    projected from the node features ``f``. The output is the attention-weighted
    sum of the node values.
    """

    def __init__(self, node_feature_size: int, key_size: int, value_size: int) -> None:
        """
        Args:
            node_feature_size: last-dim width of the node features ``f``; the
                context ``f_c`` must have twice this width.
            key_size: width of the query/key projections; logits are scaled by
                ``1 / sqrt(key_size)``.
            value_size: width of the value projection, i.e. of the output.
        """
        super(Agentembedding, self).__init__()
        self.key_size = key_size
        # Creation order (q, k, v) is kept so seeded initialisation is unchanged.
        self.q_agent = nn.Linear(2 * node_feature_size, key_size)
        self.k_agent = nn.Linear(node_feature_size, key_size)
        self.v_agent = nn.Linear(node_feature_size, value_size)

    def forward(self, f_c: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        """Attend from the context ``f_c`` over the nodes in ``f``.

        Args:
            f_c: context tensor ``(..., n_query, 2 * node_feature_size)``; in
                this file it is called with a single query row,
                ``(batch, 1, 2 * node_feature_size)``.
            f: node features ``(..., num_nodes, node_feature_size)``.

        Returns:
            ``(..., n_query, value_size)`` attention-pooled node values.
        """
        q = self.q_agent(f_c)  # (..., n_query, key_size)
        k = self.k_agent(f)  # (..., num_nodes, key_size)
        v = self.v_agent(f)  # (..., num_nodes, value_size)
        # Logits are laid out (..., num_nodes, n_query), so the softmax runs over
        # the node axis (-2): each query's weights over the nodes sum to 1.
        u = torch.matmul(k, q.transpose(-1, -2)) / math.sqrt(self.key_size)
        u_ = F.softmax(u, dim=-2).transpose(-1, -2)  # (..., n_query, num_nodes)
        agent_embedding = torch.matmul(u_, v)

        return agent_embedding


class AgentAndNode_embedding(torch.nn.Module):
    """Encode a batch of graphs into per-agent embeddings and depot-free node embeddings.

    Node and graph embeddings come from the ``Net`` (GIN) encoder. Each of the
    ``n_agent`` ``Agentembedding`` heads then attends from the concatenation
    ``[graph embedding, depot embedding]`` over the non-depot nodes, giving one
    embedding per agent.
    """

    def __init__(self, in_chnl: int, hid_chnl: int, n_agent: int, key_size: int, value_size: int, dev: str) -> None:
        """
        Args:
            in_chnl: input node-feature width passed to ``Net``.
            hid_chnl: hidden width passed to ``Net``; also the node feature size
                of every ``Agentembedding`` head.
            n_agent: number of agent heads.
            key_size: attention key size of each head.
            value_size: output width of each head.
            dev: device the sub-modules are moved to.
        """
        super(AgentAndNode_embedding, self).__init__()

        self.n_agent = n_agent

        # GIN node/graph encoder (created first, as before, so seeded init is unchanged).
        self.gin = Net(in_chnl=in_chnl, hid_chnl=hid_chnl).to(dev)
        # One attention head per agent; list position == agent index.
        self.agents = torch.nn.ModuleList(
            Agentembedding(node_feature_size=hid_chnl, key_size=key_size, value_size=value_size).to(dev)
            for _ in range(n_agent)
        )

    def forward(
        self, batch_graphs: torch_geometric.data.Batch, n_nodes: int, n_batch: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Embed the batch.

        Args:
            batch_graphs: batched graphs exposing ``x``, ``edge_index`` and
                ``batch``; the reshape below requires exactly ``n_batch`` graphs
                of ``n_nodes`` nodes each, with node 0 of every graph the depot.
            n_nodes: nodes per graph, including the depot.
            n_batch: number of graphs in the batch.

        Returns:
            ``(agent_embeddings, nodes_h_no_depot)`` with shapes
            ``(n_batch, n_agent, value_size)`` and ``(n_batch, n_nodes - 1, hid)``.
        """
        # get node and graph embeddings using gin
        nodes_h, g_h = self.gin(x=batch_graphs.x, edge_index=batch_graphs.edge_index, batch=batch_graphs.batch)
        # (n_batch, n_nodes, hid)
        nodes_h = nodes_h.reshape(n_batch, n_nodes, -1)
        # (n_batch, 1, hid) -- one graph embedding per instance
        g_h = g_h.reshape(n_batch, 1, -1)

        # (n_batch, 1 (single depot), 2 * hid)
        depot_cat_g = torch.cat((g_h, nodes_h[:, 0, :].unsqueeze(1)), dim=-1)
        # output nodes embedding should not include depot, refer to paper:
        # https://www.sciencedirect.com/science/article/abs/pii/S0950705120304445
        nodes_h_no_depot = nodes_h[:, 1:, :]

        # Each head yields (n_batch, 1, value_size); stack them along the agent axis.
        agent_embeddings = torch.cat([agent(depot_cat_g, nodes_h_no_depot) for agent in self.agents], dim=1)

        return agent_embeddings, nodes_h_no_depot


class TSPSurrogate(nn.Module):
    """Self-attention regressor mapping a set of 2-D points to one non-negative scalar.

    ``n_nodes`` is accepted for interface compatibility but not used; the
    network handles any number of points.
    """

    def __init__(self, n_nodes: int, dev: str, hidden_dim=50):
        super(TSPSurrogate, self).__init__()

        def _point_projection() -> nn.Linear:
            # (x, y) -> hidden_dim projection used for queries, keys and values
            return nn.Linear(2, hidden_dim, device=dev)

        # Creation order query -> key -> value -> fc matches seeded initialisation.
        self.query = _point_projection()
        self.key = _point_projection()
        self.value = _point_projection()
        # Readout head applied after pooling
        self.fc = nn.Linear(hidden_dim, 1, device=dev)

    def forward(self, x):
        """
        Args:
            x: point coordinates of shape ``(batch_size, n_node, 2)``.

        Returns:
            ``(batch_size, 1)`` tensor, clamped at zero by a final ReLU.
        """
        q, k, v = self.query(x), self.key(x), self.value(x)  # each (batch, n_node, hidden_dim)

        # Scaled dot-product self-attention; rows of `weights` sum to 1.
        scale = k.size(-1) ** 0.5
        weights = F.softmax((q @ k.transpose(-2, -1)) / scale, dim=-1)  # (batch, n_node, n_node)

        # Sum-pool the attended values over the point axis, then regress.
        pooled = (weights @ v).sum(dim=1)  # (batch, hidden_dim)
        return F.relu(self.fc(pooled))


class Policy(nn.Module):
    """Regress one non-negative scalar per instance from agent-to-node attention logits.

    Agent and node embeddings come from ``AgentAndNode_embedding``. Scaled
    dot-product logits between each agent and each non-depot node are fed
    through a two-layer ReLU MLP.
    """

    def __init__(
        self,
        n_nodes: int,
        in_chnl: int,
        hid_chnl: int,
        n_agent: int,
        key_size_embd: int,
        key_size_policy: int,
        val_size: int,
        clipping: int,
        dev: str,
        disable_softmax: bool,
    ) -> None:
        """
        Args:
            n_nodes: nodes per instance including the depot; the MLP input width
                is ``n_nodes - 1`` (one logit per non-depot node).
            in_chnl, hid_chnl: widths passed to the embedding network.
            n_agent: number of agent heads in the embedding network.
            key_size_embd: key size of the embedding attention heads.
            key_size_policy: key size of the agent/node logit projection.
            val_size: agent embedding width.
            clipping: stored as ``self.c``.
            dev: device for all parameters.
            disable_softmax: stored as ``self.disable_softmax``.
        """
        super(Policy, self).__init__()
        # NOTE(review): `clipping` and `disable_softmax` are stored but never read in this class.
        self.c = clipping
        self.disable_softmax = disable_softmax
        self.key_size_policy = key_size_policy
        # Layer creation order is unchanged so seeded initialisation stays the same.
        self.key_policy = nn.Linear(hid_chnl, self.key_size_policy, device=dev)
        self.q_policy = nn.Linear(val_size, self.key_size_policy, device=dev)

        # embed network
        self.embed = AgentAndNode_embedding(
            in_chnl=in_chnl, hid_chnl=hid_chnl, n_agent=n_agent, key_size=key_size_embd, value_size=val_size, dev=dev
        )
        self.last_layer = nn.Linear(n_nodes - 1, 100, device=dev)
        self.last_layer2 = nn.Linear(100, 1, device=dev)

    def forward(self, batch_graph: torch_geometric.data.Batch, n_nodes: int, n_batch: int) -> torch.Tensor:
        """Compute the scalar output for each instance in ``batch_graph``.

        Returns:
            ``(n_batch, 1)`` when ``n_agent == 1``; otherwise
            ``(n_batch, n_agent, 1)``. Values are >= 0 (final ReLU).
        """
        agent_embeddings, nodes_h_no_depot = self.embed(batch_graph, n_nodes, n_batch)

        k_policy = self.key_policy(nodes_h_no_depot)  # (n_batch, n_nodes - 1, key_size_policy)
        q_policy = self.q_policy(agent_embeddings)  # (n_batch, n_agent, key_size_policy)
        # (n_batch, n_agent, n_nodes - 1) scaled dot-product logits
        u_policy = torch.matmul(q_policy, k_policy.transpose(-1, -2)) / math.sqrt(self.key_size_policy)
        # squeeze(1) drops the agent axis only when n_agent == 1; otherwise it is a no-op.
        u_policy = u_policy.squeeze(1)
        z = F.relu(self.last_layer(u_policy))
        z = F.relu(self.last_layer2(z))

        return z


def action_sample(pi: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one category per position from ``pi`` and return it with its log-probability.

    Axes 1 and 2 of ``pi`` are swapped before building the ``Categorical``, so
    the distribution runs over axis 1 of the input (the last axis after the
    swap). ``Categorical`` normalises the probabilities itself.

    Returns:
        ``(action, log_prob)``, both with the distribution's batch shape.
    """
    categorical = Categorical(pi.transpose(2, 1))
    sampled = categorical.sample()
    return sampled, categorical.log_prob(sampled)


def get_log_prob(pi: torch.Tensor, action_int: int) -> torch.Tensor:
    dist = Categorical(pi.transpose(2, 1))
    log_prob = dist.log_prob(action_int)
    return log_prob


def create_data_model(instances):
    """Wrap ``instances`` as an OR-Tools routing problem description.

    The dict describes a single-vehicle problem whose depot is location 0;
    ``instances`` is stored as-is (not copied) under ``"locations"``.
    """
    return dict(locations=instances, num_vehicles=1, depot=0)


def compute_euclidean_distance_matrix(locations):
    """Return all pairwise Euclidean distances as a nested dict of ints.

    ``result[i][j]`` is ``int(hypot(dx, dy))`` between ``locations[i]`` and
    ``locations[j]``, and 0 on the diagonal. ``int`` truncates toward zero, so
    the fractional part is lost; callers in this module scale coordinates up
    (by 1000) before calling the solver.
    """

    def _truncated_distance(src, dst):
        return int(math.hypot((src[0] - dst[0]), (src[1] - dst[1])))

    return {
        src_idx: {
            dst_idx: 0 if src_idx == dst_idx else _truncated_distance(src, dst)
            for dst_idx, dst in enumerate(locations)
        }
        for src_idx, src in enumerate(locations)
    }


def solve(instance, return_path=False) -> float | tuple[float, list[int]]:
    """Solve a single-vehicle TSP over ``instance`` with OR-Tools.

    Arc costs are Euclidean distances truncated to ``int`` by
    ``compute_euclidean_distance_matrix``, so callers needing sub-unit
    precision should scale coordinates up first.

    Args:
        instance: sequence of (x, y) points; location 0 is the depot.
        return_path: if True, also return the visiting order.

    Returns:
        The objective value of the found tour. If ``return_path`` is True,
        ``(objective, route)`` instead, where ``route`` lists location indices
        from the depot back to the depot.

    Raises:
        RuntimeError: if OR-Tools returns no solution (previously this surfaced
            as an ``AttributeError`` on ``None``).
    """
    data = create_data_model(instance)
    n_locations = len(data["locations"])

    manager = pywrapcp.RoutingIndexManager(n_locations, data["num_vehicles"], data["depot"])
    routing = pywrapcp.RoutingModel(manager)

    distance_matrix = compute_euclidean_distance_matrix(data["locations"])

    def distance_callback(from_index, to_index):
        """Return the distance between two routing indices."""
        # OR-Tools passes internal routing indices; map them back to location indices.
        from_node = manager.IndexToNode(from_index)
        to_node = manager.IndexToNode(to_index)
        return distance_matrix[from_node][to_node]

    transit_callback_index = routing.RegisterTransitCallback(distance_callback)
    # Every arc costs its (truncated) Euclidean length.
    routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)

    search_parameters = pywrapcp.DefaultRoutingSearchParameters()
    search_parameters.first_solution_strategy = routing_enums_pb2.FirstSolutionStrategy.PATH_CHEAPEST_ARC

    solution = routing.SolveWithParameters(search_parameters)
    if solution is None:
        raise RuntimeError(f"OR-Tools found no TSP solution for an instance with {n_locations} locations")

    if not return_path:
        return solution.ObjectiveValue()

    # Walk vehicle 0's route from its start node to its end node (both the depot).
    route = []
    index = routing.Start(0)
    while not routing.IsEnd(index):
        route.append(manager.IndexToNode(index))
        index = solution.Value(routing.NextVar(index))
    route.append(manager.IndexToNode(index))

    return solution.ObjectiveValue(), route


def get_cost(data: torch.Tensor, scale: float = 1000, *, verbose: bool = False) -> list[float]:
    """Return the OR-Tools TSP tour length of every instance in ``data``.

    OR-Tools works on integer arc costs (distances are truncated to ``int``), so
    coordinates are multiplied by ``scale`` before solving and the resulting
    lengths divided by it again, keeping roughly ``1 / scale`` precision.

    Args:
        data: batch of instances along dim 0; each ``data[k]`` is a sequence of
            (x, y) points with the depot at index 0.
        scale: coordinate multiplier applied before solving (default 1000, the
            previously hard-coded value).
        verbose: print ``k, batch_size`` before each instance. This progress
            output used to be printed unconditionally.

    Returns:
        One tour length per instance, in the original coordinate units.
    """
    scaled = data * scale
    n_instances = scaled.shape[0]

    lengths = []
    for k in range(n_instances):
        if verbose:
            print(k, n_instances)
        # Hand OR-Tools plain Python floats (as get_cost4plot does) instead of
        # indexing tensor elements one by one while building the distance matrix.
        lengths.append(solve(scaled[k].tolist()) / scale)

    return lengths


def get_cost4plot(
    action: torch.Tensor, data: torch.Tensor, n_agent: int, scale: float = 1000
) -> tuple[list[float], list[list[list[int]]], list[list[float]], dict[int, dict[tuple[int, int], int]]]:
    """Split each instance into per-agent sub-tours, solve them with OR-Tools and gather plot data.

    Args:
        action: agent assignment per non-depot node; ``action[i][j]`` is the
            agent index that visits node ``j + 1`` of ``data[i]``.
        data: batch of instances; ``data[i][0]`` is the depot shared by all agents.
        n_agent: number of agents (sub-tours per instance).
        scale: coordinate multiplier applied before solving, because OR-Tools
            truncates distances to ``int`` (default 1000, the previously
            hard-coded value).

    Returns:
        A 4-tuple:
        - ``subtour_max_lengths``: per instance, the longest sub-tour length
          (0 if ``n_agent == 0``).
        - ``subtour_order``: per instance and agent, the OR-Tools visiting order
          as local indices into that agent's sub-tour (0 is the depot).
        - ``subtour_length_list``: per instance and agent, the sub-tour length.
        - ``local_idx2global_idx``: per instance, maps ``(agent_idx, local_idx)``
          to the node's index in ``data[i]``.
    """
    scaled = data * scale
    n_instances = scaled.shape[0]
    # Convert once; converting the whole batch inside the per-instance loop made
    # the conversion cost quadratic in the batch size.
    actions = action.tolist()
    instances = scaled.tolist()

    # Build each agent's sub-tour: the depot followed by its assigned nodes in input order.
    sub_tours = []
    local_idx2global_idx = {}
    for i in range(n_instances):
        depot, *customers = instances[i]
        tours = [[depot] for _ in range(n_agent)]
        mapping = {}
        # global_idx starts at 1 because action has one entry per non-depot node.
        for global_idx, (agent_idx, location) in enumerate(zip(actions[i], customers, strict=False), start=1):
            local_idx = len(tours[agent_idx])
            tours[agent_idx].append(location)
            mapping[(agent_idx, local_idx)] = global_idx
        sub_tours.append(tours)
        local_idx2global_idx[i] = mapping

    # Solve every sub-tour and record its order and length.
    subtour_max_lengths = []
    subtour_order = []
    subtour_length_list = []
    for tours in sub_tours:
        orders = []
        lengths = []
        for tour in tours:
            tour_length, order = solve(tour, return_path=True)
            orders.append(order)
            lengths.append(tour_length / scale)
        subtour_order.append(orders)
        subtour_length_list.append(lengths)
        # Tour lengths are non-negative, so default=0 reproduces the old running max from 0.
        subtour_max_lengths.append(max(lengths, default=0))

    # OPTIMIZE: return orders already mapped back to global indices via local_idx2global_idx.
    return subtour_max_lengths, subtour_order, subtour_length_list, local_idx2global_idx


class Surrogate(nn.Module):
    """Plain MLP regressor: flatten -> Linear -> act -> Linear -> act -> Linear.

    Args:
        in_dim: flattened input width (product of all non-batch dims).
        out_dim: output width.
        n_hidden: width of both hidden layers.
        nonlin: activation name, one of ``relu``, ``tanh``, ``sigmoid``,
            ``softplus``, ``lrelu``, ``elu``; any other name raises ``KeyError``.
        dev: device for the linear layers.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self, in_dim: int, out_dim: int, n_hidden: int = 64, nonlin: str = "relu", dev="cpu", **kwargs
    ) -> None:
        super(Surrogate, self).__init__()
        activation_factories = {
            "relu": nn.ReLU,
            "tanh": nn.Tanh,
            "sigmoid": nn.Sigmoid,
            "softplus": nn.Softplus,
            "lrelu": nn.LeakyReLU,
            "elu": nn.ELU,
        }

        # Registration/creation order (layer, layer2, out, nonlin) is kept so
        # state_dict keys and seeded initialisation are unchanged.
        self.layer = nn.Linear(in_dim, n_hidden, device=dev)
        self.layer2 = nn.Linear(n_hidden, n_hidden, device=dev)
        self.out = nn.Linear(n_hidden, out_dim, device=dev)
        self.nonlin = activation_factories[nonlin]()

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Flatten all non-batch dims of ``x`` and return ``(batch, out_dim)``; ``kwargs`` are ignored."""
        hidden = x.reshape(x.shape[0], -1)
        for linear in (self.layer, self.layer2):
            hidden = self.nonlin(linear(hidden))
        return self.out(hidden)
