import numpy as np
import pandas as pd
import os, json
import matplotlib.pyplot as plt
from scipy.optimize import minimize
import time
import random
import torch
from scipy.special import logsumexp

from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister
from qiskit.quantum_info import SparsePauliOp
from qiskit.circuit.library import RealAmplitudes
from qiskit.transpiler import generate_preset_pass_manager, PassManager
from qiskit.transpiler import InstructionProperties
from qiskit.visualization import plot_distribution
from qiskit.providers.fake_provider import GenericBackendV2
from qiskit.primitives import StatevectorEstimator
from qiskit.circuit import ParameterVector
from qiskit.providers import BackendV2 as Backend

from qiskit_aer import AerSimulator
from qiskit_ibm_runtime import (
    QiskitRuntimeService, 
    EstimatorV2 as Estimator,
    SamplerV2 as Sampler,
    EstimatorOptions,
    Session,
)

from data import load_dataset
from qlms import Adam, HybridModel, HybridQRNNLanguageModel, HybridQCNNLanguageModel, train_hybrid_model, evaluate_perplexity_hybrid, get_model
from args_factory import get_args

def setup_qiskit_ibm_runtime(api_key=None, crn=None, account_name="qlms"):
    """Save IBM Cloud credentials under a named account and return a service for it.

    Credentials are never hard-coded: each one is taken from the explicit
    argument when given, otherwise from the environment.

    Args:
        api_key: IBM Cloud API key. Falls back to the ``QISKIT_IBM_TOKEN``
            environment variable when ``None``.
        crn: Cloud Resource Name of the service instance. Falls back to the
            ``QISKIT_IBM_INSTANCE`` environment variable when ``None``.
        account_name: Name under which the account is saved and then loaded.

    Returns:
        A ``QiskitRuntimeService`` created from the account that was just saved.

    Raises:
        ValueError: If the API key or the CRN cannot be resolved. This is
            checked before anything is written, so an existing saved account
            is never overwritten with unusable credentials.
    """
    token = api_key if api_key is not None else os.environ.get("QISKIT_IBM_TOKEN")
    instance = crn if crn is not None else os.environ.get("QISKIT_IBM_INSTANCE")

    missing = []
    if not token:
        missing.append("api_key (or the QISKIT_IBM_TOKEN environment variable)")
    if not instance:
        missing.append("crn (or the QISKIT_IBM_INSTANCE environment variable)")
    if missing:
        # Fail fast: save_account(overwrite=True) below would otherwise replace
        # a working saved account with empty/placeholder values.
        raise ValueError("Missing IBM Cloud credentials: " + "; ".join(missing))

    from qiskit_ibm_runtime import QiskitRuntimeService

    QiskitRuntimeService.save_account(
        channel="ibm_cloud",
        token=token,
        instance=instance,
        name=account_name,
        overwrite=True,
    )

    return QiskitRuntimeService(name=account_name)

# def parse_model_path(args):
#     # Example: ./models/TS-LM/qrnn/emb3_seq6_lr0.1_PGPE_pop8_sigma0.05_PGPE_aer_simulator_seed42/
#     # Example: ./models/TS-LM/qcnn/type_33_emb3_seq6_reps2_lr0.1_PGPE_pop8_sigma0.05_aer_simulator_seed42/
#     if args.model == 'qrnn':
#         path = f"{args.model_path}/{args.dataset}/{args.model}/{args.backend}_shots{args.shots}/emb{args.emb_size}_seq{args.seq_len}/{args.alg}_lr{args.lr}_BS{args.batch_size}_EP{args.epochs}_pop{args.popsize}_sigma{args.sigma}_seed{args.seed}/"
#     elif args.model == 'qcnn':
#         path = f"{args.model_path}/{args.dataset}/{args.model}/{args.backend}_shots{args.shots}/type_{args.cnn_type}_emb{args.emb_size}_seq{args.seq_len}_reps{args.reps}/{args.alg}_lr{args.lr}_BS{args.batch_size}_EP{args.epochs}_pop{args.popsize}_sigma{args.sigma}_seed{args.seed}/"
#     return path

def set_seed(seed):
    """Seed the NumPy, standard-library ``random`` and PyTorch global RNGs.

    Args:
        seed: Value handed unchanged to each library's seeding function.
    """
    # Same order as before: NumPy, then ``random``, then PyTorch.
    for seed_rng in (np.random.seed, random.seed, torch.manual_seed):
        seed_rng(seed)
    
def _select_backend(args):
    """Return the execution backend named by ``args.backend``.

    ``'aer_simulator'`` yields a local ``AerSimulator`` seeded with
    ``args.seed``. Any other name is looked up through the saved ``"fmc-eu"``
    runtime account; the service is created only in that case, so a local
    simulation does not require saved IBM credentials.
    """
    if args.backend == 'aer_simulator':
        return AerSimulator(seed_simulator=args.seed)

    # service = setup_qiskit_ibm_runtime() # only needed once
    service = QiskitRuntimeService(name="fmc-eu")
    return service.backend(args.backend)


def _eval_results_path(load_model, backend_name, shots):
    """Build the JSON results path that sits next to the loaded model file.

    The model name is the file's base name truncated at its first ``.``.
    ``os.path.join`` is used so that a bare file name (no directory part)
    resolves to the current directory rather than to the filesystem root.
    """
    save_dir = os.path.dirname(load_model)
    model_name = os.path.basename(load_model).split('.')[0]
    return os.path.join(save_dir, f"eval_{model_name}_{backend_name}_shots{shots}.json")


def main():
    """Evaluate a saved hybrid model on the test split and store the metrics as JSON.

    Reads the run configuration from ``get_args()``, seeds the RNGs, loads the
    dataset, builds the estimator and pass manager for the selected backend,
    loads the model from ``args.load_model`` and evaluates it on the test
    loader. The test loss, perplexity, accuracy and wall-clock evaluation time
    are printed and written to
    ``eval_<model>_<backend>_shots<shots>.json`` beside the model file.
    """
    args = get_args()
    set_seed(args.seed)

    train_loader, test_loader, seq_len, vocab_size, task = load_dataset(args)
    # The dataset determines these values; store them on args for get_model.
    args.seq_len = seq_len
    args.vocab_size = vocab_size
    args.task = task

    backend = _select_backend(args)
    print(f"Using backend: {backend.name}")

    estimator = Estimator(mode=backend)
    if args.shots is not None:
        estimator.options.default_shots = args.shots
    pm = generate_preset_pass_manager(
        backend=backend,
        initial_layout=None,
        optimization_level=1
    )

    model = get_model(args, backend, estimator, pm)
    model.load(args.load_model)

    # Only the test split is evaluated.
    start_time = time.time()
    print("Starting eval...")
    test_ppl, test_acc = evaluate_perplexity_hybrid(model, test_loader, args.max_batches)
    test_ppl = float(test_ppl)
    test_acc = float(test_acc)
    # The loss is recovered as the natural log of the perplexity.
    test_loss = float(np.log(test_ppl))
    elapsed = time.time() - start_time

    print(f"Eval completed in {elapsed/60:.2f} minutes.")
    print(f"Test  Loss: {test_loss:.4f}, Perplexity: {test_ppl:.4f}, Accuracy: {test_acc:.4f}")

    history = {
        'test_loss': test_loss,
        'test_ppl': test_ppl,
        'test_acc': test_acc,
        'eval_time': elapsed,
    }

    filename = _eval_results_path(args.load_model, args.backend, args.shots)
    with open(filename, 'w') as f:
        json.dump(history, f, indent=4)
    print(f"Eval results saved to {filename}")

# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()