import os
import random
import time
from collections import defaultdict
from collections.abc import Mapping
from typing import Optional

import numpy as np
import torch
import wandb
from tqdm.autonotebook import tqdm

from zarya.metrics import average_metric_over_loaders, flatten_metrics, get_best_metrics


class Trainer:
    """Train, evaluate and run inference for a model called as ``model(**batch)``.

    Output convention the code relies on: when the batch contains ``labels``
    the model output is indexed as ``(loss, logits, ...)`` (see ``_eval_loss``);
    when ``labels`` are removed (see ``test``) element 0 is used as the logits.

    Note: ``_set_train_mode`` toggles autograd *globally* via
    ``torch.set_grad_enabled``, so after evaluation or ``test`` gradients stay
    disabled for the rest of the process until a training phase re-enables them.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler,
        num_patience_steps: Optional[int],
        fp16: bool,
        device: torch.device,
        trainable_params: int,
    ):
        """
        Args:
            model: module to optimise. ``train_model`` does not move it;
                only ``test`` calls ``model.to(device)``.
            optimizer: optimiser over the model parameters.
            scheduler: optional LR scheduler; when not ``None`` its ``step()``
                is called after every optimiser step.
            num_patience_steps: consecutive epochs without a new best
                validation result tolerated before stopping early; ``None``
                disables early stopping.
            fp16: when true, a CUDA ``GradScaler`` is created and forward
                passes run under CUDA autocast.
            device: device that batches are moved to.
            trainable_params: count stored in the logged best-metrics dict.
        """
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.grad_scaler = torch.cuda.amp.GradScaler() if fp16 else None

        self.device = device

        self.num_patience_steps = num_patience_steps
        self.trainable_params = trainable_params

    def train_model(
        self,
        dataloaders,
        epochs,
        metric_fn,
        dataset_name,
        task_type,
        debug: bool = False,
        best_model_save_path: Optional[str] = None,
    ):
        """Run up to ``epochs`` passes over every loader in ``dataloaders``.

        Loaders whose key contains ``"train"`` are optimised on; every other
        loader is evaluated and the result of ``metric_fn.compute()`` is stored
        under that loader's key. After each epoch the validation metrics are
        compared with the best so far (optionally saving the weights) and
        logged; training stops early once patience is exhausted.

        Args:
            dataloaders: mapping of loader name -> iterable of batch dicts,
                processed in the mapping's iteration order.
            epochs: maximum number of epochs.
            metric_fn: object exposing ``add_batch(predictions=, references=)``
                and ``compute()``.
            dataset_name: forwarded to ``average_metric_over_loaders``.
            task_type: ``"regression"`` squeezes the logits; anything else
                takes ``argmax`` over the last dimension.
            debug: when true, only the first four batches of each loader run.
            best_model_save_path: if given, ``state_dict`` is saved there
                whenever a new best validation result is reached.
        """
        best_metrics = None
        num_steps_without_improvement = 0

        for epoch in range(epochs):
            epoch_metrics = defaultdict(dict)
            # Total forward time (ms) summed over all batches of all eval loaders.
            epoch_time = 0
            for key, loader in dataloaders.items():
                is_training_phase = "train" in key
                epoch_time += self._run_loader(
                    key,
                    loader,
                    epoch,
                    epochs,
                    metric_fn,
                    task_type,
                    is_training_phase,
                    debug,
                )

                if not is_training_phase:
                    for metric_name, value in metric_fn.compute().items():
                        epoch_metrics[metric_name][key] = value

            best_metrics, new_best_model_achieved = self._update_best_metrics(
                epoch,
                epoch_time,
                epoch_metrics,
                best_metrics,
                dataset_name,
                best_model_save_path,
            )

            if new_best_model_achieved or self.num_patience_steps is None:
                num_steps_without_improvement = 0
                continue
            num_steps_without_improvement += 1
            if num_steps_without_improvement > self.num_patience_steps:
                break

    def _run_loader(
        self,
        key,
        loader,
        epoch,
        epochs,
        metric_fn,
        task_type,
        is_training_phase,
        debug,
    ):
        """Run one pass over ``loader``; return summed eval forward time in ms (0 when training)."""
        pbar = tqdm(loader, leave=False)
        pbar.set_description(f"{key} epoch {epoch + 1} of {epochs}")
        self._set_train_mode(is_training_phase)

        total_time = 0
        for idx, batch in enumerate(pbar):
            loss, batch_time = self._step(
                batch, metric_fn, task_type, is_training_phase
            )
            # Accumulate every batch: previously only the last one was counted.
            total_time += batch_time
            pbar.set_description(
                f"{key} epoch {epoch + 1} of {epochs}: loss {loss.item():.4f}"
            )
            if debug and idx >= 3:
                break
        return total_time

    def test(self, dataloaders, is_regression):
        """Predict on every loader and return the predictions per loader key.

        ``labels`` (if present) are excluded from the model inputs without
        mutating the caller's batch, so element 0 of the model output is
        treated as the logits.

        Args:
            dataloaders: mapping of loader name -> iterable of batch dicts.
            is_regression: flatten logits when true, otherwise ``argmax(-1)``.

        Returns:
            ``defaultdict(list)`` mapping loader key -> list of Python scalars.
        """
        self.model.to(self.device)
        self._set_train_mode(False)
        predictions = defaultdict(list)
        for loader_key, loader in dataloaders.items():
            for batch in loader:
                inputs = {k: v for k, v in batch.items() if k != "labels"}
                logits = self.model(**self.any2device(inputs, self.device))[0]
                current = logits.reshape(-1) if is_regression else logits.argmax(-1)
                # Tensor.tolist (not to_list) converts to Python scalars.
                predictions[loader_key] += current.cpu().tolist()
        return predictions

    def _update_best_metrics(
        self,
        epoch,
        epoch_time,
        epoch_metrics,
        best_metrics,
        dataset_name,
        best_model_save_path,
    ):
        """Aggregate validation metrics, update the running best, save and log.

        Validation metrics are averaged by ``average_metric_over_loaders`` with
        ``pattern="valid"`` and compared by ``get_best_metrics`` (both defined
        in ``zarya.metrics``; their exact semantics are not visible here).

        Returns:
            ``(best_metrics, new_best_model_achieved)`` from ``get_best_metrics``,
            with ``trainable_params`` (first epoch) and ``validation_time`` added.
        """
        val_metrics = average_metric_over_loaders(
            epoch_metrics=epoch_metrics,
            pattern="valid",
            dataset_name=dataset_name,
        )

        best_metrics, new_best_model_achieved = get_best_metrics(
            best_metrics=best_metrics,
            metrics=val_metrics,
        )

        if epoch == 0:
            best_metrics["trainable_params"] = self.trainable_params

        if epoch_time is not None:
            best_metrics["validation_time"] = epoch_time

        if best_model_save_path is not None and new_best_model_achieved:
            self._save_model(best_model_save_path)

        self._log_metrics(epoch_metrics, best_metrics)

        return best_metrics, new_best_model_achieved

    @staticmethod
    def _log_metrics(epoch_metrics, best_metrics):
        """Log flattened epoch metrics merged with best metrics (best wins on key clashes) to wandb."""
        metrics_to_log = flatten_metrics(epoch_metrics)
        metrics_to_log = {**metrics_to_log, **best_metrics}
        wandb.log(metrics_to_log)

    def _set_train_mode(self, train_mode=True):
        """Switch the model to train/eval mode and toggle autograd globally.

        Entering train mode also clears any stale gradients.
        """
        if train_mode:
            self.model.train()
            self.optimizer.zero_grad()
            torch.set_grad_enabled(True)
        else:
            self.model.eval()
            torch.set_grad_enabled(False)

    def _step(self, batch, metric_fn, task_type, is_training_phase):
        """Move ``batch`` to the device, run forward (and backward when training).

        Returns:
            ``(loss, elapsed_ms)`` where ``elapsed_ms`` is 0 for training steps.
        """
        batch = self.any2device(batch, self.device)
        with torch.cuda.amp.autocast(enabled=self.grad_scaler is not None):
            loss, _, elapsed = self._eval_loss(
                batch, metric_fn, task_type, is_training_phase
            )

        if is_training_phase:
            self._optimize_loss(loss)

        return loss, elapsed

    def _eval_loss(self, batch, metric, task_type, is_training_phase):
        """Run the forward pass; in eval mode also time it and feed the metric.

        Eval forward passes are timed with CUDA events on CUDA devices and with
        ``time.perf_counter`` otherwise (CUDA events cannot be recorded on a
        CPU-only setup); both report milliseconds.

        Returns:
            ``(loss, outputs, elapsed_ms)``; ``elapsed_ms`` is 0 when training.
        """
        elapsed_ms = 0
        if is_training_phase:
            outputs = self.model(**batch)
        elif torch.device(self.device).type == "cuda":
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            outputs = self.model(**batch)
            end.record()
            # elapsed_time is only valid once both events have completed.
            torch.cuda.synchronize()
            elapsed_ms = start.elapsed_time(end)
        else:
            started = time.perf_counter()
            outputs = self.model(**batch)
            elapsed_ms = (time.perf_counter() - started) * 1000.0

        loss = outputs[0]
        if not is_training_phase:
            logits = outputs[1]
            predictions = (
                logits.squeeze(-1) if task_type == "regression" else logits.argmax(-1)
            )
            metric.add_batch(predictions=predictions, references=batch["labels"])

        return loss, outputs, elapsed_ms

    def _optimize_loss(self, loss):
        """Backpropagate ``loss``, step the optimiser (via the scaler under fp16), then the scheduler."""
        if self.grad_scaler is not None:
            self.grad_scaler.scale(loss).backward()
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            loss.backward()
            self.optimizer.step()

        self.optimizer.zero_grad()

        if self.scheduler is not None:
            self.scheduler.step()

    def _save_model(self, path):
        """Save the model's ``state_dict`` to ``path``."""
        torch.save(self.model.state_dict(), path)

    @staticmethod
    def any2device(value, device):
        """
        Recursively move every tensor inside ``value`` to ``device``.

        Mappings are rebuilt as plain dicts, lists stay lists and tuples stay
        tuples. Tensors are moved with ``non_blocking=True``. Any other leaf
        (ints, strings, numpy arrays, ...) is returned unchanged.
        Args:
            value: Object to be moved
            device: target device
        Returns:
            Same structure as value, with all tensors moved to device
        """
        if isinstance(value, Mapping):
            return {k: Trainer.any2device(v, device) for k, v in value.items()}
        if isinstance(value, tuple):
            return tuple(Trainer.any2device(v, device) for v in value)
        if isinstance(value, list):
            return [Trainer.any2device(v, device) for v in value]
        if torch.is_tensor(value):
            return value.to(device, non_blocking=True)
        # Previously fell through to a bare `return`, replacing leaves with None.
        return value


def set_deterministic_mode(seed):
    """Seed all RNGs and configure cuDNN for reproducible runs.

    Args:
        seed: value passed to ``_set_seed`` and exported as ``PYTHONHASHSEED``.

    Note:
        ``PYTHONHASHSEED`` is read only when an interpreter starts, so setting
        it here affects Python processes launched afterwards (e.g. spawned
        workers), not string hashing in the current process.
    """
    _set_seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    # The real determinism switches live under torch.backends.cudnn; there are
    # no top-level torch.backends.deterministic/benchmark flags to set.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _set_seed(seed: int):
    """Seed Python's, NumPy's and PyTorch's global RNGs, plus all CUDA devices when available."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
