import cProfile
import dataclasses
import io
import pstats
import signal
import sys
import time
import traceback
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from pstats import SortKey
from typing import Any, Callable

import torch

from ..classes.model import ModelInterface
from ..datasets.classes import Minibatch, MinibatchIterator
from .events import Events


def _handle_pdb(sig, frame):
    """Signal handler that opens a pdb session on the interrupted frame.

    ``TrainerInterface.__init__`` registers this function for ``SIGUSR1``.

    Args:
        sig: Signal number passed by the ``signal`` machinery (unused).
        frame: Stack frame that was executing when the signal arrived.
    """
    # Imported lazily so pdb is only loaded when the handler actually fires.
    from pdb import Pdb

    debugger = Pdb()
    debugger.set_trace(frame)


class LossFunctionInterface(ABC):
    """Callable contract for loss functions, invoked as ``fn(model, data)``.

    The base ``__call__`` only raises, so subclasses are expected to
    override it.
    """

    def __call__(self, model: ModelInterface, data: Minibatch):
        """Placeholder for the loss computation on one minibatch.

        Args:
            model: Model the loss is computed for.
            data: Minibatch handed to the loss function.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class LossFunctionOutput:
    """Immutable bundle of the tensors produced by one loss evaluation.

    Every field is read generically (via ``dataclasses.fields``) by
    ``MovingAverage.update``, which converts each one with ``float()``.
    """

    # Tensor that Trainer.run back-propagates through (``loss.backward()``).
    loss: torch.Tensor
    # Remaining entries are tracked alongside the loss.
    # NOTE(review): judging by the names these are the individual loss
    # components (likelihood, KL divergence, normals and eikonal
    # regularisers) -- confirm against the concrete loss function.
    likelihood: torch.Tensor
    kld: torch.Tensor
    normals_term: torch.Tensor
    eikonal_term: torch.Tensor


@dataclass
class MovingAverage:
    """Exponential moving average over every field of a loss output.

    Attributes:
        metrics: Mapping from field name to its smoothed value. Stays
            ``None`` until the first call to :meth:`update` (or until a
            mapping is supplied / restored through :meth:`load_state_dict`).
        alpha: Weight given to the newest observation; the previous average
            is weighted by ``1 - alpha``.
    """
    metrics: dict = None
    alpha: float = 0.01

    def update(self, lossfunc_output: LossFunctionOutput):
        """Fold one loss output into the running averages.

        Each dataclass field of ``lossfunc_output`` is converted with
        ``float()``. A field seen for the first time is stored as-is;
        afterwards it is blended as
        ``alpha * new + (1 - alpha) * previous``.

        Args:
            lossfunc_output: Dataclass instance whose fields are all
                convertible with ``float()``.
        """
        if self.metrics is None:
            self.metrics = defaultdict(float)

        for field in dataclasses.fields(lossfunc_output):
            value = float(getattr(lossfunc_output, field.name))
            # Seed unseen names with the raw value. This keeps a plain dict
            # (e.g. ``metrics={}`` or a restored checkpoint missing a field)
            # from raising KeyError, and avoids biasing the average toward 0.
            if field.name in self.metrics:
                value = (self.alpha * value +
                         (1 - self.alpha) * self.metrics[field.name])
            self.metrics[field.name] = value

    def state_dict(self):
        """Return the serialisable state (``metrics`` and ``alpha``)."""
        return {
            "metrics": self.metrics,
            "alpha": self.alpha,
        }

    def load_state_dict(self, state_dict: dict):
        """Restore state previously produced by :meth:`state_dict`.

        Args:
            state_dict: Mapping with the keys ``"metrics"`` and ``"alpha"``.

        Raises:
            KeyError: If either key is missing.
        """
        self.metrics = state_dict["metrics"]
        self.alpha = state_dict["alpha"]


@dataclass
class TrainerState:
    """Mutable progress counters and timing statistics of a training run.

    Only ``iteration``, ``epoch``, ``elapsed_seconds``,
    ``num_gradient_updates`` and ``moving_average`` are persisted by
    :meth:`state_dict`; the remaining fields are transient.
    """
    # 1-based index of the current minibatch within the current epoch.
    iteration: int = 0
    # Upper epoch bound of the current run (set by Trainer.run).
    max_epochs: int = 0
    # Epoch preceding the current one; training resumes at last_epoch + 1.
    last_epoch: int = 0
    # 1-based index of the current epoch.
    epoch: int = 1
    # The timing fields below receive differences of time.time() values,
    # i.e. floats measured in seconds.
    elapsed_seconds: float = 0
    data_load_time: float = 0
    forward_time: float = 0
    backward_time: float = 0
    # Number of successful optimizer steps.
    num_gradient_updates: int = 0
    # Output of the most recent successful loss evaluation, if any.
    last_lossfunc_output: LossFunctionOutput = None
    # A real dataclass field with a factory: every TrainerState gets its own
    # MovingAverage instead of all instances sharing one class-level object.
    moving_average: MovingAverage = dataclasses.field(
        default_factory=MovingAverage)

    def state_dict(self):
        """Return the persistent part of the state as a plain dict."""
        return {
            "iteration": self.iteration,
            "epoch": self.epoch,
            "elapsed_seconds": self.elapsed_seconds,
            "num_gradient_updates": self.num_gradient_updates,
            "moving_average": self.moving_average.state_dict()
        }

    def load_state_dict(self, state_dict: dict):
        """Restore state produced by :meth:`state_dict`.

        The stored ``"epoch"`` is loaded into ``last_epoch`` and ``epoch``
        is set to the one after it, so the restored state points at the
        epoch following the one recorded in the checkpoint.

        Args:
            state_dict: Mapping with the keys written by :meth:`state_dict`.

        Raises:
            KeyError: If a required key is missing.
        """
        self.iteration = state_dict["iteration"]
        self.last_epoch = state_dict["epoch"]
        self.elapsed_seconds = state_dict["elapsed_seconds"]
        self.num_gradient_updates = state_dict["num_gradient_updates"]
        self.moving_average.load_state_dict(state_dict["moving_average"])
        self.epoch = self.last_epoch + 1


class TrainerInterface(ABC):
    """Base trainer: holds model/optimizer/loss and a simple event system.

    Handlers are registered per event name with :meth:`on` (decorator) or
    :meth:`add_event_handler` and invoked by :meth:`exec_event_handlers`.
    Subclasses implement :meth:`run`.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 loss_function: Callable[[torch.nn.
                                          Module, Any], LossFunctionOutput],
                 initial_state: TrainerState = None,
                 max_minibatch_iterations: int = None,
                 debug_mode=False):
        """Store the training components and install the pdb signal hook.

        Args:
            model: Module whose state is saved/restored with the trainer.
            optimizer: Optimizer whose state is saved/restored as well.
            loss_function: Callable invoked as ``loss_function(model, data)``.
            initial_state: State to resume from; a fresh ``TrainerState`` is
                created when ``None``.
            max_minibatch_iterations: Optional per-epoch cap on minibatches;
                only stored here, for use by ``run`` implementations.
            debug_mode: When true, enables autograd anomaly detection.
        """
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.event_handlers = defaultdict(list)
        self.max_minibatch_iterations = max_minibatch_iterations
        self.debug_mode = debug_mode
        if debug_mode:
            torch.autograd.set_detect_anomaly(True)
        if initial_state is None:
            self.state = TrainerState()
        else:
            self.state = initial_state

        # The SIGUSR1 -> pdb hook is a debugging convenience. SIGUSR1 does
        # not exist on every platform and signal.signal() raises ValueError
        # outside the main thread, so skip the hook in those cases instead
        # of failing construction.
        if hasattr(signal, "SIGUSR1"):
            try:
                signal.signal(signal.SIGUSR1, _handle_pdb)
            except ValueError:
                pass

    def on(self, event_name, *args, **kwargs):
        """Return a decorator registering the function for ``event_name``.

        Extra positional/keyword arguments are stored with the handler and
        passed to it on every invocation. The decorated function is returned
        unchanged.
        """
        def decorator(handler):
            self.add_event_handler(event_name, handler, *args, **kwargs)
            return handler

        return decorator

    def add_event_handler(self, event_name, handler, *args, **kwargs):
        """Register ``handler`` (with its extra arguments) for an event."""
        self.event_handlers[event_name].append((handler, args, kwargs))

    @abstractmethod
    def run(self):
        """Execute training; must be implemented by subclasses."""
        raise NotImplementedError()

    def exec_event_handlers(self, event_name):
        """Call every handler of ``event_name`` in registration order.

        Each handler is invoked as ``handler(trainer, *args, **kwargs)``
        using the arguments supplied at registration time.
        """
        handlers = self.event_handlers[event_name]
        for (handler, args, kwargs) in handlers:
            # Forward the stored registration arguments; previously they
            # were recorded but silently dropped.
            handler(self, *args, **kwargs)

    def state_dict(self):
        """Return optimizer, model and trainer state in a single dict."""
        state_dict = {}
        state_dict["optimizer"] = self.optimizer.state_dict()
        state_dict["model"] = self.model.state_dict()
        state_dict["trainer"] = self.state.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: dict):
        """Restore optimizer, model and trainer state.

        Args:
            state_dict: Mapping with the keys ``"optimizer"``, ``"model"``
                and ``"trainer"`` as produced by :meth:`state_dict`.

        Raises:
            KeyError: If one of the keys is missing.
        """
        self.optimizer.load_state_dict(state_dict["optimizer"])
        self.model.load_state_dict(state_dict["model"])
        self.state.load_state_dict(state_dict["trainer"])


class Trainer(TrainerInterface):
    """Trainer that runs an epoch / minibatch optimisation loop."""

    def run(self, minibatch_iterator: MinibatchIterator, max_epochs: int):
        """Train from ``state.last_epoch + 1`` up to and including ``max_epochs``.

        Fires ``TRAINING_STARTED``/``TRAINING_COMPLETED`` once,
        ``EPOCH_STARTED``/``EPOCH_COMPLETED`` once per epoch and
        ``ITERATION_STARTED``/``ITERATION_COMPLETED`` around every minibatch.
        A ``RuntimeError`` raised while processing a minibatch is printed and
        that minibatch is skipped (no ``ITERATION_COMPLETED``), unless
        ``debug_mode`` is set, in which case it is re-raised. In
        ``debug_mode`` each epoch also stops after a single minibatch.

        Args:
            minibatch_iterator: Iterable of minibatches; ``len()`` is taken
                when no ``max_minibatch_iterations`` was configured.
            max_epochs: Last epoch (inclusive) to run.
        """
        self.state.max_epochs = max_epochs
        self.exec_event_handlers(Events.TRAINING_STARTED)

        # elapsed_seconds accumulates across runs: remember what was already
        # on the clock and add the wall time spent in this call.
        seconds_before_run = self.state.elapsed_seconds
        run_started_at = time.time()

        if self.max_minibatch_iterations is None:
            iteration_limit = len(minibatch_iterator)
        else:
            iteration_limit = self.max_minibatch_iterations

        first_epoch = self.state.last_epoch + 1
        for epoch in range(first_epoch, max_epochs + 1):
            self.state.epoch = epoch
            self.state.last_epoch = epoch - 1
            self.exec_event_handlers(Events.EPOCH_STARTED)

            batch_wait_started_at = time.time()
            for step, batch in enumerate(minibatch_iterator, start=1):
                # Time spent waiting for the iterator to yield this batch.
                self.state.data_load_time = (time.time() -
                                             batch_wait_started_at)
                self.state.iteration = step
                self.exec_event_handlers(Events.ITERATION_STARTED)

                try:
                    tick = time.time()
                    output = self.loss_function(self.model, batch)
                    self.state.forward_time = time.time() - tick

                    tick = time.time()
                    self.optimizer.zero_grad()
                    output.loss.backward()
                    self.optimizer.step()
                    self.state.backward_time = time.time() - tick

                    self.state.last_lossfunc_output = output
                    self.state.moving_average.update(output)
                    self.state.num_gradient_updates += 1
                    seconds_this_run = time.time() - run_started_at
                    self.state.elapsed_seconds = (seconds_this_run +
                                                  seconds_before_run)
                    self.exec_event_handlers(Events.ITERATION_COMPLETED)
                except RuntimeError as error:
                    if self.debug_mode:
                        raise error
                    # Best effort outside debug mode: report and move on.
                    _, _, tb = sys.exc_info()
                    print("RuntimeError: {}".format(error.with_traceback(tb)))
                    traceback.print_tb(tb)

                batch_wait_started_at = time.time()

                if self.debug_mode or step >= iteration_limit:
                    break

            self.exec_event_handlers(Events.EPOCH_COMPLETED)

        self.exec_event_handlers(Events.TRAINING_COMPLETED)
