from dataclasses import dataclass
import pickle
from typing import Optional

import numpy as np
import torch

from .problem import Problem
from .generator_cvrp import InstanceGenCVRP

@dataclass
class Problem_Data:
    """Raw (NumPy-side) description of a batch of CVRP instances.

    Every field except ``problem_name`` defaults to ``None`` and is filled
    in after construction.
    """

    # Identifier of the problem family.
    problem_name: str = "cvrp"
    # Number of customer nodes per instance (the depot is not counted).
    problem_size: Optional[int] = None
    # Vehicle capacity; ProblemCVRP.init_problems stores a NumPy array here
    # (one entry per instance), not a plain int.
    capacity: Optional[np.ndarray] = None
    # Depot followed by customer coordinates.
    # shape: (batch, 1 + problem, 2)
    depot_node_xy: Optional[np.ndarray] = None
    # Depot demand (0) followed by customer demands.
    # shape: (batch, 1 + problem)
    depot_node_demand: Optional[np.ndarray] = None


@dataclass
class Problem_Feat:
    """Tensor-side features of a batch of CVRP instances.

    All fields default to ``None`` and are filled in after construction.
    """

    # Depot coordinates.
    # shape: (batch, 1, 2)
    depot_xy: Optional[torch.Tensor] = None
    # Customer coordinates.
    # shape: (batch, problem, 2)
    node_xy: Optional[torch.Tensor] = None
    # Customer demands; ProblemCVRP.init_problems stores demand / capacity.
    # shape: (batch, problem)
    node_demand: Optional[torch.Tensor] = None

class ProblemCVRP(Problem):
    """Provider of capacitated vehicle routing (CVRP) instances.

    Instances come either from a random generator (``InstanceGenCVRP``) or
    from a dataset loaded with ``load_problem_dataset_pkl`` /
    ``load_problem_dataset_pt``. ``init_problems`` serves them in
    consecutive batches, optionally replicated / augmented.
    """

    def __init__(self, problem_size, generator_params=None):
        """Create the provider.

        Args:
            problem_size: Problem size forwarded to ``InstanceGenCVRP``
                (presumably the number of customer nodes).
            generator_params: Keyword arguments for ``InstanceGenCVRP``.
                When ``None`` no generator is built, and instances must be
                loaded from a dataset before calling ``init_problems``.
        """
        self.name = "cvrp"
        self.problem_size = problem_size

        # State used only when serving instances from a loaded dataset.
        self.FLAG__use_saved_problems = False
        self.saved_index = None
        self.dataset_depot_xy = None
        self.dataset_node_xy = None
        self.dataset_node_demand = None
        self.dataset_capacity = None

        # Always define the attribute so that a missing generator can be
        # reported clearly instead of surfacing as an AttributeError.
        self.instanceGen = None
        if generator_params is not None:
            self.instanceGen = InstanceGenCVRP(self.problem_size, **generator_params)

    def load_problem_dataset_pkl(self, filename, num_problems, index_begin):
        """Load ``num_problems`` instances from a pickled dataset.

        The unpickled object is indexed as ``data[i][k]`` with ``k`` = 0
        (depot xy), 1 (node xy), 2 (node demand) and 3 (capacity).

        Args:
            filename: Path of the pickle file.
            num_problems: Number of instances to load.
            index_begin: Index of the first instance to load.

        Raises:
            IndexError: If the requested range is not available in the file
                (for sequence-like datasets).
        """
        # SECURITY: pickle.load can execute arbitrary code contained in the
        # file. Only load datasets from trusted sources.
        with open(filename, 'rb') as pickle_file:
            data = pickle.load(pickle_file)

        # Index element by element (not with a slice) so that an out-of-range
        # request raises instead of being silently truncated.
        records = [data[i] for i in range(index_begin, index_begin + num_problems)]

        depot_xy = torch.tensor([rec[0] for rec in records])[:, None, :]
        # shape: (batch, 1, 2)
        node_xy = torch.tensor([rec[1] for rec in records])
        # shape: (batch, problem, 2)
        node_demand = torch.tensor([rec[2] for rec in records])
        capacity = torch.tensor([rec[3] for rec in records])[:, None]
        # shape: (batch, 1)

        # Commit the new state only after everything was converted
        # successfully, so a failed load never leaves a half-filled dataset.
        self.dataset_depot_xy = depot_xy
        self.dataset_node_xy = node_xy
        self.dataset_node_demand = node_demand
        self.dataset_capacity = capacity
        self.saved_index = 0
        self.FLAG__use_saved_problems = True

    def load_problem_dataset_pt(self, filename, device):
        """Load a dataset saved with ``torch.save`` onto ``device``.

        The file must contain a mapping with the keys ``'depot_xy'``,
        ``'node_xy'``, ``'node_demand'`` and ``'capacity'``.

        Args:
            filename: Path of the ``.pt`` file.
            device: Passed to ``torch.load`` as ``map_location``.

        Raises:
            KeyError: If one of the expected keys is missing.
        """
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects. Only load datasets from trusted sources.
        loaded_dict = torch.load(filename, map_location=device, weights_only=False)

        # Read every key first so a malformed file leaves the state untouched.
        depot_xy = loaded_dict['depot_xy']
        node_xy = loaded_dict['node_xy']
        node_demand = loaded_dict['node_demand']
        capacity = loaded_dict['capacity']

        self.dataset_depot_xy = depot_xy
        self.dataset_node_xy = node_xy
        self.dataset_node_demand = node_demand
        self.dataset_capacity = capacity
        self.saved_index = 0
        self.FLAG__use_saved_problems = True

    def init_problems(self, nb_instances, aug_factor):
        """Produce the next batch of instances.

        Instances are taken from the loaded dataset when one is present
        (advancing ``saved_index``), otherwise from the random generator.

        With ``aug_factor > 1`` every instance is replicated ``aug_factor``
        times. If ``aug_factor`` is a multiple of 8 the coordinates of the
        replicas go through the 8-fold symmetry augmentation (tiled
        ``aug_factor // 8`` times); otherwise the replicas are plain copies.

        Args:
            nb_instances: Number of distinct instances in the batch.
            aug_factor: Replication factor; values <= 1 disable replication.

        Returns:
            A tuple ``(batch_size, problem_data, problem_feat)`` where
            ``batch_size`` is ``nb_instances`` times the effective
            replication factor, ``problem_data`` is a ``Problem_Data`` with
            NumPy arrays and ``problem_feat`` a ``Problem_Feat`` with tensors.

        Raises:
            ValueError: If the loaded dataset has fewer than ``nb_instances``
                instances left.
            RuntimeError: If there is neither a dataset nor a generator.
        """
        if not self.FLAG__use_saved_problems:
            depot_xy, node_xy, node_demand, capacity = self.get_random_problems(nb_instances)
        else:
            start = self.saved_index
            stop = start + nb_instances
            available = len(self.dataset_depot_xy)
            # A short slice would only fail later in torch.cat with a
            # confusing size-mismatch message, so report it here.
            if stop > available:
                raise ValueError(
                    f"Requested {nb_instances} saved instances starting at index "
                    f"{start}, but the loaded dataset only holds {available}."
                )
            depot_xy = self.dataset_depot_xy[start:stop]
            node_xy = self.dataset_node_xy[start:stop]
            node_demand = self.dataset_node_demand[start:stop]
            capacity = self.dataset_capacity[start:stop]
            self.saved_index = stop

        if aug_factor > 1:
            batch_size = nb_instances * aug_factor
            if aug_factor % 8 == 0:
                depot_xy = self.augment_xy_data_by_8_fold(depot_xy)
                node_xy = self.augment_xy_data_by_8_fold(node_xy)
                tiles = aug_factor // 8
                if tiles > 1:
                    depot_xy = depot_xy.repeat(tiles, 1, 1)
                    node_xy = node_xy.repeat(tiles, 1, 1)
            else:
                depot_xy = depot_xy.repeat(aug_factor, 1, 1)
                node_xy = node_xy.repeat(aug_factor, 1, 1)
            # Demands and capacities are invariant under the coordinate
            # transforms, so they are always plain copies.
            node_demand = node_demand.repeat(aug_factor, 1)
            capacity = capacity.repeat(aug_factor, 1)
        else:
            batch_size = nb_instances

        depot_node_xy = torch.cat((depot_xy, node_xy), dim=1)
        # The depot has zero demand. Allocate on node_demand's device so the
        # concatenation also works for datasets loaded onto a non-CPU device.
        depot_demand = torch.zeros(
            size=(batch_size, 1), dtype=torch.int, device=node_demand.device
        )
        depot_node_demand = torch.cat((depot_demand, node_demand), dim=1)

        problem_data = Problem_Data(
            problem_size=node_demand.shape[1],
            capacity=capacity.cpu().numpy(),
            depot_node_xy=depot_node_xy.cpu().numpy(),
            depot_node_demand=depot_node_demand.cpu().numpy(),
        )

        problem_feat = Problem_Feat(
            depot_xy=depot_xy,
            node_xy=node_xy,
            # Demands are exposed as a fraction of the vehicle capacity.
            node_demand=node_demand / capacity,
        )

        return batch_size, problem_data, problem_feat

    def augment_xy_data_by_8_fold(self, xy_data):
        """Return the 8 symmetric variants of ``xy_data`` stacked on dim 0.

        The variants combine ``x -> 1 - x``, ``y -> 1 - y`` and swapping
        x with y (this presumes coordinates normalised to the unit square).

        Args:
            xy_data: Tensor of shape (batch, N, 2).

        Returns:
            Tensor of shape (8 * batch, N, 2); the first ``batch`` rows are
            the unmodified input.
        """
        x = xy_data[:, :, [0]]
        y = xy_data[:, :, [1]]
        # x, y shape: (batch, N, 1)

        variants = (
            (x, y),
            (1 - x, y),
            (x, 1 - y),
            (1 - x, 1 - y),
            (y, x),
            (1 - y, x),
            (y, 1 - x),
            (1 - y, 1 - x),
        )

        aug_xy_data = torch.cat([torch.cat(pair, dim=2) for pair in variants], dim=0)
        # shape: (8*batch, N, 2)

        return aug_xy_data

    def get_random_problems(self, batch_size):
        """Draw ``batch_size`` random instances from the generator.

        Returns:
            Whatever ``InstanceGenCVRP.get_random_problems`` returns;
            ``init_problems`` unpacks it as
            ``(depot_xy, node_xy, node_demand, capacity)``.

        Raises:
            RuntimeError: If no generator was configured at construction.
        """
        generator = getattr(self, "instanceGen", None)
        if generator is None:
            raise RuntimeError(
                "No instance generator configured: pass generator_params to "
                "ProblemCVRP or load a dataset before calling init_problems."
            )
        return generator.get_random_problems(batch_size)
