import dataclasses
import json
import logging
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from sklearn import metrics
from torch import Tensor
from torch.utils.tensorboard.writer import SummaryWriter

from pkg.data import BaseData
from pkg.logic.state import State
from pkg.model import BaseModel
from pkg.utils.logging import (
    BEST_FILE_NAME,
    RESULTS_FILE_NAME,
    Timer,
    delete_old_checkpoints,
    get_checkpoint_path_without_suffix,
)

# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(__name__)


class Trainer:
    def __init__(
        self,
        data: BaseData,
        model: BaseModel,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        num_epochs: int,
        early_stopping_tolerance: int,
        gradnorm_clipping_value: float,
        log_after_num_iterations: int,
        eval_modes: list[str],
        dry_run: bool,
    ) -> None:
        """Set up the trainer, its state and its TensorBoard writer.

        Args:
            data: Provides ``train_loader`` and one ``<mode>_loader`` per
                entry in ``eval_modes``.
            model: Model to train; its ``forward`` is called with the
                keyword arguments ``data`` and ``padding_mask``.
            optimizer: Optimizer stepped once per training batch.
            device: Device every batch is moved to before the forward pass.
            num_epochs: Maximum number of training epochs.
            early_stopping_tolerance: Training stops once this many
                evaluations in a row did not improve ``val_metric``.
            gradnorm_clipping_value: Maximum gradient norm; ``None``
                disables gradient clipping.
            log_after_num_iterations: Training metrics are logged every
                this many iterations within an epoch.
            eval_modes: Loader prefixes to evaluate (e.g. ``"val"``).
            dry_run: If true, ``run`` returns immediately without training.
        """
        logger.info("Init trainer ...")
        self.writer = SummaryWriter(log_dir=".tb")

        self.data = data
        self.model = model
        self.optimizer = optimizer
        self.device = device

        self.num_epochs = num_epochs
        self.early_stopping_tolerance = early_stopping_tolerance
        self.gradnorm_clipping_value = gradnorm_clipping_value
        self.log_after_num_iterations = log_after_num_iterations

        self.dry_run = dry_run
        self.state = State()

        # Assigned before the dry-run check so that every attribute read by
        # the other methods exists on every instance, dry run or not.
        self.eval_modes = eval_modes
        self.trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )

        if dry_run:
            return

        # The model summary is only logged for real runs.
        logger.info(f"Model:\n{self.model}\nTrainable params: {self.trainable_params}")

    def run(self) -> float:
        """Evaluate once up front, then alternate training and evaluation.

        Each epoch consists of one ``train`` call followed by an evaluation
        of every mode in ``self.eval_modes``. The loop ends after
        ``self.num_epochs`` epochs or as soon as the early-stopping counter
        reaches ``self.early_stopping_tolerance``.

        Returns:
            ``0.0`` for a dry run, otherwise ``self.state.best_val_metric``.
        """
        logger.info("Run ...")

        if self.dry_run:
            return 0.0

        def evaluate_all_modes() -> dict:
            # Merge the per-mode metric dicts in the order of ``eval_modes``.
            merged = {}
            for eval_mode in self.eval_modes:
                merged.update(self.eval(mode=eval_mode))
            return merged

        # Evaluation before any training happens in this call.
        self.log_results(result=evaluate_all_modes(), save_checkpoint=True)

        for _epoch in range(self.num_epochs):
            self.train()
            self.log_results(evaluate_all_modes(), save_checkpoint=True)

            out_of_patience = (
                self.state.num_epochs_not_improved >= self.early_stopping_tolerance
            )
            if out_of_patience:
                logger.info("Early stopping criterion reached.")
                break

        return self.state.best_val_metric

    def train(self) -> None:
        """Train the model for one pass over ``self.data.train_loader``.

        Increments ``self.state.epoch`` once and ``self.state.iteration``
        once per batch. For every batch the loss returned by the model is
        back-propagated, gradients are optionally clipped and the optimizer
        is stepped. Step metrics are logged every
        ``self.log_after_num_iterations`` iterations and on the last batch.
        """
        self.model.train()
        self.state.epoch += 1
        # NOTE(review): the timer is created once per epoch and never reset
        # here, so "train_step_time" is whatever ``Timer`` reports since its
        # construction -- confirm this is the intended quantity.
        timer = Timer()

        clip_gradients = self.gradnorm_clipping_value is not None

        batch: list[Tensor]
        for epoch_iteration, batch in enumerate(self.data.train_loader, start=1):
            self.state.iteration += 1

            data = batch[0].to(self.device)
            # True wherever not every value along the last dimension is -1.
            mask = ~(data == -1).all(dim=-1)

            _, loss, *_ = self.model.forward(data=data, padding_mask=mask)

            self.optimizer.zero_grad()
            loss.backward()
            if clip_gradients:
                # clip_grad_norm_ returns the total norm measured before
                # the gradients are rescaled.
                grad_norm_before_clipping = nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.gradnorm_clipping_value,
                )
            self.optimizer.step()

            if (
                epoch_iteration % self.log_after_num_iterations == 0
                or epoch_iteration == len(self.data.train_loader)
            ):
                result = {
                    "train_step_time": timer.get_duration_in_seconds(),
                    "train_step_loss": loss.item(),
                }
                if clip_gradients:
                    result.update(grad_norm=grad_norm_before_clipping.item())
                self.log_results(result=result, save_checkpoint=False)

            # The loss is intentionally not accumulated per step: nothing
            # reads it, and copying it to the host on every iteration only
            # slows training down.

    def eval(self, mode: str = "val") -> dict:
        """Evaluate the model on ``self.data.<mode>_loader``.

        Puts the model into eval mode, runs every batch without gradient
        tracking and computes ROC-AUC and accuracy (scores thresholded at
        0.5) over all unmasked positions except the first one of each
        sequence.

        Args:
            mode: Prefix of the loader attribute on ``self.data`` and of
                every key in the returned dict.

        Returns:
            Dict with the keys ``<mode>_time``, ``<mode>_metric`` (the AUC),
            ``<mode>_auc``, ``<mode>_accuracy`` and
            ``<mode>_epoch_mean_loss``.
        """
        self.model.eval()
        timer = Timer()

        y_trues, y_scores, losses = [], [], []
        with torch.no_grad():
            batch: list[Tensor]
            for batch in getattr(self.data, f"{mode}_loader"):
                data = batch[0].to(self.device)
                # True wherever not every value along the last dimension is -1.
                mask = ~(data == -1).all(dim=-1)

                y_pred, loss, *_ = self.model.forward(data=data, padding_mask=mask)
                losses.append(loss.detach().cpu().numpy())

                # sub-setting to be pykt compatible: the first position of
                # every sequence is dropped, and the last value along the
                # final dimension of ``data`` serves as the target.
                valid = mask[:, 1:]
                batch_scores = torch.masked_select(y_pred[:, 1:], valid)
                batch_targets = torch.masked_select(data[:, 1:, -1], valid)

                y_trues.append(batch_targets.detach().cpu().numpy())
                y_scores.append(batch_scores.detach().cpu().numpy())

        all_targets = np.concatenate(y_trues, axis=0)
        all_scores = np.concatenate(y_scores, axis=0)
        # Vectorized 0.5 threshold instead of a per-element Python loop.
        all_labels = (all_scores >= 0.5).astype(int)

        auc = metrics.roc_auc_score(y_true=all_targets, y_score=all_scores)
        acc = metrics.accuracy_score(y_true=all_targets, y_pred=all_labels)

        result = {
            f"{mode}_time": timer.get_duration_in_seconds(),
            f"{mode}_metric": auc,
            f"{mode}_auc": auc,
            f"{mode}_accuracy": acc,
            # TODO: Correct for last batch size (plain mean over batches).
            f"{mode}_epoch_mean_loss": float(np.mean(losses)),
        }

        return result

    def log_results(
        self, result: dict[str, float], save_checkpoint: bool = False
    ) -> None:
        """Record metrics and update best-model / early-stopping state.

        ``result`` is extended in place with the current epoch, iteration
        and trainable parameter count. If it contains ``"val_metric"``, the
        value is compared against ``self.state.best_val_metric``: a strictly
        larger value resets the early-stopping counter and becomes the new
        best (optionally saving a checkpoint), anything else increments the
        counter. The final dict is appended as one JSON line to
        ``RESULTS_FILE_NAME``, written to ``BEST_FILE_NAME`` when the current
        iteration is the best iteration, logged, and sent to TensorBoard.

        Args:
            result: Metric name to value mapping; mutated in place.
            save_checkpoint: Whether an improved ``val_metric`` triggers
                ``save`` followed by ``delete_old_checkpoints``.
        """
        result.update(
            epoch=self.state.epoch,
            iteration=self.state.iteration,
            trainable_params=self.trainable_params,
        )

        if "val_metric" in result:
            val_metric = result["val_metric"]
            # if iteration is best wrt. val_metric
            if val_metric > self.state.best_val_metric:
                logger.info("Improvement! (resetting early stopping counter)")
                self.state.num_epochs_not_improved = 0

                self.state.best_iteration = self.state.iteration
                self.state.best_val_metric = val_metric

                if save_checkpoint:
                    self.save()
                    delete_old_checkpoints(
                        keep_iterations=self.state.get_checkpoint_iters()
                    )
            else:
                # Increment first so the message reports the new counter
                # value rather than the stale one.
                self.state.num_epochs_not_improved += 1
                logger.info(
                    f"Increasing early stopping counter to {self.state.num_epochs_not_improved}: {val_metric} not better than {self.state.best_val_metric}"
                )
            result.update(
                {
                    "best_iteration": self.state.best_iteration,
                    "best_val_metric": self.state.best_val_metric,
                }
            )

        # write to json + log
        result_json: str = json.dumps(result, skipkeys=True)
        with open(file=RESULTS_FILE_NAME, mode="a") as file:
            file.write(f"{result_json}\n")
        if self.state.best_iteration == self.state.iteration:
            with open(file=BEST_FILE_NAME, mode="w") as file:
                file.write(result_json)
        logger.info(result_json)

        # log results to tensorboard
        for key, value in result.items():
            self.writer.add_scalar(
                tag=key, scalar_value=value, global_step=self.state.iteration
            )
        self.writer.flush()

    def save(self, checkpoint_dir: str = ".") -> Path:
        """Write model/optimizer weights and the trainer state to disk.

        The weights go to a ``.pth`` file whose stem is derived from
        ``checkpoint_dir`` and the current iteration; the trainer state is
        stored next to it under the same stem with a ``.yaml`` suffix.

        Args:
            checkpoint_dir: Directory handed to
                ``get_checkpoint_path_without_suffix``.

        Returns:
            Path of the written ``.pth`` file.
        """
        weights_file = get_checkpoint_path_without_suffix(
            dir=checkpoint_dir, iteration=self.state.iteration
        ).with_suffix(".pth")
        torch.save(
            obj={
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
            },
            f=weights_file,
        )

        # NOTE: the state is serialized with ``json.dump`` even though the
        # file carries a ``.yaml`` suffix.
        trainer_state = dataclasses.asdict(self.state)
        trainer_state_file = weights_file.with_suffix(".yaml")
        with open(file=trainer_state_file, mode="w") as handle:
            json.dump(trainer_state, handle)

        logger.info(
            f"Saved checkpoint (model, optimizer, trainer) to {checkpoint_dir} ..."
        )
        return weights_file
