import abc
import logging
from abc import ABC
from logging import Logger
from typing import List, Dict

import torch
from torch import Tensor
from torch.optim import Adam
from torch.utils.data import RandomSampler, BatchSampler, DataLoader

from algorithms.convergence_algorithms.base import Algorithm
from algorithms.convergence_algorithms.exceptions import AlgorithmFinish
from algorithms.convergence_algorithms.typing import BoundedEvaluatedSamplerIdentifiableSpace
from algorithms.convergence_algorithms.utils import distance_between_tensors
from algorithms.mapping.base import InputMapping
from algorithms.mapping.trust_region import LinearTrustRegion
from algorithms.nn.datasets import PointDataset
from algorithms.nn.generative_modules import NICE
from algorithms.stopping_condition.base import AlgorithmStopCondition
from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawers.drawable_algorithms import ConvergenceDrawable


class GeneticBBO(ConvergenceDrawable, Algorithm, ABC):
    """Base class for generation-based black-box optimizers.

    Each epoch of :meth:`train` generates a batch of candidate points,
    evaluates them on the environment, keeps a bounded dataset of the most
    recent candidates, and lets the subclass update its generative model.

    Lower evaluation values are treated as better: a candidate replaces the
    current best point only when its evaluation is strictly smaller than the
    best value seen so far.

    Subclasses must implement :meth:`generate_samples` and
    :meth:`update_model`, and must define an ``ALGORITHM_NAME`` attribute,
    which :meth:`train` uses in its log and finish messages.
    """

    def __init__(
        self,
        env: BoundedEvaluatedSamplerIdentifiableSpace,
        curr_point: Tensor,
        max_dataset_size_multiplier: int = 20,
        input_mapping: InputMapping = None,
        device: int = None,
        logger: Logger = None,
    ):
        """Store the optimization state.

        Args:
            env: Space that is called to evaluate candidate points.
            curr_point: Initial best point.
            max_dataset_size_multiplier: The retained dataset is capped at
                ``generation_size * max_dataset_size_multiplier`` entries.
            input_mapping: Optional mapping between the model's coordinates
                and the (normalized) environment coordinates.
            device: Device on which the datasets are created.
            logger: Logger to use; defaults to this module's logger.
        """
        self.env = env
        self.curr_best_point = curr_point
        self.max_dataset_size_multiplier = max_dataset_size_multiplier
        self.input_mapping = input_mapping
        self.device = device
        self.logger = logger or logging.getLogger(__name__)

    @property
    def curr_point_to_draw(self):
        """Point exposed for drawing: the current best point."""
        return self.curr_best_point

    @property
    def environment(self):
        """The environment being optimized."""
        return self.env

    @property
    def best_point_until_now(self):
        """The best point found so far."""
        return self.curr_best_point

    def set_start_point(self, start_point: Tensor):
        """Replace the current best point with ``start_point``.

        When an input mapping is configured, the point is first normalized by
        the environment and then passed through ``input_mapping.map``.
        """
        # Use the same ``is not None`` test as ``evaluate_samples`` so that a
        # mapping object which happens to be falsy is still applied.
        if self.input_mapping is not None:
            start_point = self.input_mapping.map(self.environment.normalize(start_point))
        self.curr_best_point = start_point

    @abc.abstractmethod
    def generate_samples(self, generation_size: int) -> Tensor:
        """Return ``generation_size`` new candidate points."""
        raise NotImplementedError()

    @abc.abstractmethod
    def update_model(
        self,
        training_epochs: int,
        batch_size: int,
        samples_dataset: Tensor,
        evaluations_dataset: Tensor,
        new_samples: Tensor,
        evaluations: Tensor,
    ):
        """Update the generative model from the evaluated candidates.

        Args:
            training_epochs: Number of passes the subclass should make.
            batch_size: Mini-batch size for the update.
            samples_dataset: Retained candidates, newest first.
            evaluations_dataset: Evaluations matching ``samples_dataset``.
            new_samples: Candidates generated in the current epoch.
            evaluations: Evaluations matching ``new_samples``.
        """
        raise NotImplementedError()

    def evaluate_samples(self, samples: Tensor) -> Tensor:
        """Evaluate ``samples`` on the environment.

        When an input mapping is configured, the samples are passed through
        ``input_mapping.inverse`` and ``environment.denormalize`` first.
        """
        if self.input_mapping is not None:
            samples = self.environment.denormalize(self.input_mapping.inverse(samples))
        return self.environment(samples)

    def train(
        self,
        epochs: int,
        generation_size: int,
        training_epochs: int = 5,
        batch_size=32,
        num_of_epoch_with_no_improvement: int = None,
        stopping_conditions: List[AlgorithmStopCondition] = None,
        callback_handlers: List[AlgorithmCallbackHandler] = None,
    ):
        """Run the generate / evaluate / update loop.

        Args:
            epochs: Maximum number of generations to run.
            generation_size: Number of candidates generated per epoch.
            training_epochs: Forwarded to :meth:`update_model`.
            batch_size: Forwarded to :meth:`update_model`.
            num_of_epoch_with_no_improvement: Only used in the
                "No improvement" log message; stopping itself is decided by
                ``stopping_conditions``.
            stopping_conditions: Checked after every epoch with the current
                no-improvement counter.
            callback_handlers: ``on_epoch_end`` is called on each handler
                after every epoch with the epoch's new samples.

        Raises:
            AlgorithmFinish: When any stopping condition reports that the
                algorithm should stop.
        """
        stopping_conditions = stopping_conditions or []
        callback_handlers = callback_handlers or []
        self.logger.info(f"starting algorithm {self.ALGORITHM_NAME} for space {self.env}")
        no_improvement_counter = 0
        best_value = self.evaluate_samples(self.curr_best_point)[0]
        samples_dataset = torch.tensor([], device=self.device)
        evaluations_dataset = torch.tensor([], device=self.device)

        for _ in range(epochs):
            new_samples = self.generate_samples(generation_size)
            evaluations = self.evaluate_samples(new_samples)
            # Lower is better (see the progress check below), so the best
            # candidate of this generation is the one with the minimum value.
            new_best_idx = evaluations.argmin()

            # New entries are prepended, so truncating keeps the most recent.
            max_dataset_size = generation_size * self.max_dataset_size_multiplier
            samples_dataset = torch.cat((new_samples, samples_dataset))[:max_dataset_size]
            evaluations_dataset = torch.cat((evaluations, evaluations_dataset))[
                :max_dataset_size
            ]
            self.update_model(
                training_epochs,
                batch_size,
                samples_dataset,
                evaluations_dataset,
                new_samples,
                evaluations,
            )

            # Negative progress means the new candidate beats the best value.
            progress_value = evaluations[new_best_idx] - best_value
            if progress_value.item() >= 0:
                self.logger.info(
                    f"No improvement {progress_value} "
                    f"counter {no_improvement_counter}/{num_of_epoch_with_no_improvement}"
                )
                no_improvement_counter += 1
            else:
                best_value = evaluations[new_best_idx]
                self.logger.info(
                    f"moved {distance_between_tensors(self.curr_best_point, new_samples[new_best_idx])} "
                    f"toward best point. progress {progress_value}. Now - {best_value}"
                )
                self.curr_best_point = new_samples[new_best_idx]
                no_improvement_counter = 0
            for handler in callback_handlers:
                handler.on_epoch_end(self, database=new_samples)
            for stop_condition in stopping_conditions:
                if stop_condition.should_stop(self, counter=no_improvement_counter):
                    raise AlgorithmFinish(
                        stop_condition.REASON.format(
                            alg=self.ALGORITHM_NAME,
                            env=self.env,
                            best_point=self.best_point_until_now,
                            tr=self.input_mapping,
                        )
                    )

    @classmethod
    def _default_types(cls) -> Dict[str, type]:
        """Return the default type per component name.

        NOTE(review): presumably consumed by the ``Algorithm`` base class or a
        factory to build a default ``input_mapping`` - confirm against caller.
        """
        return {"input_mapping": LinearTrustRegion}


class NiceBBO(GeneticBBO):
    """``GeneticBBO`` whose candidate generator is a ``NICE`` model.

    Candidates are produced by sampling from the model and mapping the
    samples through ``model.f_inverse``; the model is trained with Adam on
    the retained dataset, using per-point weights supplied by
    ``PointDataset``.
    """

    ALGORITHM_NAME = "gen_nice"

    def __init__(self, *args, **kwargs):
        """Build the NICE model and its Adam optimizer.

        All arguments are forwarded to ``GeneticBBO.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Positional arguments follow the NICE constructor's signature
        # (defined in algorithms.nn.generative_modules); two of them are
        # derived from the environment's dimension.
        self.model = NICE(
            "gaussian",
            4,
            "affine",
            self.env.dimension,
            4,
            int(self.env.dimension * 1.5),
            self.device,
        )
        self.model.to(device=self.device)
        self.optimizer = Adam(self.model.parameters())

    def generate_samples(self, generation_size: int):
        """Draw ``generation_size`` candidates from the model.

        Samples are drawn with ``model.sample`` and mapped through
        ``model.f_inverse``.
        """
        # Candidates are data for evaluation and later training, not part of
        # the training graph: generating them without autograd keeps the
        # replay dataset from holding on to (and back-propagating through)
        # stale computation graphs.
        with torch.no_grad():
            z_samples = self.model.sample(generation_size)
            return self.model.f_inverse(z_samples)

    def update_model(
        self,
        training_epochs: int,
        batch_size: int,
        samples_dataset: Tensor,
        evaluations_dataset: Tensor,
        new_samples: Tensor,
        evaluations: Tensor,
    ):
        """Train the model on the retained dataset.

        Runs ``training_epochs`` passes over ``samples_dataset`` in random
        mini-batches of at most ``batch_size`` points. Each batch's loss is
        the model output multiplied by the dataset-provided weights, averaged
        over the batch. ``new_samples`` and ``evaluations`` are not used by
        this implementation.
        """
        total_loss = 0
        # The dataset and loader do not change between passes, so build them
        # once; RandomSampler draws a fresh permutation on every iteration.
        dataset = PointDataset(samples_dataset, evaluations_dataset)
        sampler = BatchSampler(
            RandomSampler(range(len(dataset))), batch_size=batch_size, drop_last=False
        )
        # The BatchSampler is passed as ``sampler``, so the DataLoader
        # (default batch_size=1) wraps every pre-built batch in an extra
        # leading dimension of size 1 - hence the ``[0]`` indexing below.
        data_loader = DataLoader(dataset, sampler=sampler)

        for _ in range(training_epochs):
            for batch_samples, _batch_evaluations, weights in data_loader:
                # Clear gradients left over from the previous step; otherwise
                # they accumulate across batches and corrupt every update.
                self.optimizer.zero_grad()
                loss = self.model(batch_samples[0])
                loss = (loss * weights[0]).mean()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
        self.logger.info(f"Accumulated loss - {total_loss}")
