import os
import sys
sys.path.insert(0, os.getcwd())
import json
import torch
import argparse
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import circuit.var_config as vc

from torch import optim
from models.configs import configs
from nasbench.lib import graph_util
from models.model import Model, VAEReconstructed_Loss, GVAE, Reconstructed_Loss, DegreeConsistencyLoss
from utils.utils import load_json, save_checkpoint_vae, preprocessing
from utils.utils import get_val_acc_vae, is_valid_circuit, transform_operations

# Enables autograd anomaly detection for the whole process (every backward pass).
# PyTorch documents this mode as a debugging aid that slows execution.
# NOTE(review): consider disabling it for full-length training runs.
torch.autograd.set_detect_anomaly(True)

def _build_dataset(dataset, list, model_flag='gsqas'):
    indices = np.random.permutation(list)
    X_adj = []
    X_ops = []
    X_indegree = []
    X_outdegree = []
    op_flag = None
    adj_flag = None
    if model_flag == 'gsqas':
        op_flag = 'gate_matrix'
        adj_flag = 'adj_matrix'
    elif model_flag in ["quantum_arch2vec", "quantum_arch2vec_with_degree"]:
        op_flag = 'improved_gate_matrix'
        adj_flag = 'adj_matrix_with_degree'
    else:
        raise ValueError("Invalid model, only support gsqas, quantum_arch2vec and quantum_arch2vec_with_degree.")
    for ind in indices:
        X_adj.append(torch.Tensor(dataset[ind][adj_flag]))
        X_ops.append(torch.Tensor(dataset[ind][op_flag]))
        X_indegree.append(torch.Tensor(dataset[ind]['indegree']).unsqueeze(1))
        X_outdegree.append(torch.Tensor(dataset[ind]['outdegree']).unsqueeze(1))
    X_adj = torch.stack(X_adj)
    X_ops = torch.stack(X_ops)
    X_indegree = torch.stack(X_indegree)
    X_outdegree = torch.stack(X_outdegree)
    return X_adj, X_ops, X_indegree, X_outdegree, torch.Tensor(indices)


def _decode_latent(model, z, model_flag):
    """Decode a single latent sample into a discrete circuit candidate.

    Args:
        model: the autoencoder; only ``model.decoder`` is called.
        z: latent tensor for one circuit (a leading batch axis is added here).
        model_flag: selects how the operands of two-qubit gates are decoded.

    Returns:
        ``(ad_decode, op_decode, qubit_decode, full_op_encoding)``:
        - ``ad_decode``: the adjacency as a nested list (strict upper triangle,
          clipped to ``[0, 2]`` and rounded).
        - ``op_decode``: the gate names from ``transform_operations``.
        - ``qubit_decode``: the per-node qubit encoding tensor.
        - ``full_op_encoding``: the one-hot gate encoding concatenated with
          ``qubit_decode``.
    """
    full_op, full_ad = model.decoder(z.unsqueeze(0))
    full_op = full_op.squeeze(0).cpu()
    ad = full_ad.squeeze(0).cpu()
    # The trailing vc.num_qubits columns score qubit operands; the rest score gate types.
    op = full_op[:, :-vc.num_qubits]
    op_qubits = full_op[:, -vc.num_qubits:]
    op_max_idx = torch.argmax(op, dim=-1)
    one_hot = torch.zeros_like(op)
    one_hot[torch.arange(one_hot.shape[0]), op_max_idx] = 1
    op_decode = transform_operations(op_max_idx)
    qubit_decode = torch.zeros_like(op_qubits)
    for node in range(qubit_decode.shape[0]):
        gate = op_decode[node]
        if gate in ['START', 'END']:
            qubit_decode[node][:] = 1
        elif gate in ['CNOT', 'CY', 'CZ']:
            if model_flag in ["quantum_arch2vec", "quantum_arch2vec_with_degree"]:
                # Signed operand encoding: lowest score -> control (-1), highest -> target (+1).
                qubit_decode[node][torch.argmin(op_qubits[node], dim=-1)] = -1
                qubit_decode[node][torch.argmax(op_qubits[node], dim=-1)] = 1
            elif model_flag == "gsqas":
                # Unsigned encoding: mark the two highest-scoring qubits.
                qubit_decode[node][torch.topk(op_qubits[node], 2, dim=-1).indices] = 1
        else:
            qubit_decode[node][torch.argmax(op_qubits[node], dim=-1)] = 1
    ad_decode = torch.clip(ad.triu(1), 0, 2).round().cpu().detach().numpy().tolist()
    full_op_encoding = torch.cat((one_hot, qubit_decode), dim=-1)
    return ad_decode, op_decode, qubit_decode, full_op_encoding


def _sample_prior(model, Z, num_nodes, args):
    """Decode ``args.latent_points`` samples drawn around the encoder means ``Z``.

    Samples are ``randn * Z.std(0) + Z.mean(0)``.

    Returns:
        ``(validity_counter, buckets)``. ``validity_counter`` is the number of
        samples accepted by ``is_valid_circuit``. ``buckets`` maps each valid
        decoding's ``graph_util.hash_module`` fingerprint to
        ``(adjacency, int8 encoding)``.
    """
    z_mean, z_std = Z.mean(0), Z.std(0)
    validity_counter = 0
    buckets = {}
    # Pure inference: no autograd graph is needed for the decoded samples.
    with torch.no_grad():
        for _ in range(args.latent_points):
            z = torch.randn(num_nodes, args.dim).cuda()
            z = z * z_std + z_mean
            ad_decode, op_decode, qubit_decode, full_op_encoding = _decode_latent(model, z, args.model_flag)
            if not is_valid_circuit(ad_decode, op_decode, qubit_decode, model_flag=args.model_flag):
                continue
            validity_counter += 1
            fingerprint = graph_util.hash_module(np.array(ad_decode), full_op_encoding.numpy().tolist())
            if fingerprint not in buckets:
                buckets[fingerprint] = (ad_decode, full_op_encoding.numpy().astype('int8').tolist())
    return validity_counter, buckets


def _save_loss_log(loss_dict, args):
    """Write ``loss_dict`` as JSON to ``<output_path>/dim<dim>/model_<name>.json``."""
    save_path = os.path.join(args.output_path, 'dim{}'.format(args.dim))
    # makedirs also creates a missing args.output_path (os.mkdir would fail there).
    os.makedirs(save_path, exist_ok=True)
    print('save to {}'.format(save_path))
    with open(os.path.join(save_path, 'model_{}.json'.format(args.name)), 'w') as fh:
        json.dump(loss_dict, fh)


def pretraining_model(dataset, cfg, args):
    """Pretrain the circuit autoencoder and log per-epoch diagnostics.

    The dataset is split 90/10 by position into train/validation (each split is
    shuffled by ``_build_dataset``). Each epoch:
      - trains on the training split;
      - samples ``args.latent_points`` latents around the epoch's encoder means
        and reports the valid and unique decoding ratios;
      - evaluates on the validation split with ``get_val_acc_vae``;
      - saves a checkpoint.
    Per-epoch mean losses are written to
    ``<args.output_path>/dim<args.dim>/model_<args.name>.json``.

    Args:
        dataset: indexable collection of circuit records, as read by
            ``_build_dataset``.
        cfg: configuration dict. Its ``'GAE'``, ``'prep'`` and ``'loss'``
            entries are forwarded to ``Model``, ``preprocessing`` and
            ``VAEReconstructed_Loss`` respectively.
        args: parsed command-line namespace (see the ``__main__`` block).

    Raises:
        ValueError: for an unsupported ``args.model_flag``.
    """
    beta_kl = args.beta_kl
    beta_degree = args.beta_degree
    n_train = int(len(dataset) * 0.9)
    train_ind_list, val_ind_list = range(n_train), range(n_train, len(dataset))
    X_adj_train, X_ops_train, X_indegree_train, X_outdegree_train, _ = _build_dataset(dataset, train_ind_list, args.model_flag)
    X_adj_val, X_ops_val, _, _, indices_val = _build_dataset(dataset, val_ind_list, args.model_flag)
    model = Model(input_dim=args.input_dim, hidden_dim=args.hidden_dim, latent_dim=args.dim,
                  num_hops=args.hops, num_mlp_layers=args.mlps, dropout=args.dropout,
                  model_flag=args.model_flag, **cfg['GAE']).cuda()
    optimizer_vae = optim.Adam(list(model.encoder.parameters()) + list(model.decoder.parameters()),
                               lr=1e-3, betas=(0.9, 0.999), eps=1e-08)
    bs = args.bs
    # The training tensors are not reshuffled between epochs, so the batches are built once.
    X_adj_split = torch.split(X_adj_train, bs, dim=0)
    X_ops_split = torch.split(X_ops_train, bs, dim=0)
    X_indegree_split = torch.split(X_indegree_train, bs, dim=0)
    X_outdegree_split = torch.split(X_outdegree_train, bs, dim=0)
    chunks = len(X_adj_split)  # == ceil(len(train_ind_list) / bs)
    num_nodes = X_adj_val[0].shape[0]

    recon_loss_total, KLD_loss_total, degree_loss_total, loss_total = [], [], [], []
    for epoch in range(args.epochs):
        recon_loss_epoch, KLD_loss_epoch, degree_loss_epoch, total_loss_epoch = [], [], [], []
        Z = []
        if epoch > 0:
            # Linear warm-up of the KL weight (cap 1) and degree weight (cap 0.5).
            beta_kl = min(beta_kl + 0.1, 1)
            beta_degree = min(beta_degree + 0.1, 0.5)
        model.train()
        batches = zip(X_adj_split, X_ops_split, X_indegree_split, X_outdegree_split)
        for i, (adj, ops, indegree, outdegree) in enumerate(batches):
            adj, ops, indegree, outdegree = adj.cuda(), ops.cuda(), indegree.cuda(), outdegree.cuda()
            adj, ops, prep_reverse = preprocessing(adj, ops, **cfg['prep'])
            ops_recon, adj_recon, mu, logvar = model(ops, adj)
            # Z only feeds the prior statistics; detach so it does not keep autograd history alive.
            Z.append(mu.detach())
            adj_recon, ops_recon = prep_reverse(adj_recon, ops_recon)
            adj, ops = prep_reverse(adj, ops)
            recon_loss, KLD = VAEReconstructed_Loss(**cfg['loss'])((ops_recon, adj_recon), (ops, adj), mu, logvar)
            # Degree MSE is logged for every model; it is not part of the optimized loss.
            in_degree_loss = F.mse_loss(adj_recon.sum(dim=-2), indegree.squeeze(-1).float())
            out_degree_loss = F.mse_loss(adj_recon.sum(dim=-1), outdegree.squeeze(-1).float())
            degree_loss = in_degree_loss + out_degree_loss
            degree_loss_epoch.append(degree_loss.item())

            if args.model_flag == "quantum_arch2vec":
                total_loss = recon_loss + beta_kl * KLD
            elif args.model_flag == "gsqas":
                total_loss = recon_loss + KLD
            elif args.model_flag == "quantum_arch2vec_with_degree":
                degree_consistency_loss = DegreeConsistencyLoss()(indegree, outdegree, adj_recon)
                total_loss = recon_loss + beta_kl * KLD + beta_degree * degree_consistency_loss
            else:
                raise ValueError("Invalid model_flag: {}".format(args.model_flag))

            optimizer_vae.zero_grad()
            total_loss.backward()
            nn.utils.clip_grad_norm_(model.encoder.parameters(), 5)
            nn.utils.clip_grad_norm_(model.decoder.parameters(), 5)
            optimizer_vae.step()
            recon_loss_epoch.append(recon_loss.item())
            KLD_loss_epoch.append(KLD.item())
            total_loss_epoch.append(total_loss.item())

            if i % 100 == 0:
                print('epoch {}: batch {} / {}: recon_loss: {:.5f}, KLD: {: .5f}, degree_loss:{: .5f}, total_loss: {:.5f}'
                      .format(epoch, i, chunks, recon_loss.item(), KLD.item(), degree_loss.item(), total_loss.item()))

        model.eval()
        validity_counter, buckets = _sample_prior(model, torch.cat(Z, dim=0), num_nodes, args)
        # Guard against latent_points == 0 (previously a ZeroDivisionError).
        validity = validity_counter / args.latent_points if args.latent_points > 0 else 0.0
        print('Ratio of valid decodings from the prior: {:.4f}'.format(validity))
        print('Ratio of unique decodings from the prior: {:.4f}'.format(len(buckets) / (validity_counter + 1e-8)))
        acc_ops_val, acc_ops_qubits_val, mean_corr_adj_val, mean_fal_pos_adj_val, acc_adj_val \
            = get_val_acc_vae(model, cfg, X_adj_val, X_ops_val, indices_val, args.model_flag)
        print('validation set: acc_ops:{0:.4f}, acc_ops_qubits_val:{1:.4f}, mean_corr_adj:{2:.4f}, mean_fal_pos_adj:{3:.4f}, acc_adj:{4:.4f}'.format(
                acc_ops_val, acc_ops_qubits_val, mean_corr_adj_val, mean_fal_pos_adj_val, acc_adj_val))
        epoch_loss = sum(total_loss_epoch) / len(total_loss_epoch)
        print('epoch {}: average loss {:.5f}'.format(epoch, epoch_loss))
        loss_total.append(epoch_loss)
        save_checkpoint_vae(model, optimizer_vae,
                            epoch, epoch_loss, args.dim, args.name, args.dropout, args.seed)

        recon_loss_total.append(sum(recon_loss_epoch) / len(recon_loss_epoch))
        KLD_loss_total.append(sum(KLD_loss_epoch) / len(KLD_loss_epoch))
        degree_loss_total.append(sum(degree_loss_epoch) / len(degree_loss_epoch))

    loss_dict = {
        "recon_loss_total": recon_loss_total,
        "KLD_loss_total": KLD_loss_total,
        "degree_consistency_loss_total": degree_loss_total,
    }
    _save_loss_log(loss_dict, args)
    print('loss for epochs: \n', loss_total)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Pretraining')
    parser.add_argument("--seed", type=int, default=1, help="random seed")
    # os.path.join keeps the default paths usable on both Windows and POSIX.
    parser.add_argument('--data', type=str,
                        default=os.path.join('circuit', 'data', f'data_{vc.num_qubits}_qubits.json'),
                        help='data file (default: circuit/data/data_<num_qubits>_qubits.json)')
    parser.add_argument('--name', type=str, default=f'circuits_{vc.num_qubits}_qubits_quantum_arch2vec',
                        help='run name, used for checkpoint and loss-log file names')
    parser.add_argument('--cfg', type=int, default=1,
                        help='configuration: 0 for gsqas, 1 for quantum_arch2vec, 2 for quantum_arch2vec_with_degree')
    parser.add_argument('--bs', type=int, default=32,
                        help='batch size (default: 32)')
    parser.add_argument('--epochs', type=int, default=16,
                        help='training epochs (default: 16)')
    parser.add_argument('--dropout', type=float, default=0.1,
                        help='decoder implicit regularization (default: 0.1)')
    parser.add_argument('--beta_kl', type=float, default=0.1,
                        help='initial KL-loss weight, increased by 0.1 per epoch up to 1 (default: 0.1)')
    parser.add_argument('--beta_degree', type=float, default=0,
                        help='initial degree-consistency loss weight (quantum_arch2vec_with_degree only), '
                             'increased by 0.1 per epoch up to 0.5 (default: 0)')
    parser.add_argument('--model_flag', type=str, default='quantum_arch2vec',
                        choices=['gsqas', 'quantum_arch2vec', 'quantum_arch2vec_with_degree'],
                        help='which model will be used')
    parser.add_argument('--output_path', type=str, default=os.path.join('saved_logs', 'pretraining'))
    # NOTE(review): store_true with default=True means this flag can never be switched off;
    # args.normalize is also not read anywhere in this file.
    parser.add_argument('--normalize', action='store_true', default=True,
                        help='use input normalization')
    parser.add_argument('--input_dim', type=int,
                        default=2+len(vc.allowed_gates)+vc.num_qubits)
    parser.add_argument('--hidden_dim', type=int, default=128)
    parser.add_argument('--dim', type=int, default=16,
                        help='feature (latent) dimension (default: 16)')
    parser.add_argument('--hops', type=int, default=5)
    parser.add_argument('--mlps', type=int, default=2)
    parser.add_argument('--latent_points', type=int, default=10,
                        help='latent points for validity check (default: 10)')

    args = parser.parse_args()
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cfg = configs[args.cfg]
    dataset = load_json(args.data)
    print('using {}'.format(args.data))
    print('feat dim {}'.format(args.dim))
    # Sanity print of one training sample. This build also advances the numpy RNG,
    # which determines the shuffle order pretraining_model sees; kept for reproducibility.
    train_ind_list = range(int(len(dataset) * 0.9))
    X_adj_train, X_ops_train, X_indegree, X_outdegree, indices_train = _build_dataset(dataset, train_ind_list, model_flag=args.model_flag)
    print(X_adj_train[0])
    print(X_ops_train[0])
    print(X_indegree[0])
    print(X_outdegree[0])
    print(indices_train[0])
    pretraining_model(dataset, cfg, args)