import contextlib
import dataclasses
import logging
import time
from typing import Dict, Any, Tuple, Optional

import numpy as np
import torch
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.utils.data import DataLoader
from torch_ema import ExponentialMovingAverage
from tqdm.auto import tqdm

from . import torch_geometric
from .checkpoint import CheckpointHandler, CheckpointState
from .torch_tools import to_numpy, tensor_dict_to_device
from .utils import MetricsLogger, compute_mae, compute_rmse, compute_q95


@dataclasses.dataclass()
class SWAContainer:
    """Objects needed to run stochastic weight averaging (SWA) during training.

    Once the epoch counter reaches ``start``, the training loop stops stepping
    its regular LR scheduler. It instead folds the live weights into ``model``
    and steps ``scheduler``.
    """

    # Running average of the trained model's weights.
    model: AveragedModel
    # LR schedule used during the SWA phase.
    scheduler: SWALR
    # First epoch (inclusive) at which SWA takes over.
    start: int


def train(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: Any,
    start_epoch: int,
    max_num_epochs: int,
    patience: int,
    checkpoint_handler: CheckpointHandler,
    logger: MetricsLogger,
    eval_interval: int,
    device: torch.device,
    swa: Optional[SWAContainer] = None,
    ema: Optional[ExponentialMovingAverage] = None,
    *,
    progress_bar: bool = True,
    wdb_logger: Optional[Any] = None,
):
    """Train ``model`` for epochs ``start_epoch .. max_num_epochs - 1`` with early stopping.

    Each epoch does one optimiser step per batch of ``train_loader``. On
    epochs where ``epoch % eval_interval == 0`` the model is evaluated on
    ``valid_loader``, using EMA-averaged weights when ``ema`` is given. A
    checkpoint (also with EMA weights, if any) is saved whenever the
    validation loss reaches a new minimum. Training stops early after
    ``patience`` consecutive evaluations without improvement.

    Args:
        lr_scheduler: Stepped once per epoch until SWA starts. A
            ``ReduceLROnPlateau`` scheduler receives the most recent
            validation loss (and is not stepped before the first evaluation).
            Any other scheduler is stepped without arguments.
        swa: If given, from epoch ``swa.start`` on, ``swa.model`` is updated
            with the current weights and ``swa.scheduler`` is stepped instead
            of ``lr_scheduler``.
        ema: Optional exponential moving average of the parameters, updated
            after every optimiser step and used for evaluation/checkpoints.
        progress_bar: Wrap the epoch loop in a ``tqdm`` progress bar.
        wdb_logger: Optional object with a ``log(dict)`` method (presumably a
            wandb run -- only ``.log`` is called on it).
    """
    lowest_loss = np.inf
    patience_counter = 0
    # Most recent validation loss; None until the first evaluation has run.
    valid_loss = None
    logging.info("Started training")
    epochs = range(start_epoch, max_num_epochs)
    iterator = tqdm(epochs) if progress_bar else epochs
    for epoch in iterator:
        # Train
        for batch in train_loader:
            avg_loss, opt_metrics = take_step(
                model=model,
                loss_fn=loss_fn,
                batch=batch,
                optimizer=optimizer,
                ema=ema,
                device=device,
            )
            opt_metrics["mode"] = "opt"
            opt_metrics["epoch"] = epoch
            logger.log(opt_metrics)
            if wdb_logger is not None:
                wdb_logger.log({"train_loss": avg_loss})

        # Validate
        if epoch % eval_interval == 0:
            with _ema_weights(ema):
                valid_loss, eval_metrics = evaluate(
                    model=model,
                    loss_fn=loss_fn,
                    data_loader=valid_loader,
                    device=device,
                )
            eval_metrics["mode"] = "eval"
            eval_metrics["epoch"] = epoch
            accuracy = eval_metrics["accuracy"]

            if wdb_logger is not None:
                wdb_logger.log({"epoch": epoch, "eval-accuracy": accuracy,
                                "eval_loss": valid_loss})

            logging.info(
                f"Epoch {epoch}: loss={valid_loss:.4f}, accuracy={accuracy:.2f}"
            )

            if valid_loss >= lowest_loss:
                patience_counter += 1
                if patience_counter >= patience:
                    logging.info(
                        f"Stopping optimization after {patience_counter} epochs without improvement"
                    )
                    break
            else:
                lowest_loss = valid_loss
                patience_counter = 0
                # Checkpoint the same (possibly EMA-averaged) weights that were just evaluated.
                with _ema_weights(ema):
                    checkpoint_handler.save(
                        state=CheckpointState(model, optimizer, lr_scheduler),
                        epochs=epoch,
                    )

        # LR scheduler and SWA update
        if swa is None or epoch < swa.start:
            _step_lr_scheduler(lr_scheduler, valid_loss)
        else:
            swa.model.update_parameters(model)
            swa.scheduler.step()

    logging.info("Training complete")


def _step_lr_scheduler(lr_scheduler: Any, valid_loss: Optional[float]) -> None:
    """Advance ``lr_scheduler`` by one epoch, passing a metric only if it expects one."""
    if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        # A plateau scheduler needs a metric; there is none before the first evaluation.
        if valid_loss is not None:
            lr_scheduler.step(valid_loss)
    else:
        # Epoch-based schedulers (e.g. ExponentialLR) treat a positional
        # argument as an explicit epoch index, so they must not get the loss.
        lr_scheduler.step()


def _ema_weights(ema: Optional[ExponentialMovingAverage]):
    """Return a context manager that swaps in EMA weights, or a no-op one when ``ema`` is None."""
    if ema is None:
        return contextlib.nullcontext()
    return ema.average_parameters()


def take_step(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    batch: torch_geometric.batch.Batch,
    optimizer: torch.optim.Optimizer,
    ema: Optional[ExponentialMovingAverage],
    device: torch.device,
) -> Tuple[float, Dict[str, Any]]:
    """Run a single optimisation step on one batch.

    Moves ``batch`` to ``device`` and runs forward, backward and
    ``optimizer.step()``. Then updates ``ema`` if one is given.

    Returns:
        A tuple ``(loss, metrics)``. ``loss`` is the batch loss as a Python
        float. ``metrics`` has keys:

        - ``"loss"``: the loss as a numpy value.
        - ``"acc"``: percentage of samples whose argmax over
          ``output["energy"]`` equals ``batch["signal"]``.
        - ``"time"``: wall-clock seconds for the step.
    """
    start_time = time.time()
    batch = batch.to(device)
    optimizer.zero_grad()
    output = model(batch, training=True)
    loss = loss_fn(pred=output, ref=batch)
    loss.backward()
    optimizer.step()

    if ema is not None:
        ema.update()

    # Accuracy is only a metric, so don't record softmax in the autograd graph.
    # NOTE(review): assumes output["energy"] holds per-class scores and
    # batch["signal"] integer class labels -- confirm against the model.
    with torch.no_grad():
        predictions = torch.argmax(output["energy"].softmax(-1), dim=-1)
        acc = (predictions == batch["signal"]).float().mean() * 100

    loss_dict = {
        "loss": to_numpy(loss),
        "acc": to_numpy(acc),
        "time": time.time() - start_time,
    }
    # Per-step details go to the debug log instead of stdout.
    logging.debug("loss_dict %s", loss_dict)
    # Return a detached Python float, as the signature promises.
    return loss.item(), loss_dict




def warmup(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    rounds: int,
    lrs: list,
    logger: MetricsLogger,
    eval_interval: int,
    device: torch.device,
    *,
    progress_bar: bool = True,
    wdb_logger: Optional[Any] = None,
    warm_strategy: str = "linear",
):
    """Train for ``rounds`` epochs while ramping the learning rate up.

    With ``warm_strategy="linear"``, every param group uses the LR
    ``lrs[0] + epoch * (lrs[1] - lrs[0]) / rounds`` during warmup epoch
    ``epoch``. After the last epoch the optimizer is left at ``lrs[1]``. No
    EMA update, checkpointing or early stopping happens here. The model is
    evaluated on ``valid_loader`` on epochs where ``epoch % eval_interval == 0``.

    Args:
        lrs: Exactly two absolute learning rates ``[start, end]`` with ``start < end``.
        wdb_logger: Optional object with a ``log(dict)`` method (presumably a
            wandb run -- only ``.log`` is called on it).

    Raises:
        ValueError: For an unknown ``warm_strategy``, ``lrs`` not holding two
            values, non-positive ``rounds``, ``lrs[0] >= lrs[1]``, or a param
            group whose base LR is zero.
    """
    if warm_strategy != "linear":
        raise ValueError(f"Unknown warmup strategy {warm_strategy}")
    if len(lrs) != 2:
        raise ValueError("Linear warmup requires two LR values")
    if rounds <= 0:
        raise ValueError(f"Warmup requires a positive number of rounds, got {rounds}")
    start_lr, end_lr = lrs
    slope = (end_lr - start_lr) / rounds
    if slope <= 0:
        raise ValueError("Initial LR must be smaller than final LR")

    # LambdaLR sets each group's LR to ``base_lr * lr_lambda(epoch)``. Here
    # base_lr is the group's "initial_lr", which the scheduler fills from its
    # current "lr" if absent. Dividing the absolute target LR by that base
    # makes the scheduler produce exactly the target LR.
    base_lrs = [group.get("initial_lr", group["lr"]) for group in optimizer.param_groups]
    if any(base_lr == 0 for base_lr in base_lrs):
        raise ValueError("Linear warmup requires a non-zero base LR in every param group")
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer=optimizer,
        lr_lambda=[
            # Bind base_lr as a default to avoid late-binding in the comprehension.
            (lambda epoch, base_lr=base_lr: (start_lr + slope * epoch) / base_lr)
            for base_lr in base_lrs
        ],
        last_epoch=-1,
    )

    logging.info("Started warmup for {} epochs".format(rounds))
    epochs = range(rounds)
    iterator = tqdm(epochs) if progress_bar else epochs
    # Last training-batch loss; stays None if train_loader yields nothing.
    avg_loss = None
    for epoch in iterator:
        logging.info(
            f"Warmup epoch {epoch}: lr={[group['lr'] for group in optimizer.param_groups]}"
        )
        # Train
        for batch in train_loader:
            avg_loss, opt_metrics = take_step(
                model=model,
                loss_fn=loss_fn,
                batch=batch,
                optimizer=optimizer,
                ema=None,
                device=device,
            )
            opt_metrics["mode"] = "opt"
            opt_metrics["epoch"] = epoch
            logger.log(opt_metrics)

        # Validate
        if epoch % eval_interval == 0:
            valid_loss, eval_metrics = evaluate(
                model=model,
                loss_fn=loss_fn,
                data_loader=valid_loader,
                device=device,
            )
            eval_metrics["mode"] = "eval"
            eval_metrics["epoch"] = epoch
            accuracy = eval_metrics["accuracy"]
            if wdb_logger is not None:
                wdb_logger.log({"epoch": epoch, "eval-accuracy": accuracy,
                                "eval_loss": valid_loss, "train_loss": avg_loss})

            logging.info(
                f"Epoch {epoch}: loss={valid_loss:.4f}, accuracy={accuracy:.1f}"
            )
        # LambdaLR is epoch-based. A positional argument would be taken as the
        # epoch index, so step without passing the validation loss.
        lr_scheduler.step()
    logging.info("Warmup complete")


def evaluate(
    model: torch.nn.Module,
    loss_fn: torch.nn.Module,
    data_loader: DataLoader,
    device: torch.device,
) -> Tuple[float, Dict[str, Any]]:
    """Evaluate ``model`` on every batch of ``data_loader``.

    The model is switched to eval mode for the pass. Its previous train/eval
    mode is restored afterwards, even if evaluation raises. Model outputs and
    batches are moved to the CPU before computing the loss and accuracy.

    NOTE(review): no ``torch.no_grad()`` is used here. Presumably the model
    may rely on autograd internally even with ``training=False``; confirm
    before adding one.

    Returns:
        A tuple ``(avg_loss, aux)``. ``avg_loss`` is the mean of the per-batch
        losses. ``aux`` has keys:

        - ``"loss"``: the same value as ``avg_loss``.
        - ``"accuracy"``: percentage of samples whose argmax over
          ``output["energy"]`` equals ``batch["signal"]``.
        - ``"time"``: wall-clock seconds.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    start_time = time.time()
    total_loss = 0.0
    num_batches = 0
    correct_chunks = []
    was_training = model.training
    model.eval()
    try:
        for batch in data_loader:
            batch = batch.to(device)
            output = model(batch, training=False)
            batch = batch.cpu()
            output = tensor_dict_to_device(output, device=torch.device("cpu"))

            loss = loss_fn(pred=output, ref=batch)
            total_loss += to_numpy(loss).item()
            num_batches += 1
            predictions = torch.argmax(output["energy"].softmax(-1), dim=-1)
            # One 0/1 entry per sample, so accuracy is averaged per sample, not per batch.
            correct_chunks.append((predictions == batch["signal"]).float())
    finally:
        model.train(was_training)

    if num_batches == 0:
        raise ValueError("Cannot evaluate on an empty data loader")

    avg_acc = torch.cat(correct_chunks, dim=0).mean() * 100
    avg_loss = total_loss / num_batches
    aux = {
        "loss": avg_loss,
        "accuracy": to_numpy(avg_acc),
        "time": time.time() - start_time,
    }
    logging.info(f"Evaluation: loss={avg_loss:.4f}, accuracy={avg_acc:.2f}")
    logging.debug("EVALUATION %s", aux)
    return avg_loss, aux
