from __future__ import annotations
import csv
import os
from typing import List
from sortedcontainers import SortedList
from pathlib import Path
from employee import Employee


class Instance:
    """This class represents the instance of the BDSP problem.

    Attributes
    ----------
    legs : SortedList
        Bus legs of the instance, kept in sorted order.
    distance_matrix : list
        Passive ride times, indexed as ``distance_matrix[from_pos][to_pos]``.
    start_work, end_work : list
        Integer rows read from the ``<file>_extra.csv`` file (presumably
        per-position work start/end times — TODO confirm).
    start_shifts, end_shifts : int
        Earliest leg start / latest leg end; set by ``from_file``.
    tours : list
        Sorted distinct tour ids; set by ``from_file``.
    BKS, LB, BH
        Reference values read from ``BKS.csv`` by ``from_file`` (BKS is used
        as the gap reference; meaning of BH not visible here). None if absent.
    """

    def __init__(self, legs: SortedList, distance_matrix: list, start_work: list, end_work: list) -> None:
        self.legs = legs
        self.distance_matrix = distance_matrix
        self.start_work = start_work
        self.end_work = end_work
        # Derived from the legs by ``from_file``.
        self.start_shifts = 0
        self.end_shifts = 0
        self.tours = []
        # Reference values loaded from BKS.csv; None when the instance is not listed.
        self.BKS = None
        self.LB = None
        self.BH = None

    @staticmethod
    def from_file(file: str) -> Instance:
        """Read from file

        Reads ``<file>.csv`` (legs), ``<file>_dist.csv`` (distance matrix) and
        ``<file>_extra.csv`` (work start/end rows) from
        ``~/busdriverschedulingproblem/files/instances``, then looks up the
        reference values for ``file`` in ``~/busdriverschedulingproblem/BKS.csv``.

        Parameters
        ----------
        file : str
            Input file in the form "realistic_m_n" (without the ``.csv`` suffix)

        Returns
        -------
        Instance
            Instance returned.
        """
        instance_folder = Path.home() / 'busdriverschedulingproblem' / 'files' / 'instances'

        bus_legs = SortedList()
        tours = set()
        with open(instance_folder / f'{file}.csv', 'r') as csv_file:
            # Skip the header on the raw file; the reader below converts every
            # unquoted field to float (QUOTE_NONNUMERIC).
            next(csv_file)
            csv_reader = csv.reader(csv_file, delimiter=',',
                                    quoting=csv.QUOTE_NONNUMERIC)
            for line_counter, row in enumerate(csv_reader):
                tour = int(row[0])
                bus_legs.add(BusLeg(id=line_counter,
                                    tour=tour,
                                    start=int(row[1]),
                                    end=int(row[2]),
                                    start_pos=int(row[3]),
                                    end_pos=int(row[4])))
                tours.add(tour)

        with open(instance_folder / f'{file}_dist.csv', 'r') as f:
            csv_reader = csv.reader(f, delimiter=',',
                                    quoting=csv.QUOTE_NONNUMERIC)
            distance_matrix = list(csv_reader)

        with open(instance_folder / f'{file}_extra.csv', 'r') as csv_file_extra:
            csv_reader = csv.reader(csv_file_extra, delimiter=',',
                                    quoting=csv.QUOTE_NONNUMERIC)
            start_work = [int(x) for x in next(csv_reader)]
            end_work = [int(x) for x in next(csv_reader)]

        instance = Instance(bus_legs, distance_matrix, start_work, end_work)

        bks_path = Path.home() / 'busdriverschedulingproblem' / 'BKS.csv'
        with bks_path.open('r') as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            for row in csv_reader:
                # csv yields [] for blank lines; skip them instead of crashing.
                if row and row[0] == file:
                    instance.LB = float(row[1])
                    instance.BKS = int(row[2])
                    instance.BH = int(row[3])

        # default=0 keeps the __init__ defaults for an instance without legs.
        instance.start_shifts = min((leg.start for leg in instance.legs), default=0)
        instance.end_shifts = max((leg.end for leg in instance.legs), default=0)
        instance.tours = sorted(tours)
        return instance

    def get_index(self, leg: BusLeg) -> int:
        """get the index of the leg w.r.t. the instance

        Parameters
        ----------
        leg : BusLeg
            leg considered

        Returns
        -------
        int
            index of the leg. I.e., if legs=[*,*,*,leg], then get_index(leg) = 3
        """
        return self.legs.index(leg)

    def get_passive_ride(self, i: int, j: int) -> int:
        """Return the passive ride time between positions i and j

        Parameters
        ----------
        i : int
            initial position
        j : int
            final position

        Returns
        -------
        int
            time it takes a driver to get from i to j when not actively driving a bus
            (0 when i == j, without consulting the matrix)
        """
        return 0 if i == j else self.distance_matrix[i][j]

    def get_diff(self, i: int, j: int) -> int:
        """Return the time between the end of leg ``i`` and the start of leg ``j``.

        Parameters
        ----------
        i : int
            id of the earlier leg
        j : int
            id of the later leg

        Returns
        -------
        int or None
            ``leg_j.start - leg_i.end`` (negative if the legs overlap), or None
            if no leg has id ``i`` or no leg has id ``j``.
        """
        # Two linear scans instead of a nested O(n^2) loop; each picks the
        # first leg with the requested id, as the nested loop did.
        first = next((leg for leg in self.legs if leg.id == i), None)
        if first is None:
            return None
        second = next((leg for leg in self.legs if leg.id == j), None)
        if second is None:
            return None
        return second.start - first.end


class BusLeg:
    """The bus leg class.

    Legs are identified by ``id`` (equality and hash) and ordered by
    ``(start, id)``; when either side has ``id is None`` only ``start`` is
    compared.
    """

    def __init__(self, id: int, tour: int, start: float, end: float, start_pos: int, end_pos: int) -> None:
        self.id = id
        self.tour = tour
        self.start = start
        self.end = end
        self.start_pos = start_pos
        self.end_pos = end_pos
        self.name = id
        # Employee assigned via register_employee; None until then.
        self.employee = None
        self.original_index = id

    def __str__(self) -> str:
        return str(self.id)

    def __repr__(self) -> str:
        return str(self.id)

    def __hash__(self):
        # Consistent with __eq__, which compares ids only.
        return hash(self.id)

    def __getitem__(self, item):
        # NOTE(review): returns the key unchanged; purpose not visible in this file.
        return item

    @property
    def drive(self) -> int:
        """Driving time of the leg (``end - start``)."""
        return self.end - self.start

    def register_employee(self, employee: Employee) -> None:
        """Record ``employee`` as the driver assigned to this leg."""
        self.employee = employee

    def __eq__(self, other):
        if isinstance(other, BusLeg):
            return self.id == other.id
        # Let Python fall back to its default handling instead of returning None.
        return NotImplemented

    def __le__(self, other):
        # <= is "same leg or strictly before"; previously only the equality
        # part was returned for BusLeg operands.
        if isinstance(other, BusLeg) and self.id == other.id:
            return True
        return self.__lt__(other)

    def __ge__(self, other):
        # Defined explicitly so >= does not fall back to the reflected __le__.
        if isinstance(other, BusLeg) and self.id == other.id:
            return True
        return self.__gt__(other)

    def __lt__(self, other):
        # Without an id on either side the tie-break is impossible; use start only.
        if self.id is None or other.id is None:
            return self.start < other.start
        return self.start < other.start or (self.start == other.start and self.id < other.id)

    def __gt__(self, other):
        if self.id is None or other.id is None:
            return self.start > other.start
        return self.start > other.start or (self.start == other.start and self.id > other.id)


class Solution:
    """Solution class, represented by a list of employees.

    Attributes
    ----------
    employees : List[Employee]
        The employees making up the solution.
    instance : Instance or None
        Taken from the first employee; None for an empty solution.
    value : float
        Objective value, computed by ``evaluate``.
    feasible : bool
        False if any employee's state was infeasible at the last ``evaluate``.
    change, changing_employees, changing_buslegs
        Bookkeeping fields initialised here; their use is outside this class.
    """

    def __init__(self, employees: List[Employee]) -> None:
        if not employees:
            self.employees = []
            self.instance = None
        else:
            self.employees = employees
            self.instance = employees[0].instance
        self.value = 0
        self.change = 0
        self.changing_employees = set()
        self.changing_buslegs = set()
        self.feasible = True

    def evaluate_gap(self) -> float:
        """Evaluate the GAP of the solution

        Returns
        -------
        float
            Percentage gap of ``value`` from the instance's BKS:
            ``(value - BKS) / BKS * 100``.
        """
        return (self.value - self.instance.BKS) / self.instance.BKS * 100

    def __iter__(self):
        return iter(self.employees)

    def copy(self) -> Solution:
        """Copy the current solution

        Employees are copied via ``Employee.copy``; ``value`` and ``feasible``
        (both produced by ``evaluate``) are carried over together.

        Returns
        -------
        Solution
            Solution copied
        """
        output = Solution([e.copy() for e in self.employees])
        output.value = self.value
        output.feasible = self.feasible
        return output

    def set(self, solution: Solution) -> None:
        """Set the solution to the solution given as the argument

        The employee list is shared (not copied) with ``solution``.

        Parameters
        ----------
        solution : Solution
            New solution.
        """
        self.employees = solution.employees
        self.value = solution.value
        self.change = solution.change
        self.feasible = solution.feasible

    def evaluate(self) -> None:
        """Evaluate the solution

        Sums ``employee.evaluate()`` into ``value`` and clears ``feasible`` if
        any employee's state is infeasible.
        """
        self.value = 0
        self.feasible = True
        for employee in self.employees:
            self.value += employee.evaluate()
            if employee.state.feasible is False:
                self.feasible = False

    def print_to_file(self, file: str) -> None:
        """Print the solution into the given file.

        The output format is a binary matrix n x l where:
            n is the number of employees (one row each, in list order)
            l is the number of bus legs (in the instance's leg order)
            the element (i,j) is 1 if leg j is assigned to employee i, 0 otherwise.

        Parameters
        ----------
        file : str
            Desired output file
        """
        n_legs = len(self.instance.legs)
        with open(file, 'w', newline='') as f:
            writer = csv.writer(f)
            for employee in self.employees:
                # Build each row independently so that employee ids need not
                # be exactly 0..n-1 (indexing a table by id raised IndexError).
                row = [0] * n_legs
                for leg in employee.legs:
                    row[self.instance.get_index(leg)] = 1
                writer.writerow(row)

    @staticmethod
    def from_file(instance: Instance, file: str) -> Solution:
        """Read a solution from file

        Each non-empty row of the binary matrix (see ``print_to_file``) becomes
        one employee; employees are then sorted by the start of their first leg
        and renumbered ``E0, E1, ...``.

        Parameters
        ----------
        instance : Instance
            Instance used to read the solution
        file : str
            name of the file, in the form "realistic_m_n_solution.csv",
            relative to ``~/busdriverschedulingproblem``

        Returns
        -------
        Solution
            Solution read.
        """
        employees: List[Employee] = []
        file_path = Path.home() / 'busdriverschedulingproblem' / file
        with file_path.open('r') as f:
            reader = csv.reader(f, quoting=csv.QUOTE_NONNUMERIC)
            for row in reader:
                row_legs = [index for index, value in enumerate(row) if value == 1]
                if not row_legs:
                    continue
                counter = len(employees)
                employee = Employee(counter, instance)
                employee.name = f'E{counter}'
                employees.append(employee)
                for leg in row_legs:
                    employee.add_bus(instance.legs[leg])
        employees.sort(key=lambda x: x.legs[0].start)
        for i, employee in enumerate(employees):
            employee.id = i
            employee.name = f'E{i}'
        return Solution(employees)

    def represent(self) -> str:
        """Return a compact string of the employee id for each leg.

        For each leg in instance order, the id of every employee containing it
        is emitted followed by ``x``.
        """
        output = []
        for leg in self.instance.legs:
            # employees is a list; the former ``.values()`` call raised AttributeError.
            for e in self.employees:
                if leg in e.legs:
                    output.append(e.id)
        return "".join(f'{a}x' for a in output)

    def resort_employees(self) -> None:
        """
        Sort the employee list by the start of their first legs.
        Empty employees are dropped.
        The sequences are renamed according to the new order.
        """
        sequences = [list(e.legs) for e in self.employees if len(e.legs) > 0]
        # Sorts by the first leg using BusLeg ordering.
        sequences.sort(key=lambda seq: seq[0])
        new_employees = []
        for i, legs in enumerate(sequences):
            employee = Employee(str(i), self.instance)
            employee.name = f'E{i}'
            employee.id = i
            for leg in legs:
                employee.legs.add(leg)
            new_employees.append(employee)
        self.employees = new_employees
