import numpy as np
import torch
from torch.utils.data import TensorDataset
from sklearn.metrics import accuracy_score

from config import (
    TRAIN_DATA_PATH,
    TEST_DATA_PATH,
    PCA_DIM,
    SEED,
    NQUBIT,
    C_DEPTH,
    MAX_ITER,
    BAG_SIZE,
    START_LR,
    USE_GPU,
    TRAINED_CIRCUIT_JSON,
    TRAINED_CIRCUIT_PNG,
)

from qcl_classification import QclClassification
from data_utils import load_pt_features, create_random_bags
from qulacs import QuantumState
from qulacsvis import circuit_drawer


def _accuracy(qcl, features, labels):
    """Return the accuracy of argmax predictions of ``qcl`` on ``features``.

    ``qcl.pred_amplitude(features)`` is reduced with ``argmax`` over axis 1,
    so it is expected to return one row of per-class scores per sample.
    """
    scores = qcl.pred_amplitude(features)
    return accuracy_score(labels, np.argmax(scores, axis=1))


def _save_circuit(circuit, json_path, png_path):
    """Serialize ``circuit`` to JSON at ``json_path`` and draw it to ``png_path``.

    The JSON text comes from ``circuit.to_json()``; the picture is rendered by
    ``qulacsvis.circuit_drawer`` with the ``"mpl"`` output method.
    """
    # Serialize before opening the file so a failing to_json() cannot leave a
    # truncated/empty file behind.
    circuit_json = circuit.to_json()
    # Explicit encoding: the default is locale-dependent, so the saved bytes
    # would otherwise vary between platforms.
    with open(json_path, "w", encoding="utf-8") as f:
        f.write(circuit_json)
    circuit_drawer(circuit, output_method="mpl", filename=png_path)


def main():
    """Train the quantum-circuit classifier on random bags and report results.

    Steps (all settings come from ``config``):
      1. Print the qulacs device name of an ``NQUBIT``-qubit state.
      2. Load train/test features and labels via ``load_pt_features``
         (``PCA_DIM`` is passed through — presumably the target PCA size).
      3. Build random bags of ``BAG_SIZE`` training samples together with
         their per-bag teacher label proportions.
      4. Fit ``QclClassification`` with ``fit_llp_inner_product`` on those
         bags (cross-entropy loss, ``MAX_ITER`` iterations, lr ``START_LR``).
      5. Print train and test accuracy of the argmax predictions.
      6. Save the trained ``output_gate`` circuit as JSON and as a PNG.
    """
    # Only used to report which qulacs backend is active.
    # NOTE(review): USE_GPU is imported from config but never consulted here.
    state = QuantumState(NQUBIT)
    print(state.get_device_name())

    x_train, x_test, y_train_label, y_test_label = load_pt_features(
        TRAIN_DATA_PATH, TEST_DATA_PATH, PCA_DIM
    )

    # The number of classes is inferred from the distinct training labels.
    num_class = len(np.unique(y_train_label))
    # NOTE(review): NumPy's global RNG is seeded only after the data is
    # loaded; if load_pt_features draws random numbers, seed earlier — confirm.
    np.random.seed(SEED)

    train_ds = TensorDataset(
        torch.tensor(x_train, dtype=torch.float32),
        torch.tensor(y_train_label, dtype=torch.long),
    )

    bag_sampler, teacher_probs = create_random_bags(
        train_ds, BAG_SIZE, num_class, shuffle=True, seed=SEED
    )
    print("Bag shape(teacher proportion)", teacher_probs.shape)

    qcl = QclClassification(NQUBIT, C_DEPTH, num_class)
    qcl.fit_llp_inner_product(
        x_train,
        bag_sampler,
        teacher_probs,
        n_iter=MAX_ITER,
        lr=START_LR,
        loss="ce",
        n_jobs=2,
    )

    acc_train = _accuracy(qcl, x_train, y_train_label)
    acc_test = _accuracy(qcl, x_test, y_test_label)

    print(f"train accuracy: {acc_train:.3f}")
    print(f"test accuracy: {acc_test:.3f}")

    # Save trained circuit and draw as PNG using config-defined filenames
    _save_circuit(qcl.output_gate, TRAINED_CIRCUIT_JSON, TRAINED_CIRCUIT_PNG)
    print(
        f"Saved trained circuit to {TRAINED_CIRCUIT_JSON} and {TRAINED_CIRCUIT_PNG}"
    )

if __name__ == "__main__":
    # Echo the run configuration before training starts. Every label after
    # the first carries a leading "\n", so print()'s default " " separator
    # lays the output out as one "NAME value" pair per line.
    _settings = (
        ("TRAIN_DATA_PATH", TRAIN_DATA_PATH),
        ("\nTEST_DATA_PATH", TEST_DATA_PATH),
        ("\nPCA_DIM", PCA_DIM),
        ("\nSEED", SEED),
        ("\nNQUBIT", NQUBIT),
        ("\nC_DEPTH", C_DEPTH),
        ("\nMAX_ITER", MAX_ITER),
        ("\nBAG_SIZE", BAG_SIZE),
        ("\nSTART_LR", START_LR),
    )
    print(*(item for pair in _settings for item in pair))
    main()
