from torch.utils.data import Dataset   
import torch
import numpy as np
from tqdm import tqdm
from scipy.spatial.distance import pdist, squareform

from problems.tsp.state_dtsp import StateDTSP

def generate_training_set(batch_size, proportions, distributions, means, covariances, min_size=20, max_size=50, min_new=20, max_new=50):
    """Generate one batch of random dynamic-TSP instances.

    A single initial node count and a single new-node count are drawn for the
    whole batch, so every instance in the batch has the same sizes.  New nodes
    are split across "locations" (clusters); each location has its own
    arrival-time profile and its own truncated bivariate normal over the unit
    square for coordinates.

    Args:
        batch_size: number of instances to generate.
        proportions: per-location probabilities passed to
            ``np.random.multinomial`` to split the new nodes across locations.
        distributions: per-location arrival-time profile, one of ``"am"``
            (first half of the time slots weighted 2:1), ``"pm"`` (second half
            weighted 2:1) or ``"equal"`` (uniform).
        means: per-location mean of the coordinate distribution.
        covariances: per-location 2x2 covariance of the coordinate distribution.
        min_size, max_size: inclusive range for the number of initial nodes.
        min_new, max_new: inclusive range for the number of new nodes.

    Returns:
        Tuple ``(nodes_coords, new_coords_batch, times_batch)`` of lists with
        one entry per instance: the initial coordinates ``(num_nodes, 2)``, the
        new-node coordinates ``(num_new, 2)`` ordered by arrival time, and the
        arrival times sorted ascending.

    Raises:
        ValueError: if an entry of ``distributions`` is not a known profile.
    """
    num_nodes = np.random.randint(low=min_size, high=max_size + 1)
    num_new = np.random.randint(low=min_new, high=max_new + 1)

    nodes_coords = list(np.random.random([batch_size, num_nodes, 2]))

    # Candidate arrival times are first_time .. num_nodes + num_new - 1.
    first_time = 2
    num_slots = max(num_nodes + num_new - first_time, 0)
    half = num_slots // 2

    probabilities = []
    for distribution in distributions:
        weights = np.ones(num_slots)
        if distribution == "am":
            weights[:half] = 2
        elif distribution == "pm":
            weights[half:] = 2
        elif distribution != "equal":
            # Previously an unknown name silently reused the preceding
            # location's profile (or crashed with an unbound local).
            raise ValueError(
                "Unknown arrival-time distribution {!r}; expected 'am', 'pm' or 'equal'".format(distribution)
            )
        probabilities.append(weights / np.sum(weights))

    times_batch = []
    new_coords_batch = []

    for _ in range(batch_size):
        n_loc = np.random.multinomial(num_new, proportions, size=1)[0]
        times_list = []
        new_coords_list = []
        available_times = np.arange(first_time, num_nodes + num_new)
        for location, count in enumerate(n_loc):
            if count == 0:
                # Nothing to draw; also avoids renormalising over an empty
                # set of remaining slots, which np.random.choice rejects.
                continue
            # Renormalise this location's profile over the slots still free.
            probs = probabilities[location][available_times - first_time]
            probs = probs / np.sum(probs)
            times = np.random.choice(available_times, count, replace=False, p=probs)
            times_list.extend(times)
            # Each arrival time may be used by at most one new node.
            available_times = available_times[~np.isin(available_times, times)]
            new_coords_list.extend(
                truncated_bivariate_normal([0, 1], [0, 1], means[location], covariances[location], count)
            )

        # Explicit dtype/shape keep the outputs well-formed even when no new
        # nodes were drawn.
        times_arr = np.asarray(times_list, dtype=np.int64)
        coords_arr = np.asarray(new_coords_list, dtype=float).reshape(-1, 2)
        order = np.argsort(times_arr)

        times_batch.append(times_arr[order])
        new_coords_batch.append(coords_arr[order])

    return nodes_coords, new_coords_batch, times_batch

def nearest_neighbor_graph(nodes, neighbors, knn_strat):
    """Return the k-nearest-neighbour graph as a **NEGATIVE** adjacency matrix.

    In the returned matrix an entry of 0 marks an edge and 1 marks "no edge";
    the diagonal is always 1 (no self-connections).  The graph is directed:
    row ``i`` holds the neighbours selected for node ``i``.

    Args:
        nodes: sequence of node coordinates, one row per node.
        neighbors: number of neighbours per node, or a fraction of the node
            count when ``knn_strat == 'percentage'``.  The value ``-1`` always
            means "fully connected", regardless of ``knn_strat``.
        knn_strat: ``'percentage'`` to interpret ``neighbors`` as a fraction;
            any other value leaves ``neighbors`` as an absolute count.

    Returns:
        ``(num_nodes, num_nodes)`` float array of zeros and ones.
    """
    num_nodes = len(nodes)

    # Convert a percentage to a count, but leave the -1 sentinel untouched:
    # scaling it first would turn it into -num_nodes and lose its meaning.
    if neighbors != -1 and knn_strat == 'percentage':
        neighbors = int(num_nodes * neighbors)

    if neighbors == -1 or neighbors >= num_nodes - 1:
        # Fully connected graph
        W = np.zeros((num_nodes, num_nodes))
    else:
        # Pairwise Euclidean distances
        W_val = squareform(pdist(nodes, metric='euclidean'))
        W = np.ones((num_nodes, num_nodes))

        # The neighbors + 1 closest entries per row; the extra slot accounts
        # for the node itself (distance 0), which is masked out again below.
        knns = np.argpartition(W_val, kth=neighbors, axis=-1)[:, neighbors::-1]
        rows = np.arange(num_nodes)[:, None]
        W[rows, knns] = 0

    # Remove self-connections
    np.fill_diagonal(W, 1)
    return W


def tour_nodes_to_W(tour_nodes):
    """Compute the edge adjacency matrix representation of a tour.

    Consecutive nodes in ``tour_nodes`` are joined by an undirected edge, and
    the last node is joined back to the first to close the cycle.

    Args:
        tour_nodes: sequence of integer node indices in visiting order; the
            indices must lie in ``[0, len(tour_nodes))``.

    Returns:
        ``(n, n)`` float array with 1 where the tour uses an edge and 0
        elsewhere, where ``n = len(tour_nodes)``.  Tours with fewer than two
        nodes have no edges, so an all-zero matrix is returned.
    """
    num_nodes = len(tour_nodes)
    tour_edges = np.zeros((num_nodes, num_nodes))
    if num_nodes < 2:
        # No pair of distinct nodes to connect.
        return tour_edges

    current = np.asarray(tour_nodes, dtype=int)
    # Successor of each node along the tour; the roll wraps the last node
    # around to the first, which adds the closing connection.
    following = np.roll(current, -1)
    tour_edges[current, following] = 1
    tour_edges[following, current] = 1
    return tour_edges


def truncated_bivariate_normal(alimits, blimits, mean, cov, n_samples):
    """Draw samples from a bivariate normal restricted to a rectangle.

    Uses rejection sampling: batches of ``n_samples`` points are drawn from
    the unrestricted normal and only the points inside the rectangle are kept,
    until enough have been collected.

    Args:
        alimits: ``[low, high]`` bounds for the first coordinate.
        blimits: ``[low, high]`` bounds for the second coordinate.
        mean: mean of the normal distribution.
        cov: covariance matrix of the normal distribution.
        n_samples: number of samples to return.

    Returns:
        Array of shape ``(n_samples, 2)`` with the accepted points.
    """
    lower_corner = [alimits[0], blimits[0]]
    upper_corner = [alimits[1], blimits[1]]

    # Start from an empty (0, 2) chunk so the result is well-shaped even when
    # no batch is ever drawn.
    chunks = [np.zeros((0, 2))]
    n_kept = 0
    while n_kept < n_samples:
        draws = np.random.multivariate_normal(mean, cov, size=(n_samples,))
        above_lower = np.min(draws - lower_corner, axis=1) >= 0
        below_upper = np.max(draws - upper_corner, axis=1) <= 0
        inside = draws[above_lower & below_upper]
        chunks.append(inside)
        n_kept += inside.shape[0]

    # The final batch may overshoot; keep only the first n_samples points.
    return np.concatenate(chunks, axis=0)[:n_samples, :]


class DTSP(object):
    """Problem definition for the Dynamic Travelling Salesman Problem."""

    NAME = 'dtsp'

    @staticmethod
    def get_costs(dataset, pi):
        """Compute the closed-tour length for each instance in a batch.

        Intended only for scoring completed DTSP tours.

        Args:
            dataset: graph nodes (torch.Tensor), indexed as
                (batch, node, coordinate)
            pi: node permutations representing tours (torch.Tensor), indexed
                as (batch, step)

        Returns:
            Tuple of (tour lengths, None)
        """
        # Validate the tours: each sorted row of pi must read 0, 1, ..., n - 1
        num_steps = pi.size(1)
        expected_order = torch.arange(num_steps, out=pi.data.new()).view(1, -1).expand_as(pi)
        sorted_tour, _ = pi.data.sort(1)
        assert (expected_order == sorted_tour).all(), "Invalid tour:\n{}\n{}".format(dataset, pi)

        # Reorder the node coordinates so they follow the tour
        gather_index = pi.unsqueeze(-1).expand_as(dataset)
        ordered = dataset.gather(1, gather_index)

        # Euclidean length of every step along the tour ...
        step_lengths = (ordered[:, 1:] - ordered[:, :-1]).norm(p=2, dim=2)
        # ... plus the leg that closes the cycle between the last and first node
        closing_length = (ordered[:, 0] - ordered[:, -1]).norm(p=2, dim=1)

        return step_lengths.sum(1) + closing_length, None

    @staticmethod
    def make_dataset(*args, **kwargs):
        """Build a DTSPDataset, forwarding all arguments."""
        return DTSPDataset(*args, **kwargs)

    @staticmethod
    def make_state(*args, **kwargs):
        """Build the initial StateDTSP, forwarding all arguments."""
        return StateDTSP.initialize(*args, **kwargs)

    # Do not worry about beam search at the moment ...


class DTSPDataset(Dataset): # RL only at this point
    """Dataset of dynamic TSP instances, loaded from a file or generated randomly.

    Each item is a dict with:
        'nodes': float tensor of the initial node coordinates,
        'new_nodes': float tensor of the coordinates of nodes that appear later,
        'times': long tensor with the arrival time of each new node,
        'graph': bool tensor, True where the k-nearest-neighbour graph over the
            initial nodes has an edge,
    plus, when ``supervised`` is set, 'tour_nodes' (ground-truth tour) and,
    when ``nar`` is also set, 'tour_edges' (that tour as an adjacency matrix).

    File format (one instance per line, tokens separated by single spaces):
        x y x y ... new_nodes x y ... times t t ... output n n ... <last token>
    Tour nodes after ``output`` are 1-based; the final token of the line and
    the repeated closing node of the tour are dropped.
    """

    def __init__(self, filename=None, min_size=20, max_size=50, min_new=20, max_new=50, batch_size=128,
                 num_samples=128000, offset=0, distribution=None, neighbors=20, knn_strat=None, supervised=False, nar=False):
        """Load instances from ``filename`` or generate them when it is None.

        Args:
            filename: path of the instance file; None generates random data.
            min_size, max_size: inclusive range of initial node counts
                (generation only).
            min_new, max_new: inclusive range of new node counts
                (generation only).
            batch_size: generation batch size; the final dataset size must be
                divisible by it.
            num_samples: number of lines to read, or number of samples to
                generate (rounded down to a multiple of ``batch_size``).
            offset: number of leading file lines to skip.
            distribution: stored on the instance; not otherwise used by this
                class.
            neighbors, knn_strat: forwarded to ``nearest_neighbor_graph``.
            supervised: also read ground-truth tours from the file.
            nar: with ``supervised``, also emit the tour as an adjacency matrix.
        """
        super(DTSPDataset, self).__init__()

        self.filename = filename
        self.min_size = min_size
        self.max_size = max_size
        self.min_new = min_new
        self.max_new = max_new
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.distribution = distribution
        self.offset = offset
        self.neighbors = neighbors
        self.knn_strat = knn_strat
        self.supervised = supervised
        self.nar = nar

        self.nodes_coords = []
        self.new_nodes_coords = []
        self.times = []

        if filename is not None:
            self.tour_nodes = []

            print('\nLoading from {}...'.format(filename))
            # Read inside a context manager so the file handle is closed
            # before parsing starts.
            with open(filename, "r") as instance_file:
                lines = instance_file.readlines()[offset:offset + num_samples]

            for line in tqdm(lines, ascii=True):
                nodes, new_nodes, times, tour_nodes = self._parse_line(line, self.supervised)
                self.nodes_coords.append(nodes)
                self.new_nodes_coords.append(new_nodes)
                self.times.append(times)
                if self.supervised:
                    self.tour_nodes.append(tour_nodes)

        # Generating random TSP samples (usually used for Reinforcement Learning)
        else:
            print('\nGenerating {} samples of DTSP{}-{} with {}-{} new nodes'.format(
                num_samples, min_size, max_size, min_new, max_new))
            for _ in tqdm(range(num_samples//batch_size), ascii=True):
                # A single cluster with uniformly distributed arrival times.
                nodes_coords, new_nodes_coords, times = generate_training_set(
                    batch_size, [1], ["equal"], [[0.2, 0.2]], [[[0.1, 0], [0, 0.1]]],
                    min_size, max_size, min_new, max_new)
                self.nodes_coords += nodes_coords
                self.new_nodes_coords += new_nodes_coords
                self.times += times

        self.size = len(self.nodes_coords)
        assert self.size % batch_size == 0, \
            "Number of samples ({}) must be divisible by batch size ({})".format(self.size, batch_size)

    @staticmethod
    def _parse_line(line, supervised=False):
        """Parse one instance line into (nodes, new_nodes, times, tour_nodes).

        ``tour_nodes`` is None unless ``supervised`` is true.
        """
        tokens = line.split(" ")
        new_nodes_at = tokens.index('new_nodes')
        times_at = tokens.index('times')
        output_at = tokens.index('output')

        # Two tokens per coordinate pair; the 'new_nodes' marker itself is
        # absorbed by the integer division when counting up to 'times'.
        num_old_nodes = new_nodes_at // 2
        num_new_nodes = times_at // 2 - num_old_nodes

        nodes = [
            [float(tokens[idx]), float(tokens[idx + 1])]
            for idx in range(0, 2 * num_old_nodes, 2)
        ]
        # New-node coordinates start right after the 'new_nodes' marker.
        new_nodes = [
            [float(tokens[idx]), float(tokens[idx + 1])]
            for idx in range(2 * num_old_nodes + 1, 2 * (num_new_nodes + num_old_nodes) + 1, 2)
        ]
        times = [int(time) for time in tokens[times_at + 1:output_at]]

        tour_nodes = None
        if supervised:
            # Convert tour nodes to 0-based indices; the slice skips the last
            # token of the line, and the trailing [:-1] drops the final
            # connection that closes the tour/cycle.
            tour_nodes = [int(node) - 1 for node in tokens[output_at + 1:-1]][:-1]

        return nodes, new_nodes, times, tour_nodes

    def __len__(self):
        """Return the number of instances in the dataset."""
        return self.size

    def __getitem__(self, idx):
        """Return instance ``idx`` as a dict of tensors (see class docstring)."""
        nodes = self.nodes_coords[idx]
        new_nodes = self.new_nodes_coords[idx]
        time = self.times[idx]
        item = {
            'nodes': torch.FloatTensor(nodes),
            'new_nodes': torch.FloatTensor(new_nodes),
            'times': torch.LongTensor(np.array(time)),
            # nearest_neighbor_graph marks edges with 0, so invert the mask to
            # get True where an edge exists.
            'graph': ~torch.BoolTensor(nearest_neighbor_graph(nodes, self.neighbors, self.knn_strat))
        }
        if self.supervised:
            # Add groundtruth labels in case of SL
            tour_nodes = self.tour_nodes[idx]
            item['tour_nodes'] = torch.LongTensor(tour_nodes)
            if self.nar:
                # Groundtruth for NAR decoders is the TSP tour in adjacency matrix format
                item['tour_edges'] = torch.LongTensor(tour_nodes_to_W(tour_nodes))
        return item
            
