import concurrent.futures
from dataclasses import dataclass
from functools import partial
from typing import Optional

import numba as nb
import numpy as np
import torch
import torch.nn.functional as F

from swapstar import swapstar

@dataclass
class Route_Info:
    """Container for a batch of routes together with their reward and masks.

    The shapes noted below are the ones produced by
    ``CVRPLocalSearch.pack_route`` in this module; other producers may differ.
    """

    BATCH_IDX: torch.Tensor  # (batch, pomo): instance index of every rollout
    POMO_IDX: torch.Tensor  # (batch, pomo): rollout index within its instance
    reward: torch.Tensor  # (batch, pomo): negated closed-tour length from pack_route
    load: torch.Tensor  # (batch, pomo, seq_len): remaining load after each step
    route: Optional[torch.Tensor] = None  # (batch, pomo, seq_len): node indices
    # pack_route builds ninf_mask as (batch, pomo * seq_len, node):
    # -inf for nodes already visited up to each step, 0 otherwise.
    ninf_mask: Optional[torch.Tensor] = None

class CVRPLocalSearch:
    """Refine a subset of decoded CVRP rollouts with the SWAP* local search.

    Keyword parameters read in ``__init__``:

    * ``search_proportion`` -- fraction of the rollouts per instance that are
      handed to the local search.
    * ``proportion_type`` -- ``'random'`` draws rollouts uniformly at random
      (with replacement, so duplicates are possible); ``'maximum'`` takes the
      highest-reward rollouts.
    * ``number_of_cpu`` -- stored on the instance; not used by the methods of
      this class.
    """

    def __init__(self, **search_param):
        self.search_param = search_param
        # The proportion of search samples in one batch.
        self.search_proportion = search_param['search_proportion']
        self.proportion_type = search_param['proportion_type']
        # NOTE(review): never forwarded to the SWAP* thread pool; confirm
        # whether it was meant to bound the number of worker threads.
        self.number_of_cpu = search_param['number_of_cpu']

    def search(self, route, reward, distmat, problems, demand):
        """Run SWAP* on a subset of rollouts and return the improved routes.

        Args:
            route: (batch, pomo, seq_len) tensor of node indices.
            reward: (batch, pomo) tensor used to rank rollouts in
                ``'maximum'`` mode.
            distmat: per-instance distance matrices; moved to numpy and
                passed to SWAP* instance by instance.
            problems: (batch, problem, 2) node coordinates.
            demand: per-instance node demands; moved to numpy.

        Returns:
            Route_Info for the searched rollouts only; its pomo dimension is
            ``int(pomo * search_proportion)``.

        Raises:
            NotImplementedError: for an unknown ``proportion_type``.
        """
        batch_size, pomo_size, seq_len = route.shape[:3]

        batch_index = torch.arange(batch_size).view(-1, 1)
        search_pomo_size = int(pomo_size * self.search_proportion)

        if self.proportion_type == 'random':
            # Sampled with replacement: the same rollout may be chosen twice.
            search_pomo_idx = torch.randint(pomo_size, size=(batch_size, search_pomo_size))
        elif self.proportion_type == 'maximum':
            search_pomo_idx = torch.argsort(reward, dim=1, descending=True)[:, :search_pomo_size]
        else:
            raise NotImplementedError
        search_route = route[batch_index, search_pomo_idx]

        # SWAP* runs on the CPU with one route per row.
        flat_route = search_route.reshape(-1, seq_len).cpu()
        new_route, load = self.swap_star(
            flat_route,
            distmat.cpu().numpy(),
            demand.cpu().numpy(),
            problems.cpu().numpy(),
        )

        device = route.device
        load = load.view(batch_size, search_pomo_size, seq_len).to(device)
        new_route = new_route.view(batch_size, search_pomo_size, seq_len).to(device)
        return self.pack_route(new_route, load, problems)

    def pack_route(self, route, load, problems):
        """Bundle routes with their reward, visited-node mask and index grids.

        Args:
            route: (batch, pomo, seq_len) integer node indices; 0 is the depot.
            load: (batch, pomo, seq_len) remaining load; stored unchanged.
            problems: (batch, problem, 2) node coordinates.

        Returns:
            Route_Info whose ``reward`` is the negated length of each closed
            tour and whose ``ninf_mask`` has shape
            (batch, pomo * seq_len, problem). Entry ``[b, p * seq_len + t, n]``
            is -inf when node ``n`` appears in ``route[b, p, :t + 1]`` and 0
            otherwise. The depot is never masked.
        """
        batch_size, search_pomo_size, seq_len = route.shape[:3]
        problem_size = problems.size(1)

        BATCH_IDX = torch.arange(batch_size)[:, None].expand(batch_size, search_pomo_size)
        POMO_IDX = torch.arange(search_pomo_size)[None, :].expand(batch_size, search_pomo_size)

        one_hot = F.one_hot(route, problem_size).to(torch.float)
        one_hot[route == 0] = 0  # the depot is never masked
        # shape: (batch, pomo, seq_len, problem)

        # A running sum along the sequence counts the visits up to and
        # including step t. This equals tril(ones(seq_len, seq_len)) @ one_hot
        # without materialising a (seq_len x seq_len) matrix per rollout, and
        # it stays on route's device.
        visit_count = one_hot.cumsum(dim=2)
        ninf_mask = torch.zeros_like(visit_count).masked_fill(visit_count != 0, float('-inf'))
        ninf_mask = ninf_mask.reshape(batch_size, search_pomo_size * seq_len, problem_size)

        gathering_index = route.unsqueeze(3).expand(batch_size, -1, seq_len, 2)
        # shape: (batch, pomo, seq_len, 2)
        seq_expanded = problems[:, None, :, :].expand(batch_size, search_pomo_size, problem_size, 2)
        ordered_seq = seq_expanded.gather(dim=2, index=gathering_index)
        # shape: (batch, pomo, seq_len, 2)

        # Rolling by -1 pairs each node with its successor and the last node
        # with the first, so every tour is treated as closed.
        rolled_seq = ordered_seq.roll(dims=2, shifts=-1)
        segment_lengths = ((ordered_seq - rolled_seq) ** 2).sum(3).sqrt()
        # shape: (batch, pomo, seq_len)

        travel_distances = segment_lengths.sum(2)
        return Route_Info(BATCH_IDX, POMO_IDX, -travel_distances, load, route, ninf_mask)

    def swap_star(self, route, dist, demand, problems):
        """Run SWAP* on every row of ``route``; returns ``(new_route, load)``.

        Note that ``multiple_swap_star`` takes demand before distances.
        """
        new_route, load = multiple_swap_star(route, demand, dist, problems)
        return new_route, load

def multiple_swap_star(paths, demand, distances, problems, device='cuda', indexes=None, max_iterations=1,
                       max_workers=None):
    """Run SWAP* on many routes in a thread pool and write the results back.

    Args:
        paths: 2-D tensor of node indices, one route per row. Rows are grouped
            by instance: row ``i`` belongs to instance
            ``i // (paths.shape[0] // demand.shape[0])``.
        demand: per-instance node demands; ``demand.shape[0]`` is the number
            of instances.
        distances: per-instance distance matrices, indexed like ``demand``.
        problems: per-instance node coordinates, indexed like ``demand``.
        device: unused; kept for backward compatibility.
        indexes: optional iterable of row indices to improve. By default
            every row is processed.
        max_iterations: forwarded unchanged to ``swapstar``.
        max_workers: size of the thread pool. ``None`` keeps the
            ``ThreadPoolExecutor`` default.

    Returns:
        ``(paths, load)``. ``paths`` is the input tensor, modified in place.
        ``load`` is a float32 tensor shaped like ``paths``.
    """
    batch_size = demand.shape[0]
    pomo_size = paths.shape[0] // batch_size
    # NOTE(review): rows skipped via `indexes` keep their path, but their load
    # row is never written and holds uninitialised memory.
    load = torch.empty_like(paths).to(torch.float32)
    rows = range(paths.size(0)) if indexes is None else indexes
    # Subroutes are extracted on the calling thread before any work is submitted.
    subroutes_all = [(i, get_subroutes(paths[i, :])) for i in rows]
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for i, subroutes in subroutes_all:
            batch_id = i // pomo_size
            future = executor.submit(swapstar,
                                     demand[batch_id],
                                     distances[batch_id],
                                     problems[batch_id],
                                     subroutes,
                                     max_iterations)
            futures.append((i, future))
        # Collect in submission order; result() re-raises any worker exception.
        for i, future in futures:
            batch_id = i // pomo_size
            paths[i, :], load[i, :] = merge_subroutes(future.result(), paths.size(1), demand[batch_id])

    return paths, load
            
def get_subroutes(route, end_with_zero=True):
    """Split a depot-padded route into its non-empty subroutes.

    A subroute is the stretch between two consecutive zeros (depot visits)
    that contains at least one other node. Each returned piece starts with the
    depot. When ``end_with_zero`` is true it also keeps the closing depot.
    Runs of consecutive zeros yield nothing. Nodes before the first zero or
    after the last zero are not returned. The pieces are slices of ``route``.
    """
    depot_positions = torch.nonzero(route == 0).flatten()
    tail = 1 if end_with_zero else 0
    return [
        route[start:stop + tail]
        for start, stop in zip(depot_positions, depot_positions[1:])
        if stop - start > 1
    ]

def merge_subroutes(subroutes, length, demand):
    """Concatenate depot-delimited subroutes into one zero-padded route.

    Each subroute is expected to start and end with the depot (0). Its
    trailing depot is dropped because the next subroute, or the zero padding,
    supplies it. Subroutes with two or fewer entries are skipped.

    Args:
        subroutes: iterable of subroutes. Each may be a list, tuple, numpy
            array or 1-D tensor of node indices.
        length: length of the output route; unused tail positions stay 0.
        demand: indexable by node id, giving each node's demand. Presumably
            normalised so vehicle capacity is 1 (load starts at 1) -- confirm.

    Returns:
        ``(route, load)``: a ``long`` tensor of node indices and a float
        tensor of the same length with the remaining load after each visit.
        Load is 1 at every depot position and in the padding.

    Raises:
        AssertionError: if any load drops below -1e-6 (capacity exceeded).
        torch raises an error if the subroutes do not fit in ``length``.
    """
    route = torch.zeros(length, dtype=torch.long)
    load = torch.ones(length)
    i = 0
    for r in subroutes:
        if len(r) > 2:
            # as_tensor accepts lists, tuples, numpy arrays and tensors alike.
            # No clone is needed because the values are copied into `route`.
            r = torch.as_tensor(r[:-1])
            route[i: i + len(r)] = r
            # Position i is the depot and keeps load 1; subtract as we go.
            for j in range(1, len(r)):
                load[j + i] = load[j + i - 1] - demand[r[j]]
            i += len(r)
    assert (load > -1e-6).all(), load
    return route, load

# def get_route_load(route, demand):
#     batch_size = route.size(0)
#     pomo_size = route.size(1)
#     seq_len = route.size(2)
    
#     load_list = torch.zeros((batch_size, pomo_size, 0))
#     load = torch.ones((batch_size, pomo_size))
#     for route in range(seq_len):
#         load = load - demand[]
        
#         load = torch.zeros((batch_size, pomo_size))
        
        
        