import torch

from logging import getLogger
from .env import Env as Env
from .logging_utils import *
import os.path

import itertools
import wandb


class Validator:
    """Evaluate a model on a validation set and log score / diversity metrics.

    Each call to :meth:`run` processes ``valid_episodes`` instances in batches.
    For every batch it performs ``valid_iterations`` rounds of "model rollout
    followed by ``instanceSet.remove_recreate`` repair". It then reports the
    mean cost with and without augmentation plus two rollout-diversity
    statistics, to the logger, to a per-epoch CSV file and, optionally, to wandb.
    """

    def __init__(self,
                 device,
                 env_params,
                 trainer_params, model_params, logger_params):
        """
        Args:
            device: torch device used for model inputs; ``device.type`` also
                selects the autocast backend.
            env_params: keyword arguments forwarded to ``Env``. The keys
                ``recreate_n``, ``beta`` and ``insert_in_new_tours_only`` are
                also read during repair.
            trainer_params: validation settings (the ``valid_*`` keys).
            model_params: only ``z_dim`` is read, to build the latent-code pool.
            logger_params: ``logger_params['wandb']['enable']`` toggles wandb logging.
        """
        # save arguments
        self.env_params = env_params
        self.trainer_params = trainer_params

        # result folder, logger
        self.logger = getLogger(name='validator')
        self.result_folder = get_result_folder()

        self.device = device

        # ENV
        # NOTE(review): the meaning of the leading positional ``True`` is defined
        # by Env (presumably an evaluation/validation flag) - confirm.
        self.env = Env(True, **self.env_params)

        # utility
        self.time_estimator = TimeEstimator()
        self.use_wandb = logger_params['wandb']['enable']

        # All 2**z_dim binary vectors of length z_dim, one per row; latent codes
        # for each rollout are sampled from this pool.
        self.binary_string_pool = torch.Tensor(
            [list(i) for i in itertools.product([0, 1], repeat=model_params['z_dim'])])

    def run(self, model, frozen_model, training_epoch):
        """Run a full validation pass and log the results.

        Args:
            model: policy under evaluation. It must expose ``model_params``
                (with ``z_dim`` and ``eval_type``), ``pre_forward``, ``eval``
                and be callable on an env state.
            frozen_model: forwarded unchanged to ``_validate_one_batch``, which
                does not currently use it.
            training_epoch: used as the wandb step and in the CSV file name.

        Returns:
            The average augmented score: the mean over instances of the best
            final cost across augmentations.
        """
        self.time_estimator.reset()

        score_AM = AverageMeter()
        aug_score_AM = AverageMeter()
        div_1_AM = AverageMeter()
        div_2_AM = AverageMeter()

        self._load_validation_data()

        validate_num_episode = self.trainer_params['valid_episodes']
        valid_batch_size = self.trainer_params['valid_batch_size']
        nb_iterations = self.trainer_params['valid_iterations']

        # Augmentation
        if self.trainer_params['valid_augmentation_enable']:
            aug_factor = self.trainer_params['valid_aug_factor']
        else:
            aug_factor = 1

        self.logger.info(" *** Validation Start *** ")

        # Collect per-batch log arrays and concatenate once at the end, instead
        # of re-copying the growing array on every batch.
        log_chunks = []
        episode = 0
        while episode < validate_num_episode:
            # Never request more instances than remain in the validation set.
            batch_size = min(valid_batch_size, validate_num_episode - episode)

            score, aug_score, logs_episode, div = self._validate_one_batch(model, frozen_model, batch_size,
                                                                           nb_iterations,
                                                                           aug_factor=aug_factor)

            score_AM.update(score, batch_size)
            aug_score_AM.update(aug_score, batch_size)
            div_1_AM.update(div[0], batch_size)
            div_2_AM.update(div[1], batch_size)
            log_chunks.append(logs_episode)

            episode += batch_size

            # Logs
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, validate_num_episode)
            self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score: {:.3f}, aug_score: {:.3f}".format(
                episode, validate_num_episode, elapsed_time_str, remain_time_str, score_AM.avg, aug_score_AM.avg))

        if log_chunks:
            logs = np.concatenate(log_chunks, axis=0)
        else:
            logs = np.zeros((0, nb_iterations))

        self.logger.info(" *** Validation Done *** ")
        self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))
        self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))
        self.logger.info(" DIVERSITY SCORE: {:.4f} ".format(div_1_AM.avg))
        self.logger.info(" UNIQUE ROLLOUTS: {:.4f} ".format(div_2_AM.avg))

        # Write the per-instance, per-iteration best-cost log for this epoch.
        file_path = os.path.join(self.result_folder, "valid_logs", f"epoch-{training_epoch}.csv")
        write_csv(logs, file_path)

        # Additional greedy diversity evaluation (one batch only; temporary solution).
        # Cap the batch like the main loop does, so a loaded dataset that holds
        # only ``valid_episodes`` instances is not over-read.
        greedy_batch_size = min(valid_batch_size, validate_num_episode)
        greedy_unique = self._greedy_unique_rollouts(model, frozen_model, greedy_batch_size,
                                                     nb_iterations, aug_factor)
        self.logger.info(" GREEDY UNIQUE ROLLOUTS: {:.4f} ".format(greedy_unique))

        if self.use_wandb:
            wandb.log(step=training_epoch, data={"val/no_aug_score": score_AM.avg, "val/aug_score": aug_score_AM.avg,
                                                  "val/diversity_score": div_1_AM.avg,
                                                  "val/unique_rollouts": div_2_AM.avg,
                                                  "val/greedy_unique_rollout": greedy_unique})

        return aug_score_AM.avg

    def _load_validation_data(self):
        """Load the fixed validation set from disk when ``valid_data_load`` is enabled.

        Raises:
            NotImplementedError: if the file extension is neither ``.pkl`` nor ``.pt``.
        """
        data_cfg = self.trainer_params['valid_data_load']
        if not data_cfg['enable']:
            return
        filename = data_cfg['filename']
        extension = os.path.splitext(filename)[1]
        if extension == ".pkl":
            self.env.load_problem_dataset_pkl(filename, self.trainer_params['valid_episodes'])
        elif extension == ".pt":
            self.env.load_problem_dataset_pt(filename, self.device)
        else:
            raise NotImplementedError(f"Unsupported validation data format: {extension!r}")

    def _greedy_unique_rollouts(self, model, frozen_model, batch_size, nb_iterations, aug_factor):
        """Re-run one batch with argmax decoding and return its unique-rollout ratio.

        ``model.model_params['eval_type']`` is switched to ``"argmax"`` for the
        duration of the batch. It is always restored, even if the batch raises.
        """
        org_eval_type = model.model_params['eval_type']
        model.model_params['eval_type'] = "argmax"
        try:
            # Presumably rewinds the env's dataset cursor so the greedy pass sees
            # the first validation batch again - confirm against Env.
            self.env.problem.saved_index = 0
            _, _, _, greedy_div = self._validate_one_batch(model, frozen_model, batch_size, nb_iterations,
                                                           aug_factor=aug_factor)
        finally:
            model.model_params['eval_type'] = org_eval_type
        return greedy_div[1]

    def _validate_one_batch(self, model, frozen_model, batch_size, nb_iterations, aug_factor=1):
        """Roll out and repair one batch of instances for ``nb_iterations`` rounds.

        Args:
            model: policy under evaluation (see :meth:`run`).
            frozen_model: currently unused here; kept for interface compatibility.
            batch_size: number of (un-augmented) instances in the batch.
            nb_iterations: number of rollout + repair rounds; must be >= 1.
            aug_factor: number of augmented copies per instance.

        Returns:
            A tuple ``(score, aug_score, logs, div)``:

            * ``score`` is the mean final cost of the first augmentation slice
              (presumably the un-augmented instances).
            * ``aug_score`` is the mean of the per-instance minimum over
              augmentations.
            * ``logs`` is a ``(batch_size, nb_iterations)`` array holding the
              best cost across augmentations after each round.
            * ``div`` is the :meth:`calculate_diversity` pair for the last round.

        Raises:
            ValueError: if ``nb_iterations`` < 1, because no rollout would be produced.
        """
        if nb_iterations < 1:
            raise ValueError(f"nb_iterations must be >= 1, got {nb_iterations}")

        rollout_size = self.trainer_params['valid_rollout_size']
        z_dim = model.model_params['z_dim']
        recreate_n = self.env_params['recreate_n']
        beta = self.env_params['beta']
        insert_in_new_tours_only = self.env_params['insert_in_new_tours_only']
        aug_batch_size = batch_size * aug_factor

        logs = np.zeros((batch_size, nb_iterations))

        model.eval()
        with torch.no_grad():

            self.env.init_instances(batch_size, rollout_size, self.device, aug_factor)

            for i in range(nb_iterations):

                state = self.env.reset()
                reset_state = self.env.get_model_input(self.device)

                # Sample, per augmented instance, ``rollout_size`` distinct latent
                # codes uniformly from the pool. torch.multinomial without
                # replacement raises if rollout_size > 2**z_dim.
                z_idx = torch.multinomial((torch.ones(aug_batch_size, 2 ** z_dim) / 2 ** z_dim),
                                          rollout_size, replacement=False)
                z = self.binary_string_pool[z_idx].reshape(aug_batch_size, 1, rollout_size, z_dim)
                # The transpose only swaps a singleton axis, so z ends up with the
                # same element order as pool[z_idx]: shape (aug_batch, rollout, z_dim).
                z = z.transpose(1, 2).reshape(aug_batch_size, rollout_size, z_dim)

                with torch.amp.autocast(device_type=self.device.type):
                    model.pre_forward(reset_state, z)

                done = False
                while not done:
                    with torch.amp.autocast(device_type=self.device.type):
                        selected, _, _ = model(state)
                    # shape: (batch, pomo)

                    state, done = self.env.step(selected)

                selected_nodes = self.env.selected_node_list.cpu().numpy()

                # Repair the rollouts' selections with the environment's remove/recreate routine.
                self.env.instanceSet.remove_recreate(selected_nodes, recreate_n, "allImp", T=0, beta=beta,
                                                     insert_in_new_tours_only=insert_in_new_tours_only)

                # Costs are laid out augmentation-major: (aug_factor, batch) after reshape.
                logs[:, i] = np.array(self.env.instanceSet.costs).reshape(aug_factor, -1).min(axis=0)

            div = self.calculate_diversity(selected_nodes)
            aug_cost = np.array(self.env.instanceSet.costs)
            cost = aug_cost.reshape(aug_factor, -1)
            return np.mean(cost[0]), np.mean(cost.min(axis=0)), logs, div

    def calculate_diversity(self, selected_nodes):
        """Compute two diversity statistics over the rollouts of each instance.

        Args:
            selected_nodes: integer array indexed as
                ``[instance, rollout, position]``.

        Returns:
            ``(overlap, unique_ratio)``, each averaged over instances.

            * ``overlap`` (logged as "DIVERSITY SCORE") is the average fraction
              of *other* rollouts that also contain each selected node. It is 0
              when rollouts share no nodes and 1 when every rollout selects the
              same node set. It is ``nan`` when there are fewer than two rollouts
              or no selected positions.
            * ``unique_ratio`` is the number of distinct selections (order
              ignored) divided by the number of rollouts.
        """
        rollout_size = selected_nodes.shape[1]
        nb_positions = selected_nodes.shape[2]
        denominator = (rollout_size - 1) * nb_positions * rollout_size
        row_idx = np.arange(rollout_size)[:, None]

        div_values_1 = []
        div_values_2 = []

        for rollouts in selected_nodes:
            # Score 1: for every selected node occurrence, count the other rollouts
            # that contain that node at least once. This uses a (rollout x node)
            # presence matrix instead of re-scanning all rollouts per node.
            if denominator == 0:
                div_values_1.append(np.nan)
            else:
                values, inverse = np.unique(rollouts, return_inverse=True)
                # The inverse shape differs across NumPy versions; normalize it.
                inverse = inverse.reshape(rollouts.shape)
                present = np.zeros((rollout_size, values.size), dtype=bool)
                present[row_idx, inverse] = True
                nb_rollouts_with_node = present.sum(axis=0)
                overlap = (nb_rollouts_with_node[inverse] - 1).sum()
                div_values_1.append(overlap / denominator)

            # Score 2 (unique selections): rollouts are equal if they select the same nodes in any order.
            div_values_2.append(np.unique(np.sort(rollouts, axis=1), axis=0).shape[0] / rollout_size)

        return np.mean(div_values_1), np.mean(div_values_2)
