#!/usr/bin/python3
import time
import random
import os
from typing import List, Optional
import sys
import json
import logging
import argparse
import tempfile
from pathlib import Path
from datetime import date

from utils.io.output_writer import OutputWriter
from repairer import Repairer
from data import Instance
from data import Solution
# from .data.Instance import Instance
# from .data.Solution import Solution
from destroyer import EmployeeRemover, TourRemover, Destroyer
from stats import Stats


def setup_logging(verbose: bool) -> None:
    """Configure the root logger for the ALNS run.

    Parameters
    ----------
    verbose : bool
        log DEBUG messages when True, otherwise only INFO and above
    """
    if verbose:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(
        format='%(asctime)s|%(levelname)s|%(name)s|%(message)s',
        level=level,
    )


def parse_arguments() -> argparse.Namespace:
    """Build the ALNS command-line parser and parse ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        the parsed options. Unset optional values are ``None``, except
        ``destroyers`` (defaults to a single ``_``-joined entry naming all
        three destroyers) and the ``verbose`` / ``decrease`` flags
        (default ``False``).
    """
    option_specs = [
        (('--instance_file', '-i'), dict(required=True, type=str, help='instance file')),
        (('--seed', '-s'), dict(required=False, type=str, help='random seed')),
        (('--tipo', '-t'), dict(required=False, type=str, help='static or adaptive')),
        (('--destroyers', '-d'), dict(nargs="*", required=False, type=str, help='destroyer type',
                                      default=["EmployeeWeighted_EmployeeUniform_TourRemover"])),
        (('--decay', '-l'), dict(required=False, type=float, help='lambda(decay) parameter')),
        (('--repairer', '-r'), dict(required=False, type=str, help='repairer type: BP or CG')),
        (('--verbose', '-v'), dict(required=False, help="Be verbose", action='store_true')),
        (('--decrease', '-dc'), dict(required=False, help="Decrease the size of the destroyers",
                                     action='store_true')),
        (('--timeout', '-to'), dict(required=False, type=int, help="Timeout in seconds")),
        (('--size', '-sz'), dict(required=False, type=int, help="Destruction size")),
        (('--no_improvement', '-ni'), dict(required=False, type=int,
                                           help="Max iterations without improvement")),
        (('--weights', '-w'), dict(nargs='+', required=False)),
    ]
    parser = argparse.ArgumentParser(description='Reset format')
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser.parse_args()


def update_weights(weights: List[float],
                   config: dict,
                   decay_parameter: float,
                   stats: Stats) -> List[float]:
    """Update the weights of the destroyers.

    Each weight is blended with the destroyer's score (successes divided by
    the accumulated duration) and then floored at ``config['min_weights']``::

        w_i <- max(min_weights,
                   decay * w_i + (1 - decay) * success_i / duration_i)

    Parameters
    ----------
    weights : List[float]
        current weights of the destroyers
    config : dict
        configuration; only ``'min_weights'`` (the floor applied to every
        weight) is read
    decay_parameter : float
        in [0,1]. Share of the previous weight that is kept, so larger values
        make the weights react more slowly to recent performance
    stats : Stats
        statistics of the ALNS; ``success_destroy`` and ``duration_destroy``
        hold one entry per weight, in the same order

    Returns
    -------
    List[float]
        the new weights (a new list; ``weights`` itself is not modified)

    Raises
    ------
    IndexError
        if ``stats`` holds fewer success counts or durations than there are
        weights
    """
    successes = stats.success_destroy
    durations = stats.duration_destroy
    # zip() would silently truncate; keep the original's failure on short stats.
    if len(successes) < len(weights) or len(durations) < len(weights):
        raise IndexError('stats must hold one success count and one duration per weight')

    min_weight = config['min_weights']  # loop-invariant
    return [
        max(min_weight,
            decay_parameter * weight + (1 - decay_parameter) * success / duration)
        for weight, success, duration in zip(weights, successes, durations)
    ]


def get_initial_solution(config: dict, instance: Instance, instance_file: str) -> Solution:
    """Load the starting solution of the ALNS and evaluate it.

    With ``config['start_with_BKS']`` set, the best-known solution is loaded;
    otherwise the precomputed initial solution is. Both CSV paths are
    relative to the current working directory.

    Parameters
    ----------
    config : dict
        configuration; only ``'start_with_BKS'`` is read
    instance : Instance
        instance to be solved
    instance_file : str
        file name of the instance, used to build the solution file name

    Returns
    -------
    Solution
        the loaded solution, after ``evaluate()`` has been called on it
    """
    if config['start_with_BKS']:
        source = f'files/solutions/bks/{instance_file}_solution.csv'
    else:
        source = f'files/solutions/initial/{instance_file}_initial.csv'
    solution = Solution.from_file(instance, source)
    solution.evaluate()
    return solution


def check_if_better(current: Solution, best: Solution) -> bool:
    """Return True when *current* has a strictly lower objective than *best*.

    Ties do not count as an improvement.
    """
    return best.value > current.value


def check_if_optimal(current: Solution, lower_bound: float) -> bool:
    """Return True when *current* matches the lower bound within 1e-10.

    NOTE(review): ``int()`` truncates a fractional ``lower_bound`` toward
    zero (e.g. 100.9 -> 100), so a fractional bound is effectively compared
    against its integer part; confirm the bounds are integral.
    """
    deviation = current.value - int(lower_bound)
    return -1e-10 <= deviation <= 1e-10


def get_destroyers(stats: Stats, destruction_size: int, destroyers_list: Optional[List[str]]) -> List[Destroyer]:
    """Build the destroyers to be used by the ALNS.

    Parameters
    ----------
    stats : Stats
        statistics of the ALNS, passed to every destroyer
    destruction_size : int
        initial size given to every destroyer (``sample_size`` for
        ``EmployeeRemover``, ``size`` for ``TourRemover``)
    destroyers_list : Optional[List[str]]
        destroyer names joined by ``_``, as produced by the ``--destroyers``
        option (e.g. ``["EmployeeWeighted_EmployeeUniform_TourRemover"]``).
        Every entry is split on ``_``, so ``["EmployeeWeighted",
        "TourRemover"]`` selects both. ``None`` or an empty list selects the
        same default as the command-line option.

    Returns
    -------
    List[Destroyer]
        one freshly built destroyer per requested name, in the given order

    Raises
    ------
    ValueError
        if a name is not a known destroyer, or no name is given at all
    """
    factories = {
        "EmployeeWeighted": lambda: EmployeeRemover(stats,
                                                    mode=EmployeeRemover.WEIGHTED,
                                                    sample_size=destruction_size),
        "EmployeeUniform": lambda: EmployeeRemover(stats,
                                                   mode=EmployeeRemover.UNIFORM,
                                                   sample_size=destruction_size),
        "TourRemover": lambda: TourRemover(stats, size=destruction_size),
    }
    if not destroyers_list:
        # Same default as the --destroyers command-line option.
        destroyers_list = ["EmployeeWeighted_EmployeeUniform_TourRemover"]

    # --destroyers uses nargs="*", so honour every entry, not only the first one.
    names = [name for entry in destroyers_list for name in entry.split("_") if name]
    if not names:
        raise ValueError(f"No destroyer selected from {destroyers_list!r}")
    unknown = [name for name in names if name not in factories]
    if unknown:
        raise ValueError(
            f"Unknown destroyer(s) {unknown}; expected names from {sorted(factories)}")
    return [factories[name]() for name in names]


def write_result(output: str, result_file: str) -> None:
    """Append one result row to a CSV file, writing the header first if needed.

    Parameters
    ----------
    output : str
        CSV row to append; a trailing newline is added
    result_file : str
        path of the results file. It is created if missing, and the header
        line is written only while the file is still empty.
    """
    header = (
        "date,instance,algorithm, destroyers, decreasing, seed, destruction_size,time,"
        "initial_employees,initial_value,initial_gap,"
        "best_employees,best_value,best_gap_bks, best_gap_bh,"
        "max_time,avg_time, size_changes,"
        "iterations, time_last improvement,"
        "CG_calls, incumbent_changes, feasible_calls, unfeasible_calls,"
        "success_EW,success_EU,success_TR,"
        "weight1, weight2, weight3,max_iterations_without_improvement, initial_destruction_size,\n"
    )
    with open(result_file, 'a') as handle:
        # Opening in append mode has created the file, so the size check is safe.
        if os.path.getsize(result_file) == 0:
            handle.write(header)
        handle.write(output + "\n")


def _method_label(tipo: Optional[str], decay_parameter: float) -> str:
    """Return the algorithm label used in the results CSV row and file name."""
    if tipo == 'static':
        return 'static'
    if tipo == 'adaptive':
        if decay_parameter == 0.5:
            return 'adaptive_avg'
        if decay_parameter > 0.5:
            return 'adaptive_large'
        return 'adaptive_small'
    # Unrecognised mode (e.g. --decay given without --tipo leaves it None):
    # the weights were never updated, so report the raw value rather than
    # failing on an unbound name after the whole search has run.
    return str(tipo)


def _set_destruction_size(destroyers: List["Destroyer"], size: int) -> None:
    """Propagate *size* to the ``sample_size`` attribute of every destroyer."""
    for destroyer in destroyers:
        destroyer.sample_size = size


def run_ALNS():
    """Run the ALNS on the instance given on the command line.

    The configuration is read from
    ``~/busdriverschedulingproblem/configuration.json``. Starting from the
    initial (or best-known) solution, the search repeatedly applies a
    randomly chosen destroyer followed by the repairer. It stops on the
    timeout, on the configured budget, or when the instance lower bound is
    reached. The best solution is written under ``experiments/alns``, a
    summary row is appended to ``experiments/<instance_file>/results_*.csv``,
    and the process exits with status 0.
    """
    args = parse_arguments()
    setup_logging(args.verbose)
    instance_file = args.instance_file
    instance = Instance.from_file(instance_file)
    seed = 42 if args.seed is None else int(args.seed)
    # Without an explicit --decay the run is always static, whatever --tipo says.
    tipo = 'static' if args.decay is None else args.tipo
    logging.info(f'Using {sys.version}')
    very_start = time.perf_counter()
    base_dir = Path.home() / 'busdriverschedulingproblem'
    with open(base_dir / 'configuration.json', 'r') as file:
        config = json.load(file)
    # A missing (or zero) --size falls back to 10.
    initial_size = args.size if args.size else 10
    destruction_size = initial_size
    decay_parameter = config['decay_parameter'] if args.decay is None else args.decay
    destroyer_names = str(args.destroyers).replace("[", "").replace("]", "").replace("'", "")
    max_no_improvement = args.no_improvement if args.no_improvement is not None else 10
    run_name = f'{instance_file}_{seed}_{tipo}_{decay_parameter}_{destroyer_names}_{max_no_improvement}.csv'
    trajectory_file = base_dir / 'experiments' / 'trajectories' / run_name
    solution_file = base_dir / 'experiments' / 'alns' / run_name
    stats = Stats(config, trajectory_file)
    random.seed(seed)
    # NOTE(review): the writer is never used below; it is still constructed in
    # case its constructor has side effects — confirm and drop if it has none.
    output_writer = OutputWriter(instance_file, stats)
    # Scratch file for the repairer: only its path is needed, so close the
    # handle right away (delete=False keeps the file on disk).
    with tempfile.NamedTemporaryFile(suffix='_alns.csv', delete=False) as scratch:
        file_name = scratch.name

    if args.repairer == 'BP':
        repairer_name, max_branching = 'BP', None
    else:
        # Anything other than 'BP' (including no --repairer) selects CG.
        repairer_name, max_branching = 'CG', 0

    repairer_CG = Repairer(
        instance_file,
        config,
        stats, instance,
        destruction_size,
        file_name=file_name,
        max_branching=max_branching,
        verbose=args.verbose
        )

    logging.info(f'Executing ALNS (lambda = {decay_parameter}) on {instance_file}')
    logging.info(f'Parameters: seed = {seed}, destruction_size = {destruction_size}, destroyers = {args.destroyers}')
    logging.info(f'lambda = {decay_parameter}, repairer = {repairer_name}')

    if tipo == 'static':
        logging.info('ALNS mode: static weights')
    elif tipo == 'adaptive':
        logging.info('ALNS mode: weighted weights')
    else:
        logging.warning(f'Unknown ALNS mode {tipo!r}: the destroyer weights will not be updated')

    initial = get_initial_solution(config, instance, instance_file)
    stats.save(initial.value)
    logging.info(
        f'Initial solution: {initial.value} (GAP = {initial.evaluate_gap():.2f}%) with {len(initial.employees)} employees')
    logging.info(f'Lower Bound = {instance.LB}\tBest-Known-Solution = {instance.BKS} (Optimality GAP = {(instance.BKS - instance.LB)/(instance.LB)*100:.2f}%)\t')

    current = initial.copy()
    best = initial.copy()
    destroyers = get_destroyers(stats, destruction_size, args.destroyers)

    stats.success_destroy = [0 for _ in destroyers]
    stats.runs_destroy = [0 for _ in destroyers]
    stats.name_destroy = [d.name for d in destroyers]
    # Small positive start so the success/duration ratio in update_weights is defined.
    stats.duration_destroy = [.01 for _ in destroyers]
    ind_destroyers = list(range(len(destroyers)))
    if args.weights is None:
        weights = [1 / len(destroyers) for _ in destroyers]
    else:
        weights = [float(w) for w in args.weights]
        if len(weights) != len(destroyers):
            # random.choices would reject this on the first iteration anyway.
            raise ValueError(
                f'{len(weights)} weights given for {len(destroyers)} destroyers')
    no_improvement = 0
    iteration = 0

    while True:
        if args.timeout is not None and time.perf_counter() - very_start > args.timeout:
            logging.info(f'Timeout of {args.timeout} seconds reached.')
            break

        if stats.is_budget_exhausted():
            logging.info("Configuration budget is exhausted.")
            break

        if no_improvement >= max_no_improvement and args.decrease:
            stats.size_changes += 1
            # Log before resetting the counter so the message shows the real count.
            logging.debug(f'Max iterations without improvements reached. Changing size after {no_improvement} iterations. Current size = {destruction_size}')
            no_improvement = 0
            destruction_size = min(destruction_size + 1, config['max_destruction_size'])
            _set_destruction_size(destroyers, destruction_size)

        destroyer_index = random.choices(ind_destroyers, weights)[0]
        start = time.perf_counter()
        output = destroyers[destroyer_index].apply(current)
        stats.runs_destroy[destroyer_index] += 1
        current = repairer_CG.apply(current, output)
        iteration += 1
        end = time.perf_counter()

        if current is None:
            # The repair failed: restart from the incumbent.
            current = best.copy()
            logging.warning(f'start again with the best-so-far solution {best.value=} (GAP = {best.evaluate_gap():.2f}%)')
        else:
            logging.debug(
                f'{current.value=} (GAP = {current.evaluate_gap():.2f}%) with e={len(current.employees)}.\t {best.value=} (e={len(best.employees)},GAP={best.evaluate_gap():.2f})%')
            if check_if_better(current, best):
                if args.decrease:
                    logging.debug(f'Found a better solution. Changing size after {no_improvement} iterations. Current size = {destruction_size}')
                    destruction_size = initial_size
                    # Shrink the destroyers too; otherwise they keep the enlarged size.
                    _set_destruction_size(destroyers, destruction_size)
                no_improvement = 0
                logging.info(
                    f'Time elapsed: {time.perf_counter()-very_start:.2f}\t Iteration {iteration}\t new solution ({current.value}, GAP = {current.evaluate_gap():.2f}%) is better than the best-so-far ({best.value}, GAP = {best.evaluate_gap():.2f}%) ')
                stats.incumbent_changed += 1
                stats.time_best_solution = time.perf_counter() - stats.start_time
                stats.success_destroy[destroyer_index] += 1
                best = current.copy()
                stats.save(current.value)

                if args.verbose:
                    stats.print()

                if check_if_optimal(current, instance.LB):
                    logging.info(' ========================================================================')
                    logging.info(f'|\t OPTIMAL SOLUTION FOUND: LB = {instance.LB} \tOBJECTIVE = {current.value}')
                    logging.info(' ========================================================================')
                    break
            else:
                no_improvement += 1
                current = best.copy()

        stats.duration_destroy[destroyer_index] += end - start

        if tipo == 'adaptive':
            weights = update_weights(weights, config, decay_parameter, stats)

        logging.debug(f'Weights = {weights}')

    stats.print()
    logging.info(
        f'Initial value = {initial.value} (e = {len(initial.employees)}, GAP = {initial.evaluate_gap():.2f}%),\tFinal value = {best.value} (e = {len(best.employees)}, GAP = {best.evaluate_gap():.2f}%)')

    try:
        best.print_to_file(solution_file)
    except FileNotFoundError:
        logging.exception(f'File {solution_file} not found')

    method = _method_label(tipo, decay_parameter)

    # Normalise the weights; an all-zero vector is kept as is instead of dividing by zero.
    total_weight = sum(weights)
    if total_weight:
        weights = [w / total_weight for w in weights]
    # Guard against an empty stats.time_used (presumably when no repair ran).
    time_used = list(stats.time_used)
    max_time = max(time_used) if time_used else 0
    avg_time = sum(time_used) / len(time_used) if time_used else 0

    today = date.today().strftime("%Y/%m/%d")
    output = f"{today}, {instance_file.split('_')[-1]},"
    output += f"{method},{args.destroyers}, {args.decrease}, {seed}, "
    output += f"{destruction_size}, {time.perf_counter() - very_start},"
    output += f"{len(initial.employees)}, {initial.value},{initial.evaluate_gap()},"
    output += f"{len(best.employees)},{best.value},{best.evaluate_gap()},{(best.value - instance.BH)/instance.BH*100},"
    output += f"{max_time},{avg_time}, {stats.size_changes},"
    output += f"{iteration}, {stats.time_best_solution},"
    output += f"{stats.bp_calls}, {stats.incumbent_changed}, {stats.feasible_runs_bp}, {stats.infeasible_runs_bp},"
    output += f"{','.join(str(s) for s in stats.success_destroy[:len(destroyers)])},"
    output += f"{','.join(str(w) for w in weights)},"
    output += f"{max_no_improvement}, {initial_size}"

    result_path = base_dir / 'experiments' / f'{instance_file}'
    # makedirs also creates a missing experiments/ parent and avoids an exists/mkdir race.
    os.makedirs(result_path, exist_ok=True)
    result_file = result_path / f'results_{method}_{args.destroyers}_{args.decrease}_{seed}.csv'
    write_result(output, result_file)

    logging.info('Completely finished after %.2f seconds.' % (time.perf_counter() - very_start))
    sys.exit(0)


if __name__ == '__main__':
    # Script entry point: parse the command line and run the search.
    run_ALNS()