from models import *
from models.circuits import weight_dict, qubit_dict, depth_dict
import torch
import os
import pennylane as qml


def _create_device(n_qubits, noise, finite):
    """Create the PennyLane device used by ``load_model``.

    Args:
        n_qubits: number of wires on the device.
        noise: truthy when the model is run with noise enabled.
        finite: number of shots; ``0`` means no ``shots`` argument is passed.

    Returns:
        ``lightning.qubit`` with ``shots=finite`` whenever ``finite != 0``;
        ``lightning.qubit`` without shots when noisy and ``finite == 0``;
        ``default.qubit`` without shots in the noiseless, ``finite == 0`` case.
    """
    if finite != 0:
        return qml.device("lightning.qubit", wires=n_qubits, shots=finite)
    if noise:
        return qml.device("lightning.qubit", wires=n_qubits)
    return qml.device("default.qubit", wires=n_qubits)


def load_model(conf):
    """Build the quantum model described by ``conf``.

    Args:
        conf: configuration object. Reads ``structure`` (one of 'qcl', 'qcnn',
            'drnn', 'hqnn'), ``class_idx`` (only its length is used),
            ``version``, ``finite``, ``noise``, ``encoding`` (not read for
            'drnn') and, for 'hqnn', ``resize``.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``conf.structure`` is not a supported model type.

    Side effects:
        Prints the configuration and the device. For 'drnn',
        ``conf.encoding`` is overwritten with 'interleaved'.
    """
    model_type = conf.structure
    if model_type not in ('qcl', 'qcnn', 'drnn', 'hqnn'):
        # Fail fast with a clear message instead of an UnboundLocalError on
        # `model` (or an opaque KeyError) further down.
        raise ValueError(f'unknown model structure: {model_type!r}')
    num_classes = len(conf.class_idx)

    print(f'model: {model_type} of {conf.version}, '
          f'device: finite sampling - {conf.finite}, noise - {conf.noise}')
    dev = _create_device(qubit_dict[model_type], conf.noise, conf.finite)
    print(dev)

    if model_type == 'qcl':
        return QCL(dev, num_classes=num_classes, noise=conf.noise, embedding=conf.encoding)
    if model_type == 'qcnn':
        return QCNN(dev, num_classes=num_classes, noise=conf.noise, embedding=conf.encoding)
    if model_type == 'drnn':
        # DRNN always uses interleaved data re-uploading; record it on conf.
        conf.encoding = 'interleaved'
        return DRNN(dev, num_classes=num_classes, scaling=1.5, ent_train=True, noise=conf.noise)
    # model_type == 'hqnn': input size depends on whether images were resized.
    data_size = (1, 16, 16) if conf.resize else (1, 28, 28)
    n_features = int(torch.tensor(data_size).prod())
    return HQNN(dev, n_features=n_features, num_classes=num_classes, noise=conf.noise, embedding=conf.encoding)


def load_params_from_path(conf, dev_conf=None):
    """Rebuild a model from its saved checkpoint and return its trained weights.

    Args:
        conf: configuration object. Reads ``structure``, ``resize``,
            ``model_dir``, ``dataset``, ``version``, ``class_idx``, ``noise``,
            ``finite`` and, when ``resize`` is set, ``reduction``.
        dev_conf: optional dict with ``'noise'`` and ``'finite'`` keys. When
            given, these values (instead of ``conf.noise``/``conf.finite``)
            select which checkpoint file is loaded.

    Returns:
        ``(params, model)``: the weight tensor(s) named by
        ``weight_dict[conf.structure]`` and the loaded model in eval mode.

    Side effects:
        Overwrites ``conf.encoding`` ('interleaved' for 'drnn', 'amplitude'
        otherwise) and prints the checkpoint path.
    """
    model_n = conf.structure
    model_depth = depth_dict[model_n]
    n_qubits = qubit_dict[model_n]

    noise = dev_conf['noise'] if dev_conf is not None else conf.noise
    finite = dev_conf['finite'] if dev_conf is not None else conf.finite

    # The embedding is fixed per structure and is part of the checkpoint name.
    conf.encoding = 'interleaved' if model_n == 'drnn' else 'amplitude'

    # Checkpoints for resized inputs additionally record the reduction method.
    encoding_tag = f'{conf.encoding}_{conf.reduction}' if conf.resize else conf.encoding
    file_name = (f'qubits_{n_qubits}_{encoding_tag}_{conf.class_idx}'
                 f'_sample_{finite}_noise_{noise}_depth_{model_depth}.pth')
    model_path = os.path.join(conf.model_dir, conf.dataset, conf.version, model_n, file_name)

    print(f'load model from: {model_path}...')
    # NOTE(review): the model's device is built from conf.noise/conf.finite,
    # not from dev_conf -- presumably intentional (evaluate weights trained
    # under one device setting on another); confirm.
    model = load_model(conf=conf)

    # map_location='cpu' lets checkpoints saved from a CUDA run load on a
    # CPU-only machine; load_state_dict copies values into the model's params.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (consider weights_only=True on torch versions that support it).
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()

    params = load_train_params(conf.structure, model.state_dict())
    return params, model


def load_train_params(structure, state_dict):
    """Extract the trainable circuit weights of ``structure`` from a state dict.

    Args:
        structure: model structure name, used as the key into ``weight_dict``.
        state_dict: mapping from parameter name to value (e.g. the result of
            ``model.state_dict()``).

    Returns:
        ``state_dict[name]`` when ``weight_dict[structure]`` is a single name,
        or a list of the corresponding entries, in order, when it is a list
        of names.
    """
    weight_name = weight_dict[structure]
    if isinstance(weight_name, list):
        return [state_dict[name] for name in weight_name]
    return state_dict[weight_name]


def load_circuit_structure(conf, dev, noise=False):
    """Return the bare circuit for ``conf.structure`` built on device ``dev``.

    Args:
        conf: configuration object; ``structure`` selects the circuit and
            ``encoding`` is passed as the embedding (not read for 'drnn').
        dev: PennyLane device handed to the model's ``create_circuit``.
        noise: forwarded to ``create_circuit``.

    Returns:
        The circuit from ``QCL``/``QCNN``/``DRNN``/``HQNN.create_circuit``,
        or ``None`` when ``conf.structure`` matches none of them.
    """
    kind = conf.structure
    if kind == 'qcl':
        return QCL.create_circuit(dev, noise=noise, embedding=conf.encoding)
    if kind == 'qcnn':
        return QCNN.create_circuit(dev, noise=noise, embedding=conf.encoding)
    if kind == 'drnn':
        # DRNN's circuit takes no embedding argument.
        return DRNN.create_circuit(dev, noise=noise, ent_train=True)
    if kind == 'hqnn':
        return HQNN.create_circuit(dev, noise=noise, embedding=conf.encoding)
    # Unrecognised structure: no circuit.
    return None


def cidx2qidx(model_n, cidx, num_qubits=8, class_idx=[0, 1]):
    """Map a classification output index to the qubit index that carries it.

    Args:
        model_n: model structure name (e.g. 'qcl', 'qcnn').
        cidx: index of the class in the model's output.
        num_qubits: number of circuit qubits; the 'qcnn' branch distinguishes
            8 from any other value.
        class_idx: classes being classified; only its length is used here.
            (Mutable default, but it is not modified in this function.)

    Returns:
        For 'qcl', ``cidx`` unchanged. For 'qcnn', the ``cidx``-th entry of a
        fixed qubit list chosen by the number of classes (2 or 3) and
        ``num_qubits``.
    """
    # idx of classification -> qubit idx when designing circuit outputs
    if model_n in ['qcl']:
        return cidx
    elif model_n == 'qcnn':
        # 8 qubits: outputs on qubits 0, 2, 4; otherwise on 0, 4, 8.
        if len(class_idx) == 2:
            q_idx = [0, 2] if num_qubits == 8 else [0, 4]
        elif len(class_idx) == 3:
            q_idx = [0, 2, 4] if num_qubits == 8 else [0, 4, 8]
        # NOTE(review): q_idx is unbound (UnboundLocalError) when
        # len(class_idx) is neither 2 nor 3 -- consider an explicit error.
        return q_idx[cidx]