import ot
import torch
from LinSATNet import linsat_layer
from mm_cvrp.policy import Policy
from mm_cvrp.policy import action_sample
from mm_cvrp.policy import get_cost
from torch_geometric.data import Batch
from torch_geometric.data import Data


def _project_onto_assignment_constraints(pi: torch.Tensor, n_nodes, no_agent, device, capacity) -> torch.Tensor:
    """Push the policy output through LinSAT so it satisfies the assignment constraints.

    The tensor is flattened to ``(batch, no_agent * (n_nodes - 1))`` where agent
    ``i`` owns the contiguous slice ``[i * (n_nodes - 1), (i + 1) * (n_nodes - 1))``.
    Two families of linear constraints are handed to ``linsat_layer``:

    * ``A x <= b``: the entries belonging to one agent sum to at most ``capacity``.
    * ``E x  = f``: for each of the ``n_nodes - 1`` columns, the entries across
      all agents sum to exactly 1 (every column is assigned exactly once).

    Args:
        pi: Policy output with ``batch * no_agent * (n_nodes - 1)`` elements and
            the batch on the first axis (presumably shaped
            ``(batch, no_agent, n_nodes - 1)`` -- TODO confirm against ``p_net``).
        n_nodes: Number of nodes per instance; one of them is excluded from the
            assignment (presumably the depot -- TODO confirm).
        no_agent: Number of agents.
        device: Device on which the constraint tensors are created.
        capacity: Upper bound on the per-agent row sum.

    Returns:
        Tensor of shape ``(batch, no_agent, n_nodes - 1)``.
    """
    n_batch = pi.shape[0]
    n_assignable = n_nodes - 1
    flat_scores = pi.reshape(n_batch, -1)

    # Capacity constraint: row i has ones exactly on agent i's contiguous slice.
    A = torch.eye(no_agent, device=device).repeat_interleave(n_assignable, dim=1)
    b = torch.full((no_agent,), capacity, dtype=torch.float32, device=device)

    # Column constraint: row j has ones at j, j + n_assignable, j + 2 * n_assignable, ...
    # i.e. the same column in every agent's slice, so each column sums to 1.
    E = torch.eye(n_assignable, device=device).repeat(1, no_agent)
    f = torch.ones(n_assignable, dtype=torch.float32, device=device)

    projected = linsat_layer(flat_scores, A=A, b=b, E=E, f=f, tau=1e-4, max_iter=800)
    return projected.reshape(n_batch, no_agent, n_assignable)


def validate(instances, p_net, no_agent, device, src_vector=None, capacity=25):
    """Evaluate ``p_net`` on a batch of instances and return the batch-averaged reward.

    Each instance is turned into a fully connected graph, the policy network is
    run on the batched graphs, its output is passed through the LinSAT
    constraint layer, an action is sampled from the result and scored with
    ``get_cost``.

    Args:
        instances: Tensor whose first two dimensions are ``(batch, n_nodes)``;
            ``instances[i]`` is used as the node-feature matrix of graph ``i``.
        p_net: Policy network, called as
            ``p_net(batch_graph, n_nodes=..., n_batch=...)``.
        no_agent: Number of agents.
        device: Device for the batched graph and the constraint tensors.
        src_vector: Unused. Kept (now optional) only so existing callers that
            pass it positionally keep working.
        capacity: Upper bound on the per-agent row sum enforced by the
            constraint layer. Defaults to 25, the value previously hard-coded.

    Returns:
        ``sum(reward) / batch_size`` where ``reward`` is the first value
        returned by ``get_cost``.
    """
    batch_size, n_nodes = instances.shape[0], instances.shape[1]

    # Every instance uses the same fully connected topology (self-loops
    # included), so the edge index is computed once and shared.
    edge_index = torch.nonzero(torch.ones([n_nodes, n_nodes])).t()
    graphs = [Data(x=instances[i], edge_index=edge_index) for i in range(batch_size)]
    batch_graph = Batch.from_data_list(data_list=graphs).to(device)

    pi = p_net(batch_graph, n_nodes=n_nodes, n_batch=batch_size)
    pi = _project_onto_assignment_constraints(pi, n_nodes, no_agent, device, capacity)

    # Sample an action; the log-probabilities are not needed for validation.
    action, _ = action_sample(pi)

    # reward: tensor [batch, 1] (per the original author's note)
    reward, _ = get_cost(action, instances, no_agent)

    return sum(reward) / batch_size


if __name__ == "__main__":
    # Script entry point: load a saved policy and a validation set from the
    # current directory, run validation and print the result.
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(2)

    n_agent = 5
    n_nodes = 50
    n_batch = 1000

    data = torch.load(f"./validation_data_{n_nodes}_{n_batch}")

    policy = Policy(
        in_chnl=2,
        hid_chnl=32,
        n_agent=n_agent,
        key_size_embd=64,
        key_size_policy=64,
        val_size=64,
        clipping=10,
        dev=dev,
    )
    path = f"./{n_nodes}_{n_agent}.pth"
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    policy.load_state_dict(torch.load(path, map_location=dev))

    # validate() never reads src_vector in its live code path, so None is
    # passed to satisfy the positional signature.
    mean_reward = validate(data, policy, n_agent, dev, None)
    print("Validation result:", mean_reward)
