import copy
import csv
import itertools
import logging
import os
import pickle
import random
import time
from logging import getLogger

import numpy as np
import torch
from tqdm import tqdm

from .env import Env
from .model import Model
# from SISRs import ruin

from .logging_utils import *


# Print per-iteration wall-clock split between model inference and the SISRs loop.
MEASURE_TIME = False
# Wrap each loaded destroy model in torch.compile; observed gain is only ~1-2%,
# and only after some warm-up iterations.
USE_TORCH_COMPILE = False
# Use the batched SISRs.remove_recreate_allImp_batch call instead of the
# per-instance SISRs.remove_recreate_allImp loop.
USE_SISR_PARALLEL = False
# Hide the tqdm progress bar over search iterations.
TQDM_DISABLE = True

class Search:
    """Batched ruin-and-recreate search driven by the SISRs extension.

    Each search iteration removes a set of nodes from every current solution
    (the "destroy" step) and re-inserts them with the SISRs backend (the
    "recreate" step), passing an annealed temperature ``T``. Nodes to remove
    are chosen either by one of the loaded learned destroy models (sampled
    uniformly at random per iteration) or, when
    ``tester_params['use_baseline_destroy']`` is true, by
    ``SISRs.heuristic_deconstruction_selection``.
    """

    def __init__(self,
                 env_params,
                 tester_params):
        """Set up the device, the problem-specific SISRs backend and the destroy models.

        Args:
            env_params: environment configuration. ``env_params['problem']``
                selects the SISRs backend (``"cvrp"``, ``"vrptw"`` or
                ``"pcvrp"``), and the whole dict is forwarded to ``Env``.
            tester_params: search/test configuration (device, checkpoints to
                load, annealing schedule, batch sizes, ...).

        Raises:
            NotImplementedError: if ``env_params['problem']`` is not supported.

        Side effects: sets torch's global default tensor type and binds the
        module-level name ``SISRs`` to the selected backend module.
        """
        # save arguments
        self.env_params = env_params
        self.tester_params = tester_params

        # result folder, logger
        self.logger = getLogger(name='tester')
        self.result_folder = get_result_folder()

        # cuda
        USE_CUDA = self.tester_params['use_cuda']
        if USE_CUDA:
            cuda_device_num = self.tester_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')
        self.device = device

        # The backend is bound at module level so every method sees the same module.
        global SISRs
        if env_params['problem'] == "cvrp":
            from .cpp.cvrp import SISRs
        elif env_params['problem'] == "vrptw":
            from .cpp.vrptw import SISRs
        elif env_params['problem'] == "pcvrp":
            from .cpp.pcvrp import SISRs
        else:
            raise NotImplementedError

        # One entry per loaded checkpoint: the model, the pool of all 2**z_dim
        # binary z vectors, plus the checkpoint metadata from 'model_load'
        # (path, epoch, node_to_remove, ...).
        self.deconstruction_operator = []

        if not self.tester_params['use_baseline_destroy']:
            for model_data in self.tester_params['model_load']:
                checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_data)
                # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
                checkpoint = torch.load(checkpoint_fullname, map_location=device, weights_only=False)
                model_params = checkpoint['model_params']

                model = Model(**model_params)
                model.load_state_dict(checkpoint['model_state_dict'])
                if USE_TORCH_COMPILE:
                    try:
                        model = torch.compile(model)
                    except AttributeError:
                        print("torch.compile is not available in this version of torch.")
                # Every binary vector of length z_dim; rows are sampled as the z
                # passed to model.pre_forward during the search.
                binary_vector_pool = torch.Tensor(
                    [list(bits) for bits in itertools.product([0, 1], repeat=model_params['z_dim'])])
                self.deconstruction_operator.append({'model': model, 'binary_vector_pool': binary_vector_pool, **model_data})
                # The checkpoint must have been trained to remove the configured number of nodes.
                assert checkpoint['env_params']['num_nodes_to_remove'] == model_data['node_to_remove']

        # ENV
        self.env = Env(False, **self.env_params)

        # utility
        self.time_estimator = TimeEstimator()

    def run(self):
        """Run the search over all test episodes in batches and record results.

        Optionally loads a dataset (``.pkl`` or ``.pt``) into the environment.
        It then processes ``tester_params['max_episodes']`` episodes (falling
        back to ``tester_params['test_episodes']``) in batches of at most
        ``tester_params['test_batch_size']``. After each batch it logs progress
        and appends one row per episode to ``results.csv`` (episode id, best
        cost, the batch runtime divided by the batch size, iterations) and to
        ``solutions.csv`` (episode id, tours with node 0 — presumably the
        depot — added at both ends), both inside ``self.result_folder``.

        Raises:
            NotImplementedError: if the dataset file extension is neither
                ``.pkl`` nor ``.pt``.
        """
        self.time_estimator.reset()

        costs_AM = AverageMeter()
        rt_AM = AverageMeter()
        nb_iter_AM = AverageMeter()

        data_cfg = self.tester_params['test_data_load']
        if data_cfg['enable']:
            filename = data_cfg['filename']
            extension = os.path.splitext(filename)[1]
            if extension == ".pkl":
                self.env.load_problem_dataset_pkl(filename, self.tester_params['test_episodes'])
            elif extension == ".pt":
                self.env.load_problem_dataset_pt(filename, self.device)
            else:
                raise NotImplementedError

        # 'max_episodes', when present, overrides 'test_episodes'.
        test_num_episode = self.tester_params.get("max_episodes", self.tester_params['test_episodes'])
        print(f"Number of episodes to test: {test_num_episode}")

        episode = 0
        while episode < test_num_episode:
            batch_size = min(self.tester_params['test_batch_size'], test_num_episode - episode)
            aug_factor = self.tester_params['aug_factor']

            costs, rt, nb_iter, solutions = self._test_one_batch(
                batch_size, self.tester_params['nb_iterations'], episode, aug_factor=aug_factor)

            costs_AM.update(costs.mean(), batch_size)
            rt_AM.update(rt, batch_size)
            nb_iter_AM.update(nb_iter, batch_size)

            # Episode ids written to the CSVs are 1-based.
            first_episode = episode + 1
            episode += batch_size
            episode_ids = list(range(first_episode, episode + 1))

            ############################
            # Logs
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)
            self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score: {:.3f}, running_mean: {:.3f} iter: {:.1f}".format(
                episode, test_num_episode, elapsed_time_str, remain_time_str, costs.mean(), costs_AM.avg, nb_iter))

            with open(os.path.join(self.result_folder, "results.csv"), mode='a', newline='') as file:
                writer = csv.writer(file)
                writer.writerows(zip(episode_ids, list(costs), [rt / batch_size] * batch_size, [nb_iter] * batch_size))

            with open(os.path.join(self.result_folder, "solutions.csv"), mode='a', newline='') as file:
                tours_batch = [[[0, *tour, 0] for tour in sol.getTourList()] for sol in solutions]
                writer = csv.writer(file)
                writer.writerows(zip(episode_ids, tours_batch))

            if episode == test_num_episode:
                self.logger.info(" *** Test Done *** ")
                self.logger.info(" AVG. COSTS: {:.4f} ".format(costs_AM.avg))
                self.logger.info(" AVG. RUNTIME: {:.2f} ".format(rt_AM.avg))
                self.logger.info(" AVG. ITERATIONS: {:.1f} ".format(nb_iter_AM.avg))

    @staticmethod
    def _dump_search_trace(dump_dir, episode, iteration, tours, selected_nodes):
        """Write the tours and the selected nodes of one search iteration to ``dump_dir``.

        Creates ``dump_dir`` if it is missing. Writes
        ``episode{episode}_iter{iteration}_tours.pkl`` (pickled ``tours``) and
        ``episode{episode}_iter{iteration}_selected_nodes.npy``.
        """
        os.makedirs(dump_dir, exist_ok=True)
        prefix = os.path.join(dump_dir, f"episode{episode}_iter{iteration}")
        with open(f"{prefix}_tours.pkl", "wb") as f:
            pickle.dump(tours, f)
        np.save(f"{prefix}_selected_nodes.npy", selected_nodes)

    def _test_one_batch(self, batch_size, nb_iterations, episode, aug_factor=1):
        """Run the annealed ruin-and-recreate search on one batch of instances.

        Args:
            batch_size: number of distinct instances in the batch.
            nb_iterations: maximum number of search iterations. Without a
                runtime limit it also fixes the geometric cooling rate.
            episode: index of the batch's first episode; only used to name
                trace files.
            aug_factor: number of augmented copies per instance. The
                environment holds ``batch_size * aug_factor`` solutions, and
                solution index ``b`` is credited to instance
                ``b % batch_size``.

        Returns:
            ``(incumbent_costs, runtime, iterations, incumbent_solutions)``:
            the best cost per instance (array of length ``batch_size``), the
            wall-clock seconds spent, the number of iterations executed, and
            the best solution object per instance.

        Temperature schedule: with ``tester_params['max_runtime'] > 0``, ``T``
        decays from ``SA_start_T`` to ``SA_final_T`` as a function of elapsed
        time, and the loop stops once the runtime is exceeded (asserted
        afterwards). Otherwise ``T`` is multiplied by a constant factor each
        iteration so that it reaches ``SA_final_T`` after ``nb_iterations``
        steps.

        When model-based destroy is used, the tours and selected nodes of every
        iteration are dumped to ``tester_params.get('trace_dump_dir', "results3")``.
        Set that key to ``None``/``""`` to disable the dump. The dump happens
        inside the timed search loop.
        """
        rollout_size = self.tester_params['rollout_size']
        aug_batch_size = batch_size * aug_factor
        use_model = not self.tester_params['use_baseline_destroy']
        max_runtime = self.tester_params['max_runtime']
        softmax_temp = self.tester_params['softmax_temp']
        beta = self.env_params['beta']
        insert_in_new_tours_only = self.env_params['insert_in_new_tours_only']
        recreate_n = self.env_params['recreate_n']
        trace_dump_dir = self.tester_params.get('trace_dump_dir', "results3")

        runtime_limited = not max_runtime <= 0

        # Ready
        ###############################################
        for operator_entry in self.deconstruction_operator:
            operator_entry['model'].eval()

        with torch.inference_mode():
            start_time = time.time()

            self.env.init_instances(batch_size, rollout_size, self.device, aug_factor)

            incumbent_costs = np.full(batch_size, np.inf)
            incumbent_solutions = [None] * batch_size

            T_0 = self.tester_params['SA_start_T']
            T_f = self.tester_params['SA_final_T']
            T = T_0

            if not runtime_limited and nb_iterations > 0:
                # Per-iteration factor so that T_0 * c**nb_iterations == T_f.
                c = (T_f / T_0) ** (1 / nb_iterations)

            # Counted explicitly so that no inner loop can clobber it.
            iterations_done = 0
            for it in tqdm(range(nb_iterations), total=nb_iterations, desc='SA Iterations', disable=TQDM_DISABLE):
                iterations_done = it + 1

                if MEASURE_TIME:
                    model_infer_start = time.time()
                if use_model:
                    # choose model
                    operator = random.choice(self.deconstruction_operator)
                    model = operator['model']
                    self.env.num_nodes_to_remove = operator['node_to_remove']

                    state = self.env.reset()
                    reset_state = self.env.get_model_input(self.device)

                    # Sample rollout_size distinct z vectors per (augmented) instance.
                    z_dim = model.model_params['z_dim']
                    z_idx = torch.multinomial((torch.ones(aug_batch_size, 2 ** z_dim) / 2 ** z_dim),
                                              rollout_size, replacement=False)
                    z = operator['binary_vector_pool'][z_idx].reshape(aug_batch_size, 1, rollout_size, z_dim)
                    z = z.transpose(1, 2).reshape(aug_batch_size, rollout_size, z_dim)

                    with torch.amp.autocast(device_type=self.device.type):
                        model.pre_forward(reset_state, z)

                        done = False
                        while not done:
                            selected, _, _ = model(state, softmax_temp)
                            # shape: (batch, pomo)

                            state, done = self.env.step(selected)

                    selected_nodes = self.env.selected_node_list.cpu().numpy()

                    if trace_dump_dir:
                        self._dump_search_trace(trace_dump_dir, episode, it,
                                                self.env.instanceSet.getTours(), selected_nodes)

                if MEASURE_TIME:
                    model_infer_end = time.time()
                    model_infer_time = model_infer_end - model_infer_start
                    loop_start = time.time()

                new_solutions = [None] * aug_batch_size

                if USE_SISR_PARALLEL:
                    num_processes = self.tester_params.get('num_processes', 8)
                    # Direct access to the underlying solution list (private attribute).
                    solutions_batch = self.env.instanceSet._solutions
                    if use_model:
                        A_batch = [selected_nodes[b_idx] for b_idx in range(aug_batch_size)]
                    else:
                        A_batch = [
                            SISRs.heuristic_deconstruction_selection(solutions_batch[b_idx], self.env.num_nodes_to_remove,
                                                                     rollout_size)
                            for b_idx in range(aug_batch_size)
                        ]
                    results = SISRs.remove_recreate_allImp_batch(
                        solutions_batch,
                        A_batch,
                        beta,
                        recreate_n,
                        T,
                        insert_in_new_tours_only,
                        num_processes
                    )
                    for b_idx, (sol, _) in enumerate(results):
                        new_solutions[b_idx] = sol
                        inst_idx = b_idx % batch_size
                        if sol.totalCosts < incumbent_costs[inst_idx]:
                            incumbent_costs[inst_idx] = sol.totalCosts
                            incumbent_solutions[inst_idx] = sol
                else:
                    for b_idx in range(aug_batch_size):
                        solution = self.env.instanceSet.get_solution(b_idx)
                        if use_model:
                            nodes_to_remove = selected_nodes[b_idx]
                        else:
                            nodes_to_remove = SISRs.heuristic_deconstruction_selection(
                                solution, self.env.num_nodes_to_remove, rollout_size)
                        sol, _ = SISRs.remove_recreate_allImp(solution, nodes_to_remove,
                                                              beta, recreate_n, T, insert_in_new_tours_only)

                        new_solutions[b_idx] = sol
                        inst_idx = b_idx % batch_size
                        if sol.totalCosts < incumbent_costs[inst_idx]:
                            incumbent_costs[inst_idx] = sol.totalCosts
                            incumbent_solutions[inst_idx] = sol

                if MEASURE_TIME:
                    loop_end = time.time()
                    loop_time = loop_end - loop_start
                    print(f"Iteration {it+1}/{nb_iterations}: Model inference time: {model_infer_time:.4f}s, For loop time: {loop_time:.4f}s")

                # The new solutions become the current solutions for the next iteration.
                for b_idx, sol in enumerate(new_solutions):
                    self.env.instanceSet.set_solution(b_idx, sol)

                if runtime_limited:
                    cur_runtime = time.time() - start_time
                    # Time-based schedule: T_0 at t=0, T_f at t=max_runtime.
                    T = T_f * (T_0 / T_f) ** (1 - (cur_runtime / max_runtime))
                    if cur_runtime > max_runtime:
                        break
                else:
                    T *= c

            runtime = time.time() - start_time
            if runtime_limited:
                assert runtime > max_runtime, "A runtime limit was set, but search terminated based on number of iterations."

            return incumbent_costs, runtime, iterations_done, incumbent_solutions
