import torch
import numpy as np
import pickle

def save_random_problems(batch_size, problem_size, charging_station_size):
    """Sample EVRP-CS instances and persist them to ``.pt`` and ``.pkl`` files.

    Instances are drawn by rejection sampling. A candidate (depot, charging
    stations, customer nodes) is kept only when, for every customer node,
    the depot-to-node distance plus the distance from that node to its
    nearest station (the depot is included among the stations) is strictly
    less than 2. Each accepted candidate immediately gets integer demands in
    [1, 9] divided by a problem-size-dependent scaler.

    The resulting dict (keys ``depot_xy``, ``cs_xy``, ``node_xy``,
    ``node_demand``) is written to ``EVRPCS_{batch_size}_{problem_size}.pt``
    via ``torch.save`` and to ``EVRPCS_{batch_size}_{problem_size}.pkl`` via
    ``pickle``, both relative to the current working directory.

    Args:
        batch_size: number of instances to keep.
        problem_size: number of customer nodes per instance; must be one of
            20, 50, 100, 200, 500, 1000, 5000 or 7000.
        charging_station_size: number of charging stations per instance,
            not counting the depot.

    Returns:
        Tuple ``(depot_xy, node_xy, cs_xy, node_demand)`` with shapes
        ``(batch, 1, 2)``, ``(batch, problem, 2)``,
        ``(batch, charging_station_size + 1, 2)`` and ``(batch, problem)``.
        Index 0 along the station axis of ``cs_xy`` is the depot.

    Raises:
        NotImplementedError: if ``problem_size`` has no demand scaler.
    """
    demand_scalers = {
        20: 30,
        50: 40,
        100: 50,
        200: 100,
        500: 150,
        1000: 200,
        5000: 300,
        7000: 300,
    }
    if problem_size not in demand_scalers:
        raise NotImplementedError(
            f"no demand scaler for problem_size={problem_size}; "
            f"supported sizes: {sorted(demand_scalers)}"
        )
    demand_scaler = demand_scalers[problem_size]

    depot_list = []
    cs_list = []
    node_list = []
    node_demand_list = []
    while len(depot_list) < batch_size:
        depot_xy = torch.rand(size=(1, 1, 2))
        # shape: (1, 1, 2)
        cs_xy = torch.rand(size=(1, charging_station_size, 2))
        # Prepend the depot so that it also counts as a charging station.
        cs_xy = torch.cat((depot_xy, cs_xy), -2)
        # shape: (1, charging_station_size + 1, 2)
        node_xy = torch.rand(size=(1, problem_size, 2))
        # shape: (1, problem, 2)
        depot_to_node = torch.norm(depot_xy - node_xy, dim=-1)
        # shape: (1, problem)
        node_to_cs = torch.norm(node_xy.unsqueeze(2) - cs_xy.unsqueeze(1), dim=-1)
        # shape: (1, problem, charging_station_size + 1)
        node_to_nearest_cs, _ = torch.min(node_to_cs, -1)
        # shape: (1, problem)
        total_dist = depot_to_node + node_to_nearest_cs
        if (total_dist < 2).all():
            depot_list.append(depot_xy)
            cs_list.append(cs_xy)
            node_list.append(node_xy)
            # Demands are drawn per accepted sample (after its coordinates).
            node_demand = torch.randint(1, 10, size=(1, problem_size)) / float(demand_scaler)
            node_demand_list.append(node_demand)

    depot_xy = torch.cat(depot_list, dim=0)
    # shape: (batch, 1, 2)
    node_xy = torch.cat(node_list, dim=0)
    # shape: (batch, problem, 2)
    cs_xy = torch.cat(cs_list, dim=0)
    # shape: (batch, charging_station_size + 1, 2)
    node_demand = torch.cat(node_demand_list, dim=0)
    # shape: (batch, problem)

    my_dict = {
        'depot_xy': depot_xy,
        'cs_xy': cs_xy,
        'node_xy': node_xy,
        'node_demand': node_demand
    }
    torch.save(my_dict, f'EVRPCS_{batch_size}_{problem_size}.pt')
    with open(f'EVRPCS_{batch_size}_{problem_size}.pkl', 'wb') as f:
        pickle.dump(my_dict, f)
    return depot_xy, node_xy, cs_xy, node_demand


def get_random_problems(batch_size, problem_size, charging_station_size):
    """Sample a batch of EVRP-CS instances in memory (nothing is written).

    Instances are drawn by rejection sampling. A candidate (depot, charging
    stations, customer nodes) is kept only when, for every customer node,
    the depot-to-node distance plus the distance from that node to its
    nearest station (the depot is included among the stations) is strictly
    less than 2. After the whole batch is accepted, integer demands in
    [1, 9] are drawn for all nodes at once and divided by a
    problem-size-dependent scaler.

    Args:
        batch_size: number of instances to keep.
        problem_size: number of customer nodes per instance; must be one of
            20, 50, 100, 200, 500, 1000, 5000 or 7000.
        charging_station_size: number of charging stations per instance,
            not counting the depot.

    Returns:
        Tuple ``(depot_xy, node_xy, cs_xy, node_demand)`` with shapes
        ``(batch, 1, 2)``, ``(batch, problem, 2)``,
        ``(batch, charging_station_size + 1, 2)`` and ``(batch, problem)``.
        Index 0 along the station axis of ``cs_xy`` is the depot.

    Raises:
        NotImplementedError: if ``problem_size`` has no demand scaler. This
            is checked before any sampling, so no random numbers are
            consumed in that case.
    """
    demand_scalers = {
        20: 30,
        50: 40,
        100: 50,
        200: 100,
        500: 150,
        1000: 200,
        5000: 300,
        7000: 300,
    }
    # Validate up front: previously an unsupported size ran the entire
    # rejection-sampling loop (and advanced the global RNG) before failing.
    if problem_size not in demand_scalers:
        raise NotImplementedError(
            f"no demand scaler for problem_size={problem_size}; "
            f"supported sizes: {sorted(demand_scalers)}"
        )
    demand_scaler = demand_scalers[problem_size]

    depot_list = []
    cs_list = []
    node_list = []
    while len(depot_list) < batch_size:
        depot_xy = torch.rand(size=(1, 1, 2))
        # shape: (1, 1, 2)
        cs_xy = torch.rand(size=(1, charging_station_size, 2))
        # Prepend the depot so that it also counts as a charging station.
        cs_xy = torch.cat((depot_xy, cs_xy), -2)
        # shape: (1, charging_station_size + 1, 2)
        node_xy = torch.rand(size=(1, problem_size, 2))
        # shape: (1, problem, 2)
        depot_to_node = torch.norm(depot_xy - node_xy, dim=-1)
        # shape: (1, problem)
        node_to_cs = torch.norm(node_xy.unsqueeze(2) - cs_xy.unsqueeze(1), dim=-1)
        # shape: (1, problem, charging_station_size + 1)
        node_to_nearest_cs, _ = torch.min(node_to_cs, -1)
        # shape: (1, problem)
        total_dist = depot_to_node + node_to_nearest_cs
        if (total_dist < 2).all():
            depot_list.append(depot_xy)
            cs_list.append(cs_xy)
            node_list.append(node_xy)

    depot_xy = torch.cat(depot_list, dim=0)
    # shape: (batch, 1, 2)
    node_xy = torch.cat(node_list, dim=0)
    # shape: (batch, problem, 2)
    cs_xy = torch.cat(cs_list, dim=0)
    # shape: (batch, charging_station_size + 1, 2)

    node_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / float(demand_scaler)
    # shape: (batch, problem)

    return depot_xy, node_xy, cs_xy, node_demand


def augment_xy_data_by_8_fold(xy_data):
    """Stack the 8 square symmetries of 2-D coordinates along the batch axis.

    Each variant is built from (first, second) coordinate pairs drawn from
    ``x``, ``1 - x``, ``y`` and ``1 - y``. The variants are concatenated in
    this order: identity, flip x, flip y, flip both, swap, and then the swap
    combined with each of the three flips.

    Args:
        xy_data: tensor indexed as ``xy_data[:, :, k]`` for k in {0, 1};
            assumed shape ``(batch, N, 2)``.

    Returns:
        Tensor of shape ``(8 * batch, N, 2)``.
    """
    xs = xy_data[:, :, [0]]
    ys = xy_data[:, :, [1]]
    # xs, ys shape: (batch, N, 1)
    flipped_xs = 1 - xs
    flipped_ys = 1 - ys

    coordinate_pairs = [
        (xs, ys),
        (flipped_xs, ys),
        (xs, flipped_ys),
        (flipped_xs, flipped_ys),
        (ys, xs),
        (flipped_ys, xs),
        (ys, flipped_xs),
        (flipped_ys, flipped_xs),
    ]
    variants = [torch.cat(pair, dim=2) for pair in coordinate_pairs]
    # each variant shape: (batch, N, 2)

    return torch.cat(variants, dim=0)


if __name__ == '__main__':
    # Generate and persist 100 instances with 7000 customers and 4 charging stations.
    save_random_problems(batch_size=100, problem_size=7000, charging_station_size=4)