# Data Loading
import os.path as osp

# Pytorch Module
import torch
import torch.optim as optim
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv,GATConv, SAGEConv, Sequential
from torch_geometric.data import Dataset, download_url,DataLoader
from torch_geometric.utils import to_undirected

# Optuna Module
import optuna
from optuna.trial import TrialState
from optuna.importance import get_param_importances

# F1-Score
from sklearn.metrics import f1_score

# Parsing
import argparse

import sys
sys.path.insert(0, '../../utils/')
from utils_GNN_CRF import (
    compute_tab_logits_pred_sentence_level,
    extract_sentence_level_label,
    F1_Loss,
    find_max_size_batch_without_nt,
    remove_internal_nodes,
    remove_internal_nodes_without_labels,
    reshape_label_to_batch_padded,
    reshape_mask_to_batch_padded,
    reshape_data_to_batch_padded,
    find_the_most_frequent_element
)

from torchcrf import CRF

## Dataset class that loads the pre-processed graph files
class MyOwnDataset(Dataset):
    """PyG dataset serving pre-processed graphs stored as ``data_{idx}.pt`` files.

    No ``download`` or ``process`` step is defined. Every graph is loaded
    lazily from the base class's ``processed_dir`` when it is accessed.

    Args:
        num_data: Number of graph files, i.e. ``data_0.pt`` .. ``data_{num_data - 1}.pt``.
        root: Dataset root directory, forwarded to the PyG ``Dataset`` base class.
        transform: Optional transform, forwarded to the base class.
        pre_transform: Optional pre-transform, forwarded to the base class.
    """

    def __init__(self, num_data, root, transform=None, pre_transform=None):
        # Assign before the base constructor runs. Depending on the PyG version,
        # ``Dataset.__init__`` may query ``processed_file_names``, which reads
        # ``self.num_data``. Setting it afterwards would raise AttributeError.
        self.num_data = num_data
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        # No raw files are declared; the graphs are read directly from the processed files.
        return []

    @property
    def processed_file_names(self):
        return ['data_{}.pt'.format(idx) for idx in range(self.num_data)]

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        # NOTE(review): torch >= 2.6 defaults ``torch.load`` to ``weights_only=True``,
        # which may refuse to unpickle PyG ``Data`` objects. If loading fails, pass
        # ``weights_only=False``, but only for trusted files.
        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
        return data
    
    
# Define the Sequential Network
def define_model_GATConv(trial, CLASSES,NODE_FEATURES):
    """Build an Optuna-parameterised GAT stack followed by linear layers.

    Hyper-parameters sampled from ``trial``, in this order:
      * ``n_layers`` (1-5): number of GATConv layers.
      * ``n_units_l{i}`` (10-100) and ``head_l{i}`` (1-3): output channels and
        attention heads of GAT layer ``i``. Each layer's output width is
        ``channels * heads``.
      * ``dropout_l{i}`` (0.01-0.2): dropout after GAT layer ``i`` for ``i >= 1``.
        The first GAT layer has no dropout.
      * ``n_layers_linear`` (0-2) and ``n_units_lin_l{i}`` (4-200): hidden
        linear layers, each followed by ReLU.

    NOTE: hidden linear layers are created for ``i in range(1, n_layers_linear)``.
    Values 0 and 1 therefore both add no hidden layer, and 2 adds exactly one.
    This is kept as-is so the search space matches trials already stored in the study.

    Args:
        trial: Optuna trial used to sample the hyper-parameters.
        CLASSES: Output size of the final linear layer.
        NODE_FEATURES: Input node-feature size of the first GAT layer.

    Returns:
        A ``torch_geometric.nn.Sequential`` taking ``(x, edge_index)``.
    """
    num_gat = trial.suggest_int("n_layers", 1, 5)
    width = NODE_FEATURES

    # First GAT layer. The following step applies ReLU and replaces edge_index
    # with its undirected version, which all later GAT layers use.
    channels = trial.suggest_int("n_units_l0", 10, 100)
    n_heads = trial.suggest_int("head_l0", 1, 3)
    modules = [
        (GATConv(width, channels, n_heads), 'x, edge_index -> x'),
        (lambda x, edge_index: (F.relu(x), to_undirected(edge_index)),
         'x, edge_index -> x, edge_index'),
    ]
    width = channels * n_heads

    # Remaining GAT layers: conv -> ReLU -> dropout.
    for idx in range(1, num_gat):
        channels = trial.suggest_int(f"n_units_l{idx}", 10, 100)
        n_heads = trial.suggest_int(f"head_l{idx}", 1, 3)
        drop = trial.suggest_float(f"dropout_l{idx}", 0.01, 0.2)
        modules.extend([
            (GATConv(width, channels, n_heads), 'x, edge_index -> x'),
            nn.ReLU(inplace=True),
            nn.Dropout(drop),
        ])
        width = channels * n_heads

    # Optional hidden linear layers (see NOTE in the docstring).
    num_linear = trial.suggest_int("n_layers_linear", 0, 2)
    for idx in range(1, num_linear):
        hidden = trial.suggest_int(f"n_units_lin_l{idx}", 4, 200)
        modules.extend([(nn.Linear(width, hidden), 'x -> x'), nn.ReLU(inplace=True)])
        width = hidden

    # Projection to one score per class.
    modules.append((nn.Linear(width, CLASSES), 'x -> x'))
    return Sequential('x, edge_index', modules)


class Net_large_CRF(torch.nn.Module):
    """GAT emission network followed by a linear-chain CRF (``torchcrf.CRF``).

    The GNN produces ``num_labels`` scores per node. Internal nodes are removed
    and the remaining nodes are packed into padded emissions for the CRF.
    The emission shape is presumably ``(batch, max_size, num_labels)``, given
    ``batch_first=True``; TODO confirm against ``reshape_data_to_batch_padded``.

    Args:
        trial: Optuna trial used to sample the GNN architecture.
        num_node_features: Input node-feature size.
        num_labels: Number of tags. Used for the GNN output size, the CRF
            states, and the emission size passed to the reshape helper.
        batch_first: Forwarded to ``CRF``.
        max_size: Padding length passed to the reshape helpers. It was
            previously hard-coded to 62 inside ``forward``.
    """

    def __init__(self,trial,num_node_features, num_labels = 3, batch_first=True, max_size=62):
        super(Net_large_CRF, self).__init__()

        self.num_labels = num_labels
        self.batch_first = batch_first
        self.max_size = max_size

        self.GNN_attention = define_model_GATConv(trial, self.num_labels, num_node_features)

        self.crf = CRF(self.num_labels, batch_first=self.batch_first)

    def forward(self, data, labels=None):
        """Return the CRF loss when ``labels`` is given, otherwise the decoded tags.

        Args:
            data: PyG batch providing ``x``, ``edge_index``, ``batch`` and
                ``indice_nt_tensor``.
            labels: Per-node gold tags (the caller passes ``data["y"]``), or
                ``None`` for inference.

        Returns:
            With ``labels``: the negated CRF log-likelihood, used as the training loss.
            Without: the output of ``CRF.decode``, i.e. one list of tag ids per sequence.
        """
        x, edge_index, batch, indice_nt_tensor = data.x, data.edge_index, data.batch, data.indice_nt_tensor

        node_scores = self.GNN_attention(x, edge_index)

        if labels is not None:
            kept_scores, kept_batch, kept_labels = remove_internal_nodes(indice_nt_tensor, node_scores, batch, labels)
            tags, padding_mask = reshape_label_to_batch_padded(self.max_size, kept_batch, kept_labels)
            emissions = reshape_data_to_batch_padded(self.max_size, kept_batch, kept_scores, self.num_labels)
            # CRF.forward returns a log-likelihood; negate it to obtain a loss to minimise.
            return -self.crf(emissions=emissions, tags=tags, mask=padding_mask.byte())

        kept_scores, kept_batch = remove_internal_nodes_without_labels(indice_nt_tensor, node_scores, batch)
        padding_mask = reshape_mask_to_batch_padded(self.max_size, kept_batch)
        emissions = reshape_data_to_batch_padded(self.max_size, kept_batch, kept_scores, self.num_labels)
        return self.crf.decode(emissions=emissions, mask=padding_mask.byte())
    
    
# Define the Optuna objective function
def objective(trial, DEVICE, EPOCHS, CLASSES, NODE_FEATURES, train_loader, dev_loader, test_loader):
    """Train one sampled GNN-CRF model and return its token-level dev macro F1.

    After every epoch the dev-set F1 is reported to Optuna. The trial is
    pruned when the study's pruner decides it should be.

    Args:
        trial: Optuna trial used to sample the model and optimiser settings.
        DEVICE: Device the model and batches are moved to.
        EPOCHS: Number of training epochs; must be >= 1.
        CLASSES: Number of tag classes. Used for the model and the F1 labels
            ``0 .. CLASSES - 1``.
        NODE_FEATURES: Input node-feature size.
        train_loader: Loader of training batches exposing ``x``, ``y`` and
            ``indice_nt_tensor``.
        dev_loader: Loader of dev batches with the same fields.
        test_loader: Unused; kept for interface compatibility.

    Returns:
        Macro F1 on the dev set after the last epoch.

    Raises:
        ValueError: If ``EPOCHS < 1``, since no score would ever be computed.
        optuna.exceptions.TrialPruned: When the pruner stops the trial.
    """
    if EPOCHS < 1:
        raise ValueError("EPOCHS must be >= 1, got {}".format(EPOCHS))

    model = Net_large_CRF(trial, NODE_FEATURES, num_labels=CLASSES).to(DEVICE)

    optimizer_name = trial.suggest_categorical("optimizer", ["AdamW"])
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

    # scheduler.step() runs once per epoch, multiplying the learning rate by 0.99.
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.99)
    max_grad_norm = trial.suggest_float("max_grad_norm", 1e-1, 1e+1)

    for epoch in range(EPOCHS):
        model.train()
        for graph_batch in train_loader:
            graph_batch = graph_batch.to(DEVICE)
            graph_batch["x"] = graph_batch["x"].float()
            optimizer.zero_grad()
            loss = model(graph_batch, graph_batch["y"])

            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
            optimizer.step()

        # Token-level evaluation on the dev set.
        model.eval()
        predicted_tokens = []
        gold_tokens = []
        with torch.no_grad():
            for graph_batch in dev_loader:
                graph_batch = graph_batch.to(DEVICE)
                graph_batch["x"] = graph_batch["x"].float()
                class_predict = model(graph_batch)

                # Decoding yields one list of tag ids per sequence; flatten it to token level.
                predicted_tokens.extend(tag for sequence in class_predict for tag in sequence)

                # Gold labels without internal nodes. Labels where
                # indice_nt_tensor <= 0 are replaced by the sentinel 10 and then dropped.
                gold_clean = torch.where(graph_batch["indice_nt_tensor"] > 0, graph_batch["y"], 10)
                gold_tokens.extend(gold_clean[gold_clean != 10].tolist())

        f1_token_level = f1_score(
            y_true=gold_tokens,
            y_pred=predicted_tokens,
            labels=list(range(CLASSES)),
            average="macro",
        )
        trial.report(f1_token_level, epoch)
        scheduler.step()

        # Without this check the reported values are never acted on, so no trial could be pruned.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return f1_token_level


# Entry point executed when the file is run as a script
def main():
    """Run the Optuna search for the BERT-large GNN-CRF model and print a summary.

    Builds loaders for the Train/Dev/Test splits under ``data_dir``. It then
    optimises ``objective`` against a SQLite-backed study, resuming it if it
    already exists, for at most ``num_trials`` trials or ``timeout`` seconds.
    Finally it prints trial statistics, the best trial and the parameter importances.
    """
    batch_size = 190
    num_trials = 200
    # Alternative graph variant: '../../data/aurc/bert/large_depth_all_connected'
    data_dir = '../../data/aurc/bert/large_depth_IN_connected'
    EPOCHS = 15
    NODE_FEATURES = 1024
    # Alternative: torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    DEVICE = "cpu"
    CLASSES = 3
    timeout = 3600 * 8  # seconds (8 hours)
    num_data_train = 3960
    num_data_dev = 790
    num_data_test = 1959

    train_dataset = MyOwnDataset(num_data=num_data_train, root=osp.join(data_dir, "Train"))
    test_dataset = MyOwnDataset(num_data=num_data_test, root=osp.join(data_dir, "Test"))
    dev_dataset = MyOwnDataset(num_data=num_data_dev, root=osp.join(data_dir, "Dev"))

    # NOTE(review): the training loader is not shuffled; confirm this is intentional.
    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Create the study, or resume it if it already exists in the SQLite file.
    study = optuna.create_study(
        direction="maximize",
        study_name="Trained_BERT_GNN_CRF_IN_connected",
        storage="sqlite:///../optuna_db/Trained_BERT_large_GNN_CRF_IN_connected.db",
        load_if_exists=True,
    )

    study.optimize(
        lambda trial: objective(trial, DEVICE, EPOCHS, CLASSES, NODE_FEATURES, train_loader, dev_loader, test_loader),
        n_trials=num_trials,
        timeout=timeout,
    )

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))

    # study.best_trial raises ValueError when no trial has completed
    # (e.g. every trial failed or was pruned).
    if not complete_trials:
        print("No completed trials: cannot report a best trial.")
        return

    print("Best trial:")
    best_trial = study.best_trial

    print("  Value: ", best_trial.value)

    print("  Params: ")
    for key, value in best_trial.params.items():
        print("    {}: {}".format(key, value))

    print("Name of the study : " + study.study_name)

    print("  Params Importance: ")
    try:
        dict_params = get_param_importances(study)
    except ValueError as err:
        # Importance evaluation needs enough completed trials (at least two).
        print("    unavailable: {}".format(err))
        return
    for key, value in dict_params.items():
        print("    {}: {}".format(key, value))
 
        
if __name__ == '__main__':
    # Start the hyper-parameter search only when executed directly, not on import.
    main()
