#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@title: Mitigating Barren Plateaus in Quantum Neural Networks via an AI-Driven Submartingale-Based Framework.
@topic: Train and Evaluation.
@author: anonymous
"""

import time
from collections import defaultdict
from typing import List, Dict, Callable, Tuple

import torch
from torch import Tensor

from utils import compute_accuracy, remove_file, load_checkpoint, save_checkpoint


def train(model: torch.nn.Module, opt: torch.optim.Optimizer,
          train_data_loader: List[Tensor], val_data_loader: List[Tensor],
          num_epochs: int, dirs: str, device: str) -> Tuple[List[float], float, float]:
    """
    @descriptions: Train the VQC, validate after every epoch, and save a
        checkpoint whenever the validation accuracy improves.
    @inputs:
        model: the quantum neural networks.
        opt: the optimizer.
        train_data_loader: the train data loader.
        val_data_loader: the validation data loader.
        num_epochs: the number of training epochs.
        dirs: the path for saving the checkpoints.
        device: cpu/mps/cuda.
    @return:
        loss_train_list: the mean training loss of every epoch (Python floats).
        grad_var: the variance, across epochs, of the gradient of the first
            element of the first parameter tensor, sampled after the last
            training batch of each epoch; NaN if no gradient was recorded.
        best_acc_val: the best validation accuracy seen, including any value
            restored from the loaded checkpoint.
    """
    # Define the loss function.
    loss_ce = torch.nn.CrossEntropyLoss()
    # Resume from the best checkpoint if it can be loaded, otherwise start fresh.
    try:
        model, opt, start_epoch, best_acc_val = \
                    load_checkpoint(f"{dirs}/model_best.pth.tar", model, opt)
    except Exception:
        # Deliberately not a bare `except:` so Ctrl-C / SystemExit still propagate.
        print("Failed to load the checkpoints.")
        start_epoch = 1
        best_acc_val = 0.0
    # Train with additional epochs if given #epochs <= trained #epochs.
    # NOTE(review): the loop below always counts from epoch 0, so a resumed run
    # reuses the Epoch{n} checkpoint file names — confirm this is intended.
    if num_epochs <= start_epoch:
        num_epochs += start_epoch + 1

    grad0_train = []
    loss_train_list = []
    for epoch in range(num_epochs):
        # --- Training ---
        model.train()
        t0 = time.time()  # record the start runtime
        train_log = _batch_step(train_data_loader, model, opt, loss_ce, compute_accuracy, True, device)
        loss_train, acc_train = train_log['loss'], train_log['metric']
        loss_train_list.append(loss_train.item())
        t1 = time.time() - t0  # training runtime of this epoch

        # --- Validation ---
        # Evaluate in eval mode (as `evaluation` does) so dropout / batch-norm
        # layers, if any, behave deterministically; model.train() above
        # re-enables training mode at the start of the next epoch.
        model.eval()
        with torch.no_grad():
            val_log = _batch_step(val_data_loader, model, opt, loss_ce, compute_accuracy, False, device)
            loss_val, acc_val = val_log['loss'], val_log['metric']
            print(f"Epoch: {epoch+1:4d} | Time(s): {t1:0.3f} | "
                  f"Loss_train: {loss_train.item():0.4f} | Loss_val: {loss_val.item():0.4f} | "
                  f"Acc_train: {acc_train:0.2%} | Acc_val: {acc_val:0.2%}")

        # Delete the previously existing parameter file (optional).
        remove_file(f"{dirs}/Epoch{epoch-1}.pth.tar")
        # Save the model checkpoints if we get better validation accuracy.
        if acc_val > best_acc_val:
            best_acc_val = acc_val
            save_checkpoint(state = {
                                    'epoch': epoch + 1,
                                    'state_dict': model.state_dict(),
                                    'best_acc': best_acc_val,
                                    'optimizer': opt.state_dict()
                                    },
                            is_best = True,
                            directory = dirs,
                            filename = f"Epoch{epoch}.pth.tar"
                            )

        # --- Gradient sampling ---
        # Record the gradient of the first element of the first parameter
        # tensor (presumably shaped (nlayers, nqubits, nrot) — only element 0
        # is used), as left by the last training batch.
        with torch.no_grad():
            first_param = next(iter(model.parameters()), None)
            if first_param is not None and first_param.grad is not None:
                # clone(): indexing yields a view into .grad. An in-place
                # opt.zero_grad() (set_to_none=False, the default before
                # PyTorch 2.0) would otherwise overwrite every stored sample.
                grad0_train.append(first_param.grad.reshape(-1)[0].detach().clone())

    # Calculate the gradient variance; guard the case where nothing was recorded
    # (e.g. zero epochs or no parameter received a gradient).
    if grad0_train:
        grad_var = torch.var(torch.stack(grad0_train)).item()
    else:
        grad_var = float("nan")
    return loss_train_list, grad_var, best_acc_val  # list, float, float

def _batch_step(data_loader: List[Tensor],
                model: torch.nn.Module, opt: torch.optim.Optimizer,
                loss_func: Callable, eval_func: Callable,
                is_grad: bool, device: str) -> Dict:
    """
    @descriptions: an optimization step in batches.
    @inputs:
        data_loader: a dataloader that contains data and labels.
        model: the quantum neural networks.
        opt: the optimizer.
        loss_func: the loss function.
        eval_func: the evaluation function.
        is_grad: require grad or not.
        device: cpu/mps/cuda.
    @return:
        log_dict: a dictionary that stores loss and evaluation results.
    """
    num_batch = 0
    log_dict = defaultdict(float)
    log_dict['loss'] = torch.tensor(0.0, dtype=torch.float64, device=device)
    log_dict['metric'] = 0.0
    for data_batch, labels_batch in data_loader:
        num_batch += 1
        if torch.cuda.is_available() and device == 'cuda':
            data_batch = data_batch.to(device)
            labels_batch = labels_batch.to(device)
        # Output the logits
        logits_batch = model(data_batch)
        # Compute the loss and metric
        loss = loss_func(logits_batch, labels_batch)  # double_tensor
        metric = eval_func(logits_batch, labels_batch)  # double
        # Backward
        if is_grad:
            opt.zero_grad()  # reset the gradient to avoid accumulation.
            loss.backward()  # compute the gradient
            opt.step()  # update the weights (W_{new} = W_{old} - ∂(Gradient))
        # Store the loss and metric
        log_dict['loss'] += loss
        log_dict['metric'] += metric
    num_batch = 1 if num_batch < 1 else num_batch
    return {k: v/num_batch for k, v in log_dict.items()}

def evaluation(model: torch.nn.Module, opt: torch.optim.Optimizer, path: str,
               data_loader: List[Tensor], device: str) -> float:
    """
    @descriptions: Load a saved checkpoint and evaluate it on the given data.
    @inputs:
        model: the quantum neural networks.
        opt: the optimizer (restored alongside the model by load_checkpoint).
        path: the file path for the model checkpoint.
        data_loader: the test data loader.
        device: cpu/mps/cuda.
    @return:
        acc_test: the mean test accuracy, which is also printed. Returning it
            is backward-compatible with callers that ignored the old None.
    """
    model, opt, _, _ = load_checkpoint(path, model, opt)
    # Switch to eval mode *after* loading, so it applies to whichever module
    # object load_checkpoint returns (fixes dropout / batch-norm, if any).
    model.eval()
    loss_ce = torch.nn.CrossEntropyLoss()
    with torch.no_grad():
        test_log = _batch_step(data_loader, model, opt, loss_ce, compute_accuracy, False, device)
    acc_test = test_log['metric']
    print(f"Best Testing Accuracy: {acc_test:.2%}.")
    return acc_test
