import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from mm_cvrp.network import Net
from mm_cvrp.ortools_tsp import solve
from torch.distributions import Categorical


class Agentembedding(nn.Module):
    """Single-head scaled dot-product attention that pools node features.

    A query projected from ``f_c`` scores every row of ``f`` through a key
    projection; the softmax-weighted sum of the value projections of ``f`` is
    returned as the embedding.
    """

    def __init__(self, node_feature_size: int, key_size: int, value_size: int) -> None:
        super().__init__()
        self.key_size = key_size
        # The query input is twice as wide as a node feature vector.
        self.q_agent = nn.Linear(2 * node_feature_size, key_size)
        self.k_agent = nn.Linear(node_feature_size, key_size)
        self.v_agent = nn.Linear(node_feature_size, value_size)

    def forward(self, f_c: torch.tensor, f: torch.tensor) -> torch.tensor:
        """Attend from ``f_c`` over ``f``.

        Args:
            f_c: query input of shape ``(..., n_queries, 2 * node_feature_size)``.
            f: node features of shape ``(..., n_nodes, node_feature_size)``;
                presumably ``(batch, n_locations, node_feature_size)`` — TODO confirm.

        Returns:
            Tensor of shape ``(..., n_queries, value_size)``.
        """
        query = self.q_agent(f_c)
        keys = self.k_agent(f)
        values = self.v_agent(f)

        scale = math.sqrt(self.key_size)
        # scores: (..., n_nodes, n_queries); normalize over the node axis.
        scores = keys @ query.transpose(-1, -2) / scale
        weights = scores.softmax(dim=-2).transpose(-1, -2)
        return weights @ values


class AgentAndNode_embedding(torch.nn.Module):
    """Encode a batch of graphs into per-agent and per-node embeddings.

    A GIN encoder (``Net``) produces node and graph embeddings. Each agent owns
    an ``Agentembedding`` attention module that pools the non-depot node
    embeddings, using the graph embedding concatenated with the depot (node 0)
    embedding as its query input.
    """

    def __init__(self, in_chnl: int, hid_chnl: int, n_agent: int, key_size: int, value_size: int, dev: str) -> None:
        super().__init__()

        self.n_agent = n_agent

        # Graph encoder shared by all agents.
        self.gin = Net(in_chnl=in_chnl, hid_chnl=hid_chnl).to(dev)
        # One independent attention module per agent, created in agent order.
        self.agents = torch.nn.ModuleList(
            [
                Agentembedding(node_feature_size=hid_chnl, key_size=key_size, value_size=value_size).to(dev)
                for _ in range(n_agent)
            ]
        )

    def forward(
        self, batch_graphs: torch_geometric.data.batch, n_nodes: int, n_batch: int
    ) -> tuple[torch.tensor, torch.tensor]:
        """Embed the graphs.

        Args:
            batch_graphs: batched graphs exposing ``x``, ``edge_index`` and ``batch``.
            n_nodes: number of nodes per graph (node 0 is the depot).
            n_batch: number of graphs in the batch.

        Returns:
            ``(agent_embeddings, nodes_h_no_depot)``: agent embeddings of shape
            ``(n_batch, n_agent, value_size)`` and node embeddings of shape
            ``(n_batch, n_nodes - 1, d)`` with the depot removed.
        """
        node_emb, graph_emb = self.gin(x=batch_graphs.x, edge_index=batch_graphs.edge_index, batch=batch_graphs.batch)
        # (n_batch, n_nodes, d)
        node_emb = node_emb.reshape(n_batch, n_nodes, -1)
        # (n_batch, 1, d_graph) — one graph embedding per instance
        graph_emb = graph_emb.reshape(n_batch, 1, -1)

        # (n_batch, 1, d_graph + d): graph embedding joined with the single depot's embedding
        depot_node = node_emb[:, 0, :].unsqueeze(1)
        query_input = torch.cat((graph_emb, depot_node), dim=-1)
        # Output node embeddings exclude the depot, following
        # https://www.sciencedirect.com/science/article/abs/pii/S0950705120304445
        customer_emb = node_emb[:, 1:, :]

        per_agent = [self.agents[idx](query_input, customer_emb) for idx in range(self.n_agent)]
        agent_embeddings = torch.cat(per_agent, dim=1)

        return agent_embeddings, customer_emb


class Policy(nn.Module):
    """Assignment policy: probability of each agent visiting each non-depot node.

    Agent embeddings act as queries and node embeddings as keys. The clipped
    compatibility ``c * tanh(q k^T / sqrt(d))`` is normalized with a softmax over
    the agent axis, so every node gets a distribution over agents.
    """

    def __init__(
        self,
        in_chnl: int,
        hid_chnl: int,
        n_agent: int,
        key_size_embd: int,
        key_size_policy: int,
        val_size: int,
        clipping: int,
        dev: str,
        disable_softmax: bool,
    ) -> None:
        super().__init__()
        self.c = clipping  # tanh clipping range for the compatibility scores
        self.disable_softmax = disable_softmax
        self.key_size_policy = key_size_policy
        # Creation order is kept (key, query, embed, node weight) so parameter
        # initialization consumes the RNG in the same sequence.
        self.key_policy = nn.Linear(hid_chnl, self.key_size_policy, device=dev)
        self.q_policy = nn.Linear(val_size, self.key_size_policy, device=dev)

        self.embed = AgentAndNode_embedding(
            in_chnl=in_chnl, hid_chnl=hid_chnl, n_agent=n_agent, key_size=key_size_embd, value_size=val_size, dev=dev
        )
        self.key_node_linear = nn.Linear(hid_chnl, 1, device=dev)

    def forward(self, batch_graph: torch_geometric.data.batch, n_nodes: int, n_batch: int) -> torch.Tensor:
        """Return assignment probabilities of shape ``(n_batch, n_agent, n_nodes - 1)``.

        The values are normalized over the agent axis (dim -2). When
        ``disable_softmax`` is set, each node column is additionally scaled by a
        positive learned weight.
        """
        agent_embeddings, node_features = self.embed(batch_graph, n_nodes, n_batch)

        keys = self.key_policy(node_features)
        queries = self.q_policy(agent_embeddings)
        compat = queries @ keys.transpose(-1, -2) / math.sqrt(self.key_size_policy)
        clipped = self.c * torch.tanh(compat)
        prob = clipped.softmax(dim=-2)
        if not self.disable_softmax:
            return prob

        # (n_batch, 1, n_nodes - 1): per-node weight in (0, 2).
        node_weight = (torch.sigmoid(self.key_node_linear(node_features)) * 2).transpose(1, 2)
        # NOTE(review): .sum() spans the whole batch, not each instance, so one
        # instance's output depends on the others. Because every agent of a node
        # is scaled by the same factor, a per-node renormalization (as
        # Categorical does) cancels this branch entirely — confirm that is intended.
        return prob * (node_weight / node_weight.sum())


def action_sample(pi: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one agent per node from the policy output.

    Args:
        pi: probabilities shaped ``(batch, n_agent, n_nodes)``; each node's column
            is treated as a (possibly unnormalized) distribution over agents.

    Returns:
        ``(action, log_prob)``, both shaped ``(batch, n_nodes)``: the sampled
        agent index for every node and its log-probability.
    """
    per_node = Categorical(pi.transpose(1, 2))
    action = per_node.sample()
    return action, per_node.log_prob(action)


def get_log_prob(pi: torch.Tensor, action_int: torch.Tensor) -> torch.Tensor:
    """Log-probability of given agent assignments under the policy output.

    Args:
        pi: probabilities shaped ``(batch, n_agent, n_nodes)``; each node's column
            is treated as a (possibly unnormalized) distribution over agents.
        action_int: integer tensor of agent indices shaped ``(batch, n_nodes)``
            (e.g. the ``action`` returned by ``action_sample``).

    Returns:
        Tensor shaped ``(batch, n_nodes)`` with the log-probability of each
        node's assigned agent.
    """
    # Move the agent axis last so Categorical treats it as the event dimension.
    dist = Categorical(pi.transpose(2, 1))
    return dist.log_prob(action_int)


def get_cost(action: torch.Tensor, data: torch.Tensor, n_agent: int, return_path: bool = False) -> tuple[list, ...]:
    """Evaluate a location-to-agent assignment by solving one TSP per agent.

    For each instance ``i``, every agent's sub-tour starts at the depot
    ``data[i][0]`` and contains the locations ``data[i][1:]`` that ``action``
    assigns to that agent. Each sub-tour is solved with ``solve``, and the
    instance cost is the longest sub-tour.

    Args:
        action: agent index per non-depot location; ``action[i][j]`` is the
            agent for location ``data[i][j + 1]``. If a row is shorter than the
            number of locations, the trailing locations are left unassigned
            (``zip`` with ``strict=False``).
        data: location coordinates per instance; ``data[i][0]`` is the depot.
        n_agent: number of agents (sub-tours) per instance.
        return_path: also return the paths reported by ``solve``.

    Returns:
        ``(subtour_max_lengths, all_subtour_length_list)``, or, when
        ``return_path`` is true, ``(subtour_max_lengths,
        all_subtour_length_list, all_path_list)``. These hold the longest
        sub-tour length per instance, every agent's sub-tour length per
        instance, and every agent's path per instance.
    """
    # The OR-tools TSP solver works on integer-rounded coordinates, so scale the
    # coordinates up before solving and scale the lengths back down afterwards.
    scale = 1000
    # Convert to Python lists once, not once per instance.
    coords = (data * scale).tolist()
    assignments = action.tolist()

    subtour_max_lengths = []
    all_subtour_length_list = []
    all_path_list = []
    for i in range(data.shape[0]):
        depot = coords[i][0]
        sub_tours = [[depot] for _ in range(n_agent)]
        for agent_idx, location in zip(assignments[i], coords[i][1:], strict=False):
            sub_tours[agent_idx].append(location)

        subtour_length_list = []
        subtour_path_list = []
        for instance in sub_tours:
            sub_tour_length, path = solve(instance, return_path=True)
            subtour_length_list.append(sub_tour_length / scale)
            subtour_path_list.append(path)

        subtour_max_lengths.append(max(subtour_length_list))
        all_subtour_length_list.append(subtour_length_list)
        all_path_list.append(subtour_path_list)

    if return_path:
        return subtour_max_lengths, all_subtour_length_list, all_path_list
    return subtour_max_lengths, all_subtour_length_list


def get_cost4plot(action: torch.Tensor, data: torch.Tensor, n_agent: int) -> tuple[list, list, list, dict]:
    """Evaluate an assignment and keep the information needed to plot the tours.

    Every agent's sub-tour starts at the depot ``data[i][0]`` and contains the
    locations assigned to it by ``action``; each sub-tour is solved with ``solve``.

    Args:
        action: agent index per non-depot location; ``action[i][j]`` is the
            agent for location ``data[i][j + 1]``.
        data: location coordinates per instance; ``data[i][0]`` is the depot.
        n_agent: number of agents (sub-tours) per instance.

    Returns:
        ``(subtour_max_lengths, subtour_order, subtour_length_list,
        local_idx2global_idx)``:

        * ``subtour_max_lengths[i]``: longest sub-tour length of instance ``i``,
          floored at 0.
        * ``subtour_order[i][a]``: order/path returned by ``solve`` for agent
          ``a``, expressed in that sub-tour's local indices.
        * ``subtour_length_list[i][a]``: length of agent ``a``'s sub-tour.
        * ``local_idx2global_idx[i][(a, k)]``: index into ``data[i]`` of the
          ``k``-th entry of agent ``a``'s sub-tour. The depot (local index 0)
          is not in the map; it is always global index 0.
    """
    # The OR-tools TSP solver is run on integer-rounded coordinates, so scale the
    # coordinate values up before solving and the lengths back down afterwards.
    scale = 1000
    # Convert to Python lists once, not once per instance.
    coords = (data * scale).tolist()
    assignments = action.tolist()
    n_instances = data.shape[0]

    subtour_max_lengths = [0 for _ in range(n_instances)]
    local_idx2global_idx = {i: {} for i in range(n_instances)}
    subtour_order = []
    subtour_length_list = []
    for i in range(n_instances):
        depot = coords[i][0]
        sub_tours = [[depot] for _ in range(n_agent)]
        for idx, (agent_idx, location) in enumerate(zip(assignments[i], coords[i][1:], strict=False)):
            local_idx = len(sub_tours[agent_idx])
            sub_tours[agent_idx].append(location)
            # action row has len(data[i]) - 1 entries: entry idx is location idx + 1.
            local_idx2global_idx[i][(agent_idx, local_idx)] = idx + 1

        part_subtour_order = []
        part_subtour_length_list = []
        for instance in sub_tours:
            sub_tour_length, order = solve(instance, return_path=True)
            part_subtour_order.append(order)
            sub_tour_length /= scale
            part_subtour_length_list.append(sub_tour_length)
            if sub_tour_length >= subtour_max_lengths[i]:
                subtour_max_lengths[i] = sub_tour_length
        subtour_order.append(part_subtour_order)
        subtour_length_list.append(part_subtour_length_list)

    # FIXME: local search
    # OPTIMIZE: return indices already mapped back through local_idx2global_idx.
    return subtour_max_lengths, subtour_order, subtour_length_list, local_idx2global_idx


class Surrogate(nn.Module):
    """Three-layer MLP on flattened inputs.

    The input is flattened per sample and passed through two hidden linear
    layers, each followed by the same activation, and then a linear output layer.
    """

    def __init__(
        self, in_dim: int, out_dim: int, n_hidden: int = 64, nonlin: str = "relu", dev="cpu", **kwargs
    ) -> None:
        super().__init__()
        activation_classes = {
            "relu": nn.ReLU,
            "tanh": nn.Tanh,
            "sigmoid": nn.Sigmoid,
            "softplus": nn.Softplus,
            "lrelu": nn.LeakyReLU,
            "elu": nn.ELU,
        }

        self.layer = nn.Linear(in_dim, n_hidden, device=dev)
        self.layer2 = nn.Linear(n_hidden, n_hidden, device=dev)
        self.out = nn.Linear(n_hidden, out_dim, device=dev)
        # An unknown activation name raises KeyError.
        self.nonlin = activation_classes[nonlin]()

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Map ``x`` of shape ``(batch, ...)`` to ``(batch, out_dim)``."""
        hidden = x.reshape(x.shape[0], -1)
        for linear in (self.layer, self.layer2):
            hidden = self.nonlin(linear(hidden))
        return self.out(hidden)
