import abc
import copy
import logging
import math
from logging import Logger
from typing import List, Dict, Callable, Any

import torch
from torch import Tensor
from torch.nn import Sequential, Linear
from torch.optim import Optimizer, Adam
from tqdm.auto import trange

from algorithms.convergence_algorithms.base import Algorithm
from algorithms.convergence_algorithms.basic_config import FuncConfig
from algorithms.convergence_algorithms.exceptions import AlgorithmFinish
from algorithms.convergence_algorithms.typing import (
    BoundedEvaluatedSamplerIdentifiableSpace,
)
from algorithms.convergence_algorithms.utils import (
    ball_perturb,
    reset_all_weights,
    distance_between_tensors,
)
from algorithms.mapping.base import InputMapping, DefaultMapping, OutputMapping
from algorithms.mapping.trust_region import TanhTrustRegion
from algorithms.mapping.value_normalizers import AdaptedOutputUnconstrainedMapping
from algorithms.nn.modules import (
    ConstructableModel,
    BaseSequentialModel,
    ConfigurableModule,
    BasicNetwork,
)
from algorithms.space.exceptions import NoMoreBudgetError
from algorithms.stopping_condition.base import AlgorithmStopCondition
from algorithms.stopping_condition.trsut_region import StopAfterXTimes
from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawers.drawable_algorithms import ConvergenceDrawable


class ConvergenceAlgorithm(ConvergenceDrawable, Algorithm):
    """Base class for sample-and-fit convergence algorithms with a trust region.

    The current candidate point is the parameter tensor of ``model_to_train``.
    Each training iteration:

    1. samples points around the candidate;
    2. evaluates them on ``env``;
    3. fits ``helper_network`` on a replay memory of samples via
       :meth:`train_helper_model`;
    4. updates the candidate via :meth:`train_model`.

    After enough iterations without improvement, the sampling radius
    ``epsilon`` is shrunk. If an ``input_mapping`` is configured, the trust
    region is also re-centred or squeezed around the best point.

    When ``input_mapping`` is set, model parameters live in mapped ("unreal")
    coordinates. :meth:`real_data` converts them back to the environment's
    coordinates before evaluation.

    Subclasses must implement :meth:`train_helper_model`, :meth:`train_model`
    and :meth:`gradient`.
    """

    def __init__(
        self,
        env: BoundedEvaluatedSamplerIdentifiableSpace,
        helper_network: ConfigurableModule,
        model_to_train: ConstructableModel,
        value_optimizer: Optimizer,
        model_to_train_optimizer: Optimizer,
        epsilon: float,
        epsilon_factor: float,
        min_epsilon: float,
        perturb: float,
        max_batch_size: int = 1024,
        num_of_batch_reply: int = 32,
        maximum_movement_for_shrink: float = math.inf,
        output_mapping: OutputMapping = None,
        input_mapping: InputMapping = None,
        dtype: torch.dtype = torch.float64,
        device: int = None,
        logger: Logger = logging.getLogger(__name__),
    ):
        """
        Args:
            env: Space that is evaluated with ``env(data, debug_mode=...)``.
                When ``input_mapping`` is used, it also provides
                ``normalize``/``denormalize``.
            helper_network: Network fitted by :meth:`train_helper_model`. Its
                weights are reset whenever the trust region is updated.
            model_to_train: Model whose parameter tensor is the current point.
            value_optimizer: Optimizer for ``helper_network``, stored as
                ``self.helper_optimizer``.
            model_to_train_optimizer: Optimizer for ``model_to_train``.
            epsilon: Sampling radius passed to ``ball_perturb``.
            epsilon_factor: Multiplier applied to ``epsilon`` on every shrink.
            min_epsilon: Lower bound for ``epsilon``.
            perturb: Stored only. It is not read in this class; presumably
                subclasses use it.
            max_batch_size: Upper bound on the helper-training batch size.
            num_of_batch_reply: The replay memory keeps at most
                ``num_of_batch_reply * exploration_size`` samples.
            maximum_movement_for_shrink: Threshold in model (mapped)
                coordinates. If the best point moved less than this since the
                last trust-region update, the region is squeezed; otherwise it
                is only re-centred.
            output_mapping: Adapted on every exploration. Defaults to
                ``DefaultMapping()``.
            input_mapping: Optional trust-region mapping between normalized
                environment coordinates and model coordinates.
            dtype: Forwarded to ``ball_perturb`` when sampling.
            device: Forwarded to ``ball_perturb``; evaluations are moved here.
            logger: Logger used for progress messages.
        """
        self.env = env
        self.helper_network = helper_network
        self.model_to_train = model_to_train
        self.best_model = copy.deepcopy(model_to_train)
        self.helper_optimizer = value_optimizer
        self.model_to_train_optimizer = model_to_train_optimizer
        self.epsilon = epsilon
        self.min_epsilon = min_epsilon
        self.epsilon_factor = epsilon_factor
        self.perturb = perturb
        self.max_batch_size = max_batch_size
        self.num_of_batch_reply = num_of_batch_reply
        self.maximum_movement_for_shrink = maximum_movement_for_shrink
        self.output_mapping = output_mapping or DefaultMapping()
        self.input_mapping = input_mapping
        self.dtype = dtype
        self.device = device
        self.logger = logger

    @property
    def best_point_until_now(self):
        """Detached parameter tensor of the best model, in model coordinates."""
        return self.best_model.model_parameter_tensor().detach()

    @property
    def environment(self):
        """The environment being optimized."""
        return self.env

    @property
    def curr_point_to_draw(self):
        """Detached parameter tensor of the current model, in model coordinates."""
        return self.model_to_train.model_parameter_tensor().detach()

    def real_data(self, data: Tensor):
        """Convert model-coordinate points into environment coordinates.

        Applies ``input_mapping.inverse`` and then ``env.denormalize``. If no
        input mapping is configured, ``data`` is returned unchanged.
        """
        if self.input_mapping:
            data = self.env.denormalize(self.input_mapping.inverse(data))
        return data

    def eval_data(self, data: Tensor, debug: bool = False):
        """Evaluate model-coordinate points on the environment."""
        data = self.real_data(data)
        return self.env(data, debug_mode=debug)

    def train(
        self,
        epochs: int,
        exploration_size: int,
        num_loop_without_improvement: int,
        min_iteration_before_shrink: int,
        max_num_of_shrink: int = None,
        helper_model_training_epochs: int = 60,
        warmup_minibatch: int = 5,
        warmup_loops: int = 6,
        stopping_conditions: List[AlgorithmStopCondition] = None,
        callback_handlers: List[AlgorithmCallbackHandler] = None,
        **kwargs,
    ):
        """Run the main optimization loop.

        Args:
            epochs: Maximum number of loop iterations.
            exploration_size: Points sampled per iteration (the current point
                plus ``exploration_size - 1`` perturbations).
            num_loop_without_improvement: Non-improving iterations required
                before a shrink.
            min_iteration_before_shrink: Minimum iterations since the last
                shrink before another one may happen.
            max_num_of_shrink: If truthy, a ``StopAfterXTimes`` stop condition
                is added for this number of shrinks.
            helper_model_training_epochs: Forwarded as ``epochs`` to
                :meth:`train_helper_model`.
            warmup_minibatch: The initial exploration samples
                ``warmup_minibatch * exploration_size`` points.
            warmup_loops: Number of :meth:`train_helper_model` calls per
                warm-up.
            stopping_conditions: Checked after every iteration with
                ``counter=num_of_shrinks``. The caller's list is not modified.
            callback_handlers: Notified on start, on each epoch end, on each
                shrink, and on end.
            **kwargs: Forwarded to :meth:`training_start_hook`.

        ``NoMoreBudgetError`` and ``AlgorithmFinish`` end training gracefully.
        Handlers' ``on_algorithm_end`` is called in both cases.
        """
        self.logger.info(
            f"Starting running {self.__class__.__name__} for {epochs} epochs"
        )
        # Copy so the append below never mutates the caller's list; otherwise
        # repeated calls would accumulate StopAfterXTimes conditions.
        stopping_conditions = list(stopping_conditions or [])
        callback_handlers = callback_handlers or []

        if max_num_of_shrink:
            stopping_conditions.append(StopAfterXTimes(max_num_of_shrink))

        for callback_handler in callback_handlers:
            callback_handler.on_algorithm_start(self)

        self.training_start_hook(
            epochs,
            exploration_size,
            num_loop_without_improvement,
            min_iteration_before_shrink,
            max_num_of_shrink,
            helper_model_training_epochs,
            warmup_minibatch,
            callback_handlers,
            **kwargs,
        )

        # Defined up front so the end-of-run handlers below can inspect it even
        # if an exception is raised before the first exploration finishes.
        database = torch.tensor([])

        try:
            database, evaluations = self.explore(warmup_minibatch * exploration_size)
            num_of_samples = database.shape[-2]
            batch_size = min(self.max_batch_size, num_of_samples)
            minibatches = num_of_samples // batch_size

            self.warm_up(
                batch_size,
                database,
                evaluations,
                exploration_size,
                helper_model_training_epochs,
                minibatches,
                warmup_loops,
            )

            best_model_value = self.eval_data(
                self.best_model.model_parameter_tensor().detach()
            )
            reply_memory_size = self.num_of_batch_reply * exploration_size
            no_improvement_in_model_count = 0
            counter = 0
            num_of_shrinks = 0
            # Best point (model coordinates) at the last trust-region update;
            # used to measure how far the best point moved since then.
            last_tr_unreal_best = self.best_point_until_now.clone()
            self.logger.info(f"starting on {best_model_value}")

            for _ in trange(epochs, desc=f"Training EGL {epochs} epochs"):
                counter += 1
                # Explore and append to the replay memory, keeping only the
                # most recent `reply_memory_size` samples.
                samples, new_evaluations = self.explore(exploration_size)
                database = torch.cat((database, samples), dim=(len(samples.shape) - 2))[
                    ..., -reply_memory_size:, :
                ]
                evaluations = torch.cat(
                    (evaluations, new_evaluations), dim=(len(evaluations.shape) - 1)
                )[..., -reply_memory_size:]

                better = (evaluations < best_model_value).sum()
                # numel() matches `better`, which counts over every element.
                total = evaluations.numel()
                better_percentage = better / total
                self.logger.info(
                    f"{better_percentage} percent are better than best, {better} / {total}"
                )

                num_of_samples = database.shape[-2]
                batch_size = min(self.max_batch_size, num_of_samples)
                minibatches = num_of_samples // batch_size
                test_loss = self.train_helper_model(
                    database,
                    evaluations,
                    minibatches,
                    batch_size,
                    exploration_size,
                    helper_model_training_epochs,
                    exploration_size,
                )
                self.train_model()

                # Handle end of epoch
                for handler in callback_handlers:
                    handler.on_epoch_end(
                        self,
                        database=database,
                        test_losses=test_loss,
                        best_model_value=best_model_value,
                    )

                # Check improvement (lower value is better)
                real_model = self.real_data(
                    self.model_to_train.model_parameter_tensor().detach()
                )
                new_model_evaluation = self.env(real_model.detach())
                if best_model_value > new_model_evaluation:
                    self.logger.info(
                        f"Improved best known point to {new_model_evaluation} From {best_model_value} In {self.env}"
                    )
                    self.best_model = copy.deepcopy(self.model_to_train)
                    best_model_value = new_model_evaluation
                else:
                    self.logger.warning(
                        f"No improvement ({no_improvement_in_model_count}) for value {new_model_evaluation} "
                        f"in {self.env}"
                    )
                    no_improvement_in_model_count += 1

                if (
                    no_improvement_in_model_count >= num_loop_without_improvement
                    and counter >= min_iteration_before_shrink
                ):
                    counter = 0
                    no_improvement_in_model_count = 0
                    num_of_shrinks += 1

                    unreal_distance_between_bests = distance_between_tensors(
                        last_tr_unreal_best,
                        self.best_model.model_parameter_tensor().detach(),
                    )
                    self.epsilon *= self.epsilon_factor
                    self.epsilon = max(self.epsilon, self.min_epsilon)
                    # Small movement means we are near a local optimum, so squeeze
                    # the region; a large movement only re-centres it.
                    should_shrink_in_addition_to_move = (
                        self.maximum_movement_for_shrink > unreal_distance_between_bests
                    )
                    if self.input_mapping:
                        self.before_shrinking_hook()
                        # Convert to mapping-independent coordinates before the
                        # mapping itself changes.
                        best_parameters_real = self.input_mapping.inverse(
                            self.best_model.model_parameter_tensor().detach()
                        )
                        real_database = self.input_mapping.inverse(database.detach())
                        if should_shrink_in_addition_to_move:
                            self.logger.info(
                                f"Shrinking trust region, movement {unreal_distance_between_bests}, new center is "
                                f"{self.env.denormalize(best_parameters_real).tolist()} with {self.input_mapping}"
                            )

                            self.input_mapping.squeeze(
                                best_parameters_real,
                                gradient=self.curr_gradient(),
                                grad_net=self.gradient,
                                epsilon=self.epsilon,
                            )
                        else:
                            self.input_mapping.move_center(best_parameters_real)
                            self.logger.info(
                                f"Moving trust region after moving {unreal_distance_between_bests}, new center is "
                                f"{self.env.denormalize(best_parameters_real).tolist()} with {self.input_mapping}"
                            )

                        # Re-express the best point and replay memory in the new
                        # mapping's coordinates.
                        self.model_to_train = self.model_to_train.from_parameter_tensor(
                            self.input_mapping.map(best_parameters_real).clone()
                        )
                        self.best_model = copy.deepcopy(self.model_to_train)
                        self.after_shrinking_hook()
                        # NOTE - the helper network is reset only if the trust
                        # region has changed, because otherwise it should look
                        # the same.
                        reset_all_weights(self.helper_network)
                        database = self.input_mapping.map(real_database)
                        self.warm_up(
                            batch_size,
                            database,
                            evaluations,
                            exploration_size,
                            helper_model_training_epochs,
                            minibatches,
                            warmup_loops,
                        )
                    # Measure the next movement from this update's best point,
                    # in the current mapping's coordinate system.
                    last_tr_unreal_best = self.best_point_until_now.clone()
                    self.logger.info(f"Shrinking sample radius to {self.epsilon}")
                    self.logger.info(f"Space status {self.env}")
                    for handler in callback_handlers:
                        handler.on_algorithm_update(
                            self, database=database, best_model_value=best_model_value
                        )

                for stop_condition in stopping_conditions:
                    if stop_condition.should_stop(self, counter=num_of_shrinks):
                        raise AlgorithmFinish(
                            stop_condition.REASON.format(
                                alg=self.__class__.__name__,
                                env=self.env,
                                best_point=self.best_point_until_now,
                                tr=self.input_mapping,
                            )
                        )
        except NoMoreBudgetError as e:
            self.logger.warning("No more Budget", exc_info=e)
        except AlgorithmFinish as e:
            self.logger.info(f"{self.__class__.__name__} Finish stopped {e}")

        for handler in callback_handlers:
            self.logger.info(
                f"Calling upon {handler.on_algorithm_end} finishing convergence"
            )
            if database.numel() != 0:
                handler.on_algorithm_end(self, database=database)
            else:
                handler.on_algorithm_end(self)

    def warm_up(
        self,
        batch_size,
        database,
        evaluations,
        exploration_size,
        helper_model_training_epochs,
        minibatches,
        warmup_loops,
    ):
        """Fit the helper model ``warmup_loops`` times on the whole database.

        On each call, all samples in ``database`` are reported as new.
        """
        for i in range(warmup_loops):
            self.logger.info(f"{i} loop for warmup for {self.__class__.__name__}")
            self.train_helper_model(
                database,
                evaluations,
                minibatches,
                batch_size,
                exploration_size,
                helper_model_training_epochs,
                # Sample axis is -2 everywhere else in this class; for a 2-D
                # database this equals len(database).
                database.shape[-2],
            )

    def explore(self, exploration_size: int):
        """Sample ``exploration_size`` points around the current model and evaluate them.

        Also adapts ``output_mapping`` to the new evaluations.

        Returns:
            ``(samples, evaluations)``, with samples in model coordinates.
        """
        current_model_parameters = self.model_to_train.model_parameter_tensor()
        new_model_samples = self.samples_points(
            current_model_parameters, exploration_size
        )

        # Evaluate
        evaluations = self.evaluate_point(new_model_samples).to(device=self.device)

        self.output_mapping.adapt(evaluations)
        return new_model_samples, evaluations

    def samples_points(self, base_point: Tensor, exploration_size: int):
        """Return ``exploration_size - 1`` ball perturbations of ``base_point``,
        followed by ``base_point`` itself as the last row.
        """
        self.logger.info(
            f"Exploring new data points. Sampling {exploration_size} points"
        )
        new_model_samples = torch.cat(
            (
                # TODO - use DI to allow other exploration
                ball_perturb(
                    base_point,
                    self.epsilon,
                    exploration_size - 1,
                    self.dtype,
                    self.device,
                ),
                base_point.reshape(1, -1),
            )
        )
        return new_model_samples

    def evaluate_point(self, new_model_samples: Tensor):
        """Evaluate model-coordinate samples on the environment."""
        self.logger.info(
            f"Evaluating {len(new_model_samples)} on env with {self.input_mapping}"
        )
        return self.eval_data(new_model_samples)

    def train_helper_model(
        self,
        samples: Tensor,
        samples_value: Tensor,
        num_of_minibatch: int,
        batch_size: int,
        exploration_size: int,
        epochs: int,
        new_samples_count: int,
    ):
        """Fit ``helper_network`` on ``(samples, samples_value)``.

        Subclasses must override this. The return value is forwarded to
        callback handlers as ``test_losses``.
        """
        raise NotImplementedError()

    def train_model(self):
        """Update ``model_to_train``. Subclasses must override this."""
        raise NotImplementedError()

    def training_start_hook(self, *args, **kwargs):
        """Called once at the start of :meth:`train`; no-op by default."""
        pass

    def before_shrinking_hook(self):
        """Called before the trust region changes; no-op by default."""
        pass

    def after_shrinking_hook(self):
        """Called after the trust region changes and the model is re-mapped; no-op by default."""
        pass

    def set_start_point(self, start_point: Tensor):
        """Set both the current and the best model to ``start_point``.

        ``start_point`` is given in environment coordinates. It is normalized
        and mapped into model coordinates when an input mapping is configured.
        """
        if self.input_mapping:
            start_point = self.input_mapping.map(self.env.normalize(start_point))
        self.model_to_train = self.model_to_train.from_parameter_tensor(start_point)
        self.best_model = copy.deepcopy(self.model_to_train)

    @abc.abstractmethod
    def gradient(self, x) -> Tensor:
        """Return the estimated gradient at model-coordinate point ``x``."""
        raise NotImplementedError()

    def curr_gradient(self) -> Tensor:
        """Return the estimated gradient at the current model point."""
        return self.gradient(self.model_to_train.model_parameter_tensor())

    @classmethod
    def object_default_values(cls) -> dict:
        """Return default constructor/training values used by the config system."""
        return {
            "epsilon": 0.1,
            "epsilon_factor": 0.97,
            "min_epsilon": 1e-4,
            "perturb": 0,
            "exploration_size": 64,
            "epochs": 150_000,
            "num_loop_without_improvement": 10,
            "min_iteration_before_shrink": 40,
            "max_num_of_shrink": math.inf,
            "model_to_train": FuncConfig(
                lambda device, dims, **kwargs: BaseSequentialModel(
                    Sequential(Linear(dims, 1, bias=False, dtype=torch.float64)).to(
                        device=device
                    )
                )
            ),
            "value_optimizer": FuncConfig(
                lambda helper_network, value_lr=1e-3, **kwargs: Adam(
                    helper_network.parameters(),
                    lr=value_lr,
                    betas=(0.9, 0.999),
                    eps=1e-04,
                )
            ),
            "model_to_train_optimizer": FuncConfig(
                lambda model_to_train, model_lr=0.01, **kwargs: Adam(
                    model_to_train.parameters(),
                    lr=model_lr,
                    betas=(0.9, 0.999),
                    eps=1e-04,
                )
            ),
        }

    @classmethod
    def _manipulate_parameters(cls) -> Dict[str, Callable]:
        """Return parameter transforms that scale distances by ``sqrt(dims)``."""
        return {
            "epsilon": FuncConfig(
                lambda epsilon, dims, **kwargs: epsilon * math.sqrt(dims)
            ),
            "maximum_movement_for_shrink": FuncConfig(
                lambda maximum_movement_for_shrink, dims, **kwargs: maximum_movement_for_shrink
                * math.sqrt(dims)
            ),
        }

    @classmethod
    def _default_types(cls) -> Dict[str, type]:
        """Return default concrete types for injectable components."""
        return {
            "input_mapping": TanhTrustRegion,
            "output_mapping": AdaptedOutputUnconstrainedMapping,
            "helper_network": BasicNetwork,
        }

    @classmethod
    def _additional_configs(cls) -> Dict[str, Dict[str, Any]]:
        """Return named configuration variants that override the defaults."""
        return {
            "v2": {
                "exploration_size": 8,
                "maximum_movement_for_shrink": 2,
                "min_trust_region_size": 0,
            },
            "v3": {
                "exploration_size": 8,
                "maximum_movement_for_shrink": 0.2,
                "min_trust_region_size": 0,
            },
            "faster": {
                "model_to_train_optimizer": FuncConfig(
                    lambda model_to_train, **kwargs: Adam(
                        model_to_train.parameters(),
                        lr=0.05,
                        betas=(0.9, 0.999),
                        eps=1e-04,
                    )
                ),
            },
            "sample_norm": {
                "exploration_size": FuncConfig(
                    lambda dims, sample_base=8, **kwargs: int(sample_base)
                    * (math.ceil(math.sqrt(dims)))
                ),
            },
            "sample_log": {
                "exploration_size": FuncConfig(
                    lambda dims, sample_base=8, **kwargs: int(sample_base)
                    * max(1, math.floor(math.log(dims) ** 2))
                ),
            },
            "sample_norm_2": {
                "exploration_size": FuncConfig(
                    lambda dims, **kwargs: 16 * (math.ceil(math.sqrt(dims)))
                ),
            },
        }
