#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@title: Mitigating Barren Plateaus in Quantum Neural Networks via an AI-Driven Submartingale-Based Framework.
@topic: A GenAI-driven framework for generating effective VQCs' initial parameters.
@author: anonymous
@instructions for using Vertex AI:
    gcloud init
    gcloud auth application-default login
"""

import argparse
from typing import List, Union
import numpy as np
import torch
from torch import Tensor
import pennylane as qml
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.model_selection import train_test_split
from openai import OpenAI
#import vertexai
from vertexai.generative_models import GenerativeModel
from model import circuit1, circuit2, circuit3, QuantumModel
from train_eval import train, evaluation
from utils import read_yaml_file, load_dataset, data_batch_loader, \
            preprocessing_titanic, preprocessing_mnist, check_mkdirs, save_files, \
            generate_init_params, expected_improvement, update_model_params

# Command-line arguments.
# Each entry is (flag, keyword arguments forwarded to ``add_argument``); the
# order below is the order in which the options are registered.
parser = argparse.ArgumentParser()
_ARG_SPECS = (
    ('--data_name', dict(type=str, default='iris', help='The name of dataset.',
                         choices=['iris', 'wine', 'titanic', 'mnist'])),
    ('--vqc_type', dict(type=int, default=1, help='The type of VQC.')),
    ('--init', dict(type=str, default='uniform', help='Initialization methods.',
                    choices=['uniform', 'normal', 'beta'])),
    ('--opt', dict(type=str, default='adam', help='The optimizer for training.',
                   choices=['adam', 'sgd'])),
    ('--nsearches', dict(type=int, default=50, help='The number of iterations.')),
    ('--epochs', dict(type=int, default=30, help='The number of training epochs.')),
    ('--batch_size', dict(type=int, default=20, help='The size of one batch.')),
    ('--lr', dict(type=float, default=0.1, help='The learning rate for optimization.')),
    ('--nlayers', dict(type=int, default=4, help='The number of layers in VQC.')),
    ('--nqubits', dict(type=int, default=4, help='The number of qubits in VQC.')),
    ('--vertexai_proj_id', dict(type=str, default='', help='Project id for VertexAI.init().')),
    ('--vertexai_region', dict(type=str, default='', help='Region for VertexAI.init().')),
    ('--openai_key', dict(type=str, default='', help='API Key for OpenAI models.')),
    ('--genai_name', dict(type=str, default='gemini-1.5-flash-002', help='GenAI model name.')),
    ('--temp', dict(type=float, default=0.5, help='The randomness in token selection.')),
    ('--top_p', dict(type=float, default=0.95, help='How the model selects tokens for output.')),
    # qml device background: https://github.com/PennyLaneAI/pennylane/pull/1982
    ('--qml_dev', dict(type=str, default='default.qubit.torch', help='qml device name.')),
    ('--device', dict(type=str, default='cuda', help='device name.',
                      choices=['cpu', 'cuda', 'mps'])),
    ('--GPU', dict(type=int, default=0, help='gpu id.')),
    ('--config_file', dict(type=str, default='model', help='The name of config files.',
                           choices=['model', ''])),
)
for _flag, _options in _ARG_SPECS:
    parser.add_argument(_flag, **_options)
# Unrecognised command-line tokens are tolerated and discarded.
args, _unknown_cli = parser.parse_known_args()

# Read config files
# When a config file name is given (the default is 'model'), every option
# listed below is overwritten by the value found in the file's 'search'
# section, so the corresponding command-line values are ignored.
if args.config_file:
    config_file = read_yaml_file("../config", args.config_file)
    main_config = config_file['search']
    # All keys are mandatory: a missing one raises KeyError on first access.
    for _cfg_key in ('data_name', 'vqc_type', 'init', 'opt', 'nsearches',
                     'epochs', 'batch_size', 'lr', 'nlayers', 'nqubits',
                     'vertexai_proj_id', 'vertexai_region', 'openai_key',
                     'genai_name', 'temp', 'top_p', 'qml_dev', 'device', 'GPU'):
        setattr(args, _cfg_key, main_config[_cfg_key])
# Validate explicitly rather than with `assert`, which is removed when Python
# runs with -O and would let an invalid circuit size through.
if args.nlayers < 1 or args.nqubits < 2:
    raise ValueError(
        f"Invalid circuit size: nlayers={args.nlayers} (must be >= 1), "
        f"nqubits={args.nqubits} (must be >= 2)."
    )


# Initialize the hyper-parameters
SEED_ID = 42
np.random.seed(SEED_ID)
# Ratios used for the data splits
TRAIN_RATIO = 0.8
TRAIN_RATIO_WINE = 0.77
VAL_RATIO_IRIS = 0.75
class_ids = [0, 1]  # class id
N_SAMPLES = 200  # the number of samples for each class
NUM_ROT = 3  # The number of parameters (phi, theta, omega) for rotation in VQC
# Data file per dataset name (mnist has no file name)
file_name = dict(
    iris='iris_2class_scaled.npy',
    wine='wine_2class_scaled.npy',
    titanic='titanic_2class_raw.csv',
    mnist='',
)
# Lookup tables keyed by the --vqc_type / --opt option values
circuits = dict(zip((1, 2, 3), (circuit1, circuit2, circuit3)))
optimizers = dict(sgd=torch.optim.SGD, adam=torch.optim.Adam)
# Arguments substituted into the prompts
pdict = dict(
    nlayers=args.nlayers,
    nqubits=args.nqubits,
    nrot=NUM_ROT,
    nclasses=len(class_ids),
    data_card="N/A",
    init=args.init,
    feedback="N/A",
)
# Generation settings for the GenAI model
genai_config = dict(
    max_output_tokens=8192,
    temperature=args.temp,
    top_p=args.top_p,
)
# Set up the paths
data_path = "../data"
output_path = "./parameters"
check_mkdirs(output_path)
# One tag identifies the run; every artifact name is derived from it.
_run_tag = "_".join([
    f"{args.data_name}",
    f"vqc{args.vqc_type}",
    f"{args.opt}",
    f"{args.init}",
    f"{args.nlayers}ly",
    f"{args.nqubits}qb",
])
checkpoint_file = "ckpt_" + _run_tag
init_params_file = "params_" + _run_tag + ".pkl"
loss_file = "loss_" + _run_tag + ".pkl"
checkpoint_path = "./checkpoints/" + checkpoint_file
init_params_path = output_path + "/" + init_params_file
loss_path = output_path + "/" + loss_file
# Instantiate: default.qubit | default.qubit.torch | lightning.gpu
# ref: https://docs.pennylane.ai/en/stable/code/api/pennylane.device.html
qml_dev = qml.device(args.qml_dev, wires=args.nqubits)
#vertexai.init(project=args.vertexai_proj_id, location=args.vertexai_region)  # init() if necessary


def build_genai_model(genai_name: str) -> Union[OpenAI, GenerativeModel]:
    """
    @descriptions: Create the client object for the requested GenAI model.
    @inputs:
        genai_name: the model name; any name containing 'gpt' selects OpenAI.
    @return:
        An OpenAI client built with args.openai_key when the name contains
        'gpt', otherwise a Vertex AI GenerativeModel built from genai_name.
    """
    wants_openai = 'gpt' in genai_name
    if not wants_openai:
        # Anything that is not a GPT model is handled by Vertex AI.
        return GenerativeModel(genai_name)
    return OpenAI(api_key=args.openai_key)


def generate_init_params(train_data_loader: List[Tensor],
                       val_data_loader: List[Tensor],
                       test_data_loader: List[Tensor],
                       nsearch: int, nlayers: int, nqubits: int,
                       pdict: dict, genai_name: str, genai_config: dict,
                       dirs: str, device: str) -> tuple:
    """
    @descriptions: Generate a VQC's effective initial parameters using a GenAI model.
        Each search asks the GenAI model for parameters, trains a fresh VQC
        from them, and keeps the result when its expected improvement
        exceeds the per-iteration lower bound.
    @inputs:
        train_data_loader: the train data loader.
        val_data_loader: the validation data loader.
        test_data_loader: the test data loader.
        nsearch: the number of searches for initial parameters.
        nlayers: the number of layers in VQCs.
        nqubits: the number of qubits in VQCs.
        pdict: the config for generating/updating prompts.
            NOTE: pdict['feedback'] is overwritten in place after every search.
        genai_name: the name of the GenAI model.
        genai_config: the config for the GenAI model.
        dirs: the path for saving the checkpoints.
        device: cpu/mps/cuda.
    @return: a 2-tuple (best_init_params, best_loss_curves)
        best_init_params: dict 'Srch{i}' -> (ei, grad_var, parameters).
        best_loss_curves: dict 'Srch{i}' -> train loss returned by train().
        Both are empty when nsearch <= 0 or no search exceeded the bound.
    """
    # This function has the same name as the sampler imported from `utils` at
    # the top of the file, so the module-level name now refers to this
    # function. Calling it with the sampler's four arguments would re-enter
    # this function and fail with a TypeError; import the sampler under a
    # distinct local alias instead.
    from utils import generate_init_params as sample_init_params

    if nsearch <= 0:
        # Nothing to search; this also avoids dividing by zero in `lb` below.
        return {}, {}

    prompts = generate_prompts(pdict) # Initialize prompts
    genai_model = build_genai_model(genai_name) # Initialize a Generative AI Model

    lb = (nsearch*nqubits**6)**(-1) # Assumed theoretical lower bound per iteration
    params_shapes = (nlayers, nqubits, NUM_ROT)  # identical for every search
    best_grad_var = 0.0
    best_init_params = {}
    best_loss_curves = {}
    print("Start searching ...")
    for search in range(nsearch):
        print(f"Initialize VQCs' parameters using the GenAI model for the {search}-th iteration.")
        init_params_dict = sample_init_params(genai_model, genai_name, genai_config, prompts)
        # The parameters are stored in a dict. The prompt requests
        # 'layer0': (nlayers, nqubits, NUM_ROT), 'layer1': (out_dim, nqubits),
        # 'layer2': (out_dim).
        # NOTE(review): the previous comment here gave [nlayers, nqubits] for
        # the second entry, which disagrees with the prompt -- confirm.
        init_params = {k: torch.tensor(v, dtype=torch.float64)
                       for k, v in init_params_dict.items()}

        # Training
        print("Start training ...")
        ckpt_srch_path = f"{dirs}/srch{search}"
        check_mkdirs(ckpt_srch_path)
        QModel = QuantumModel(circuits[args.vqc_type], qml_dev, params_shapes, pdict['nclasses'])
        update_model_params(QModel, init_params, device)
        # Define the optimizers
        opt = optimizers[args.opt](QModel.parameters(), lr=args.lr)
        # Setup the device if necessary
        if torch.cuda.is_available() and device == 'cuda':
            torch.cuda.set_device(args.GPU)
            QModel.to(device)
        # Train the VQC
        loss_train, grad_var, _ = train(QModel, opt, train_data_loader, val_data_loader,
                                        args.epochs, ckpt_srch_path, device)

        # Compute EI and update the best EI accordingly.
        ei = expected_improvement(grad_var, best_grad_var)
        print(f"EI: {ei}")
        if ei > lb:
            # NOTE(review): these values are read from QModel after train()
            # has returned, so they may be the trained parameters rather than
            # the generated initial ones -- confirm which is intended.
            init_params_new = {f'layer{i}': torch.round(params.data, decimals=4).tolist()
                               for i, params in enumerate(QModel.parameters())}
            pdict['feedback'] = f"The previous initial model parameter is {init_params_new}. "\
                    f"The model's gradient variance, {grad_var}, is higher than "\
                    f"the historical largest gradient variance, {best_grad_var}."
            best_grad_var = grad_var
            # Store the initial parameters of VQCs with higher grad var.
            best_init_params[f'Srch{search}'] = (ei, grad_var, init_params_new)
            best_loss_curves[f'Srch{search}'] = loss_train
        else:
            pdict['feedback'] = "Expected improvement is not updated in this search."
        print(pdict['feedback'])
        # Update the prompts so the next search sees this round's feedback
        prompts = generate_prompts(pdict)

        # Evaluation
        print("Evaluation:")
        QModel_path = f"{ckpt_srch_path}/model_best.pth.tar" # for eval
        evaluation(QModel, opt, QModel_path, test_data_loader, device)
    return best_init_params, best_loss_curves


def generate_prompts(pdict: dict):
    """
    @descriptions: Build the prompt text sent to the GenAI model.
    @inputs:
        pdict: the prompt arguments; the keys 'nlayers', 'nqubits', 'nrot',
            'nclasses', 'init', 'data_card' and 'feedback' are required.
    @return:
        prompt: the formatted prompt string.
    """
    # Look the fields up first, in the order they appear in the prompt.
    nlayers = pdict['nlayers']
    nqubits = pdict['nqubits']
    nrot = pdict['nrot']
    out_dim = pdict['nclasses']
    init = pdict['init']
    data_card = pdict['data_card']
    feedback = pdict['feedback']
    prompt = f"""
        Role: data generator.
        Goal: Generate a dictionary iteratively with the following shape:
        "{{'layer0': a list, shape=(nlayers, nqubits, NUM_ROT),
          'layer1': a list, shape=(out_dim, nqubits),
          'layer2': a list, shape=(out_dim),
         }}"
        Requirements:
        * Data shape: nlayers={nlayers}, nqubits={nqubits}, NUM_ROT={nrot}, out_dim={out_dim}.
        * Data type: float, rounded to four decimals.
        * Data distribution: numerical numbers in each list are sampled from standard {init} distributions, which may be modeled from the following dataset.
        * Dataset description: {data_card}
        * Adjust the sampling based on feedback from the previous searches: {feedback}
        * Crucially, ensure that the length of 'layer0' = 'nlayers' and the length of 'layer1' = 'out_dim'.
        * Print out a dictionary [only] (Don't show python code OR include [```python\\n], [```json\\n], [\\n```]).
        """
    return prompt


# Script entry point: load and preprocess the selected dataset, build the
# train/val/test loaders, run the GenAI-driven parameter search and save the
# selected parameters and their loss curves.
if __name__ == '__main__':
    print("args: ", args)
    # Load dataset and preprocessing
    print("Load dataset and preprocessing ...")
    if args.data_name in ['iris', 'wine', 'titanic']:
        data, labels = load_dataset(args.data_name, data_path, file_name[args.data_name])
        # wine uses its own train ratio (TRAIN_RATIO_WINE)
        SPLIT_RATIO = TRAIN_RATIO_WINE if args.data_name == 'wine' else TRAIN_RATIO
        X_train, X_test, y_train, y_test = train_test_split(data, labels, train_size=SPLIT_RATIO)
        if args.data_name == 'titanic':
            X_train, y_train = preprocessing_titanic(X_train, y_train, class_ids, N_SAMPLES)
            X_test, y_test = preprocessing_titanic(X_test, y_test, class_ids, N_SAMPLES)
        # Reduce feature dimension s.t. nfeat = nqubits.
        # NOTE(review): train and test are embedded by two independent t-SNE
        # fits, so their coordinates are not guaranteed to be comparable, and
        # the test perplexity (n_test - 5) must stay positive -- confirm.
        if len(X_train.shape) == 2 and X_train.shape[1] > args.nqubits:
            X_train = TSNE(n_components=args.nqubits, random_state=SEED_ID,
                           n_jobs=2, method='exact').fit_transform(X_train)
        if len(X_test.shape) == 2 and X_test.shape[1] > args.nqubits:
            X_test = TSNE(n_components=args.nqubits, perplexity=X_test.shape[0]-5,
                          random_state=SEED_ID, n_jobs=2, method='exact').\
                            fit_transform(X_test)
    elif args.data_name == 'mnist':
        train_data, test_data = load_dataset(args.data_name, data_path, file_name[args.data_name])
        RSIZE = int(np.sqrt(args.nqubits))
        X_train, y_train = preprocessing_mnist(train_data, class_ids, N_SAMPLES, RSIZE)
        X_test, y_test = preprocessing_mnist(test_data, class_ids, N_SAMPLES, RSIZE)
    else:
        # A config file can set a name that bypasses the argparse `choices`
        # check; fail here with a clear message instead of a later NameError.
        raise ValueError(f"Unsupported dataset name: {args.data_name!r}")
    # Apply label encoding if the label ids don't follow the order.
    # The encoder is fitted on the train labels only and then applied to the
    # test labels, so both splits share one label -> index mapping.
    if max(class_ids)+1 != len(class_ids):
        label_encoder = LabelEncoder()
        y_train = label_encoder.fit_transform(y_train)
        y_test = label_encoder.transform(y_test)
    # Apply one-hot encoding to the encoded labels.
    # Fitting on train and only transforming test keeps the column layout
    # identical even if a class is absent from the test split.
    onehot_encoder = OneHotEncoder(sparse_output=False)
    y_train, y_test = y_train.reshape(-1, 1), y_test.reshape(-1, 1)
    y_train = onehot_encoder.fit_transform(y_train)
    y_test = onehot_encoder.transform(y_test)
    # Split the validation set from the train data.
    SPLIT_RATIO = VAL_RATIO_IRIS if args.data_name == 'iris' else TRAIN_RATIO  # iris: 60/20/20
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, train_size=SPLIT_RATIO)
    # Convert to torch tensor
    X_train, X_val, X_test = torch.DoubleTensor(X_train), torch.DoubleTensor(X_val), torch.DoubleTensor(X_test)
    y_train, y_val, y_test = torch.DoubleTensor(y_train), torch.DoubleTensor(y_val), torch.DoubleTensor(y_test)
    # Build dataloaders for batch training
    train_data_loader = data_batch_loader(X_train, y_train, args.batch_size)
    val_data_loader = data_batch_loader(X_val, y_val, args.batch_size)
    test_data_loader = data_batch_loader(X_test, y_test, args.batch_size)
    # Update arguments for generating prompts
    pdict['data_card'] = f"We are using the {args.data_name} dataset. "\
        f"The shape of train data is ({X_train.shape[0]}, {X_train.shape[1]}). "\
        f"The number of clusters is {len(class_ids)}."
    pdict['feedback'] = "N/A"
    # Generate the effective initial parameters for VQCs using a GenAI model.
    opt_init_params, opt_loss = generate_init_params(train_data_loader, val_data_loader, test_data_loader,
                                                   args.nsearches, args.nlayers, args.nqubits,
                                                   pdict, args.genai_name, genai_config,
                                                   checkpoint_path, args.device)
    # Each stored value is (ei, grad_var, parameters); report the first two.
    ei_dict = {k:v[0] for k, v in opt_init_params.items()}
    print(f"EI: {ei_dict}")
    gvar_dict = {k:v[1] for k, v in opt_init_params.items()}
    print(f"GVar: {gvar_dict}")
    # Save files
    save_files(init_params_path, opt_init_params)
    save_files(loss_path, opt_loss)
    torch.cuda.empty_cache()
