
import torch

import os
from logging import getLogger

from TSPEnv import TSPEnv as Env
from TSPModel import TSPModel as Model

from utils.utils import *
import utils.transactionutils as utils

class TSPTester:
    """Evaluate a checkpointed ``Model`` on problems served by ``Env``.

    The constructor picks the torch device, builds the model and the
    environment, and restores the model weights from a checkpoint file.
    ``run`` then evaluates ``tester_params['test_episodes']`` episodes in
    batches and logs the running and final averaged scores.
    """

    def __init__(self,
                 env_params,
                 model_params,
                 tester_params):
        """Build the model/environment pair and load the checkpoint.

        Args:
            env_params: Passed through unchanged to ``Env``.
            model_params: Passed through unchanged to ``Model``; the key
                ``'data_augment'`` is additionally read at test time.
            tester_params: Mapping read for ``'use_cuda'``,
                ``'cuda_device_num'`` (only when CUDA is used),
                ``'model_load'`` (a mapping with ``'path'`` and ``'epoch'``),
                ``'test_episodes'`` and ``'test_batch_size'``.
        """
        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.tester_params = tester_params

        # result folder, logger (the logger is registered under the name 'trainer')
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()

        # cuda
        # NOTE(review): torch.set_default_tensor_type is deprecated in recent
        # PyTorch releases; kept here so tensor defaults behave as before.
        if self.tester_params['use_cuda']:
            cuda_device_num = self.tester_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')
        self.device = device

        # ENV and MODEL
        self.model = Model(device, self.model_params)
        self.env = Env(device, self.env_params)

        # Restore weights from '<path>/checkpoint-<epoch>.pt'
        model_load = tester_params['model_load']
        checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
        checkpoint = torch.load(checkpoint_fullname, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # utility
        self.time_estimator = TimeEstimator()

    def run(self):
        """Evaluate all test episodes batch by batch and log the scores.

        Raises:
            ValueError: If the effective batch size is not positive, which
                would otherwise keep the episode counter from advancing.
        """
        self.time_estimator.reset()

        score_AM = AverageMeter()
        aug_score_AM = AverageMeter()

        test_num_episode = self.tester_params['test_episodes']
        episode = 0

        while episode < test_num_episode:

            # The last batch is shrunk so exactly test_num_episode episodes run.
            remaining = test_num_episode - episode
            batch_size = min(self.tester_params['test_batch_size'], remaining)
            if batch_size <= 0:
                # Without this guard `episode` would never reach the target.
                raise ValueError(
                    "tester_params['test_batch_size'] must be positive, got {}".format(
                        self.tester_params['test_batch_size']))

            score, aug_score = self._test_one_batch(batch_size)

            # Weight each batch by its size so the final averages are per-episode.
            score_AM.update(score, batch_size)
            aug_score_AM.update(aug_score, batch_size)

            episode += batch_size

            ############################
            # Logs
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)
            self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(
                episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))

            all_done = (episode == test_num_episode)

            if all_done:
                self.logger.info(" *** Test Done *** ")
                self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))
                self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))

    def _test_one_batch(self, batch_size):
        """Roll the model out on one batch of problems and score the result.

        Args:
            batch_size: Number of problems to load from the environment.

        Returns:
            A ``(no_aug_score, aug_score)`` pair of Python floats: the
            negated mean of the best per-problem reward, without and with
            taking the best over the augmentation axis. ``aug_factor`` is
            fixed to 1 in this method, so both values are the same.
        """
        aug_factor = 1

        # Ready
        ###############################################
        self.model.eval()
        with torch.no_grad():
            self.env.load_problems(batch_size, aug_factor)
            reset_state, _, _ = self.env.reset()
            batch_nodes = reset_state.problems
            batch_tunnel = reset_state.tunnels
            B, N, _ = batch_nodes.shape
            batch_tunnel_env = utils.expand_all_as_tunnels(N, batch_tunnel)
            # NOTE(review): this expanded `batch_tunnel` is not read again in
            # this method; the call is kept because the helper is opaque here.
            batch_tunnel = utils.expand_every_tunnels(N, batch_tunnel)

            batch_nodes = batch_nodes.to(self.device)
            batch_tunnel = batch_tunnel.to(self.device)
            batch_tunnel_env = batch_tunnel_env.to(self.device)

            batch_coord_tunnel = utils.generate_coord_from_indexes(batch_nodes, batch_tunnel_env)

            # Start from the raw coordinates so the encoder input is always
            # defined; previously it was only assigned when augmentation was
            # enabled, which raised UnboundLocalError otherwise.
            encoder_nodes_input = batch_nodes
            if self.model_params['data_augment']:
                # 8-fold POMO augmentation (original author's note: input_dim=32)
                encoder_nodes_input = utils.augment_xy_data_by_8_fold_POMO(batch_nodes, training=True)
                batch_coord_tunnel = utils.augment_tunnel_data_by_8_fold(batch_coord_tunnel, training=True)
            embeddings_nodes = self.model.encoder_nodes(encoder_nodes_input)
            embeddings_tunnels = self.model.encoder_tunnels(batch_coord_tunnel)

            # CONSIDER DIRECTION, ELSE, BATCH_COORD_TUNNEL=(..,batchtunnel_env)
            _, L, _ = embeddings_nodes.shape
            tunnel_table = utils.create_output_matrix_with_batch(batch_tunnel_env, L)
            tunnel_table = torch.tensor(tunnel_table).to(self.device).double()
            # The batched product is done in float64, then cast back to float32.
            logit_k_tunnels = torch.bmm(embeddings_tunnels.transpose(1, 2).double(), tunnel_table).float()
            # batch_tunnel_env already lives on self.device (moved above).
            self.model.pre_forward(embeddings_nodes, logit_k_tunnels.mT, batch_tunnel_env)

        # POMO Rollout
        ###############################################
        state, reward, done = self.env.pre_step()
        while not done:
            selected, _ = self.model(state)
            # shape: (batch, pomo)
            state, reward, done = self.env.step(selected)

        # Return
        ###############################################
        aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)
        # shape: (augmentation, batch, pomo)

        max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo
        # shape: (augmentation, batch)
        no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive value

        max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation
        # shape: (batch,)
        aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive value

        return no_aug_score.item(), aug_score.item()
