#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@title: Mitigating Barren Plateaus in Quantum Neural Networks via an AI-Driven Submartingale-Based Framework.
@topic: Basic initialization-based strategies.
@author: anonymous
"""

import argparse
import math
import numpy as np
import scipy.stats as ss
import torch
from torch import Tensor
import pennylane as qml
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.model_selection import train_test_split
from model import circuit1, circuit2, circuit3, QuantumModel
from train_eval import train, evaluation
from utils import read_yaml_file, load_dataset, batch_to_full_data, data_batch_loader, \
            preprocessing_titanic, preprocessing_mnist, check_mkdirs, init_model_params


# Arguments
def str2bool(value) -> bool:
    """Parse a boolean command-line value.

    ``type=bool`` is an argparse pitfall: any non-empty string is truthy, so
    ``bool("False")`` is ``True`` and ``--is_prior False`` would silently keep
    the prior enabled. This converter accepts the usual spellings instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"yes"``/``"no"``, ``"1"``/``"0"``
            (case-insensitive, surrounding whitespace ignored).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean
            spelling (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


parser = argparse.ArgumentParser()
parser.add_argument('--data_name', type=str, default='mnist', help='The name of dataset.',
                    choices=['iris', 'wine', 'titanic', 'mnist'])
parser.add_argument("--vqc_type", type=int, default=1, help='The type of VQC.')
parser.add_argument('--init', type=str, default='uniform', help='Initialization methods.',
                    choices=['uniform', 'normal', 'beta', 'glorot_uniform', 'glorot_normal', 'he_uniform', 'he_normal', 'orthogonal'])
parser.add_argument('--opt', type=str, default='adam', help='The optimizer for training.',
                    choices=['adam', 'sgd'])
parser.add_argument('--epochs', type=int, default=30, help='The number of training epochs.')
parser.add_argument('--batch_size', type=int, default=20, help='The size of one batch.')
parser.add_argument('--lr', type=float, default=0.1, help='The learning rate for optimization.')
parser.add_argument('--nlayers', type=int, default=4, help='The number of layers in VQC.')
parser.add_argument('--nqubits', type=int, default=4, help='The number of qubits in VQC.')
# str2bool instead of bool: bool("False") would be True.
parser.add_argument('--is_prior', type=str2bool, default=True, help='Consider data information as a prior or not.')
parser.add_argument('--qml_dev', type=str, default='default.qubit.torch', help='qml device name.') # https://github.com/PennyLaneAI/pennylane/pull/1982
parser.add_argument('--device', type=str, default='cuda', help='device name.',
                    choices=['cpu', 'cuda', 'mps'])
parser.add_argument('--GPU', type=int, default=0, help='gpu id.')
parser.add_argument("--config_file", type=str, default='model', help='The name of config files.',
                    choices=['model', ''])
# parse_known_args: unrecognized arguments are ignored instead of raising.
args = parser.parse_known_args()[0]

# Read config files.
# NOTE: when a config file is given (the default is ``--config_file model``),
# every key listed below is taken from its 'search' section and OVERRIDES the
# corresponding command-line flag. Pass ``--config_file ''`` to use the CLI only.
if args.config_file:
    config_file = read_yaml_file("../config", args.config_file)
    train_config = config_file['search']
    # Same keys, same order as before; a missing key still raises KeyError.
    for _key in ('data_name', 'vqc_type', 'init', 'opt', 'epochs', 'batch_size',
                 'lr', 'nlayers', 'nqubits', 'is_prior', 'qml_dev', 'device', 'GPU'):
        setattr(args, _key, train_config[_key])
# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if args.nlayers < 1 or args.nqubits < 2:
    raise ValueError(
        f"Expected nlayers >= 1 and nqubits >= 2, "
        f"got nlayers={args.nlayers}, nqubits={args.nqubits}."
    )


# Initialize the hyper-parameters
SEED_ID = 42
# Seed both global RNGs for reproducibility. sklearn's train_test_split (called
# below without random_state) draws from NumPy's global RNG. The initializers in
# this file sample through torch (uniform_, normal_, Beta.sample), which
# np.random.seed does not affect.
np.random.seed(SEED_ID)
torch.manual_seed(SEED_ID)
TRAIN_RATIO, TRAIN_RATIO_WINE, VAL_RATIO_IRIS = 0.8, 0.77, 0.75  # the ratio for data split
class_ids = [0, 1]  # class id
N_SAMPLES = 200  # the number of samples for each class
NUM_ROT = 3  # The number of parameters (phi, theta, omega) for rotation in VQC
file_name = {'iris': 'iris_2class_scaled.npy',
            'wine': 'wine_2class_scaled.npy',
            'titanic': 'titanic_2class_raw.csv',
            'mnist': ''
            }
circuits = {1: circuit1, 2: circuit2, 3: circuit3}
optimizers = {'sgd': torch.optim.SGD,
            'adam': torch.optim.Adam
            }
# Set up the paths
data_path = "../data"
checkpoint_file = f"ckpt_{args.data_name}_vqc{args.vqc_type}_{args.opt}_{args.init}_"\
                  f"{args.nlayers}ly_{args.nqubits}qb"
checkpoint_path = f"./checkpoints/{checkpoint_file}"
model_path = f"{checkpoint_path}/model_best.pth.tar"  # for eval
# Instantiate: default.qubit | default.qubit.torch | lightning.gpu
# ref: https://docs.pennylane.ai/en/stable/code/api/pennylane.device.html
qml_dev = qml.device(args.qml_dev, wires=args.nqubits)

# Optionally treat the training data as a prior distribution for the initializers:
# data-dependent (is_prior=True) vs. data-independent (is_prior=False) initializations.

def uniform_init(params: Tensor, data: Tensor, is_prior: bool) -> Tensor:
    """Fill ``params`` in place from a uniform distribution and return it.

    Args:
        params: Tensor to initialize; overwritten via ``Tensor.uniform_``.
        data: Training data. Only its global min and max are read, and only
            when ``is_prior`` is true.
        is_prior: If true, sample from U[min(data), max(data)]; otherwise
            sample from U[0, 1].

    Returns:
        ``params`` itself, after in-place filling.
    """
    low, high = (data.min().item(), data.max().item()) if is_prior else (0.0, 1.0)
    print(f"The range of uniform distribution -- [{low:.2f}, {high:.2f}]")
    return params.uniform_(low, high)

def _std_uniform_init(params: Tensor, data: Tensor, is_prior: bool, std: float) -> Tensor:
    """Fill ``params`` in place from a uniform distribution scaled by ``std``.

    Args:
        params: Tensor to initialize; overwritten via ``Tensor.uniform_``.
        data: Training data. Its min/max are read only when ``is_prior``.
        is_prior: If true, sample from U[std * min(data), std * max(data)];
            otherwise sample from U[-std, std].
        std: Scale factor / half-width of the interval. Must not be ``None``.

    Returns:
        ``params`` itself, after in-place filling.
    """
    assert std is not None
    if not is_prior:
        return params.uniform_(-std, std)
    low = data.min().item() * std
    high = data.max().item() * std
    return params.uniform_(low, high)

def normal_init(params: Tensor, data: Tensor, is_prior: bool) -> Tensor:
    """Fill ``params`` in place from a normal distribution and return it.

    Args:
        params: Tensor to initialize; overwritten via ``Tensor.normal_``.
        data: Training data. Its mean and (unbiased) std, as computed by
            ``torch.mean``/``torch.std``, are read only when ``is_prior``.
        is_prior: If true, sample from N(mean(data), std(data)); otherwise
            sample from the standard normal N(0, 1).

    Returns:
        ``params`` itself, after in-place filling.
    """
    mu, sigma = 0.0, 1.0  # standard normal unless the data prior is used
    if is_prior:
        mu = torch.mean(data).item()
        sigma = torch.std(data).item()
    print(f"Parameters for normal distribution -- mu: {mu:.2f}, sigma: {sigma:.2f}.")
    return params.normal_(mean=mu, std=sigma)

def _std_normal_init(params: Tensor, data: Tensor, is_prior: bool, std: float) -> Tensor:
    """Fill ``params`` in place from a normal distribution whose sigma is scaled by ``std``.

    Args:
        params: Tensor to initialize; overwritten via ``Tensor.normal_``.
        data: Training data. Its mean/std are read only when ``is_prior``.
        is_prior: If true, sample from N(mean(data), std(data) * std);
            otherwise sample from N(0, std).
        std: Multiplier applied to the base sigma. Must not be ``None``.

    Returns:
        ``params`` itself, after in-place filling.
    """
    assert std is not None
    if is_prior:
        center, spread = torch.mean(data).item(), torch.std(data).item()
    else:
        center, spread = 0.0, 1.0
    return params.normal_(mean=center, std=spread * std)

def beta_init(params: Tensor, data: Tensor, is_prior: bool) -> Tensor:
    """Draw a ``params``-shaped sample from a Beta distribution.

    When ``is_prior`` is true, alpha and beta are fitted to all values of
    ``data`` (flattened) with ``scipy.stats.beta.fit``, with location fixed to 0
    and scale to 1. If the fit fails (e.g. some values lie outside the open
    interval (0, 1)), the function falls back to Beta(0.5, 0.5) and prints
    why. Beta(0.5, 0.5) is also used when ``is_prior`` is false.

    Unlike the other initializers in this file, this returns a NEW float64
    tensor rather than filling ``params`` in place. The sample is placed on
    ``params.device``.

    Args:
        params: Tensor whose shape (and device) the sample should have.
        data: Training data; may live on any torch device.
        is_prior: Whether to fit the Beta parameters to ``data``.

    Returns:
        A float64 tensor with ``params.shape`` and values in [0, 1].
    """
    a, b = 0.5, 0.5  # default Beta(0.5, 0.5)
    if is_prior:
        # SciPy needs a host-side NumPy array. NumPy cannot convert a CUDA
        # tensor (or one requiring grad), which previously made the fit fail
        # silently on GPU.
        samples = torch.as_tensor(data).detach().cpu().numpy().ravel()
        try:
            a, b, _, _ = ss.beta.fit(samples, floc=0, fscale=1)
        except Exception as err:  # deliberate best effort: keep the default
            print(f"Beta fit failed ({err!r}); falling back to alpha=0.5, beta=0.5.")
            a, b = 0.5, 0.5
    print(f"Parameters for beta distribution -- alpha: {a:.2f}, beta: {b:.2f}.")
    sample = torch.distributions.Beta(a, b).sample(params.shape).double()
    return sample.to(params.device)

def glorot_uniform_init(params: Tensor, data: Tensor, is_prior: bool, gain: float = 1.) -> Tensor:
    """Glorot (Xavier) uniform initialization.

    Let x = gain * sqrt(6 / (fan_in + fan_out)). Without a prior, ``params``
    is filled from U[-x, x]. With ``is_prior``, x instead scales the data
    range: U[x * min(data), x * max(data)] (delegated to ``_std_uniform_init``).

    Fans are computed by ``torch.nn.init._calculate_fan_in_and_fan_out``, which
    requires ``params`` to have at least 2 dimensions.
    """
    fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(params)
    normal_std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    # A U[-b, b] variable has standard deviation b / sqrt(3), so b = std * sqrt(3).
    bound = normal_std * math.sqrt(3.0)
    print(f"Using glorot uniform initialization: fan_in {fan_in}, fan_out {fan_out}.")
    return _std_uniform_init(params, data, is_prior, bound)

def glorot_normal_init(params: Tensor, data: Tensor, is_prior: bool, gain: float = 1.) -> Tensor:
    """Glorot (Xavier) normal initialization.

    Let s = gain * sqrt(2 / (fan_in + fan_out)). Without a prior, ``params``
    is filled from N(0, s). With ``is_prior``, it is filled from
    N(mean(data), std(data) * s) (delegated to ``_std_normal_init``).

    Fans are computed by ``torch.nn.init._calculate_fan_in_and_fan_out``, which
    requires ``params`` to have at least 2 dimensions.
    """
    fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(params)
    scale = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    print(f"Using glorot normal initialization: fan_in {fan_in}, fan_out {fan_out}.")
    return _std_normal_init(params, data, is_prior, scale)

def he_uniform_init(params: Tensor, data: Tensor, is_prior: bool, nl: str = 'relu') -> Tensor:
    """He (Kaiming) uniform initialization.

    Let x = calculate_gain(nl) * sqrt(3 / fan_in). For the default 'relu'
    (gain sqrt(2)) this is sqrt(6 / fan_in). Without a prior, ``params`` is
    filled from U[-x, x]. With ``is_prior``, it is filled from
    U[x * min(data), x * max(data)] (delegated to ``_std_uniform_init``).
    """
    fan_in, _unused_fan_out = torch.nn.init._calculate_fan_in_and_fan_out(params)
    nl_gain = torch.nn.init.calculate_gain(nonlinearity=nl)
    # Convert the target std to a uniform half-width: b = std * sqrt(3).
    bound = (nl_gain / math.sqrt(float(fan_in))) * math.sqrt(3.0)
    print(f"Using he uniform initialization: fan_in {fan_in}.")
    return _std_uniform_init(params, data, is_prior, bound)

def he_normal_init(params: Tensor, data: Tensor, is_prior: bool, nl: str = 'relu') -> Tensor:
    """He (Kaiming) normal initialization.

    Let s = calculate_gain(nl) / sqrt(fan_in). For the default 'relu' this is
    sqrt(2 / fan_in). Without a prior, ``params`` is filled from N(0, s). With
    ``is_prior``, it is filled from N(mean(data), std(data) * s) (delegated to
    ``_std_normal_init``).
    """
    fan_in, _unused_fan_out = torch.nn.init._calculate_fan_in_and_fan_out(params)
    nl_gain = torch.nn.init.calculate_gain(nonlinearity=nl)
    scale = nl_gain / math.sqrt(float(fan_in))
    print(f"Using he normal initialization: fan_in {fan_in}.")
    return _std_normal_init(params, data, is_prior, scale)

def orthogonal_init(params: Tensor, data: Tensor, is_prior: bool, gain: float = 1.) -> Tensor:
    """Orthogonal initialization, adapted from ``torch.nn.init.orthogonal_``.

    ``params`` is viewed as a (size(0), numel // size(0)) matrix. A random
    matrix of that shape is drawn with ``normal_init`` (so it follows the data
    prior when ``is_prior`` is true). The matrix is orthogonalized with a QR
    decomposition and copied into ``params``, which is then scaled by ``gain``.

    Raises:
        ValueError: If ``params`` has fewer than 2 dimensions.

    Returns:
        ``params`` itself, modified in place (unchanged if it is empty).
    """
    if params.ndimension() < 2:
        raise ValueError("Only params with 2 or more dimensions are supported.")
    if params.numel() == 0:
        return params  # no-op
    n_rows = params.size(0)
    n_cols = params.numel() // n_rows
    # QR needs a tall (or square) matrix, so work on the transpose when wide.
    is_wide = n_rows < n_cols
    flat = normal_init(params.new_empty((n_rows, n_cols)), data, is_prior)
    if is_wide:
        flat.t_()
    q, r = torch.linalg.qr(flat)
    # Sign-correct Q with the signs of R's diagonal, following
    # https://arxiv.org/pdf/math-ph/0609050.pdf (as torch.nn.init.orthogonal_ does).
    q *= torch.diag(r, 0).sign()
    if is_wide:
        q.t_()
    params.view_as(q).copy_(q)
    params.mul_(gain)
    print("Using orthogonal initialization.")
    return params


if __name__ == '__main__':
    print("args: ", args)
    # Resolve the compute device up front. If CUDA was requested but is not
    # available, fall back to CPU so that the device string handed to the
    # helpers below matches where the model actually lives.
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("CUDA is not available, falling back to CPU.")
        args.device = 'cpu'
    if args.device == 'cuda':
        # Select the GPU before anything is placed on CUDA (presumably
        # batch_to_full_data(..., args.device) below), so data and model share
        # the same GPU when args.GPU != 0.
        torch.cuda.set_device(args.GPU)

    # Load dataset and preprocessing
    print("Load dataset and preprocessing:")
    if args.data_name in ['iris', 'wine', 'titanic']:
        data, labels = load_dataset(args.data_name, data_path, file_name[args.data_name])
        SPLIT_RATIO = TRAIN_RATIO_WINE if args.data_name == 'wine' else TRAIN_RATIO  # wine: 80/20/30
        X_train, X_test, y_train, y_test = train_test_split(data, labels, train_size=SPLIT_RATIO)
        if args.data_name == 'titanic':
            X_train, y_train = preprocessing_titanic(X_train, y_train, class_ids, N_SAMPLES)
            X_test, y_test = preprocessing_titanic(X_test, y_test, class_ids, N_SAMPLES)
        # Reduce feature dimension s.t. nfeat = nqubits.
        # NOTE(review): t-SNE has no out-of-sample transform, so train and test
        # are embedded by two independent fits whose coordinates are not
        # aligned with each other; consider a transformable reducer (e.g. PCA).
        if len(X_train.shape) == 2 and X_train.shape[1] > args.nqubits:
            X_train = TSNE(n_components=args.nqubits, random_state=SEED_ID,
                           n_jobs=2, method='exact').fit_transform(X_train)
        if len(X_test.shape) == 2 and X_test.shape[1] > args.nqubits:
            X_test = TSNE(n_components=args.nqubits, perplexity=X_test.shape[0]-5,
                          random_state=SEED_ID, n_jobs=2, method='exact').\
                            fit_transform(X_test)
    if args.data_name == 'mnist':
        train_data, test_data = load_dataset(args.data_name, data_path, file_name[args.data_name])
        RSIZE = int(np.sqrt(args.nqubits))
        X_train, y_train = preprocessing_mnist(train_data, class_ids, N_SAMPLES, RSIZE)
        X_test, y_test = preprocessing_mnist(test_data, class_ids, N_SAMPLES, RSIZE)
    # Apply label encoding if the label ids don't follow the order.
    # Fit on the train labels only and REUSE that mapping for the test labels;
    # refitting on the test split could assign different ids to the same class.
    if max(class_ids)+1 != len(class_ids):
        label_encoder = LabelEncoder()
        y_train = label_encoder.fit_transform(y_train)
        y_test = label_encoder.transform(y_test)
    # Apply one-hot encoding to the encoded labels (fitted on train only, so
    # the test matrix has the same columns even if a class is absent from it).
    onehot_encoder = OneHotEncoder(sparse_output=False)
    y_train, y_test = y_train.reshape(-1, 1), y_test.reshape(-1, 1)
    y_train = onehot_encoder.fit_transform(y_train)
    y_test = onehot_encoder.transform(y_test)
    # Split the validation set from the train data.
    SPLIT_RATIO = VAL_RATIO_IRIS if args.data_name == 'iris' else TRAIN_RATIO  # iris: 60/20/20
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, train_size=SPLIT_RATIO)
    # Convert to torch tensor
    X_train, X_val, X_test = torch.DoubleTensor(X_train), torch.DoubleTensor(X_val), torch.DoubleTensor(X_test)
    y_train, y_val, y_test = torch.DoubleTensor(y_train), torch.DoubleTensor(y_val), torch.DoubleTensor(y_test)
    # Build dataloaders for batch training
    train_data_loader = data_batch_loader(X_train, y_train, args.batch_size)
    val_data_loader = data_batch_loader(X_val, y_val, args.batch_size)
    test_data_loader = data_batch_loader(X_test, y_test, args.batch_size)

    print("Start training ...")
    check_mkdirs(checkpoint_path)
    params_shapes = (args.nlayers, args.nqubits, NUM_ROT)
    QModel = QuantumModel(circuits[args.vqc_type], qml_dev, params_shapes, len(class_ids))
    # Initialize the model parameters with basic distributions
    init_funcs = {'uniform': uniform_init,
                'normal': normal_init,
                'beta': beta_init,
                'glorot_uniform': glorot_uniform_init,
                'glorot_normal': glorot_normal_init,
                'he_uniform': he_uniform_init,
                'he_normal': he_normal_init,
                'orthogonal': orthogonal_init
                }
    X_train_full, _ = batch_to_full_data(train_data_loader, args.device)
    init_model_params(QModel, init_funcs[args.init], X_train_full, args.is_prior, args.device)
    # Move the model BEFORE constructing the optimizer, as the torch.optim docs
    # recommend, so the optimizer is built over the parameters on their final device.
    if args.device == 'cuda':
        QModel.to(args.device)
    # Define the optimizers
    opt = optimizers[args.opt](QModel.parameters(), lr=args.lr)

    # Train the VQC
    loss_train, grad_var, _ = train(QModel, opt, train_data_loader, val_data_loader,
                                    args.epochs, checkpoint_path, args.device)
    print(f"The gradiant variance: {grad_var}")

    # Evaluation
    print("Evaluation:")
    evaluation(QModel, opt, model_path, test_data_loader, args.device)
    torch.cuda.empty_cache()
