import os
import sys
sys.path.insert(0, os.getcwd())
import torch
import numpy as np
import torch.nn as nn
import argparse
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from scipy.stats import pearsonr
import torch.nn.functional as F
from utils.utils import load_json
import circuit.var_config as vc
import pennylane as qml
import matplotlib.pyplot as plt
from circuit.circuit_manager import circuit_qnode
from torch_geometric.nn import GATConv

class AdjPredictor1(nn.Module):
    """MLP mapping a flattened per-sample feature block to ``num_nodes**2`` scores.

    Each sample is flattened to ``num_nodes * feature_dim`` values and passed
    through four ``Linear -> LayerNorm -> ReLU`` stages (with dropout after the
    first three), followed by a linear head producing
    ``num_nodes * num_nodes`` values per sample.

    Args:
        num_nodes: Rows per input sample; the head emits ``num_nodes ** 2``
            values. Defaults to 12.
        feature_dim: Features per row. Defaults to 25.
        dropout: Dropout probability. Defaults to 0.1.
    """

    def __init__(self, num_nodes: int = 12, feature_dim: int = 25, dropout: float = 0.1):
        super().__init__()
        self.num_nodes = num_nodes
        self.feature_dim = feature_dim
        # Layers are created in the same order and under the same attribute
        # names as before, so seeded initialisation and saved state_dicts
        # remain compatible.
        self.linear1 = nn.Linear(num_nodes * feature_dim, 1024)
        self.linear2 = nn.Linear(1024, 768)
        self.linear3 = nn.Linear(768, 512)
        self.linear4 = nn.Linear(512, 256)
        self.linear5 = nn.Linear(256, num_nodes * num_nodes)
        # NOTE: despite the "BN" prefix these are LayerNorm modules.
        self.BN1 = nn.LayerNorm(1024)
        self.BN2 = nn.LayerNorm(768)
        self.BN3 = nn.LayerNorm(512)
        self.BN4 = nn.LayerNorm(256)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor whose elements regroup into rows of
                ``num_nodes * feature_dim`` values (presumably shaped
                ``(batch, num_nodes, feature_dim)`` -- confirm with callers).

        Returns:
            A ``(prediction, feature)`` tuple: ``prediction`` has shape
            ``(batch, num_nodes * num_nodes)`` and ``feature`` is the
            untouched input tensor.
        """
        feature = x
        # reshape() instead of view(): view() raises on non-contiguous inputs
        # (e.g. permuted/transposed tensors); reshape() copies only if needed.
        x = x.reshape(-1, self.num_nodes * self.feature_dim)
        x = self.dropout(torch.relu(self.BN1(self.linear1(x))))
        x = self.dropout(torch.relu(self.BN2(self.linear2(x))))
        x = self.dropout(torch.relu(self.BN3(self.linear3(x))))
        # No dropout after the last hidden stage, and no activation on the head.
        x = torch.relu(self.BN4(self.linear4(x)))
        x = self.linear5(x)
        return x, feature
    
class AdjPredictor2(nn.Module):
    """MLP mapping a flattened per-sample feature block to ``num_nodes**2`` scores.

    Same layer stack as the 25-feature predictor in this file, but the default
    input width is ``12 * 17``. Each sample is flattened to
    ``num_nodes * feature_dim`` values and passed through four
    ``Linear -> LayerNorm -> ReLU`` stages (with dropout after the first
    three), followed by a linear head producing ``num_nodes * num_nodes``
    values per sample.

    Args:
        num_nodes: Rows per input sample; the head emits ``num_nodes ** 2``
            values. Defaults to 12.
        feature_dim: Features per row. Defaults to 17.
        dropout: Dropout probability. Defaults to 0.1.
    """

    def __init__(self, num_nodes: int = 12, feature_dim: int = 17, dropout: float = 0.1):
        super().__init__()
        self.num_nodes = num_nodes
        self.feature_dim = feature_dim
        # Layers are created in the same order and under the same attribute
        # names as before, so seeded initialisation and saved state_dicts
        # remain compatible.
        self.linear1 = nn.Linear(num_nodes * feature_dim, 1024)
        self.linear2 = nn.Linear(1024, 768)
        self.linear3 = nn.Linear(768, 512)
        self.linear4 = nn.Linear(512, 256)
        self.linear5 = nn.Linear(256, num_nodes * num_nodes)
        # NOTE: despite the "BN" prefix these are LayerNorm modules.
        self.BN1 = nn.LayerNorm(1024)
        self.BN2 = nn.LayerNorm(768)
        self.BN3 = nn.LayerNorm(512)
        self.BN4 = nn.LayerNorm(256)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor whose elements regroup into rows of
                ``num_nodes * feature_dim`` values (presumably shaped
                ``(batch, num_nodes, feature_dim)`` -- confirm with callers).

        Returns:
            A ``(prediction, feature)`` tuple: ``prediction`` has shape
            ``(batch, num_nodes * num_nodes)`` and ``feature`` is the
            untouched input tensor.
        """
        feature = x
        # reshape() instead of view(): view() raises on non-contiguous inputs
        # (e.g. permuted/transposed tensors); reshape() copies only if needed.
        x = x.reshape(-1, self.num_nodes * self.feature_dim)
        x = self.dropout(torch.relu(self.BN1(self.linear1(x))))
        x = self.dropout(torch.relu(self.BN2(self.linear2(x))))
        x = self.dropout(torch.relu(self.BN3(self.linear3(x))))
        # No dropout after the last hidden stage, and no activation on the head.
        x = torch.relu(self.BN4(self.linear4(x)))
        x = self.linear5(x)
        return x, feature

def _build_dataset(dataset, list, change_order = False, encode_flag='original'):
    indices = np.random.permutation(list)
    X_adj = []
    X_ops = []
    X_indegree = []
    X_outdegree = []
    op_flag = None
    adj_flag = None
    if encode_flag == 'original':
        op_flag = 'gate_matrix'
        adj_flag = 'adj_matrix'
    else:
        op_flag = 'improved_gate_matrix'
        adj_flag = 'adj_matrix_with_degree'
    for ind in indices:
        if change_order == True:
            new_sort = torch.sort(torch.Tensor(dataset[ind]['outdegree']), descending=True).indices
            X_adj.append(torch.Tensor(dataset[ind][adj_flag])[new_sort, :][:, new_sort])
            X_ops.append(torch.Tensor(dataset[ind][op_flag])[new_sort,:])
            X_indegree.append(torch.Tensor(dataset[ind]['indegree']).unsqueeze(1)[new_sort,:])
            X_outdegree.append(torch.Tensor(dataset[ind]['outdegree']).unsqueeze(1)[new_sort,:])
        else:
            X_adj.append(torch.Tensor(dataset[ind][adj_flag]))
            X_ops.append(torch.Tensor(dataset[ind][op_flag]))
            X_indegree.append(torch.Tensor(dataset[ind]['indegree']).unsqueeze(1))
            X_outdegree.append(torch.Tensor(dataset[ind]['outdegree']).unsqueeze(1))
    X_adj = torch.stack(X_adj)
    X_ops = torch.stack(X_ops)
    X_indegree = torch.stack(X_indegree)
    X_outdegree = torch.stack(X_outdegree)
    return X_adj, X_ops, X_indegree, X_outdegree, torch.Tensor(indices)

if __name__ == '__main__':
    # Train AdjPredictor1 to regress the stored adjacency matrices from the
    # operation matrices, reporting train/validation loss and exact-match
    # accuracy on the held-out split every 10 epochs.

    # Seed every RNG in use so shuffling, weight init and dropout are repeatable.
    seed = 42
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Use the GPU when one is present instead of crashing on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # os.path.join keeps the path valid on non-Windows systems as well
    # (the previous literal used backslash separators).
    data_path = os.path.join("circuit", "data", f"data_{vc.num_qubits}_qubits_test.json")
    dataset = load_json(data_path)

    # Contiguous 80% / 10% / 10% split; samples are shuffled within each split
    # by _build_dataset.
    num_samples = len(dataset)
    train_end = int(num_samples * 0.8)
    val_end = int(num_samples * 0.9)
    train_ind_list = range(train_end)
    val_ind_list = range(train_end, val_end)
    test_ind_list = range(val_end, num_samples)

    X_adj_train, X_ops_train, X_indegree, X_outdegree, indices_train = _build_dataset(
        dataset, train_ind_list, change_order=False, encode_flag="improved")
    X_adj_val, X_ops_val, X_indegree_val, X_outdegree_val, indices_val = _build_dataset(
        dataset, val_ind_list, change_order=False, encode_flag="improved")
    X_adj_test, X_ops_test, X_indegree_test, X_outdegree_test, indices_test = _build_dataset(
        dataset, test_ind_list, change_order=False, encode_flag="improved")

    train_data = TensorDataset(X_ops_train, X_adj_train)
    val_data = TensorDataset(X_ops_val, X_adj_val)
    test_data = TensorDataset(X_ops_test, X_adj_test)

    train_loader = DataLoader(train_data, batch_size=32, shuffle=True, drop_last=False)
    val_loader = DataLoader(val_data, batch_size=100, shuffle=False, drop_last=False)
    # batch_size=1 so that each test sample is scored individually below.
    test_loader = DataLoader(test_data, batch_size=1, shuffle=False, drop_last=False)

    model = AdjPredictor1().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_func = nn.MSELoss()

    for epoch in range(200):
        report = (epoch + 1) % 10 == 0

        # ---- training pass ----
        model.train()
        train_loss = []
        for b_x, b_y in train_loader:
            optimizer.zero_grad()
            b_x, b_y = b_x.to(device), b_y.to(device)
            forward, _ = model(b_x)
            forward = forward.view(-1, 12, 12)
            loss = loss_func(forward, b_y)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
        if report:
            print(f"epoch:{epoch+1}")
            print('train_loss:{:.5f}'.format(sum(train_loss)/len(train_loss)))

        # ---- validation pass ----
        # Run every epoch (not only when reporting): iterating a DataLoader
        # may draw from the global RNG, so this keeps the seeded run unchanged.
        valid_loss = []
        model.eval()
        with torch.no_grad():
            for b_x, b_y in val_loader:
                b_x, b_y = b_x.to(device), b_y.to(device)
                forward, _ = model(b_x)
                forward = forward.view(-1, 12, 12)
                valid_loss.append(loss_func(forward, b_y).item())
        if report:
            print('valid_loss:{:.5f}'.format(sum(valid_loss) / len(valid_loss)))

        # ---- exact-match accuracy on the test split ----
        if report:
            count = 0
            with torch.no_grad():
                for b_x, b_y in test_loader:
                    b_x, b_y = b_x.to(device), b_y.to(device)
                    forward, _ = model(b_x)
                    forward = forward.view(-1, 12, 12).squeeze(0)
                    # A sample counts as correct only if every rounded entry
                    # equals the target matrix.
                    if torch.all(forward.round().eq(b_y.squeeze(0))):
                        count += 1
            # Normalise by the real test-set size rather than a hard-coded 10000.
            num_test = len(test_data)
            accuracy = count / num_test if num_test else 0.0
            print(f"correct prediction: {accuracy}")