from dataclasses import dataclass
import pickle
from typing import Optional

import numpy as np
import torch

from .problem import Problem
from .generator_vrptw import InstanceGenVRPTW

@dataclass
class Problem_Data:
    """Container for the raw description of a batch of VRPTW instances.

    Every field except ``problem_name`` defaults to ``None`` and is expected
    to be filled in by whoever builds the batch.  The ``depot_node_*`` names
    suggest the depot and the customer nodes are stored together in a single
    array per quantity.
    """

    # Identifier of the problem family.
    problem_name: str = "vrptw"
    # Number of customer nodes per instance (depot not counted) -- TODO confirm
    # against the producer of this object.
    problem_size: Optional[int] = None
    # Vehicle capacity.  NOTE(review): annotated as an array because the
    # producer in this file stores a NumPy array here; confirm no caller
    # relies on a plain int.
    capacity: Optional[np.ndarray] = None
    # Coordinates of depot + nodes.
    depot_node_xy: Optional[np.ndarray] = None
    # Demand of depot + nodes.
    depot_node_demand: Optional[np.ndarray] = None
    # Time windows of depot + nodes.
    depot_node_tw: Optional[np.ndarray] = None
    # Service duration of depot + nodes.
    depot_node_sd: Optional[np.ndarray] = None


@dataclass
class Problem_Feat:
    """Container for the tensor features of a batch of VRPTW instances.

    All fields default to ``None`` and are expected to be assigned after
    construction.
    """

    depot_xy: Optional[torch.Tensor] = None
    # shape: (batch, 1, 2)
    node_xy: Optional[torch.Tensor] = None
    # shape: (batch, problem, 2)
    node_demand: Optional[torch.Tensor] = None
    # shape: (batch, problem)
    depot_tw: Optional[torch.Tensor] = None
    # time window of the depot; presumably (batch, 1, 2) -- TODO confirm
    node_tw: Optional[torch.Tensor] = None
    # time windows of the nodes; presumably (batch, problem, 2) -- TODO confirm
    node_sd: Optional[torch.Tensor] = None
    # service duration of the nodes; presumably (batch, problem) -- TODO confirm
    # NOTE(review): ProblemVRPTW.init_problems publishes the service durations
    # under the attribute name `node_service_duration`; confirm which name the
    # consumers of this object actually read.

class ProblemVRPTW(Problem):
    """Source of VRPTW instance batches.

    Batches are either drawn from a random instance generator or read
    sequentially from a dataset previously loaded with
    ``load_problem_dataset_pt``.
    """

    def __init__(self, problem_size, generator_params=None):
        """Set up an empty problem source.

        Args:
            problem_size: Forwarded to the instance generator as its first
                argument.
            generator_params: Optional keyword arguments for
                ``InstanceGenVRPTW``.  When ``None`` no generator is created,
                so only a loaded dataset can provide instances.
        """
        self.name = "vrptw"
        self.problem_size = problem_size

        # State used when instances come from a saved dataset.
        self.FLAG__use_saved_problems = False
        self.saved_index = None
        self.dataset_depot_xy = None
        self.dataset_node_xy = None
        self.dataset_node_demand = None
        self.dataset_capacity = None
        self.dataset_depot_tw = None
        self.dataset_node_tw = None
        self.dataset_service_duration = None

        if generator_params is not None:
            self.instanceGen = InstanceGenVRPTW(self.problem_size, **generator_params)

    def load_problem_dataset_pkl(self, filename, num_problems, index_begin):
        """Loading from pickle files is not supported for this problem.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def load_problem_dataset_pt(self, filename, device):
        """Load a saved dataset and switch to serving instances from it.

        The file must contain a dict with the keys ``'depot_xy'``,
        ``'node_xy'``, ``'node_demand'``, ``'capacity'``, ``'depot_tw'``,
        ``'node_tw'`` and ``'node_sd'``.

        Args:
            filename: Path (or file-like object) handed to ``torch.load``.
            device: Passed to ``torch.load`` as ``map_location``.
        """
        self.FLAG__use_saved_problems = True

        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code.  Only load dataset files that come
        # from a trusted source.
        loaded_dict = torch.load(filename, map_location=device, weights_only=False)
        self.dataset_depot_xy = loaded_dict['depot_xy']
        self.dataset_node_xy = loaded_dict['node_xy']
        self.dataset_node_demand = loaded_dict['node_demand']
        self.dataset_capacity = loaded_dict['capacity']
        self.dataset_depot_tw = loaded_dict['depot_tw']
        self.dataset_node_tw = loaded_dict['node_tw']
        self.dataset_service_duration = loaded_dict['node_sd']
        # Reading restarts at the first saved instance.
        self.saved_index = 0

    def init_problems(self, nb_instances, aug_factor):
        """Produce the next batch of instances, optionally augmented.

        Args:
            nb_instances: Number of distinct instances to fetch (from the
                loaded dataset if one is active, otherwise from the random
                generator).
            aug_factor: Replication factor.  Values ``<= 1`` disable
                augmentation.  For multiples of 8 the coordinates go through
                the 8-fold transformation (tiled ``aug_factor // 8`` times);
                for any other value the coordinates are simply tiled.  All
                non-coordinate tensors are tiled ``aug_factor`` times.

        Returns:
            Tuple ``(batch_size, problem_data, problem_feat)`` where
            ``batch_size`` is ``nb_instances`` times the effective
            augmentation factor, ``problem_data`` is a ``Problem_Data`` of
            NumPy arrays with the depot prepended along dim 1, and
            ``problem_feat`` is a ``Problem_Feat`` of tensors whose demand is
            divided by the capacity.
        """
        if self.FLAG__use_saved_problems:
            # Consume the next consecutive window of the saved dataset.
            window = slice(self.saved_index, self.saved_index + nb_instances)
            depot_xy = self.dataset_depot_xy[window]
            node_xy = self.dataset_node_xy[window]
            node_demand = self.dataset_node_demand[window]
            capacity = self.dataset_capacity[window]
            depot_tw = self.dataset_depot_tw[window]
            node_tw = self.dataset_node_tw[window]
            node_sd = self.dataset_service_duration[window]
            self.saved_index += nb_instances
        else:
            (depot_xy, node_xy, node_demand, capacity,
             depot_tw, node_tw, node_sd) = self.get_random_problems(nb_instances)

        if aug_factor > 1:
            batch_size = nb_instances * aug_factor
            if aug_factor % 8 == 0:
                # 8 geometric variants, tiled to reach the requested factor.
                tiles = aug_factor // 8
                depot_xy = self.augment_xy_data_by_8_fold(depot_xy)
                node_xy = self.augment_xy_data_by_8_fold(node_xy)
                if tiles > 1:
                    depot_xy = depot_xy.repeat(tiles, 1, 1)
                    node_xy = node_xy.repeat(tiles, 1, 1)
            else:
                # No geometric variants available: plain replication.
                depot_xy = depot_xy.repeat(aug_factor, 1, 1)
                node_xy = node_xy.repeat(aug_factor, 1, 1)
            # Non-coordinate data is unaffected by the geometric transforms,
            # so it is replicated in the same block order as the coordinates.
            node_demand = node_demand.repeat(aug_factor, 1)
            capacity = capacity.repeat(aug_factor, 1)
            depot_tw = depot_tw.repeat(aug_factor, 1, 1)
            node_tw = node_tw.repeat(aug_factor, 1, 1)
            node_sd = node_sd.repeat(aug_factor, 1)
        else:
            batch_size = nb_instances

        # The depot has zero demand and zero service duration.  The zero rows
        # are created on the device of the tensors they are concatenated
        # with, so the concatenation also works when the data does not live
        # on the default device.
        depot_demand = torch.zeros(
            size=(batch_size, 1), dtype=torch.int, device=node_demand.device
        )
        depot_sd = torch.zeros(size=(batch_size, 1), device=node_sd.device)

        depot_node_xy = torch.cat((depot_xy, node_xy), dim=1)
        depot_node_demand = torch.cat((depot_demand, node_demand), dim=1)
        depot_node_tw = torch.cat((depot_tw, node_tw), dim=1)
        depot_node_sd = torch.cat((depot_sd, node_sd), dim=1)

        problem_data = Problem_Data()
        problem_data.problem_size = node_demand.shape[1]
        problem_data.capacity = capacity.cpu().numpy()
        problem_data.depot_node_xy = depot_node_xy.cpu().numpy()
        problem_data.depot_node_demand = depot_node_demand.cpu().numpy()
        problem_data.depot_node_tw = depot_node_tw.cpu().numpy()
        problem_data.depot_node_sd = depot_node_sd.cpu().numpy()

        problem_feat = Problem_Feat()
        problem_feat.depot_xy = depot_xy
        problem_feat.node_xy = node_xy
        problem_feat.node_demand = node_demand / capacity
        problem_feat.depot_tw = depot_tw
        problem_feat.node_tw = node_tw
        # Fill the declared field, and keep the historical attribute name so
        # that readers of either name see the service durations.
        problem_feat.node_sd = node_sd
        problem_feat.node_service_duration = node_sd

        return batch_size, problem_data, problem_feat

    def augment_xy_data_by_8_fold(self, xy_data):
        """Return 8 transformed copies of the coordinates, stacked on dim 0.

        The variants combine ``x -> 1 - x``, ``y -> 1 - y`` and swapping the
        two coordinates (assumes coordinates are normalised to [0, 1] --
        TODO confirm).  The first block of the result is the input itself.

        Args:
            xy_data: Tensor of shape (batch, N, 2).

        Returns:
            Tensor of shape (8*batch, N, 2).
        """
        # List-indexing keeps the last dim, so x and y are (batch, N, 1).
        x = xy_data[:, :, [0]]
        y = xy_data[:, :, [1]]
        flipped_x = 1 - x
        flipped_y = 1 - y

        variants = (
            (x, y),
            (flipped_x, y),
            (x, flipped_y),
            (flipped_x, flipped_y),
            (y, x),
            (flipped_y, x),
            (y, flipped_x),
            (flipped_y, flipped_x),
        )
        # Each pair becomes (batch, N, 2); stacking on dim 0 gives (8*batch, N, 2).
        return torch.cat([torch.cat(pair, dim=2) for pair in variants], dim=0)

    def get_random_problems(self, batch_size):
        """Delegate to the instance generator to draw ``batch_size`` random instances."""
        return self.instanceGen.get_random_problems(batch_size)
