import copy
import csv
import itertools
import logging
import os
import random
import time
from logging import getLogger

import numpy as np
import torch

from .env import Env
from .model import Model
# from SISRs import ruin

from .logging_utils import *


class Search:
    """Test-time ruin-and-recreate search with a simulated-annealing temperature schedule.

    In every iteration the nodes to remove are chosen either by one of the loaded
    deconstruction models or by ``SISRs.heuristic_deconstruction_selection``; the
    selection is then handed to ``SISRs.remove_recreate_allImp`` and the cheapest
    solution seen so far is tracked per test instance.
    """

    def __init__(self,
                 env_params,
                 tester_params):
        """Set up device, problem-specific SISRs backend, models and environment.

        Args:
            env_params: dict forwarded to ``Env``. Keys read directly by this class:
                'problem' ("cvrp", "vrptw" or "pcvrp"), 'beta',
                'insert_in_new_tours_only' and 'recreate_n'.
            tester_params: dict with the search configuration. Keys read by this
                class: 'use_cuda', 'cuda_device_num', 'use_baseline_destroy',
                'model_load', 'test_data_load', 'test_episodes', 'test_batch_size',
                'aug_factor', 'nb_iterations', 'rollout_size', 'max_runtime',
                'softmax_temp', 'SA_start_T', 'SA_final_T', 'SA_nb_resets' and
                'SA_min_diff_resets'.

        Raises:
            NotImplementedError: if ``env_params['problem']`` is not supported.
            AssertionError: if a checkpoint was trained with a different number of
                nodes to remove than configured in 'model_load'.
        """
        # save arguments
        self.env_params = env_params
        self.tester_params = tester_params

        # result folder, logger
        self.logger = getLogger(name='tester')
        self.result_folder = get_result_folder()

        # cuda
        if self.tester_params['use_cuda']:
            cuda_device_num = self.tester_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')
        self.device = device

        # The problem-specific extension is bound to a module-level name because
        # _test_one_batch refers to it as a global.
        global SISRs
        problem = env_params['problem']
        if problem == "cvrp":
            from .cpp.cvrp import SISRs
        elif problem == "vrptw":
            from .cpp.vrptw import SISRs
        elif problem == "pcvrp":
            from .cpp.pcvrp import SISRs
        else:
            raise NotImplementedError("Unsupported problem type: {!r}".format(problem))

        self.deconstruction_operator = []
        if not self.tester_params['use_baseline_destroy']:
            self._load_deconstruction_operators(device)

        # ENV
        self.env = Env(False, **self.env_params)

        # utility
        self.time_estimator = TimeEstimator()

    def _load_deconstruction_operators(self, device):
        """Load every checkpoint listed in tester_params['model_load'] into self.deconstruction_operator."""
        for model_data in self.tester_params['model_load']:
            checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_data)
            checkpoint = torch.load(checkpoint_fullname, map_location=device)
            model_params = checkpoint['model_params']

            model = Model(**model_params)
            model.load_state_dict(checkpoint['model_state_dict'])

            # All 2**z_dim binary vectors of length z_dim, one per row.
            binary_vector_pool = torch.Tensor(
                [list(bits) for bits in itertools.product([0, 1], repeat=model_params['z_dim'])])

            self.deconstruction_operator.append(
                {'model': model, 'binary_vector_pool': binary_vector_pool, **model_data})
            assert checkpoint['env_params']['num_nodes_to_remove'] == model_data['node_to_remove'], \
                "Checkpoint {} was trained with a different number of nodes to remove.".format(checkpoint_fullname)

    def run(self):
        """Run the search on all test episodes, logging progress and appending results to CSV files.

        Writes ``results.csv`` (episode id, cost, runtime share, iterations) and
        ``solutions.csv`` (episode id, tours) into ``self.result_folder``.

        Raises:
            NotImplementedError: if the test data file is neither ``.pkl`` nor ``.pt``.
        """
        self.time_estimator.reset()

        costs_AM = AverageMeter()
        rt_AM = AverageMeter()
        nb_iter_AM = AverageMeter()

        self._load_test_data()

        test_num_episode = self.tester_params['test_episodes']
        aug_factor = self.tester_params['aug_factor']
        episode = 0

        while episode < test_num_episode:
            remaining = test_num_episode - episode
            batch_size = min(self.tester_params['test_batch_size'], remaining)

            costs, rt, nb_iter, solutions = self._test_one_batch(
                batch_size, self.tester_params['nb_iterations'], episode, aug_factor=aug_factor)

            costs_AM.update(costs.mean(), batch_size)
            rt_AM.update(rt, batch_size)
            nb_iter_AM.update(nb_iter, batch_size)

            episode += batch_size

            ############################
            # Logs
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)
            self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score: {:.3f}, running_mean: {:.3f} iter: {:.1f}".format(
                episode, test_num_episode, elapsed_time_str, remain_time_str, costs.mean(), costs_AM.avg, nb_iter))

            # Episode ids in the CSV files are 1-based.
            episode_ids = range(episode - batch_size + 1, episode + 1)

            self._append_csv_rows(
                "results.csv",
                zip(episode_ids, costs, [rt / batch_size] * batch_size, [nb_iter] * batch_size))

            # Every tour is written with the depot (node 0) added at both ends.
            tours_batch = [[[0, *tour, 0] for tour in sol.getTourList()] for sol in solutions]
            self._append_csv_rows("solutions.csv", zip(episode_ids, tours_batch))

            all_done = (episode == test_num_episode)

            if all_done:
                self.logger.info(" *** Test Done *** ")
                self.logger.info(" AVG. COSTS: {:.4f} ".format(costs_AM.avg))
                self.logger.info(" AVG. RUNTIME: {:.2f} ".format(rt_AM.avg))
                self.logger.info(" AVG. ITERATIONS: {:.1f} ".format(nb_iter_AM.avg))

    def _load_test_data(self):
        """Load the test data set into the environment if tester_params['test_data_load'] enables it."""
        data_cfg = self.tester_params['test_data_load']
        if not data_cfg['enable']:
            return

        filename = data_cfg['filename']
        extension = os.path.splitext(filename)[1]
        if extension == ".pkl":
            self.env.load_problem_dataset_pkl(filename, self.tester_params['test_episodes'])
        elif extension == ".pt":
            self.env.load_problem_dataset_pt(filename, self.device)
        else:
            raise NotImplementedError("Unsupported test data format: {!r}".format(extension))

    def _append_csv_rows(self, filename, rows):
        """Append ``rows`` to the CSV file ``filename`` inside the result folder."""
        with open(os.path.join(self.result_folder, filename), mode='a', newline='') as file:
            csv.writer(file).writerows(rows)

    def _test_one_batch(self, batch_size, nb_iterations, episode, aug_factor=1):
        """Search on one batch of instances and return the best solutions found.

        The search stops after ``nb_iterations`` iterations or, if
        ``tester_params['max_runtime']`` is positive, once that many seconds elapsed.

        Args:
            batch_size: number of problem instances in the batch.
            nb_iterations: maximum number of search iterations.
            episode: index of the first episode of the batch (currently unused).
            aug_factor: number of solutions kept per instance.

        Returns:
            Tuple ``(incumbent_costs, runtime, iterations, incumbent_sols)``:
            best cost per instance (numpy array of length ``batch_size``), elapsed
            seconds, number of iterations that were actually executed, and the
            solution objects belonging to the best costs.

        Raises:
            AssertionError: if a runtime limit is set but the iteration budget is
                exhausted first, or if a reset is triggered with ``batch_size != 1``.
        """
        rollout_size = self.tester_params['rollout_size']
        aug_batch_size = batch_size * aug_factor
        use_model = not self.tester_params['use_baseline_destroy']
        max_runtime = self.tester_params['max_runtime']
        softmax_temp = self.tester_params['softmax_temp']
        nb_resets = self.tester_params["SA_nb_resets"]
        beta = self.env_params['beta']
        insert_in_new_tours_only = self.env_params['insert_in_new_tours_only']
        recreate_n = self.env_params['recreate_n']

        # A non-positive max_runtime means the search is bounded by nb_iterations only.
        runtime_limited = max_runtime > 0

        # Ready
        ###############################################
        for operator in self.deconstruction_operator:
            operator['model'].eval()

        with torch.no_grad():

            start_time = time.time()

            self.env.init_instances(batch_size, rollout_size, self.device, aug_factor)

            incumbent_costs = np.full(batch_size, np.inf)
            incumbent_sols = [None] * batch_size

            T_0 = self.tester_params['SA_start_T']
            T_f = self.tester_params['SA_final_T']
            T = T_0

            reset_size = aug_factor
            if runtime_limited:
                # Resets are spread evenly over the available runtime.
                time_to_next_reset = (max_runtime - (time.time() - start_time)) / (nb_resets + 1)
                next_reset_time = time.time() + time_to_next_reset
            else:
                # Geometric cooling from T_0 to T_f over nb_iterations steps.
                cooling = (T_f / T_0) ** (1 / nb_iterations) if nb_iterations > 0 else 1.0
                # 0 disables resets (consistent with the runtime-limited branch);
                # otherwise reset at least every iteration to avoid a zero modulus.
                reset_frequency = max(1, nb_iterations // nb_resets) if nb_resets > 0 else 0

            reset_counter = 0
            # Counted explicitly: the loop index is off by one after a runtime
            # break and undefined when the loop body never runs.
            iterations_done = 0
            for i in range(nb_iterations):

                if runtime_limited and time.time() - start_time > max_runtime:
                    break

                if use_model:
                    selected_nodes = self._select_nodes_with_model(aug_batch_size, rollout_size, softmax_temp)

                for b_idx in range(aug_batch_size):  # 1 batch - one problem instance
                    solution = self.env.instanceSet.get_solution(b_idx)
                    if use_model:
                        nodes_to_remove = selected_nodes[b_idx]
                    else:
                        nodes_to_remove = SISRs.heuristic_deconstruction_selection(
                            solution, self.env.num_nodes_to_remove, rollout_size)

                    sol, _ = SISRs.remove_recreate_allImp(solution, nodes_to_remove,
                                                          beta, recreate_n, T, insert_in_new_tours_only)

                    self.env.instanceSet.set_solution(b_idx, sol)

                    # Solution b_idx is credited to instance b_idx % batch_size.
                    instance_idx = b_idx % batch_size
                    if sol.totalCosts < incumbent_costs[instance_idx]:
                        incumbent_costs[instance_idx] = sol.totalCosts
                        incumbent_sols[instance_idx] = sol

                if runtime_limited:
                    do_reset = time.time() > next_reset_time and reset_counter < nb_resets
                    if do_reset:
                        next_reset_time += time_to_next_reset
                else:
                    do_reset = reset_frequency > 0 and i > 0 and i % reset_frequency == 0

                if do_reset:
                    assert batch_size == 1, "Resets are only supported for batch_size == 1."
                    reset_counter += 1
                    reset_size = self._reset_to_best(batch_size, aug_factor, reset_size)

                cost = incumbent_costs
                print(cost)
                print(np.mean(self.env.instanceSet.costs), np.mean(cost))

                if runtime_limited:
                    # Temperature follows the elapsed share of the runtime budget.
                    cur_runtime = time.time() - start_time
                    T = T_f * (T_0 / T_f) ** (1 - (cur_runtime / max_runtime))
                else:
                    T *= cooling
                print(T)

                iterations_done = i + 1

            if runtime_limited:
                assert time.time() - start_time > max_runtime, "A runtime limit was set, but search terminated based on number of iterations."

            return incumbent_costs, time.time() - start_time, iterations_done, incumbent_sols

    def _select_nodes_with_model(self, aug_batch_size, rollout_size, softmax_temp):
        """Roll out a randomly chosen deconstruction model and return the selected nodes as a numpy array."""
        # choose model
        operator = random.choice(self.deconstruction_operator)
        model = operator['model']
        self.env.num_nodes_to_remove = operator['node_to_remove']

        state = self.env.reset()
        reset_state = self.env.get_model_input(self.device)

        # Sample z vectors: rollout_size distinct rows of the binary pool per solution.
        z_dim = model.model_params['z_dim']
        nb_vectors = 2 ** z_dim
        z_idx = torch.multinomial(torch.ones(aug_batch_size, nb_vectors) / nb_vectors,
                                  rollout_size, replacement=False)
        # Indexing the (2**z_dim, z_dim) pool with (aug_batch, rollout) indices
        # already yields (aug_batch, rollout, z_dim); the reshape only pins the shape.
        z = operator['binary_vector_pool'][z_idx].reshape(aug_batch_size, rollout_size, z_dim)

        with torch.amp.autocast(device_type=self.device.type):
            model.pre_forward(reset_state, z)

            done = False
            while not done:
                selected, _, _ = model(state, softmax_temp)
                # shape: (batch, pomo)

                state, done = self.env.step(selected)

        return self.env.selected_node_list.cpu().numpy()

    def _reset_to_best(self, batch_size, aug_factor, reset_size):
        """Overwrite all current solutions with copies of the cheapest ones and return the new reset size."""
        if reset_size > self.tester_params['SA_min_diff_resets']:
            # Never shrink below one kept solution, otherwise nothing is left to copy.
            reset_size = max(1, reset_size // 2)

        order = np.argsort(self.env.instanceSet.costs)[:reset_size]
        # Ceil division plus truncation gives exactly aug_factor entries even when
        # aug_factor is not a multiple of reset_size.
        copies_per_solution = -(-aug_factor // reset_size)
        order = np.repeat(order, copies_per_solution)[:aug_factor]

        # Shallow copy so that later set_solution calls do not affect the lookup.
        old_solutions = copy.copy(self.env.instanceSet._solutions)
        print(f"RESET TO {reset_size} BEST")
        print("##########")
        print("##########")

        for b_idx in range(batch_size * aug_factor):
            self.env.instanceSet.set_solution(b_idx, old_solutions[order[b_idx]])

        return reset_size
