import torch
import numpy as np
from statistics import mean, stdev as std
from utils_execution import compute_tour_cost
from copy import deepcopy
import time

# Beam width used by beam_search_baseline: the maximum number of lowest-cost
# partial tours that beam_search_rollout keeps after each expansion step.
BEAM_WIDTH = 100

# @torch.compile
def expand_single(beam_vis, bvnz, beam_last, beam_cost, beam_par, W):
    """Expand one beam into one child beam per candidate next node.

    Args:
        beam_vis: 1-D boolean mask of the nodes this beam has visited.
        bvnz: 1-D index tensor of the candidate next nodes (K entries).
        beam_last: scalar index tensor, the node the beam currently ends at.
        beam_cost: scalar tensor, the cost accumulated by the beam so far.
        beam_par: 1-D index tensor holding each node's predecessor so far.
        W: cost matrix, read as ``W[from_node, to_node]``.

    Returns:
        A tuple ``(child_vis, bvnz, child_cost, child_par)`` where row ``k`` of
        ``child_vis`` / ``child_par`` and entry ``k`` of ``child_cost``
        describe the beam after stepping to ``bvnz[k]``. ``bvnz`` is returned
        unchanged as the new last node of each child.
    """
    num_children = bvnz.shape[0]
    device = bvnz.device
    # One row per child; shared by both scatter-style assignments below.
    child_rows = torch.arange(num_children, device=device)
    visited_flag = torch.tensor(True, device=device)

    # Cost of stepping from the beam's current end node to every candidate.
    step_cost = W[beam_last.expand_as(bvnz), bvnz]
    child_cost = beam_cost + step_cost  # scalar + (K,) broadcasts to (K,)

    # Each child starts as a copy of the parent, then marks its own new node.
    child_vis = beam_vis.repeat(num_children, 1)
    child_vis[child_rows, bvnz] = visited_flag

    # Record the parent's end node as the predecessor of the newly added node.
    child_par = beam_par.repeat(num_children, 1)
    child_par[child_rows, bvnz] = beam_last

    return child_vis, bvnz, child_cost, child_par

# Batched form of expand_single: the first five arguments are mapped over
# dim 0 (one entry per beam), W is shared by all beams (in_dims None), and each
# of the four outputs is stacked along dim 0.
expand_multiple = torch.func.vmap(expand_single, in_dims=(0, 0, 0, 0, 0, None), out_dims=(0, 0, 0, 0))

def beam_search_rollout(start_route, W, beam_width):
    """Construct a closed tour over all nodes with beam search on ``W``.

    Starting from the node selected by ``start_route.argmax()``, every beam is
    extended by each of its unvisited nodes and only the ``beam_width``
    cheapest partial tours survive each step. After ``num_nodes - 1`` steps the
    surviving tours are closed back to the start node and the cheapest closed
    tour is returned.

    Args:
        start_route: 1-D tensor; its argmax is the start node and its nonzero
            entries (``start_route.bool()``) are treated as already visited.
            Presumably one-hot -- confirm against the caller.
        W: cost matrix indexed as ``W[from_node, to_node]``; ``W.shape[0]`` is
            taken as the number of nodes.
        beam_width: maximum number of partial tours kept after each step.

    Returns:
        A ``(2, N)`` tensor with ``N = start_route.shape[0]``. Row 0 is
        ``arange(N)``; row 1 holds, for each node, the node visited just
        before it on the chosen tour (for the start node: the last node of
        the tour, which closes the cycle).
    """
    device = start_route.device
    num_nodes = W.shape[0]
    snode = start_route.argmax()

    # Beam state, one row/entry per partial tour (a single beam to begin with).
    beam_par = torch.arange(start_route.shape[0], device=device).unsqueeze(0)
    beam_vis = start_route.bool().unsqueeze(0)
    beam_last = snode.unsqueeze(0)
    beam_cost = torch.zeros(1, device=device, dtype=torch.float)

    for _ in range(num_nodes - 1):
        # Column indices of the unvisited nodes, grouped per beam. Every beam
        # has visited the same number of nodes, so each beam contributes the
        # same number of candidates and the result is a dense (beams, K) grid.
        bvnz = (~beam_vis).nonzero()[:, 1].reshape(beam_vis.shape[0], -1)

        new_beam_vis, new_beam_last, new_beam_cost, new_beam_par = expand_multiple(
            beam_vis, bvnz, beam_last, beam_cost, beam_par, W
        )

        # Merge the (beams, K, ...) children into one flat candidate list.
        beam_vis = new_beam_vis.flatten(end_dim=1)
        beam_last = new_beam_last.flatten(end_dim=1)
        beam_cost = new_beam_cost.flatten(end_dim=1)
        beam_par = new_beam_par.flatten(end_dim=1)

        # Keep only the cheapest candidates (order among them is irrelevant).
        keep = min(beam_width, beam_par.shape[0])
        indices = torch.topk(beam_cost, keep, largest=False, sorted=False).indices
        beam_vis = beam_vis[indices]
        beam_last = beam_last[indices]
        beam_cost = beam_cost[indices]
        beam_par = beam_par[indices]

    # The returned tour is a cycle, so rank the survivors by their closed-tour
    # cost: the accumulated path cost plus the edge from the last node back to
    # the start node. Ranking by path cost alone can pick a tour whose closing
    # edge makes it more expensive than another survivor.
    closed_cost = beam_cost + W[beam_last, snode]
    best_index = closed_cost.argmin()
    best_par = beam_par[best_index]
    # Close the cycle: the start node's predecessor is the tour's last node.
    best_par[snode] = beam_last[best_index]

    return torch.stack([torch.arange(best_par.shape[0], device=best_par.device), best_par], 0)

# Batched form of beam_search_rollout: start_route and W are mapped over dim 0,
# beam_width is shared (in_dims None).
# NOTE(review): no active call site is visible in this file, and the rollout
# calls nonzero(), whose output shape depends on the data -- confirm that
# torch.func.vmap can actually trace it before relying on this.
vmapped_beam_search_rollout = torch.func.vmap(beam_search_rollout, in_dims=(0, 0, None))


def beam_search_baseline(data):
    """Run the beam-search rollout on every sample and summarise tour costs.

    Args:
        data: indexable and iterable collection of samples exposing
            ``num_samples``; each sample provides ``x``, ``start_route`` and
            ``edge_attr`` (reshaped here to ``(num_nodes, num_nodes)``).
            All samples are assumed to have as many nodes as ``data[0]``.

    Returns:
        ``(mean, std)`` of the tour costs reported by ``compute_tour_cost``.
        ``std`` is the sample standard deviation, or ``0.0`` when there is
        only a single tour (where it is undefined).
    """
    num_nodes = data[0].x.shape[0]

    # Fetch each sample once; the rollout gets private copies of its inputs.
    samples = (data[i] for i in range(data.num_samples))
    tours = [
        beam_search_rollout(
            sample.start_route.clone(),
            sample.edge_attr.reshape(num_nodes, num_nodes).clone(),
            BEAM_WIDTH,
        )
        for sample in samples
    ]

    tour_lengths = [
        compute_tour_cost(tour, sample.edge_attr).item()
        for sample, tour in zip(data, tours)
    ]

    # statistics.stdev needs at least two values; do not throw away the
    # computed mean just because the dataset holds a single sample.
    spread = std(tour_lengths) if len(tour_lengths) > 1 else 0.0
    return mean(tour_lengths), spread
