import logging
import random
import time
from tqdm import tqdm

import torch
import numpy as np
from logging import getLogger
import copy

from .env import Env
from .model import Model
# from SISRs import ruin

from .logging_utils import *

import itertools
import csv
import os


# Debug switch: when True, Search._test_one_batch prints the wall-clock time
# spent in the per-instance loop of every iteration.
MEASURE_TIME = False #True
# NOTE(review): no use of this flag is visible in this part of the file.
# Original author's note: enabling it gives only a ~1-2% speedup, and only
# after some warm-up iterations.
USE_TORCH_COMPILE = False # provides few % speedup, like 1 or 2% but only after some iterations
# Forwarded as `disable=` to the tqdm progress bar in Search._test_one_batch;
# True hides the bar.
MUTE_TQDM = True


class Search:
    """Test driver that improves routing solutions with SISRs ruin-and-recreate
    moves under a temperature schedule ("SA agent").

    ``run`` iterates over the test episodes, either solving one instance in
    this process or fanning a batch out to worker processes, and appends the
    per-instance results to ``results.csv`` and ``solutions.csv`` in the
    result folder.
    """

    def __init__(self,
                 env_params,
                 tester_params):
        """Store the configuration, bind the SISRs backend and build the environment.

        Args:
            env_params: Environment configuration. ``env_params['problem']``
                selects the backend (``"cvrp"``, ``"vrptw"`` or ``"pcvrp"``);
                the whole dict is forwarded to ``Env`` as keyword arguments.
            tester_params: Search/test configuration read by ``run`` and
                ``_test_one_batch``.

        Raises:
            NotImplementedError: If ``env_params['problem']`` is not one of
                the supported problem names.
        """
        print(f"{50*'='}\nUsing SA Agent\n{50*'='}")

        # save arguments
        self.env_params = env_params
        self.tester_params = tester_params

        # result folder, logger
        self.logger = getLogger(name='tester')
        self.result_folder = get_result_folder()

        # The backend is published as a module-level global because
        # _test_one_batch calls ``SISRs.*`` without going through ``self``.
        global SISRs
        problem = env_params['problem']
        if problem == "cvrp":
            from .cpp.cvrp import SISRs
        elif problem == "vrptw":
            from .cpp.vrptw import SISRs
        elif problem == "pcvrp":
            from .cpp.pcvrp import SISRs
        else:
            raise NotImplementedError(f"Unsupported problem type: {problem!r}")

        # No learned deconstruction models are registered by this class; the
        # list is still iterated by run() / _test_one_batch().
        self.deconstruction_operator = []

        # ENV
        self.env = Env(False, **self.env_params)

        # utility
        self.time_estimator = TimeEstimator()

    def run(self):
        """Process all test episodes and log / persist the results.

        Batches with more than one instance are dispatched to worker
        processes; a batch of exactly one instance is solved in-process by
        ``_test_one_batch``.

        Raises:
            NotImplementedError: If test-data loading is enabled and the file
                extension is neither ``.pkl`` nor ``.pt``.
            ValueError: If ``tester_params['test_batch_size']`` is smaller
                than 1 while there are episodes to process (the loop could
                never advance otherwise).
        """
        self.time_estimator.reset()

        costs_AM = AverageMeter()
        rt_AM = AverageMeter()
        nb_iter_AM = AverageMeter()

        self._load_test_dataset()

        test_num_episode = self.tester_params['test_episodes']
        episode = 0

        while episode < test_num_episode:
            remaining = test_num_episode - episode
            batch_size = min(self.tester_params['test_batch_size'], remaining)
            if batch_size < 1:
                # Without this guard ``episode`` would never grow and the
                # loop would spin forever.
                raise ValueError(
                    "tester_params['test_batch_size'] must be >= 1, got "
                    f"{self.tester_params['test_batch_size']!r}"
                )
            aug_factor = self.tester_params['aug_factor']

            if batch_size > 1:
                # Parallel execution (one worker process per problem instance)
                costs, rt, nb_iter, solutions = self._run_parallel_batch(batch_size, episode)
            else:
                # Serial fallback (batch_size == 1)
                costs, rt, nb_iter, solutions = self._test_one_batch(
                    batch_size,
                    self.tester_params['nb_iterations'],
                    episode,
                    aug_factor=aug_factor
                )

            costs_AM.update(costs.mean(), batch_size)
            rt_AM.update(rt, batch_size)
            nb_iter_AM.update(nb_iter, batch_size)

            episode += batch_size

            elapsed, remain = self.time_estimator.get_est_string(episode, test_num_episode)
            self.logger.info(
                "episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score: {:.3f}, "
                "running_mean: {:.3f} iter: {:.1f}".format(
                    episode, test_num_episode, elapsed, remain,
                    costs.mean(), costs_AM.avg, nb_iter
                )
            )

            self._write_batch_results(episode, batch_size, costs, rt, nb_iter, solutions)

            if episode == test_num_episode:
                self.logger.info(" *** Test Done *** ")
                self.logger.info(" AVG. COSTS: {:.4f} ".format(costs_AM.avg))
                self.logger.info(" AVG. RUNTIME: {:.2f} ".format(rt_AM.avg))
                self.logger.info(" AVG. ITERATIONS: {:.1f} ".format(nb_iter_AM.avg))

    def _load_test_dataset(self):
        """Load the configured test dataset into the environment, if enabled."""
        load_cfg = self.tester_params['test_data_load']
        if not load_cfg['enable']:
            return

        filename = load_cfg['filename']
        extension = os.path.splitext(filename)[1]
        if extension == ".pkl":
            self.env.load_problem_dataset_pkl(
                filename,
                self.tester_params['test_episodes']
            )
        elif extension == ".pt":
            # interested only in CPU
            self.env.load_problem_dataset_pt(filename, "cpu")
        else:
            raise NotImplementedError(f"Unsupported test data extension: {extension!r}")

    def _run_parallel_batch(self, batch_size, episode):
        """Solve ``batch_size`` consecutive instances in separate worker processes.

        Args:
            batch_size: Number of instances (and submitted worker tasks).
            episode: Index of the first instance of the batch; task ``k``
                receives ``episode + k``.

        Returns:
            Tuple ``(costs, rt, nb_iter, solutions)``: a float32 array built
            from the workers' costs, the sum of the workers' reported
            runtimes, the mean of their iteration counts, and the list of
            their solutions (``None`` for a worker that timed out).
        """
        # Imported lazily, exactly where the parallel path needs them.
        from concurrent.futures import ProcessPoolExecutor, TimeoutError
        from .search_mp import _worker_test_one_batch
        import multiprocessing as mp

        mp.set_start_method("spawn", force=True)

        # Light, picklable snapshot of the deconstruction operator
        decon_sd = [
            {
                "cfg": copy.deepcopy(item["model"].cfg),
                "weights": copy.deepcopy(item["model"].state_dict()),
            }
            for item in self.deconstruction_operator
        ]

        worker_budget = self.tester_params['max_runtime']
        worker_timeout = None if worker_budget <= 0 else worker_budget + 5  # small grace margin

        # os.cpu_count() may return None when the count cannot be determined;
        # fall back to a single worker instead of failing inside min().
        max_workers = min(batch_size, os.cpu_count() or 1)

        with ProcessPoolExecutor(max_workers=max_workers) as pool:
            futures = [
                pool.submit(
                    _worker_test_one_batch,
                    self.env_params,
                    self.tester_params,
                    decon_sd,
                    episode + k,
                    None     # optional per-worker seed
                )
                for k in range(batch_size)
            ]

            results = []
            for fut in futures:
                try:
                    results.append(fut.result(timeout=worker_timeout))
                except TimeoutError:
                    # Worker exceeded its own limit.
                    # NOTE: the task itself is not interrupted; leaving the
                    # ``with`` block still waits for it to finish.
                    self.logger.warning("Worker timed out; assigning inf cost.")
                    results.append((np.inf, worker_timeout or 0.0, 0, None))

        costs = np.array([r[0] for r in results], dtype=np.float32)
        rt = sum(r[1] for r in results)
        nb_iter = np.mean([r[2] for r in results])
        solutions = [r[3] for r in results]
        return costs, rt, nb_iter, solutions

    def _write_batch_results(self, episode, batch_size, costs, rt, nb_iter, solutions):
        """Append one row per instance of the finished batch to the CSV outputs.

        ``episode`` is the number of instances completed so far, so the batch
        covers the 1-based instance ids ``episode - batch_size + 1 .. episode``.
        """
        instance_ids = range(episode - batch_size + 1, episode + 1)

        with open(os.path.join(self.result_folder, "results.csv"), mode='a', newline='') as f:
            csv.writer(f).writerows(
                zip(
                    instance_ids,
                    costs,
                    [rt / batch_size] * batch_size,  # batch runtime split evenly
                    [nb_iter] * batch_size
                )
            )

        with open(os.path.join(self.result_folder, "solutions.csv"), mode='a', newline='') as f:
            tours_batch = [s.getTourList() if s else [] for s in solutions]
            # Each tour is written with a leading and trailing 0.
            tours_batch = [[[0, *t, 0] for t in tours_inst] for tours_inst in tours_batch]
            csv.writer(f).writerows(zip(instance_ids, tours_batch))

    def _test_one_batch(self, batch_size, nb_iterations, episode, aug_factor=1):
        """Run the ruin-and-recreate search on one batch of instances in-process.

        Every iteration, each of the ``batch_size * aug_factor`` solutions
        held by the environment is passed to the SISRs backend together with
        the current temperature ``T``, and the returned solution replaces it.
        The cheapest solution seen per instance is tracked as incumbent;
        slot ``b_idx`` belongs to instance ``b_idx % batch_size``.

        ``T`` goes from ``SA_start_T`` towards ``SA_final_T``: geometrically
        over ``nb_iterations`` when no runtime limit is set, otherwise as a
        function of the elapsed share of ``max_runtime``.

        Args:
            batch_size: Number of problem instances in the batch.
            nb_iterations: Maximum number of search iterations.
            episode: Unused here; kept for interface compatibility.
            aug_factor: Number of solution slots per instance.

        Returns:
            Tuple ``(incumbent_costs, runtime_seconds, iterations_done,
            incumbent_solutions)``. With ``nb_iterations <= 0`` and no
            runtime limit, the costs are all ``inf``, the solutions all
            ``None`` and ``iterations_done`` is 0.

        Raises:
            AssertionError: If a runtime limit is set but the loop ran out of
                iterations before the limit was reached.
        """
        rollout_size = self.tester_params['rollout_size']
        max_runtime = self.tester_params['max_runtime']
        beta = self.env_params['beta']
        insert_in_new_tours_only = self.env_params['insert_in_new_tours_only']
        recreate_n = self.env_params['recreate_n']

        runtime_limited = max_runtime > 0

        # Ready
        ###############################################
        for model in self.deconstruction_operator:
            model['model'].eval()

        with torch.inference_mode():
            start_time = time.time()

            self.env.init_instances(batch_size, rollout_size, "cpu", aug_factor) # Note: interested only on CPU

            incumbent_costs = np.full(batch_size, np.inf)
            incumbent_solutions = [None] * batch_size

            T_0 = self.tester_params['SA_start_T']
            T_f = self.tester_params['SA_final_T']
            T = T_0

            # Per-iteration factor that takes T from T_0 to T_f in exactly
            # nb_iterations steps. Only needed (and only defined) when the
            # loop body can actually run without a runtime limit.
            cooling = None
            if not runtime_limited and nb_iterations > 0:
                cooling = (T_f / T_0) ** (1 / nb_iterations)

            num_slots = batch_size * aug_factor
            # Explicit counter so the return value is defined even when the
            # loop body never executes.
            iterations_done = 0

            # tqdm progress bar for iterations
            for i in tqdm(range(nb_iterations), total=nb_iterations, desc='SA Iterations', disable=MUTE_TQDM):

                if MEASURE_TIME:
                    loop_start = time.time()

                new_solutions = [None] * num_slots
                for b_idx in range(num_slots):
                    solution = self.env.instanceSet.get_solution(b_idx)
                    sel_nodes = SISRs.heuristic_deconstruction_selection(
                        solution, self.env.num_nodes_to_remove, rollout_size)
                    sol, _ = SISRs.remove_recreate_allImp(
                        solution, sel_nodes, beta, recreate_n, T, insert_in_new_tours_only)
                    new_solutions[b_idx] = sol

                    inst_idx = b_idx % batch_size
                    if sol.totalCosts < incumbent_costs[inst_idx]:
                        incumbent_costs[inst_idx] = sol.totalCosts
                        incumbent_solutions[inst_idx] = sol

                if MEASURE_TIME:
                    loop_end = time.time()
                    loop_time = loop_end - loop_start
                    print(f"Iteration {i+1}/{nb_iterations}: For loop time: {loop_time:.4f}s")

                # Write the new solutions back only after all slots were processed.
                for b_idx, sol in enumerate(new_solutions):
                    self.env.instanceSet.set_solution(b_idx, sol)

                iterations_done = i + 1

                if runtime_limited:
                    cur_runtime = time.time() - start_time
                    T = T_f * (T_0 / T_f) ** (1 - (cur_runtime / max_runtime))
                    if cur_runtime > max_runtime:
                        break
                else:
                    T *= cooling

            if runtime_limited:
                assert time.time() - start_time > max_runtime, "A runtime limit was set, but search terminated based on number of iterations."

            return incumbent_costs, time.time() - start_time, iterations_done, incumbent_solutions
