from __future__ import annotations

from sortedcontainers import SortedList
from typing import List
# from data import BusLeg

# Shift limits, expressed in minutes.
EMPLOYEE_D_MAX = 540            # upper bound on total driving time (9 h)
EMPLOYEE_W_MAX = 600            # upper bound on working time (10 h)
EMPLOYEE_W_MIN = 6 * 60 + 30    # lower bound on working time (6.5 h)
EMPLOYEE_T_MAX = 840            # upper bound on shift span (14 h)


class Employee:
    """Class that represents the Employee (or Shift).

    An employee owns a sorted collection of bus legs (``legs``) and a
    ``State`` object holding the evaluation of those legs.
    """

    # Class-wide counter, incremented on every construction (copy() included).
    ID = 1

    def __init__(self, id: int, instance) -> None:
        self.id = id
        Employee.ID += 1
        self.legs = SortedList()
        self.state = State(self)
        self.instance = instance
        self.objective = 0
        self.name = 'E' + str(self.id)
        self.feasible = True
        # Snapshot restored by revert(). Initialised here so that calling
        # revert() before the first evaluate() is a no-op instead of raising
        # AttributeError.
        self.previous_state = self.state
        self.previous_objective = self.objective

    def revert(self) -> None:
        """Restore the state and objective saved by the last evaluate()."""
        self.objective = self.previous_objective
        self.state = self.previous_state

    def add_bus(self, leg) -> None:
        """Insert ``leg`` into the sorted legs and call ``leg.register_employee(self)``."""
        self.legs.add(leg)
        leg.register_employee(self)

    def evaluate(self):
        """Re-evaluate the shift from its current legs and return the result.

        The current state and objective are kept as the revert() snapshot,
        then a fresh State is built and evaluated.
        """
        self.previous_state = self.state
        self.previous_objective = self.objective
        self.state = State(self)
        self.objective = self.state.evaluate()
        return self.objective

    def _eq_(self, other):
        # NOTE(review): single underscores, so Python never calls this for
        # ``==`` (equality stays identity-based). Renaming it to ``__eq__``
        # would also make instances unhashable unless ``__hash__`` is defined;
        # check callers (sets / dict keys) before changing it.
        if isinstance(other, Employee):
            return self.legs == other.legs
        return False

    def __iter__(self):
        """Iterate over the legs in sorted order."""
        return iter(self.legs)

    def copy(self) -> Employee:
        """Return a new Employee holding a copy of the legs.

        The State object is shared with this employee, not copied. Building
        the copy goes through ``__init__``, so it increments ``Employee.ID``.
        """
        output = Employee(self.id, self.instance)
        output.legs = self.legs.copy()
        output.objective = self.objective
        output.state = self.state
        # Carry the revert() snapshot over so the copy reverts like the original.
        output.previous_state = self.previous_state
        output.previous_objective = self.previous_objective
        return output

    def evaluation_abort(self, new_leg) -> bool:
        """Return True if ``new_leg`` overlaps in time with a neighbouring leg.

        The neighbours are located by index arithmetic around
        ``self.legs.bisect_left(new_leg)``: the previous leg at ``idx - 1`` and
        the next leg at ``idx + 1``. This assumes ``new_leg`` has already been
        inserted into ``self.legs`` — TODO confirm with callers.

        Overlap means ``new_leg.start < previous.end`` or
        ``next.start < new_leg.end``. Each neighbour is checked on its own, so
        a leg at either end of the shift is still checked against the one
        neighbour it has.
        """
        idx = self.legs.bisect_left(new_leg)
        prev_idx = idx - 1
        next_idx = idx + 1
        if prev_idx >= 0 and new_leg.start < self.legs[prev_idx].end:
            return True
        if next_idx < len(self.legs) and self.legs[next_idx].start < new_leg.end:
            return True
        return False


class Constraints:
    """A named constraint together with its category, weight and current value."""

    def __init__(self, name: str, category: int, weight: float, value: float) -> None:
        """Record the constraint description.

        :param name: human-readable name of the constraint.
        :param category: 0 for a hard constraint, 1 for a soft one.
        :param weight: penalty weight of the constraint.
        :param value: current value of the constraint.
        """
        self.name, self.category, self.weight, self.value = (
            name,
            category,
            weight,
            value,
        )


class State:
    """Evaluation of one employee's shift (its sorted legs).

    evaluate() fills in the time components of the shift (span, driving time,
    passive ride time, breaks, splits, tour changes), the hard-constraint
    penalty terms and the soft objective, and returns
    ``round(hard_constraints + objective)``. Hard violations are weighted by
    1000 and clear ``feasible``.

    Durations are compared with limits such as ``4*60`` and ``6*60``, so they
    are presumably minutes — confirm against the data loader.
    """

    def __init__(self, employee: Employee):
        self.feasible = True           # cleared when a violation is detected
        self.employee = employee
        self.MultiValue = {}           # not read by the active evaluation code
        self.constraints = []          # only used by commented-out Constraints code
        self.work_time = 0
        self.start_shift = 10**(20)    # large sentinel until evaluate() runs
        self.end_shift = 0
        self.start = 0                 # not read by this class
        self.end = 0                   # not read by this class
        self.bus_penalty = 0
        self.drive_penalty = 0
        self.drive_time = 0
        self.rest_penalty = 0
        self.rest = 0                  # not read by this class
        self.first15 = False
        self.break30 = False
        self.center30 = False
        self.unpaid = 0
        self.ride = 0
        self.change = 0
        self.split = 0
        self.split_time = 0
        self.objective = 0
        self.total_time = 0
        self.upmax = 0

    def _reset_accumulators(self) -> None:
        """Reset every value the evaluate_* methods accumulate or latch.

        This makes evaluate() return the same result however often it is
        called on the same State.
        """
        self.feasible = True
        self.work_time = 0
        self.bus_penalty = 0
        self.drive_penalty = 0
        self.drive_time = 0
        self.rest_penalty = 0
        self.first15 = False
        self.break30 = False
        self.center30 = False
        self.unpaid = 0
        self.ride = 0
        self.change = 0
        self.split = 0
        self.split_time = 0
        self.upmax = 0

    def evaluate_drive_penalty(self):
        """Set ``drive_penalty`` from driving blocks reaching ``4*60``.

        A driving block is closed by a gap between consecutive legs of either:
        - at least 30 min;
        - at least 20 min after an earlier >= 20 min gap in the block;
        - at least 15 min after two earlier >= 15 min gaps in the block.

        Each time the running block total is >= ``4*60`` (including right
        after the first leg), the excess ``total - 4*60`` is added.
        """
        legs = self.employee.legs
        if not legs:
            self.drive_penalty = 0
            return
        penalty = 0
        b_20 = 0
        b_15 = 0
        dc = legs[0].drive
        # The first leg alone can already exceed the limit.
        if dc >= 4*60:
            penalty += dc - 4*60
        for k in range(len(legs) - 1):
            leg_i = legs[k]
            leg_j = legs[k + 1]
            diff = leg_j.start - leg_i.end
            block = (diff >= 30) or (diff >= 20 and b_20 == 1) or (diff >= 15 and b_15 == 2)
            if block:
                dc = leg_j.drive
                b_20 = 0
                b_15 = 0
            else:
                dc += leg_j.drive
                if diff >= 20:
                    b_20 = 1
                if diff >= 15:
                    b_15 += 1
            if dc >= 4*60:
                penalty += dc - 4*60
        self.drive_penalty = penalty

    def evaluate_rest_penalty(self):
        """Set ``rest_penalty``: work minutes not covered by the required rest breaks.

        Reads ``self.work_time``; evaluate() sets it to the actual (unclamped)
        work time before calling this method.

        - work time < 6 h: no penalty;
        - total rest < 30 min: ``max(0, work_time - (6*60 - 1))``;
        - total rest < 45 min: ``max(0, work_time - 9*60)``.

        Rest only counts when both ``first15`` and ``break30`` hold. It is the
        sum of non-split gaps (passive ride subtracted, < 3 h) of at least
        15 min.
        """
        self.rest_penalty = 0
        if self.work_time < 6*60:
            return
        legs = self.employee.legs
        instance = self.employee.instance
        rest_time = 0
        if self.break30 and self.first15:
            for t in range(len(legs) - 1):
                leg_i = legs[t]
                leg_j = legs[t + 1]
                r = instance.get_passive_ride(leg_i.end_pos, leg_j.start_pos)
                gap = leg_j.start - leg_i.end - r
                if gap >= 3*60:
                    continue  # shift split, not a rest break
                if gap >= 15:
                    rest_time += gap
        if rest_time < 30:
            self.rest_penalty = max(0, self.work_time - (6*60 - 1))
        elif rest_time < 45:
            self.rest_penalty = max(0, self.work_time - 9*60)

    def evaluate_first15(self) -> None:
        """Set ``first15`` if a break of >= 15 min starts within the first 6 h.

        A break starts at the end of a leg, and its length is measured with
        the passive ride subtracted. Gaps of >= 180 min are splits: they are
        not breaks, and their length extends the 6 h window.
        """
        legs = self.employee.legs
        instance = self.employee.instance
        split_time = 0
        for key in range(len(legs) - 1):
            leg_i = legs[key]
            leg_j = legs[key + 1]
            ride = instance.get_passive_ride(int(leg_i.end_pos), int(leg_j.start_pos))
            gap = leg_j.start - leg_i.end - ride
            if gap >= 180:
                split_time += gap
                continue
            if gap >= 15 and leg_i.end <= self.start_shift + 6*60 + split_time:
                self.first15 = True

    def evaluate_break30(self) -> None:
        """Set ``break30`` if some non-split gap (passive ride subtracted) lasts >= 30 min."""
        legs = self.employee.legs
        instance = self.employee.instance
        for key in range(len(legs) - 1):
            leg_i = legs[key]
            leg_j = legs[key + 1]
            ride = instance.get_passive_ride(int(leg_i.end_pos), int(leg_j.start_pos))
            gap = leg_j.start - leg_i.end - ride
            if gap >= 180:
                continue
            if gap >= 30:
                self.break30 = True

    def evaluate_unpaid(self) -> None:
        """Accumulate unpaid break minutes into ``unpaid``.

        For every non-split gap (passive ride subtracted, < 180 min):
        - the part lying between 2 h after shift start and 2 h before shift
          end is counted if it is at least 15 min;
        - ``center30`` is set if the part between 3 h after start and 3 h
          before end is at least 30 min;
        - ``first15`` is set as in evaluate_first15, but without the
          split-time extension.
        """
        legs = self.employee.legs
        instance = self.employee.instance
        for key in range(len(legs) - 1):
            leg_i = legs[key]
            leg_j = legs[key + 1]
            ride = instance.get_passive_ride(int(leg_i.end_pos), int(leg_j.start_pos))
            if leg_j.start - leg_i.end - ride >= 180:
                continue
            if min(self.end_shift - 3*60, leg_j.start - ride) - max(self.start_shift + 3*60, leg_i.end) >= 30:
                self.center30 = True
            diff_1 = leg_j.start - leg_i.end - ride
            if diff_1 >= 15 and leg_i.end <= self.start_shift + 6*60:
                self.first15 = True
            break_end = min(self.end_shift - 2*60, leg_j.start - ride)
            break_start = max(self.start_shift + 2*60, leg_i.end)
            if break_end - break_start >= 15:
                self.unpaid += round(break_end - break_start)

    def evaluate_upmax(self):
        """Set ``upmax``, the cap on unpaid minutes deducted from work time.

        The cap is 0 unless both first15 and break30 hold; otherwise it is 90
        if center30 holds, else 60.
        """
        if self.break30 is False or self.first15 is False:
            self.upmax = 0
        elif self.center30:
            self.upmax = 90
        else:
            self.upmax = 60

    def evaluate_split(self) -> None:
        """Count split gaps (>= 180 min, passive ride subtracted) into ``split`` and ``split_time``."""
        legs = self.employee.legs
        instance = self.employee.instance
        for key in range(len(legs) - 1):
            leg_i = legs[key]
            leg_j = legs[key + 1]
            r = int(instance.get_passive_ride(int(leg_i.end_pos), int(leg_j.start_pos)))
            diff = leg_j.start - leg_i.end
            if diff - r >= 180:
                self.split += 1
                self.split_time += diff - r

    def evaluate(self):
        """Evaluate the shift and return ``round(hard_constraints + objective)``.

        Returns 0 when the employee has no legs. All accumulated values are
        reset first, so repeated calls on the same State give the same result.

        Soft objective: ``2*work_time + span + passive ride
        + 30 * tour changes + 180 * splits``.

        Hard part: 1000 times the sum of
        - the bus-chain penalty;
        - driving time above EMPLOYEE_D_MAX;
        - span above EMPLOYEE_T_MAX;
        - the drive-block penalty;
        - the rest-break penalty;
        - work time above EMPLOYEE_W_MAX.

        A positive hard part clears ``feasible``.
        """
        self._reset_accumulators()
        legs = self.employee.legs
        if not legs:
            return 0
        instance = self.employee.instance
        first_leg = legs[0]
        last_leg = legs[-1]
        self.start_shift = first_leg.start - instance.start_work[first_leg.start_pos]
        self.end_shift = last_leg.end + instance.end_work[last_leg.end_pos]
        self.total_time = self.end_shift - self.start_shift
        for key in range(len(legs) - 1):
            leg_i = legs[key]
            leg_j = legs[key + 1]
            self.drive_time += leg_i.drive
            i = leg_i.end_pos
            j = leg_j.start_pos
            r = instance.get_passive_ride(i, j)
            self.ride += r
            diff = leg_j.start - leg_i.end
            # Chain check only when the tour or the position changes between legs.
            if leg_i.tour != leg_j.tour or leg_i.end_pos != leg_j.start_pos:
                slack = diff - instance.distance_matrix[i][j]
                if slack < 0:
                    # The gap is shorter than distance_matrix[i][j]: penalise the shortfall.
                    self.bus_penalty -= slack
                    self.feasible = False
                elif diff <= 0:
                    self.bus_penalty -= diff
                if leg_i.tour != leg_j.tour:
                    self.change += 1
        self.drive_time += last_leg.drive
        self.evaluate_first15()
        self.evaluate_break30()
        self.evaluate_unpaid()
        self.evaluate_upmax()
        self.evaluate_split()
        self.evaluate_drive_penalty()
        # The rest-break rules depend on the actual work time, so compute it
        # (without the minimum clamp) before evaluating the rest penalty.
        self.work_time = round(self.total_time - self.split_time - min(self.unpaid, self.upmax))
        self.evaluate_rest_penalty()
        # Work time entering the objective is never below EMPLOYEE_W_MIN (390).
        self.work_time = max(self.work_time, EMPLOYEE_W_MIN)
        self.objective = round(2*self.work_time + self.total_time
                               + self.ride + 30*self.change + 180*self.split)
        hard_constraints = 1000*(self.bus_penalty +
                                 max(self.drive_time - EMPLOYEE_D_MAX, 0) +
                                 max(self.total_time - EMPLOYEE_T_MAX, 0) +
                                 self.drive_penalty + self.rest_penalty +
                                 max(self.work_time - EMPLOYEE_W_MAX, 0)
                                 )
        if hard_constraints > 0:
            self.feasible = False
        return round(hard_constraints + self.objective)

    def copy(self):
        """Return a copy of this state bound to a copy of its employee.

        Scalar attributes are copied. ``constraints`` and ``MultiValue`` get
        fresh shallow copies, so the two states do not share containers.
        """
        employee_copy = self.employee.copy()
        new_state = State(employee_copy)
        new_state.__dict__.update(self.__dict__)
        # The update above also copied self.employee. Re-bind the copied
        # employee so the new state no longer points at the original one.
        new_state.employee = employee_copy
        new_state.constraints = list(self.constraints)
        new_state.MultiValue = dict(self.MultiValue)
        employee_copy.state = new_state
        return new_state

