import abc
from collections import OrderedDict

from typing import Iterable
from torch import nn as nn
import wandb

from lfrl.core.rl_algorithms.batch.batch_rl_algorithm import BatchRLAlgorithm
from lfrl.core.rl_algorithms.offline.offline_rl_algorithm import OfflineRLAlgorithm
from lfrl.core.rl_algorithms.offline.mb_offline_rl_algorithm import OfflineMBRLAlgorithm
from lfrl.core.rl_algorithms.online.online_rl_algorithm import OnlineRLAlgorithm
from lfrl.core.rl_algorithms.online.mbrl_algorithm import MBRLAlgorithm
from lfrl.trainers.trainer import Trainer
from lfrl.torch.pytorch_util import np_to_pytorch_batch


"""
TODO: All of this code should be gutted, and the interface between RLAlgorithm
and training from a batch should be completely redone.
"""


class TorchOnlineRLAlgorithm(OnlineRLAlgorithm):
    """``OnlineRLAlgorithm`` whose trainer exposes PyTorch modules.

    Adds wandb registration, device placement and train/eval mode switching
    for every module yielded by ``self.trainer.networks``. ``self.trainer``
    is not assigned in this class; presumably the base class sets it.
    """

    def configure_logging(self):
        """Register each distinct trainer network with ``wandb.watch``.

        Networks listed more than once are watched only once, in the order
        in which they first appear in ``self.trainer.networks``.
        """
        # A plain set() would iterate in hash order, which for objects hashed
        # by identity varies between runs. Keeping first-seen order makes the
        # sequence of wandb.watch calls reproducible.
        for net in OrderedDict.fromkeys(self.trainer.networks):
            wandb.watch(net)

    def to(self, device):
        """Call ``to(device)`` on every trainer network."""
        for net in self.trainer.networks:
            net.to(device)

    def training_mode(self, mode):
        """Call ``train(mode)`` on every trainer network."""
        for net in self.trainer.networks:
            net.train(mode)


class TorchBatchRLAlgorithm(BatchRLAlgorithm):
    """``BatchRLAlgorithm`` whose trainer exposes PyTorch modules.

    Adds wandb registration, device placement and train/eval mode switching
    for every module yielded by ``self.trainer.networks``. ``self.trainer``
    is not assigned in this class; presumably the base class sets it.
    """

    def configure_logging(self):
        """Register each distinct trainer network with ``wandb.watch``.

        Networks listed more than once are watched only once, in the order
        in which they first appear in ``self.trainer.networks``.
        """
        # A plain set() would iterate in hash order, which for objects hashed
        # by identity varies between runs. Keeping first-seen order makes the
        # sequence of wandb.watch calls reproducible.
        for net in OrderedDict.fromkeys(self.trainer.networks):
            wandb.watch(net)

    def to(self, device):
        """Call ``to(device)`` on every trainer network."""
        for net in self.trainer.networks:
            net.to(device)

    def training_mode(self, mode):
        """Call ``train(mode)`` on every trainer network."""
        for net in self.trainer.networks:
            net.train(mode)


class TorchMBRLAlgorithm(MBRLAlgorithm):
    """``MBRLAlgorithm`` whose two trainers expose PyTorch modules.

    Adds wandb registration, device placement and train/eval mode switching
    for the modules of both ``self.trainer`` and ``self.model_trainer``.
    Neither attribute is assigned in this class; presumably the base class
    sets them.
    """

    def configure_logging(self):
        """Register each distinct network of both trainers with ``wandb.watch``.

        ``self.trainer`` networks come first, then ``self.model_trainer``
        networks; a network that appears more than once is watched only once,
        at its first position.
        """
        # Merging through an ordered mapping, rather than ``set(a + b)``,
        # accepts any iterables (``+`` needs two sequences of the same type)
        # and keeps the wandb.watch call order reproducible between runs.
        unique_nets = OrderedDict.fromkeys(self.trainer.networks)
        unique_nets.update(OrderedDict.fromkeys(self.model_trainer.networks))
        for net in unique_nets:
            wandb.watch(net)

    def to(self, device):
        """Call ``to(device)`` on every network of both trainers."""
        for net in self.trainer.networks:
            net.to(device)
        for net in self.model_trainer.networks:
            net.to(device)

    def training_mode(self, mode):
        """Call ``train(mode)`` on every network of both trainers."""
        for net in self.trainer.networks:
            net.train(mode)
        for net in self.model_trainer.networks:
            net.train(mode)


class TorchOfflineRLAlgorithm(OfflineRLAlgorithm):
    """``OfflineRLAlgorithm`` whose trainer exposes PyTorch modules.

    Adds wandb registration, device placement and train/eval mode switching
    for every module yielded by ``self.trainer.networks``. ``self.trainer``
    is not assigned in this class; presumably the base class sets it.
    """

    def configure_logging(self):
        """Register each distinct trainer network with ``wandb.watch``.

        Networks listed more than once are watched only once, in the order
        in which they first appear in ``self.trainer.networks``.
        """
        # A plain set() would iterate in hash order, which for objects hashed
        # by identity varies between runs. Keeping first-seen order makes the
        # sequence of wandb.watch calls reproducible.
        for net in OrderedDict.fromkeys(self.trainer.networks):
            wandb.watch(net)

    def to(self, device):
        """Call ``to(device)`` on every trainer network."""
        for net in self.trainer.networks:
            net.to(device)

    def training_mode(self, mode):
        """Call ``train(mode)`` on every trainer network."""
        for net in self.trainer.networks:
            net.train(mode)


class TorchOfflineMBRLAlgorithm(OfflineMBRLAlgorithm):
    """``OfflineMBRLAlgorithm`` whose two trainers expose PyTorch modules.

    Adds wandb registration, device placement and train/eval mode switching
    for the modules of both ``self.trainer`` and ``self.model_trainer``.
    Neither attribute is assigned in this class; presumably the base class
    sets them.
    """

    def configure_logging(self):
        """Register each distinct network of both trainers with ``wandb.watch``.

        ``self.trainer`` networks come first, then ``self.model_trainer``
        networks; a network that appears more than once is watched only once,
        at its first position.
        """
        # Merging through an ordered mapping, rather than ``set(a + b)``,
        # accepts any iterables (``+`` needs two sequences of the same type)
        # and keeps the wandb.watch call order reproducible between runs.
        unique_nets = OrderedDict.fromkeys(self.trainer.networks)
        unique_nets.update(OrderedDict.fromkeys(self.model_trainer.networks))
        for net in unique_nets:
            wandb.watch(net)

    def to(self, device):
        """Call ``to(device)`` on every network of both trainers."""
        for net in self.trainer.networks:
            net.to(device)
        for net in self.model_trainer.networks:
            net.to(device)

    def training_mode(self, mode):
        """Call ``train(mode)`` on every network of both trainers."""
        for net in self.trainer.networks:
            net.train(mode)
        for net in self.model_trainer.networks:
            net.train(mode)


class TorchTrainer(Trainer, metaclass=abc.ABCMeta):
    """Trainer base that routes batches through ``np_to_pytorch_batch``.

    ``train`` counts its invocations, converts the incoming batch and passes
    the result to ``train_from_torch``. The ``train_from_torch`` and
    ``networks`` members defined here are placeholders that return ``None``;
    subclasses are presumably meant to override both.
    """

    def __init__(self):
        # Number of times ``train`` has been called; see ``get_diagnostics``.
        self._num_train_steps = 0

    def train(self, np_batch):
        """Count the call, convert ``np_batch`` and train on the result."""
        self._num_train_steps += 1
        self.train_from_torch(np_to_pytorch_batch(np_batch))

    def get_diagnostics(self):
        """Return an ``OrderedDict`` with the number of ``train`` calls."""
        diagnostics = OrderedDict()
        diagnostics['num train calls'] = self._num_train_steps
        return diagnostics

    def train_from_torch(self, batch):
        """Placeholder training step; does nothing in this base class."""

    @property
    def networks(self) -> Iterable[nn.Module]:
        """Placeholder for the trainer's modules; returns ``None`` here."""
