import os
import spacy
from spacy import displacy

import sys
sys.path.insert(0, '../../utils/')
from utils_AURC import (
    InputFeatures_Constituent, 
    parse_data)

from utils_GNN_CRF import (
    compute_tab_logits_pred_sentence_level,
    extract_sentence_level_label,
    F1_Loss,
    find_max_size_batch_without_nt,
    remove_internal_nodes,
    remove_internal_nodes_without_labels,
    reshape_label_to_batch_padded,
    reshape_mask_to_batch_padded,
    reshape_data_to_batch_padded,
    find_the_most_frequent_element,
    add_dimensions_null_to_embedding,
    construct_the_edges_matrix_consistuency_padding,
    find_the_most_frequent_element
)

from tqdm import tqdm
import time
import random
import datetime
from collections import defaultdict
import numpy as np

# Data Loading
import os.path as osp
import json
from transformers import (AdamW, BertConfig, get_linear_schedule_with_warmup, 
                                BertForTokenClassification, BertTokenizer, BertPreTrainedModel, BertModel)

# Pytorch Module
import torch
import torch.optim as optim
from torch import nn
import torch.nn.functional as F
from torch.nn import Linear, ReLU

from torch_geometric.nn import GCNConv,GATConv, SAGEConv,Sequential,AGNNConv,global_mean_pool
from torch_geometric.data import Dataset, download_url,DataLoader, Data
from torch_geometric.utils import to_undirected, sort_edge_index, to_networkx

# Torch CRF Module 
from torchcrf import CRF

# Optuna Module
import optuna
from optuna.trial import TrialState
from optuna.importance import get_param_importances

# F1-Score
from sklearn.metrics import f1_score

# Parsing
import argparse


## Define the Class to load the data
class MyOwnDataset(Dataset):
    """PyTorch Geometric dataset over pre-saved ``data_<idx>.pt`` graph files.

    No raw files are declared and no ``download``/``process`` step is defined
    here. Each sample is read with ``torch.load`` from ``self.processed_dir``
    when it is accessed.

    Args:
        num_data: number of samples; files ``data_0.pt`` ..
            ``data_{num_data - 1}.pt`` are expected.
        root: dataset root directory, forwarded to the base ``Dataset``.
        transform: forwarded unchanged to the base ``Dataset``.
        pre_transform: forwarded unchanged to the base ``Dataset``.
    """

    def __init__(self, num_data, root, transform=None, pre_transform=None):
        # Set before the base initialiser runs, so that processed_file_names
        # is already usable if the base class inspects it during setup.
        self.num_data = num_data
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)

    @staticmethod
    def _file_name(idx):
        # Single place that defines the on-disk name of sample `idx`.
        return 'data_{}.pt'.format(idx)

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return list(map(self._file_name, range(self.num_data)))

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        path = osp.join(self.processed_dir, self._file_name(idx))
        return torch.load(path)

    
# Define the Sequential Network
def define_model_GATConv(trial, CLASSES,NODE_FEATURES):
    """Build a PyG ``Sequential`` GAT network whose sizes are sampled by *trial*.

    Hyper-parameters are suggested in this order:

    * ``n_layers`` (1-3) GATConv layers. Layer ``i`` samples ``n_units_l{i}``
      (50-300) and ``head_l{i}`` (1-3). Its output width is units * heads.
    * After layer 0: ReLU, and ``edge_index`` is made undirected for all
      following layers.
    * After every later layer: in-place ReLU and Dropout with probability
      ``dropout_l{i}`` (0.01-0.1).
    * ``n_layers_linear`` (1-3). ``n_layers_linear - 1`` hidden Linear + ReLU
      layers follow, each of width ``n_units_lin_l{i}`` (50-250).
    * A final Linear projection to ``CLASSES`` outputs.

    Args:
        trial: Optuna trial used to sample the architecture.
        CLASSES: size of the output layer.
        NODE_FEATURES: input node-feature dimension.

    Returns:
        A ``torch_geometric.nn.Sequential`` taking ``(x, edge_index)``.
    """
    n_layers = trial.suggest_int("n_layers", 1, 3)
    modules = []
    width = NODE_FEATURES

    for depth in range(n_layers):
        hidden = trial.suggest_int("n_units_l{}".format(depth), 50, 300)
        n_heads = trial.suggest_int("head_l{}".format(depth), 1, 3)
        modules.append((GATConv(width, hidden, n_heads), 'x, edge_index -> x'))
        # GATConv concatenates its heads, so the next layer sees hidden * heads.
        width = hidden * n_heads

        if depth == 0:
            # First layer only: ReLU plus symmetrising the graph for later layers.
            modules.append((lambda x, edge_index: (F.relu(x), to_undirected(edge_index)),
                            'x, edge_index -> x, edge_index'))
            continue

        modules.append(nn.ReLU(inplace=True))
        drop_p = trial.suggest_float("dropout_l{}".format(depth), 0.01, 0.1)
        modules.append(nn.Dropout(drop_p))

    n_layers_linear = trial.suggest_int("n_layers_linear", 1, 3)
    for depth in range(1, n_layers_linear):
        hidden = trial.suggest_int("n_units_lin_l{}".format(depth), 50, 250)
        modules.append((nn.Linear(width, hidden), 'x -> x'))
        modules.append(nn.ReLU(inplace=True))
        width = hidden

    modules.append((nn.Linear(width, CLASSES), 'x -> x'))
    return Sequential('x, edge_index', modules)


class Net_large_CRF_end2end(torch.nn.Module):
    """BERT encoder, then a sampled GAT stack, then a linear-chain CRF.

    For every sentence, ``num_prefix_nodes`` zero-filled node slots are
    placed before the ``max_length`` BERT token embeddings. The resulting
    ``num_prefix_nodes + max_length`` node features are fed to the GAT, and
    only the emissions of the token slots go to the CRF.
    NOTE(review): the prefix slots are presumably reserved for the graph's
    non-token (internal) nodes; confirm against the edge construction.
    """

    def __init__(self,trial,num_node_features, num_labels = 3, batch_first=True, output_hidden_states=True,
            output_attentions=False, max_length = 64, model_name="bert-large-uncased",
            num_prefix_nodes=21):
        """
        Args:
            trial: Optuna trial used to sample the GAT architecture.
            num_node_features: BERT hidden size, used as the GAT input width.
            num_labels: number of tags (GAT output width and CRF tag count).
            batch_first: forwarded to the CRF.
            output_hidden_states: forwarded to ``from_pretrained``.
            output_attentions: forwarded to ``from_pretrained``.
            max_length: number of token positions per sentence.
            model_name: pretrained BERT checkpoint name.
            num_prefix_nodes: zero-filled node slots placed before the tokens
                of each sentence. The default 21 matches the previously
                hard-coded value.
        """
        super(Net_large_CRF_end2end, self).__init__()

        self.num_labels = num_labels
        self.batch_first = batch_first
        self.max_length = max_length
        self.num_node_features = num_node_features
        self.num_prefix_nodes = num_prefix_nodes

        # Only the `.bert` encoder of this model is used in forward(); its
        # token-classification head is not called.
        self.tokenbert = BertForTokenClassification.from_pretrained(
            model_name,
            num_labels=self.num_labels,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions
        )

        self.GNN_attention = define_model_GATConv(trial, self.num_labels, self.num_node_features)

        self.crf = CRF(self.num_labels, batch_first=self.batch_first)

    def forward(self, data, labels=None):
        """Run BERT, then the GAT, then the CRF, on a batched graph.

        Args:
            data: batch exposing ``input_ids``, ``attention_mask``,
                ``token_type_ids`` and ``edge_index``. ``input_ids`` and
                ``attention_mask`` are presumably (batch, max_length); TODO
                confirm against the data pipeline.
            labels: optional gold tags, reshaped to (-1, max_length).

        Returns:
            With ``labels``: the negated output of ``self.crf``, which is used
            as the training loss. Without ``labels``: the result of
            ``self.crf.decode`` over the masked token positions.
        """
        n_prefix = self.num_prefix_nodes
        nodes_per_graph = n_prefix + self.max_length

        token_states = self.tokenbert.bert(
            data.input_ids,
            attention_mask=data.attention_mask,
            token_type_ids=data.token_type_ids
        ).last_hidden_state

        # Allocate on the same device and dtype as the BERT output. A bare
        # torch.zeros(...) lives on the CPU in float32, which breaks GPU and
        # half-precision runs.
        hidden_states = token_states.new_zeros(
            (token_states.size(0), nodes_per_graph, self.num_node_features))
        hidden_states[:, n_prefix:, :] = token_states
        hidden_states = hidden_states.view(-1, self.num_node_features)

        out = self.GNN_attention(hidden_states, data.edge_index)
        out = out.view(-1, nodes_per_graph, self.num_labels)
        # Only the token slots (after the prefix) provide CRF emissions.
        logits = out[:, n_prefix:, :]
        mask = data.attention_mask.byte()

        if labels is not None:
            labels = labels.view((-1, self.max_length))
            return -self.crf(emissions=logits, tags=labels, mask=mask)
        return self.crf.decode(emissions=logits, mask=mask)
        

# Define the Optuna objective function
def _build_optimizer(trial, model):
    """Sample the optimizer and learning rate from *trial* and build them for *model*.

    Parameters whose name contains a ``no_decay`` substring (biases and
    normalisation weights) go into a group without weight decay. Every other
    parameter uses a weight decay of 0.01.
    """
    optimizer_name = trial.suggest_categorical("optimizer", ["AdamW"])
    lr = trial.suggest_float("lr", 1e-6, 1e-4, log=True)

    # 'gamma'/'beta' are the legacy BERT LayerNorm parameter names. Current
    # transformers checkpoints name them 'LayerNorm.weight'/'LayerNorm.bias'
    # ('bias' already covers the latter).
    no_decay = ("bias", "gamma", "beta", "LayerNorm.weight")
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    # torch.optim reads the param-group key 'weight_decay'. The previous key
    # 'weight_decay_rate' was silently ignored, so AdamW's default decay was
    # applied to every parameter, including biases and LayerNorm.
    grouped_parameters = [
        {"params": decay_params, "weight_decay": 0.01},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    return getattr(optim, optimizer_name)(grouped_parameters, lr=lr)


def _dev_token_f1(model, loader, device, num_classes):
    """Return the macro F1 of *model* over the non-padding tokens of *loader*."""
    model.eval()
    predicted, gold = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            # Without labels the model returns the CRF decoding as one tag list
            # per sentence. These are flattened and aligned with the masked gold
            # labels below, which assumes one predicted tag per unmasked position.
            for sentence_tags in model(batch):
                predicted.extend(sentence_tags)
            keep = batch["attention_mask"].flatten() > 0
            gold.extend(batch["y"].flatten()[keep].tolist())
    # sklearn's signature is f1_score(y_true, y_pred).
    return f1_score(gold, predicted, labels=list(range(num_classes)), average="macro")


def objective(trial, DEVICE, EPOCHS, CLASSES, NODE_FEATURES, train_loader, dev_loader, test_loader):
    """Train one sampled model and return its token-level macro F1 on the dev set.

    Args:
        trial: Optuna trial that samples the architecture and optimisation
            hyper-parameters.
        DEVICE: device the model and batches are moved to.
        EPOCHS: number of training epochs (must be >= 1).
        CLASSES: number of tag classes.
        NODE_FEATURES: BERT hidden size fed to the GAT.
        train_loader: batches used for training.
        dev_loader: batches used for the per-epoch evaluation.
        test_loader: accepted for interface compatibility; not used here.

    Returns:
        The dev macro F1 after the last epoch. This value is also appended to
        ``<timestamp>.txt``.

    Raises:
        ValueError: if ``EPOCHS`` < 1, because no dev score would be computed.
    """
    if EPOCHS < 1:
        raise ValueError("EPOCHS must be at least 1 to produce a dev score")
    time_string = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    # The model is built first: it samples the GAT hyper-parameters.
    model = Net_large_CRF_end2end(trial, NODE_FEATURES, num_labels=CLASSES).to(DEVICE)
    optimizer = _build_optimizer(trial, model)
    # Multiply the learning rate by 0.99 after every epoch.
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.99)
    max_grad_norm = trial.suggest_float("max_grad_norm", 1e-1, 1e+1)

    for epoch in range(EPOCHS):
        model.train()
        for batch in train_loader:
            batch = batch.to(DEVICE)
            optimizer.zero_grad()
            loss = model(batch, batch["y"])
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
            optimizer.step()

        f1_token_level = _dev_token_f1(model, dev_loader, DEVICE, CLASSES)
        # NOTE(review): values are only reported. trial.should_prune() is never
        # consulted, so the study's pruner cannot stop a trial early.
        trial.report(f1_token_level, epoch)
        scheduler.step()

    with open(time_string + ".txt", "a") as f:
        f.write("f1 score on token level dev is : " + str(f1_token_level))

    return f1_token_level


def _print_study_report(study):
    """Print trial counts, the best trial and hyper-parameter importances of *study*."""
    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))

    print("Best trial:")
    trial = study.best_trial

    print("  Value: ", trial.value)

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))

    print("Name of the study : " + study.study_name)

    print("  Params Importance: ")
    dict_params = get_param_importances(study)
    for key, value in dict_params.items():
        print("    {}: {}".format(key, value))


# Define the main function to call by script
def main(actual_number_tested):
    """Run the Optuna search for one dataset variant and print a study report.

    Args:
        actual_number_tested: index into ``depth_list_params``, which selects
            the data directory, the study name and the SQLite storage.

    The study is created with ``load_if_exists=True``, so repeated runs keep
    adding trials to the same stored study. Only the selected study is
    reported; the previous code re-opened and re-printed a hard-coded study 2
    regardless of the argument.
    """
    batch_size = 190
    num_trials = 300
    depth_list_params = [
        {"data_dir":'../../data/aurc/bert/end_2_end_depth_2_Cross',
        "study_name":"Trained_BERT_GNN_CRF_endtoend_depth_2",
        "storage":"sqlite:///../optuna_db/Trained_BERT_large_endtoend_depth_2.db"},

        {"data_dir":"../../data/aurc/bert/end_2_end",
        "study_name":"Trained_BERT_GNN_CRF_endtoend",
        "storage":"sqlite:///../optuna_db/Trained_BERT_large_endtoend.db"},

        {"data_dir":'../../data/aurc/bert/end_2_end_depth_4_Cross',
        "study_name":"Trained_BERT_GNN_CRF_endtoend_depth_4",
        "storage":"sqlite:///../optuna_db/Trained_BERT_large_endtoend_depth_4.db"},

        {"data_dir":'../../data/aurc/bert/end_2_end_depth_2_IN',
        "study_name":"Trained_BERT_GNN_CRF_endtoend_depth_2_IN",
        "storage":"sqlite:///../optuna_db/Trained_BERT_large_endtoend_depth_2_IN.db"},

        {"data_dir":"../../data/aurc/bert/end_2_end_IN",
        "study_name":"Trained_BERT_GNN_CRF_endtoend_IN",
        "storage":"sqlite:///../optuna_db/Trained_BERT_large_endtoend_IN.db"},

        {"data_dir":'../../data/aurc/bert/end_2_end_depth_4_IN',
        "study_name":"Trained_BERT_GNN_CRF_endtoend_depth_4_IN",
        "storage":"sqlite:///../optuna_db/Trained_BERT_large_endtoend_depth_4_IN.db"}
    ]
    config = depth_list_params[actual_number_tested]

    data_dir = config["data_dir"]
    EPOCHS = 8
    NODE_FEATURES = 1024
    DEVICE = "cpu"
    CLASSES = 3
    timeout = 3600*24
    num_data_train = 4157
    num_data_dev = 593
    num_data_test = 1189

    train_dataset = MyOwnDataset(num_data = num_data_train,root = data_dir+"/Train")
    test_dataset = MyOwnDataset(num_data = num_data_test,root = data_dir+"/Test")
    dev_dataset = MyOwnDataset(num_data = num_data_dev,root = data_dir+"/Dev")

    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    ## Initialize the Optuna search
    study = optuna.create_study(
        direction="maximize",
        study_name = config["study_name"],
        storage = config["storage"],
        load_if_exists=True)

    study.optimize(
        lambda trial: objective(trial, DEVICE, EPOCHS, CLASSES, NODE_FEATURES,
                                train_loader, dev_loader, test_loader),
        n_trials=num_trials, timeout=timeout)

    _print_study_report(study)
        
        
if __name__ == "__main__":
    # The study index can be chosen on the command line. Without an argument
    # it defaults to 2, matching the previous hard-coded call main(2).
    _parser = argparse.ArgumentParser(
        description="Optuna search for the end-to-end BERT + GAT + CRF model.")
    _parser.add_argument(
        "study_index", nargs="?", type=int, default=2,
        help="index of the dataset/study configuration in main() (default: 2)")
    main(_parser.parse_args().study_index)
