import json
import os
from math import ceil, floor
import numpy as np
import torch
from typing import Tuple, List, Optional

from DataHandling.FileBasedDatasetBase import FileBasedDataset
from Utils import logger
from Utils.Constants import FileNamesConstants, Diff
from SelectionAgent.SelectionAgentDataLoader import SelectionAgentDataloader
from Utils.utils import function_start_save_params


class SelectionAgentModel:
    """Iteratively prune a pool of CNN models using a trained LSTM predictor.

    On every agent step the LSTM scores each remaining model from a growing number of saved CNN training
    steps. The lowest-scoring models are then removed. This stops when ``final_size`` models remain,
    when ``max_steps`` steps were done, when the cut size reaches zero, or when the input sequences reach
    ``Diff.NUMBER_STEPS_SAVED`` steps.
    """

    def __init__(self, agent_folder, lstm_folder, lstm_model_path, lstm_model,
                 initial_cut_size: float, cut_decay: Optional[float], max_steps: int,
                 min_models_steps_for_eval: int, sequence_size_increase_between_cuts: int):
        """
        :param agent_folder: folder in which the run's hyper-parameters file is saved
        :param lstm_folder: only recorded in the hyper-parameters file
        :param lstm_model_path: only recorded in the hyper-parameters file
        :param lstm_model: predictor used to score the models; it is switched to eval mode here
        :param initial_cut_size: the number of models removed in the first iteration. A value strictly between 0
                and 1 is interpreted as a fraction of the current dataset size instead of an absolute count.
        :param cut_decay: decay in the cut size after every agent step (None keeps the cut size constant, a negative
                value makes it grow)
        :param max_steps: maximum number of agent steps
        :param min_models_steps_for_eval: The size of the first sequence used by the agent (minimal number of CNN steps
                    used by selection agent for eval = first number of CNN steps used for eval).
        :param sequence_size_increase_between_cuts:  How many more CNN steps are used by the agent after every cut. e.g., if
                min_models_steps_for_eval=1 and sequence_size_increase_between_cuts=4. In the first agent iteration
                data that will be used is CNN data after first step (100 mini batches), in second agent iteration the
                CNN step will be 5(=4+1) (500 mini batch steps).
        """
        # Must stay the first statement: it snapshots exactly the constructor arguments.
        func_run_params = locals().copy()
        func_run_params.pop('self')
        func_run_params.pop('lstm_model')  # the model object itself is not saved as a hyper-parameter
        function_start_save_params(func_run_params, extra_data=None, config=None,
                                   save_path=os.path.join(agent_folder, FileNamesConstants.MODEL_HYPER_PARAMS))

        self._lstm_model = lstm_model
        self._lstm_model.eval()
        self.initial_cut_size = initial_cut_size
        self.cut_decay = cut_decay
        self.max_steps = max_steps
        self.min_models_steps_for_eval = min_models_steps_for_eval
        self.sequence_size_increase_between_cuts = sequence_size_increase_between_cuts
        # Set by _predict_single_epoch once the inputs reach Diff.NUMBER_STEPS_SAVED steps; reset per run_agent call.
        self._maxed_out_data = False

    @staticmethod
    def _is_fraction_cut_size(cut_size) -> bool:
        """Return True when ``cut_size`` is a fraction of the dataset (strictly between 0 and 1)."""
        return 0 < cut_size < 1

    @staticmethod
    def __sort_models_by_pred(curr_preds, models_ids):
        """Return ``(pred, model_id)`` pairs sorted by ascending prediction (stable for equal predictions)."""
        return sorted(zip(curr_preds, models_ids), key=lambda val: val[0])

    @staticmethod
    def select_bad_models(curr_preds, models_ids, remove_size):
        """Return the ids of the ``remove_size`` models with the lowest predictions.

        :param curr_preds: one prediction per model, aligned with ``models_ids``
        :param models_ids: ids of the models that were scored
        :param remove_size: number of ids to return
        :return: list of model ids, lowest prediction first
        """
        pred_ids = SelectionAgentModel.__sort_models_by_pred(curr_preds, models_ids)
        return [cid for _, cid in pred_ids[:remove_size]]

    def __calculate_new_cut_size(self, last_cut_size, current_dataset_size):
        """Calculate a new cut size after each agent iteration.

        :param last_cut_size: cut size used in the previous iteration (a fraction in (0, 1) or an absolute count)
        :param current_dataset_size: number of models still remaining
        :return: (number of models to remove in the next iteration, cut size to pass to the next call)
        """
        cut_size_is_percentile = self._is_fraction_cut_size(last_cut_size)

        if self.cut_decay is None:
            new_cut_size = last_cut_size
        else:
            new_cut_size = last_cut_size - last_cut_size * self.cut_decay
            if not cut_size_is_percentile:
                # Absolute counts are rounded in the direction of the decay (floor when shrinking, ceil when
                # growing), so any non-zero decay changes a positive cut size by at least one model.
                new_cut_size = floor(new_cut_size) if self.cut_decay >= 0 else ceil(new_cut_size)

        if cut_size_is_percentile:
            num_models_to_cut = ceil(new_cut_size * current_dataset_size)
        else:
            # int(): an absolute float cut size (e.g. 5.0 with cut_decay=None) must still be a valid slice bound.
            num_models_to_cut = int(new_cut_size)

        return num_models_to_cut, new_cut_size

    def _predict_single_epoch(self, dataloader, batch_size):
        """Score every model currently served by ``dataloader`` with the LSTM.

        Also sets ``self._maxed_out_data`` when the last batch's sequence length (dim 1) reaches
        ``Diff.NUMBER_STEPS_SAVED``.

        :param dataloader: iterable of ``(inputs, _)`` batches; ``len(dataloader)`` is used as the number of models
        :param batch_size: kept for interface compatibility; predictions are placed using a running offset, so a
                differently sized (e.g. last) batch cannot misalign them
        :return: (array with one prediction per model, number of models that were run)
        """
        step_preds = np.zeros(len(dataloader))
        num_models_ran = 0
        curr_inputs = None  # stays None for an empty dataloader
        use_double = self._lstm_model.is_double()
        with torch.no_grad():
            for curr_inputs, _ in dataloader:
                if use_double:
                    curr_inputs = curr_inputs.double()
                curr_batch_size = curr_inputs.shape[0]
                step_preds[num_models_ran: num_models_ran + curr_batch_size] = self._lstm_model(
                    curr_inputs).detach().numpy().squeeze()
                num_models_ran += curr_batch_size
        # NOTE(review): assumes dim 1 of the inputs is the number of CNN steps in the sequence - confirm.
        if curr_inputs is not None and curr_inputs.shape[1] >= Diff.NUMBER_STEPS_SAVED:
            self._maxed_out_data = True
        return step_preds, num_models_ran

    def run_agent(self, dataset: FileBasedDataset, final_size: int, batch_size: int)\
            -> Tuple[List[str], List[float], List[str], np.ndarray, List[int]]:
        """Prune ``dataset`` down to ``final_size`` models and rank the survivors.

        :param dataset: dataset of the candidate models
        :param final_size: final amount of models selected by the agent
        :param batch_size: batch size used by the agent dataloader
        :return: (final model ids sorted by descending prediction, their predictions, all models ids,
                all models results (both from ``dataloader.original_models_results``), resources used per step
                = models scored x CNN steps added in that step)
        """
        logger().log('SelectionAgentModel::run_agent', f'Running on dataset size: {len(dataset)} to {final_size=}')
        # The flag is instance state; reset it so a previous run cannot end this one after its first step.
        self._maxed_out_data = False
        resources_used = list()
        dataloader = SelectionAgentDataloader(maps_dataset=dataset, batch_size=batch_size,
                                              initial_num_steps_to_use=self.min_models_steps_for_eval,
                                              sequence_size_increase=self.sequence_size_increase_between_cuts)
        steps_done = 1
        last_cut_size = self.initial_cut_size
        # Same fraction/absolute rule as __calculate_new_cut_size.
        if self._is_fraction_cut_size(last_cut_size):
            num_models_to_cut = floor(len(dataset) * last_cut_size)
        else:
            num_models_to_cut = int(last_cut_size)

        while steps_done <= self.max_steps and len(dataloader) > final_size:
            # A non-positive cut would either do nothing or (negative slice bound) remove almost every model.
            if num_models_to_cut <= 0:
                logger().log('SelectionAgentModel::run_agent', 'Got to final size before finished steps')
                break

            step_preds, num_models_ran = self._predict_single_epoch(dataloader, batch_size)
            # Never cut below the requested final size.
            if len(dataloader) - num_models_to_cut < final_size:
                num_models_to_cut = len(dataloader) - final_size

            if steps_done == 1:
                resources_used.append(num_models_ran * self.min_models_steps_for_eval)
            else:
                resources_used.append(num_models_ran * self.sequence_size_increase_between_cuts)

            ids_to_remove = self.select_bad_models(step_preds, dataloader.current_models_ids, num_models_to_cut)
            logger().log('SelectionAgentModel::run_agent', 'step: ', steps_done, ', current size: ', len(dataloader),
                         ', remove size:', num_models_to_cut, ', ids selected for removing: ', ids_to_remove)
            dataloader.remove_files(removed_ids=ids_to_remove)
            if self._maxed_out_data:
                break
            steps_done += 1
            num_models_to_cut, last_cut_size = self.__calculate_new_cut_size(current_dataset_size=len(dataloader),
                                                                             last_cut_size=last_cut_size)

        # NOTE(review): reset() semantics come from SelectionAgentDataloader (presumably rewinds the sequence
        # length used for the following predictions) - confirm.
        dataloader.reset()
        final_ids = dataloader.current_models_ids
        if len(final_ids) > final_size:
            final_preds, _ = self._predict_single_epoch(dataloader, batch_size)
            num_models_to_cut = len(final_ids) - final_size
            ids_to_remove = self.select_bad_models(final_preds, dataloader.current_models_ids, num_models_to_cut)
            logger().log('SelectionAgentModel::run_agent', 'step: size fix, current size: ', len(dataloader),
                         ', remove size:', num_models_to_cut, ', ids selected for removing: ', ids_to_remove)
            dataloader.remove_files(removed_ids=ids_to_remove)
            final_ids = dataloader.current_models_ids

        # Sort final models by predictions, best first (reverse of the stable ascending sort).
        final_preds, _ = self._predict_single_epoch(dataloader, batch_size)
        ranked = list(reversed(self.__sort_models_by_pred(final_preds, final_ids)))
        final_preds = [pred for pred, _ in ranked]
        final_ids = [model_id for _, model_id in ranked]

        all_models_ids, all_models_results = dataloader.original_models_results
        return final_ids, final_preds, all_models_ids, all_models_results, resources_used
