import os
import json
import torch
import numpy as np
import circuit.var_config as vc
import torch.nn.functional as F

# Working directory captured once, at import time; referenced by load_json when
# building the path of the dataset directory.
current_path = os.getcwd()

def load_json(f_name):
    """Load a circuit dataset from a JSON file.

    ``f_name`` is opened as given when that path exists (relative paths resolve
    against the current working directory). Otherwise it is looked up in the
    ``circuit/data`` directory below ``current_path`` (the working directory
    captured when this module was imported).

    Args:
        f_name: path or bare file name of the JSON dataset.

    Returns:
        The parsed JSON document.

    Raises:
        FileNotFoundError: if the file exists in neither location.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    file_path = f_name
    if not os.path.exists(file_path):
        # Fallback to the data directory; joined component-wise so it works on
        # every OS (the old literal used Windows-only backslash separators and
        # its result was never actually used).
        file_path = os.path.join(current_path, 'circuit', 'data', f_name)
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def transform_operations(max_idx):
    """Translate gate indices into gate names.

    Args:
        max_idx: iterable whose elements support ``.item()`` (e.g. a 1-D integer
            tensor, presumably argmax indices over gate classes).

    Returns:
        List of gate-name strings, in the same order as ``max_idx``.

    Raises:
        KeyError: for an index outside 0..12.
    """
    gate_names = {
        0: 'START', 1: 'PauliX', 2: 'PauliY', 3: 'PauliZ', 4: 'Hadamard',
        5: 'RX', 6: 'RY', 7: 'RZ', 8: 'CNOT', 9: 'CY', 10: 'CZ', 11: 'U3',
        12: 'END',
    }
    return [gate_names[entry.item()] for entry in max_idx]

def save_checkpoint(model, optimizer_vae, epoch, loss, dim, name, dropout, seed):
    """Save a checkpoint to ``pretrained/dim-{dim}/model-ae-{name}.pt``.

    The saved object is a dict with keys ``epoch``, ``loss``, ``model_state``
    (``model.state_dict()``) and ``optimizer_vae_state`` (the optimizer's
    ``state_dict()``, or ``None`` when ``optimizer_vae`` is ``None``). The target
    directory is created relative to the current working directory if needed.

    ``dropout`` and ``seed`` are accepted for call-site compatibility but are not
    stored in the checkpoint.
    """
    checkpoint = {
        'epoch': epoch,
        'loss': loss,
        'model_state': model.state_dict(),
        'optimizer_vae_state': None if optimizer_vae is None else optimizer_vae.state_dict(),
    }
    dir_name = 'pretrained/dim-{}'.format(dim)
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(dir_name, exist_ok=True)
    f_path = os.path.join(dir_name, 'model-ae-{}.pt'.format(name))
    torch.save(checkpoint, f_path)


def save_checkpoint_vae(model, optimizer_vae, epoch, loss, dim, name, dropout, seed):
    """Save a checkpoint to ``pretrained/dim-{dim}/model-{name}.pt``.

    The saved object is a dict with keys ``epoch``, ``loss``, ``model_state``
    (``model.state_dict()``) and ``optimizer_vae_state`` (the optimizer's
    ``state_dict()``, or ``None`` when ``optimizer_vae`` is ``None``). The target
    directory is created relative to the current working directory if needed.

    ``dropout`` and ``seed`` are accepted for call-site compatibility but are not
    stored in the checkpoint.
    """
    checkpoint = {
        'epoch': epoch,
        'loss': loss,
        'model_state': model.state_dict(),
        'optimizer_vae_state': None if optimizer_vae is None else optimizer_vae.state_dict(),
    }
    dir_name = 'pretrained/dim-{}'.format(dim)
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(dir_name, exist_ok=True)
    f_path = os.path.join(dir_name, 'model-{}.pt'.format(name))
    torch.save(checkpoint, f_path)

def normalize_adj(A):
    """Degree-normalize a batch of adjacency matrices.

    Returns ``D_in @ A @ D_out`` with ``D_in = diag(1 / sqrt(A.sum(dim=1)))``
    (column sums) and ``D_out = diag(1 / sqrt(A.sum(dim=2)))`` (row sums).
    ``A`` must be 3-D. A zero sum yields an ``inf`` scale factor.
    """
    inv_sqrt_col = 1.0 / torch.sqrt(A.sum(dim=1))
    inv_sqrt_row = 1.0 / torch.sqrt(A.sum(dim=2))
    left = torch.diag_embed(inv_sqrt_col)
    right = torch.diag_embed(inv_sqrt_row)
    # Only batched (3-D) operands are accepted, as in stacked_spmm.
    assert left.dim() == 3 and A.dim() == 3
    scaled = left @ A
    assert scaled.dim() == 3 and right.dim() == 3
    return scaled @ right

def _degree_normalize(A):
    """Return ``(D_in @ A @ D_out, sqrt_in, sqrt_out)`` for a 3-D ``A``.

    ``sqrt_in = sqrt(A.sum(dim=1))`` and ``sqrt_out = sqrt(A.sum(dim=2))``;
    ``D_in``/``D_out`` are the diagonal matrices of their reciprocals. The square
    roots are returned so the scaling can be undone exactly.
    """
    sqrt_in = torch.sqrt(A.sum(dim=1))
    sqrt_out = torch.sqrt(A.sum(dim=2))
    D_in = torch.diag_embed(1.0 / sqrt_in)
    D_out = torch.diag_embed(1.0 / sqrt_out)
    # D_in (built from column sums) multiplies from the left and D_out (row sums)
    # from the right -- the historic "swap D_in and D_out" ordering.
    return torch.matmul(torch.matmul(D_in, A), D_out), sqrt_in, sqrt_out


def preprocessing(A, H, method, lbd=None):
    """Transform a batch of adjacency matrices and node features before encoding.

    Args:
        A: 3-D adjacency tensor (asserted); presumably (batch, nodes, nodes) with
            a strictly upper-triangular DAG -- TODO confirm.
        H: node-feature tensor; method 1 indexes it as 3-D.
        method: 0 identity, 1 add a global node, 2 normalized A^T,
            3 normalized ``lbd*A + (1-lbd)*A^T``, 4 unnormalized
            ``lbd*A + (1-lbd)*A^T``, 5 symmetrized ``A + triu(A, 1)^T``.
        lbd: mixing weight; required for methods 3 and 4.

    Returns:
        ``(A', H', prep_reverse)`` where ``prep_reverse(A', H')`` maps
        (reconstructed) tensors back to a strictly upper-triangular adjacency and
        the original feature layout. Method 2 returns only ``(A', H)``.

    Raises:
        AssertionError: if ``A`` is not 3-D, or ``lbd`` is missing for methods 3/4.
        ValueError: for an unknown ``method``.
    """
    # FixMe: Attention multiplying D or lbd are not friendly with the crossentropy loss in GAE
    assert A.dim() == 3

    if method == 0:
        def prep_reverse(A, H):
            return A.triu(1), H
        return A, H, prep_reverse

    if method == 1:
        # Add a global node: every existing node gets an edge to it (new column
        # of ones), it has no outgoing edges (new row of zeros), and it is marked
        # by a new one-hot feature column.
        A = F.pad(A, (0, 1), 'constant', 1.0)
        A = F.pad(A, (0, 0, 0, 1), 'constant', 0.0)
        H = F.pad(H, (0, 1, 0, 1), 'constant', 0.0)
        H[:, -1, -1] = 1.0

        def prep_reverse(A, H):
            # Drop the global node (last row/column) and its feature column.
            return A[..., :-1, :-1].triu(1), H[..., :-1, :-1]
        return A, H, prep_reverse

    if method == 2:
        # Normalized A^T.
        # NOTE(review): no global node is added here despite the historic comment,
        # and only (DAD, H) is returned (no prep_reverse); callers unpacking three
        # values will fail. Kept as-is for compatibility -- confirm intended use.
        DAD, _, _ = _degree_normalize(A.transpose(-1, -2))
        return DAD, H

    if method == 3:
        assert lbd is not None, 'method 3 requires lbd'
        # using lambda*A + (1-lambda)*A^T, then degree normalization
        A = lbd * A + (1 - lbd) * A.transpose(-1, -2)
        DAD, sqrt_in, sqrt_out = _degree_normalize(A)

        def prep_reverse(DAD, H):
            assert DAD.dim() == 3
            # Multiply by the inverse diagonal matrices. (Element-wise 1.0/D on the
            # full diagonal matrix turned off-diagonal zeros into inf -> NaN.)
            A = torch.matmul(torch.matmul(torch.diag_embed(sqrt_in), DAD),
                             torch.diag_embed(sqrt_out))
            return A.triu(1), H
        return DAD, H, prep_reverse

    if method == 4:
        # bidirectional DAG: lambda*A + (1-lambda)*A^T
        assert lbd is not None, 'method 4 requires lbd'
        A = lbd * A + (1 - lbd) * A.transpose(-1, -2)

        def prep_reverse(A, H):
            # The strict upper triangle holds lbd * (original A); undo the scaling.
            return 1.0 / lbd * A.triu(1), H
        return A, H, prep_reverse

    if method == 5:
        A = A + A.triu(1).transpose(-1, -2)

        def prep_reverse(A, H):
            return A.triu(1), H
        return A, H, prep_reverse

    raise ValueError(f'unknown preprocessing method: {method!r}')
    

def get_accuracy(inputs, targets, model_flag='gsqas'):
    """Compute reconstruction metrics for a batch of circuit graphs.

    Args:
        inputs: ``(ops_recon, adj_recon)``. ``ops_recon`` is 3-D ``(N, I, F)``; its
            last ``vc.num_qubits`` feature columns are qubit markers and the rest
            is the gate-type part (compared via argmax).
        targets: ``(ops, adj)`` in the same layout.
        model_flag: 'gsqas' compares markers thresholded at 0.5 and decodes edges
            to {0, 1}; any other value additionally requires markers below -0.5 to
            agree with targets equal to -1 and decodes edges to {0, 1, 2}.

    Returns:
        ``(correct_ops, correct_ops_qubits, mean_correct_adj,
        mean_false_positive_adj, correct_adj)``. Only the strict upper triangle of
        both adjacencies is considered.
        ``mean_correct_adj`` (true edges predicted non-zero / number of true
        edges) and ``mean_false_positive_adj`` (non-zero predictions on non-edges /
        number of non-edges) are 0-d tensors; the other three are Python floats.
    """
    n_graphs, n_nodes, _ = inputs[0].shape
    recon_feats, recon_adj = inputs[0], inputs[1]
    true_feats, true_adj = targets[0], targets[1]

    nq = vc.num_qubits
    recon_gates, recon_qubits = recon_feats[:, :, :-nq], recon_feats[:, :, -nq:]
    true_gates, true_qubits = true_feats[:, :, :-nq], true_feats[:, :, -nq:]

    # post processing, assume non-symmetric: keep the strict upper triangle only
    recon_adj = recon_adj.triu(1)
    true_adj = true_adj.triu(1)
    recon_edges = recon_adj.type(torch.bool)
    true_edges = true_adj.type(torch.bool)
    n_true_edges = true_edges.float().sum()
    n_upper = n_graphs * n_nodes * (n_nodes - 1) / 2.0

    gate_hits = recon_gates.argmax(dim=-1) == true_gates.argmax(dim=-1)
    correct_ops = gate_hits.float().mean().item()
    mean_correct_adj = recon_edges[true_edges].float().sum().item() / n_true_edges
    non_edges = (~true_edges).triu(1)
    mean_false_positive_adj = recon_edges[non_edges].sum().item() / (n_upper - n_true_edges)

    positive_ok = torch.all((recon_qubits > 0.5).eq(true_qubits == 1), dim=-1)
    if model_flag != "gsqas":
        negative_ok = torch.all((recon_qubits < -0.5).eq(true_qubits == -1), dim=-1)
        correct_ops_qubits = torch.logical_and(positive_ok, negative_ok).float().mean().item()
        upper_bound = 2
    else:
        correct_ops_qubits = positive_ok.float().mean().item()
        upper_bound = 1
    decoded = torch.clip(recon_adj, 0, upper_bound).round().detach()

    correct_adj = decoded.eq(true_adj).float().triu(1).sum().item() / n_upper
    return correct_ops, correct_ops_qubits, mean_correct_adj, mean_false_positive_adj, correct_adj


def get_train_acc(inputs, targets, model_flag='gsqas'):
    """Format the reconstruction metrics of one training batch as a log line.

    Args:
        inputs, targets, model_flag: forwarded unchanged to ``get_accuracy``
            (``model_flag`` now defaults to 'gsqas', as in ``get_accuracy``).

    Returns:
        A single-line string with all five metrics returned by ``get_accuracy``.
    """
    # get_accuracy returns five metrics; each gets its own labelled placeholder
    # (the old four-placeholder string shifted every label after acc_ops).
    acc_ops, acc_qubits, mean_corr_adj, mean_fal_pos_adj, acc_adj = get_accuracy(inputs, targets, model_flag)
    return ('training batch: acc_ops:{0:.4f}, acc_qubits:{1:.4f}, mean_corr_adj:{2:.4f}, '
            'mean_fal_pos_adj:{3:.4f}, acc_adj:{4:.4f}').format(
                acc_ops, acc_qubits, mean_corr_adj, mean_fal_pos_adj, acc_adj)


def get_val_acc_vae(model, cfg, X_adj, X_ops, indices, model_flag):
    """Evaluate the reconstruction accuracy of ``model`` on a validation set.

    The data is split along dim 0 into chunks of 500. Each chunk is moved to the
    GPU, preprocessed with ``preprocessing(..., **cfg['prep'])``, reconstructed by
    ``model.forward(ops, adj)``, mapped back with the returned ``prep_reverse`` and
    scored with ``get_accuracy``. Per-chunk metrics are averaged, weighted by
    ``len(chunk_indices) / len(indices)``.

    Side effect: calls ``model.eval()`` and does not restore the previous mode.

    Returns:
        ``(correct_ops, correct_ops_qubits, mean_correct_adj,
        mean_false_positive_adj, correct_adj)`` averaged over all chunks
        (all zeros for an empty set).
    """
    model.eval()
    bs = 500
    n_total = len(indices)
    # Same order as get_accuracy's return value.
    totals = [0, 0, 0, 0, 0]
    batches = zip(torch.split(X_adj, bs, dim=0),
                  torch.split(X_ops, bs, dim=0),
                  torch.split(indices, bs, dim=0))
    # Pure evaluation: no autograd graph is needed, which saves GPU memory.
    with torch.no_grad():
        for adj, ops, ind in batches:
            adj, ops = adj.cuda(), ops.cuda()
            # preprocessing
            adj, ops, prep_reverse = preprocessing(adj, ops, **cfg['prep'])
            # forward
            ops_recon, adj_recon, _mu, _logvar = model.forward(ops, adj)
            # reverse preprocessing on both reconstruction and targets
            adj_recon, ops_recon = prep_reverse(adj_recon, ops_recon)
            adj, ops = prep_reverse(adj, ops)
            metrics = get_accuracy((ops_recon, adj_recon), (ops, adj), model_flag)
            for k, value in enumerate(metrics):
                totals[k] += value * len(ind) / n_total
    return tuple(totals)

def stacked_mm(A, B):
    """Matrix product of two batches of matrices.

    Both operands must be 3-D (checked in order A, then B); the product uses
    ``torch.matmul`` semantics, so batch dimensions broadcast.
    """
    for operand in (A, B):
        assert operand.dim() == 3
    return A @ B

def stacked_spmm(A, B):
    """Batched matrix product of two 3-D tensors.

    Despite the name, this is a plain ``torch.matmul`` product (dense semantics,
    broadcasting batch dimensions); both operands must be 3-D.
    """
    assert A.dim() == 3 and B.dim() == 3
    return A.matmul(B)

def is_valid_circuit(adj, ops, ops_qubits, model_flag):
    """Check whether a decoded circuit graph is structurally valid.

    Args:
        adj: square adjacency matrix (nested lists or tensor), converted with
            ``torch.as_tensor``. ``adj[i][j] != 0`` is read as an edge i -> j:
            in-degree = column sums, out-degree = row sums.
        ops: list of gate names, one per node; must start with 'START' and end
            with 'END', and inner names must be in ``vc.allowed_gates``.
        ops_qubits: 2-D tensor, one row of qubit markers per node (non-zero =
            the node acts on that qubit).
        model_flag: 'gsqas' selects the degree/neighbour rules for that model;
            any other value selects the weighted rules, in which an edge value of
            1 or 2 must equal the number of qubits shared by its endpoints.

    Returns:
        True if every check passes, otherwise False.
    """
    two_qubit_gates = ('CNOT', 'CY', 'CZ')
    adj = torch.as_tensor(adj)
    n = len(ops)
    # check length
    if len(adj) != n or len(adj[0]) != n:
        return False
    # check operations
    if ops[0] != 'START' or ops[-1] != 'END':
        return False
    # NOTE(review): the last marker column is excluded ([:-1]); confirm intended.
    if not (torch.all(ops_qubits[0, :-1] == 1) and torch.all(ops_qubits[-1, :-1] == 1)):
        return False
    for i in range(1, n - 1):
        if ops[i] not in vc.allowed_gates:
            return False
        expected_qubits = 2 if ops[i] in two_qubit_gates else 1
        if torch.count_nonzero(ops_qubits[i]).item() != expected_qubits:
            return False
    # check that every edge joins nodes sharing at least one qubit (the previous
    # np.intersect1d compared marker *values*, which nearly always overlap)
    for i in range(n):
        for j in range(i, n):
            if adj[i][j] != 0 and connected_qubit_count(ops_qubits[i], ops_qubits[j]) == 0:
                return False
    in_degree = adj.sum(dim=0)
    out_degree = adj.sum(dim=1)

    if model_flag == 'gsqas':
        # NOTE(review): the neighbour checks below use count_nonzero(a.eq(b)),
        # which counts positions where two marker rows are equal (zeros
        # included) rather than shared qubits, and the literal 4 presumably
        # stands for vc.num_qubits. Kept as-is; confirm the intended rules.
        # check the node 'START'
        if in_degree[0] != 0:
            return False
        if out_degree[0] > vc.num_qubits or out_degree[0] < 2:
            return False
        count = 0
        for j in range(n):
            if adj[0][j] == 1 and ops[j] in two_qubit_gates \
                    and torch.count_nonzero(ops_qubits[0].eq(ops_qubits[j])) == 2:
                count += 1
        if count != 4 - out_degree[0]:
            return False
        # check the node 'END'
        if out_degree[-1] != 0:
            return False
        if in_degree[-1] > vc.num_qubits or in_degree[-1] < 2:
            return False
        count = 0
        for j in range(n):
            if adj[j][-1] == 1 and ops[j] in two_qubit_gates \
                    and torch.count_nonzero(ops_qubits[-1].eq(ops_qubits[j])) == 2:
                count += 1
        if count != 4 - in_degree[-1]:
            return False
        # check real gate nodes
        for i in range(1, n - 1):
            if ops[i] in two_qubit_gates:
                if in_degree[i] > 2 or in_degree[i] < 1:
                    return False
                count = 0
                for j in range(i):
                    if adj[j][i] == 1 and ops[j] in two_qubit_gates \
                            and torch.count_nonzero(ops_qubits[i].eq(ops_qubits[j])) == 2:
                        count += 1
                if count != 2 - in_degree[i]:
                    return False
                if out_degree[i] > 2 or out_degree[i] < 1:
                    return False
                count = 0
                for j in range(i, n):
                    if adj[i][j] == 1 and ops[j] in two_qubit_gates \
                            and torch.count_nonzero(ops_qubits[i].eq(ops_qubits[j])) == 2:
                        count += 1
                if count != 2 - out_degree[i]:
                    return False
            else:
                if in_degree[i] != 1 or out_degree[i] != 1:
                    return False
                for j in range(i):
                    if adj[j][i] == 1 and torch.count_nonzero(ops_qubits[i].eq(ops_qubits[j])) != 1:
                        return False
                for j in range(i, n):
                    if adj[i][j] == 1 and torch.count_nonzero(ops_qubits[i].eq(ops_qubits[j])) != 1:
                        return False
    else:
        # check the node 'START'
        if in_degree[0] != 0 or out_degree[0] != vc.num_qubits:
            return False
        # check the node 'END'
        if in_degree[-1] != vc.num_qubits or out_degree[-1] != 0:
            return False
        # check real gate nodes: one in/out edge per qubit the gate acts on
        for i in range(1, n - 1):
            if ops[i] in two_qubit_gates:
                if in_degree[i] != 2 or out_degree[i] != 2:
                    return False
            else:
                if in_degree[i] != 1 or out_degree[i] != 1:
                    return False
        # check qubits: an edge value of 1 or 2 must equal the shared-qubit count
        for i in range(1, n):
            for j in range(i):
                if adj[j][i] == 1 and connected_qubit_count(ops_qubits[i], ops_qubits[j]) != 1:
                    return False
                if adj[j][i] == 2 and connected_qubit_count(ops_qubits[i], ops_qubits[j]) != 2:
                    return False
        for i in range(n - 1):
            for j in range(i, n):
                if adj[i][j] == 1 and connected_qubit_count(ops_qubits[i], ops_qubits[j]) != 1:
                    return False
                if adj[i][j] == 2 and connected_qubit_count(ops_qubits[i], ops_qubits[j]) != 2:
                    return False
    return True

def connected_qubit_count(ops_qubit_1, ops_qubit_2):
    """Count the positions where both qubit-marker rows are non-zero.

    Positions follow ``ops_qubit_1``; ``ops_qubit_2`` is only indexed where
    ``ops_qubit_1`` is non-zero.
    """
    return sum(
        1
        for pos, marker in enumerate(ops_qubit_1)
        if marker != 0 and ops_qubit_2[pos] != 0
    )
            