import copy
import csv
import itertools
import logging
import os
import random
import time
from logging import getLogger

import numpy as np
import torch
from tqdm import tqdm

from .env import Env
from .model import Model
# from SISRs import ruin

from .logging_utils import *


# When enabled, print the wall-clock duration of each SA iteration's
# per-instance loop (debug aid; switch to True to activate).
MEASURE_TIME: bool = False
# torch.compile toggle; observed gain is small (about 1-2%) and only appears
# after some warm-up iterations.
USE_TORCH_COMPILE: bool = False
# Passed as `disable=` to the tqdm bar over SA iterations, i.e. hides it.
MUTE_TQDM: bool = True


class Search:
    """Simulated-annealing (SA) driver over batches of routing instances.

    The ruin-and-recreate moves are provided by the compiled ``SISRs`` module
    selected in ``__init__`` from ``env_params['problem']``. This class handles
    batching, the temperature schedule, tracking of the best (incumbent)
    solution per instance, and logging / CSV output.
    """

    def __init__(self,
                 env_params,
                 tester_params):
        """Store configuration, bind the problem-specific ``SISRs`` module, build the env.

        Args:
            env_params: environment configuration. ``'problem'`` selects the
                compiled operator module (``"cvrp"``, ``"vrptw"`` or ``"pcvrp"``);
                the whole dict is also forwarded to ``Env``.
            tester_params: search/test configuration read by ``run`` and
                ``_test_one_batch``.

        Raises:
            NotImplementedError: if ``env_params['problem']`` is not supported.
        """
        print(f"{50*'='}\nUsing SA Agent\n{50*'='}")

        # save arguments
        self.env_params = env_params
        self.tester_params = tester_params

        # result folder, logger
        self.logger = getLogger(name='tester')
        self.result_folder = get_result_folder()

        # Bind the operator module at module (global) scope so the other
        # methods can call ``SISRs.*`` directly.
        global SISRs
        problem = env_params['problem']
        if problem == "cvrp":
            from .cpp.cvrp import SISRs
        elif problem == "vrptw":
            from .cpp.vrptw import SISRs
        elif problem == "pcvrp":
            from .cpp.pcvrp import SISRs
        else:
            raise NotImplementedError(f"Unsupported problem type: {problem!r}")

        # Optional learned deconstruction models (dicts with a 'model' key);
        # put into eval mode before each batch. Empty in this agent.
        self.deconstruction_operator = []

        # ENV
        self.env = Env(False, **self.env_params)

        # utility
        self.time_estimator = TimeEstimator()

    def _load_test_data(self):
        """Load the test dataset into the env if ``tester_params['test_data_load']['enable']``.

        ``.pkl`` files are loaded together with the ``test_episodes`` count;
        ``.pt`` files are loaded onto the CPU.

        Raises:
            NotImplementedError: for any other file extension.
        """
        load_cfg = self.tester_params['test_data_load']
        if not load_cfg['enable']:
            return
        filename = load_cfg['filename']
        extension = os.path.splitext(filename)[1]
        if extension == ".pkl":
            self.env.load_problem_dataset_pkl(filename, self.tester_params['test_episodes'])
        elif extension == ".pt":
            self.env.load_problem_dataset_pt(filename, "cpu")  # note: interested only on CPU
        else:
            raise NotImplementedError(f"Unsupported test data file extension: {extension!r}")

    def _append_batch_csv(self, first_episode, batch_size, costs, rt, nb_iter, solutions):
        """Append one row per instance of a finished batch to results.csv and solutions.csv.

        Args:
            first_episode: 1-based episode id of the batch's first instance.
            batch_size: number of instances in the batch.
            costs: best cost per instance (iterable of numbers).
            rt: wall-clock runtime of the whole batch in seconds; each row
                gets ``rt / batch_size``.
            nb_iter: number of SA iterations performed for the batch.
            solutions: per-instance solution objects exposing ``getTourList()``.
        """
        episode_ids = range(first_episode, first_episode + batch_size)

        # Convert numpy scalars to plain floats so the CSV holds plain numbers.
        cost_values = [float(c) for c in costs]
        with open(os.path.join(self.result_folder, "results.csv"), mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerows(zip(episode_ids, cost_values,
                                 [rt / batch_size] * batch_size, [nb_iter] * batch_size))

        # Every tour is wrapped with node 0 at both ends (presumably the depot).
        tours_batch = [[[0, *tour, 0] for tour in s.getTourList()] for s in solutions]
        with open(os.path.join(self.result_folder, "solutions.csv"), mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerows(zip(episode_ids, tours_batch))

    def run(self):
        """Run SA over all test episodes in batches, logging progress and writing CSVs.

        The number of episodes is ``tester_params['max_episodes']`` if present,
        otherwise ``tester_params['test_episodes']``. After each batch, one row
        per instance is appended to ``results.csv`` and ``solutions.csv`` in
        the result folder. A summary is logged once all episodes are done.

        Raises:
            ValueError: if ``tester_params['test_batch_size']`` is not positive
                (the loop would otherwise never advance).
        """
        self.time_estimator.reset()

        costs_AM = AverageMeter()
        rt_AM = AverageMeter()
        nb_iter_AM = AverageMeter()

        self._load_test_data()

        # Note: the .pkl loader above is still sized by ``test_episodes``.
        test_num_episode = self.tester_params.get("max_episodes", self.tester_params['test_episodes'])
        print(f"Number of episodes to test: {test_num_episode}")

        episode = 0
        while episode < test_num_episode:
            remaining = test_num_episode - episode
            batch_size = min(self.tester_params['test_batch_size'], remaining)
            if batch_size <= 0:
                raise ValueError(
                    f"test_batch_size must be positive, got {self.tester_params['test_batch_size']}")
            aug_factor = self.tester_params['aug_factor']

            costs, rt, nb_iter, solutions = self._test_one_batch(
                batch_size, self.tester_params['nb_iterations'], episode, aug_factor=aug_factor)

            costs_AM.update(costs.mean(), batch_size)
            rt_AM.update(rt, batch_size)
            nb_iter_AM.update(nb_iter, batch_size)

            first_episode = episode + 1  # 1-based id of this batch's first instance
            episode += batch_size

            ############################
            # Logs
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)
            self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score: {:.3f}, running_mean: {:.3f} iter: {:.1f}".format(
                episode, test_num_episode, elapsed_time_str, remain_time_str, costs.mean(), costs_AM.avg, nb_iter))

            self._append_batch_csv(first_episode, batch_size, costs, rt, nb_iter, solutions)

            if episode == test_num_episode:
                self.logger.info(" *** Test Done *** ")
                self.logger.info(" AVG. COSTS: {:.4f} ".format(costs_AM.avg))
                self.logger.info(" AVG. RUNTIME: {:.2f} ".format(rt_AM.avg))
                self.logger.info(" AVG. ITERATIONS: {:.1f} ".format(nb_iter_AM.avg))

    def _test_one_batch(self, batch_size, nb_iterations, episode, aug_factor=1):
        """Run SA on one batch and return the best solution found per instance.

        Each iteration works on every (augmented) instance:
        - take its current solution from ``self.env.instanceSet``;
        - select nodes to remove;
        - call the ``SISRs`` remove/recreate operator with the current
          temperature ``T``;
        - install the result as the new current solution.

        ``T`` is presumably used for the SA acceptance test inside the compiled
        operator (TODO confirm).

        Temperature schedule, from ``SA_start_T`` to ``SA_final_T``:
        - If ``tester_params['max_runtime'] > 0``, ``T`` follows the elapsed
          fraction of ``max_runtime`` and the search stops once that time is
          exceeded.
        - Otherwise ``T`` is cooled geometrically over ``nb_iterations``.

        Instance ``b_idx`` of the augmented batch contributes to incumbent
        ``b_idx % batch_size``. This assumes the env stores augmented copies of
        instance k at k, k + batch_size, ... (TODO confirm against Env).

        Args:
            batch_size: number of distinct instances in the batch.
            nb_iterations: maximum number of SA iterations (must be >= 1).
            episode: index of the batch's first episode (unused here).
            aug_factor: number of augmented copies per instance.

        Returns:
            ``(incumbent_costs, runtime_seconds, iterations_done, incumbent_solutions)``.
            ``incumbent_costs`` is a numpy array of length ``batch_size``.

        Raises:
            ValueError: if ``nb_iterations < 1``.
            AssertionError: if a runtime limit is set but the iteration budget
                ran out before the limit was reached.
        """
        if nb_iterations < 1:
            raise ValueError(f"nb_iterations must be >= 1, got {nb_iterations}")

        rollout_size = self.tester_params['rollout_size']
        max_runtime = self.tester_params['max_runtime']
        beta = self.env_params['beta']
        insert_in_new_tours_only = self.env_params['insert_in_new_tours_only']
        recreate_n = self.env_params['recreate_n']

        runtime_limited = max_runtime > 0

        # Pick the operator pair once instead of re-checking per instance.
        if self.tester_params.get("use_agent", True):
            select_nodes = SISRs.agents_heuristic_deconstruction_selection
            remove_recreate = SISRs.agents_remove_recreate_allImp
        else:
            select_nodes = SISRs.heuristic_deconstruction_selection
            remove_recreate = SISRs.remove_recreate_allImp

        # Ready
        ###############################################
        for model in self.deconstruction_operator:
            model['model'].eval()

        with torch.inference_mode():
            start_time = time.time()

            self.env.init_instances(batch_size, rollout_size, "cpu", aug_factor)  # Note: interested only on CPU

            n_instances = batch_size * aug_factor
            incumbent_costs = np.full(batch_size, np.inf)
            incumbent_solutions = [None] * batch_size

            T_0 = self.tester_params['SA_start_T']
            T_f = self.tester_params['SA_final_T']
            T = T_0

            if not runtime_limited:
                # Geometric cooling: after nb_iterations multiplications, T == T_f.
                cooling = (T_f / T_0) ** (1 / nb_iterations)

            iterations_done = 0
            # tqdm progress bar for iterations
            for i in tqdm(range(nb_iterations), total=nb_iterations, desc='SA Iterations', disable=MUTE_TQDM):
                if MEASURE_TIME:
                    loop_start = time.time()

                new_solutions = [None] * n_instances
                for b_idx in range(n_instances):
                    solution = self.env.instanceSet.get_solution(b_idx)
                    sel_nodes = select_nodes(solution, self.env.num_nodes_to_remove, rollout_size)
                    sol, _ = remove_recreate(solution, sel_nodes,
                                             beta, recreate_n, T, insert_in_new_tours_only)
                    new_solutions[b_idx] = sol

                    base_idx = b_idx % batch_size
                    if sol.totalCosts < incumbent_costs[base_idx]:
                        incumbent_costs[base_idx] = sol.totalCosts
                        incumbent_solutions[base_idx] = sol

                if MEASURE_TIME:
                    loop_time = time.time() - loop_start
                    print(f"Iteration {i+1}/{nb_iterations}: For loop time: {loop_time:.4f}s")

                for b_idx, new_sol in enumerate(new_solutions):
                    self.env.instanceSet.set_solution(b_idx, new_sol)

                iterations_done = i + 1
                if runtime_limited:
                    cur_runtime = time.time() - start_time
                    # Anneal with the elapsed fraction of the time budget: T_0 at 0, T_f at max_runtime.
                    T = T_f * (T_0 / T_f) ** (1 - (cur_runtime / max_runtime))
                    if cur_runtime > max_runtime:
                        break
                else:
                    T *= cooling

            # Measure once so the checked and the returned runtime agree.
            runtime = time.time() - start_time
            if runtime_limited:
                assert runtime > max_runtime, "A runtime limit was set, but search terminated based on number of iterations."

            return incumbent_costs, runtime, iterations_done, incumbent_solutions
