import random
from abc import ABC
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple

from utils.logging import get_logger
from data import Solution, Employee, BusLeg
from stats import Stats



# Module-level logger named after this module; the destroyers' apply() methods
# write their debug traces to it.
logger = get_logger(__name__)


@dataclass(frozen=True)
class DestroyOutput:
    """Immutable result returned by a destroyer's ``apply``.

    Attributes
    ----------
    legs : List[BusLeg]
        legs of the removed employees
    initial_solution : List[List[Tuple[BusLeg, bool]]]
        one inner list per removed employee; each entry pairs one of that
        employee's legs with a boolean flag chosen by the destroyer
    n_employees : int
        number of removed employees
    employees : Optional[List[Employee]]
        the removed employees, or None when the producer does not pass them
    """
    legs: List[BusLeg]
    # Producers pass (leg, flag) pairs here, not bare legs.
    initial_solution: List[List[Tuple[BusLeg, bool]]]
    n_employees: int
    employees: Optional[List[Employee]] = None


class Destroyer(ABC):
    """Base class for operators that remove employees from a solution.

    Subclasses override :meth:`apply`. The base class declares no abstract
    methods, so it can still be instantiated directly.
    """

    def __init__(self, stats: Stats) -> None:
        """Store the statistics object on ``self.stats``.

        Parameters
        ----------
        stats : Stats
            statistics object kept on the instance
        """
        self.stats = stats

    def _remove_employees(self,
                          solution: Solution,
                          employees: Iterable[Employee]) -> None:
        """Removes employees from solution

        Parameters
        ----------
        solution : Solution
            current solution; its ``employees`` container is modified in place
        employees : Iterable[Employee]
            employees that we want to remove (callers pass lists and sets)
        """
        for employee in employees:
            # Delegates to the container's own remove(), so an employee that
            # is not part of the solution surfaces as that container's error.
            solution.employees.remove(employee)

    def apply(self, solution: Optional[Solution] = None) -> Optional[DestroyOutput]:
        """Destroy part of ``solution``; the base implementation does nothing.

        The ``solution`` parameter mirrors the signature of the overriding
        subclasses, so ``apply(solution)`` is valid on any destroyer. It is
        optional only to keep existing zero-argument calls working.

        Parameters
        ----------
        solution : Optional[Solution]
            current solution (ignored by the base class)

        Returns
        -------
        Optional[DestroyOutput]
            always None here; subclasses return a DestroyOutput
        """
        return None


class EmployeeRemover(Destroyer):
    """Destroyer that removes a random sample of employees from a solution.

    In ``UNIFORM`` mode the employees are drawn uniformly. In ``WEIGHTED``
    mode half of the sample is drawn with probability proportional to each
    employee's ``objective`` and the other half uniformly from the rest.
    """

    UNIFORM = 1  # select employees in a uniform way
    WEIGHTED = 2  # select employees according to the objective functions

    def __init__(self, stats: Stats, mode: int, sample_size: int) -> None:
        """Configure the remover.

        Parameters
        ----------
        stats : Stats
            statistics object; ``rre_runs`` is incremented on every apply
        mode : int
            ``EmployeeRemover.UNIFORM`` or ``EmployeeRemover.WEIGHTED``
        sample_size : int
            target number of employees to remove
        """
        super().__init__(stats)
        self.mode = mode
        # Every mode other than UNIFORM is labelled as weighted.
        if self.mode == self.UNIFORM:
            self.name = 'EmployeeRemover(Uniform)'
        else:
            self.name = 'EmployeeRemover(Weighted)'
        self.sample_size = sample_size
        self.success = 0

    def _sample_weighted(self, pool: List[Employee]) -> List[Employee]:
        """Draw half of the sample by objective weight and half uniformly.

        Parameters
        ----------
        pool : List[Employee]
            employees of the current solution

        Returns
        -------
        List[Employee]
            distinct employees: the weighted picks followed by the uniform
            picks. For an odd ``sample_size`` this holds ``sample_size - 1``
            employees because each half is ``int(sample_size / 2)``.
        """
        half = max(0, int(self.sample_size / 2))
        costs = [e.objective for e in pool]
        total = sum(costs)
        if total:
            weights = [cost / total for cost in costs]
            # random.choices draws with replacement and never returns an item
            # whose weight is not positive. Capping the request at the number
            # of drawable employees guarantees the rejection loop below can
            # produce an all-distinct draw and therefore terminates.
            k = min(half, sum(1 for w in weights if w > 0))
            while True:
                weighted = random.choices(population=pool,
                                          weights=weights,
                                          k=k)
                if len(set(weighted)) == k:
                    break
        else:
            # Empty pool or objectives summing to zero: there is nothing to
            # weight by, so fall back to a uniform draw instead of dividing
            # by zero.
            weighted = random.sample(population=pool,
                                     k=min(half, len(pool)))
        picked = set(weighted)
        remaining = [e for e in pool if e not in picked]
        uniform = random.sample(population=remaining,
                                k=min(half, len(remaining)))
        return weighted + uniform

    def get_employees(self, solution: Solution) -> List[Employee]:
        """get employees to remove

        Parameters
        ----------
        solution : Solution
            current solution

        Returns
        -------
        List[Employee]
            list of employees to be removed, sorted by ``state.start_shift``

        Raises
        ------
        ValueError
            if ``mode`` is neither ``UNIFORM`` nor ``WEIGHTED``
        """
        if self.mode == self.UNIFORM:
            size = min(self.sample_size, len(solution.employees))
            employees = random.sample(population=solution.employees,
                                      k=size)
        elif self.mode == self.WEIGHTED:
            employees = self._sample_weighted(solution.employees)
        else:
            raise ValueError(
                f'Unknown EmployeeRemover mode {self.mode!r}; '
                f'expected {self.UNIFORM} (uniform) or {self.WEIGHTED} (weighted)'
            )
        return sorted(employees, key=lambda e: e.state.start_shift)

    def get_legs(self, employees: List[Employee]) -> List[BusLeg]:
        """get legs to remove

        Parameters
        ----------
        employees : List[Employee]
            employees to remove

        Returns
        -------
        List[BusLeg]
            all legs of the given employees, sorted by ``start``
        """
        legs = [leg for e in employees for leg in e.legs]
        return sorted(legs, key=lambda leg: leg.start)

    def apply(self,
              solution: Solution) -> DestroyOutput:
        """Remove a sample of employees from ``solution`` in place.

        Parameters
        ----------
        solution : Solution
            current solution; the sampled employees are removed from it

        Returns
        -------
        DestroyOutput
            removed legs, one list of ``(leg, False)`` pairs per removed
            employee, the number of removed employees and the employees
        """
        logger.debug(f'Starting {self.name} with {self.sample_size} employees')
        employees = self.get_employees(solution)
        legs = self.get_legs(employees)
        # Every leg is paired with False by this destroyer.
        initial_solution = [[(leg, False) for leg in e.legs] for e in employees]
        self._remove_employees(solution, employees)
        logger.debug(f'Removed {len(legs)} legs = {[leg.id for leg in legs]}')
        logger.debug(f'Removed {len(employees)} employees = {[e.name for e in employees]}')
        self.stats.rre_runs += 1
        return DestroyOutput(legs,
                             initial_solution=initial_solution,
                             n_employees=len(employees),
                             employees=employees
                             )


class TourRemover(Destroyer):
    """Destroyer that removes every employee serving randomly chosen tours."""

    def __init__(self, stats: Stats, size: int) -> None:
        """Configure the remover.

        Parameters
        ----------
        stats : Stats
            statistics object; ``tours_runs`` is incremented on every apply
        size : int
            keep selecting tours until at least this many employees are
            collected (or no tours are left)
        """
        super().__init__(stats)
        self.tour = None  # most recently selected tour
        self.solution = None  # copy of the solution taken at the start of apply
        self.name = 'TourRemover'
        self.remover_size = size
        self.success = 0

    def get_legs(self,
                 employees: list[Employee]) -> Tuple[List[BusLeg], List[List[Tuple[BusLeg, bool]]]]:
        """Collect the legs of the given employees and flag them per tour.

        Parameters
        ----------
        employees : list[Employee]
            employees to remove

        Returns
        -------
        Tuple[List[BusLeg], List[List[Tuple[BusLeg, bool]]]]
            all legs sorted by ``start``, and per employee a list of
            ``(leg, flag)`` pairs where ``flag`` is True when the leg does
            not belong to ``self.tour``
        """
        legs = [leg for e in employees for leg in e.legs]
        # NOTE(review): the flag is computed against self.tour only, i.e. the
        # last tour selected by apply(). When apply() selects several tours,
        # legs of the earlier ones are flagged True as well - confirm that
        # this is intended.
        fixed_legs = [[(leg, leg.tour != self.tour) for leg in e.legs] for e in employees]
        legs = sorted(legs, key=lambda leg: leg.start)
        return legs, fixed_legs

    def apply(self,
              solution: Solution) -> DestroyOutput:
        """Remove the employees of randomly selected tours from ``solution``.

        Tours are visited in random order, each at most once, until at least
        ``remover_size`` employees are collected or the tours run out. When
        the instance cannot provide that many employees, all employees that
        serve any tour are removed.

        Parameters
        ----------
        solution : Solution
            current solution; the selected employees are removed from it

        Returns
        -------
        DestroyOutput
            removed legs, the per-employee ``(leg, flag)`` lists from
            ``get_legs``, the number of removed employees and the employees
            sorted by their earliest leg start
        """
        logger.debug(f'Starting {self.name}')
        self.solution = solution.copy()
        employees = set()
        tours = set()
        # A shuffled pass gives the same selection order as repeatedly drawing
        # a random unseen tour, but it ends once every tour has been tried, so
        # an unreachable remover_size cannot cause an endless loop.
        candidates = list(solution.instance.tours)
        random.shuffle(candidates)
        for tour in candidates:
            if len(employees) >= self.remover_size:
                break
            if tour in tours:
                continue
            self.tour = tour
            tours.add(tour)
            employees.update(
                e for e in solution.employees
                if any(leg.tour == tour for leg in e.legs)
            )
        logger.debug(f'Removed tours {[t for t in tours]}')
        self._remove_employees(solution, employees)
        employees = sorted(employees, key=lambda e: min(leg.start for leg in e.legs))
        legs, fixed_legs = self.get_legs(employees)
        logger.debug(f'Removed {len(legs)} legs = {[leg.id for leg in legs]}')
        logger.debug(f'Removed {len(employees)} employees = {[e.name for e in employees]}')
        self.stats.tours_runs += 1
        return DestroyOutput(legs,
                             initial_solution=fixed_legs,
                             n_employees=len(employees),
                             employees=employees
                             )

