# Data Loading
import os.path as osp

# Pytorch Module
import torch
import torch.optim as optim
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv,GATConv, SAGEConv, Sequential
from torch_geometric.data import Dataset, download_url,DataLoader
from torch_geometric.utils import to_undirected

# Optuna Module
import optuna
from optuna.trial import TrialState
from optuna.importance import get_param_importances

# F1-Score
from sklearn.metrics import f1_score

# Parsing
import argparse

## Define the Class to load the data
class MyOwnDataset(Dataset):
    """Dataset of already-processed graphs stored one sample per file.

    Sample ``idx`` is read from ``data_<idx>.pt`` inside ``self.processed_dir``
    (an attribute inherited from the base ``Dataset``).

    Parameters
    ----------
    num_data : int
        Number of ``data_<idx>.pt`` files the dataset is expected to hold;
        valid indices are ``0 .. num_data - 1``.
    root : str
        Root directory forwarded to the base ``Dataset``.
    transform, pre_transform : callable, optional
        Forwarded unchanged to the base ``Dataset``.
    """

    def __init__(self, num_data, root, transform=None, pre_transform=None):
        # Assign before running the base initialiser: ``processed_file_names``
        # reads ``num_data``, so it must already exist if the base class
        # queries that property while setting itself up.
        self.num_data = num_data
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """Return an empty list: this dataset declares no raw files."""
        return []

    @property
    def processed_file_names(self):
        """Return the expected file names ``data_0.pt`` .. ``data_<num_data-1>.pt``."""
        return [f'data_{idx}.pt' for idx in range(self.num_data)]

    def len(self):
        """Return the number of samples (one per processed file name)."""
        return len(self.processed_file_names)

    def get(self, idx):
        """Load and return the object stored in ``data_<idx>.pt``."""
        # NOTE(review): recent PyTorch releases may default ``torch.load`` to
        # ``weights_only=True``, which can reject pickled graph objects --
        # confirm against the installed version.
        return torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))

# F1 loss
class F1_Loss(nn.Module):
    '''Differentiable ("soft") macro-F1 loss. Can work with gpu tensors.

    Class probabilities obtained with a softmax replace hard predictions, so
    the per-class true/false positive counts stay differentiable. The number
    of classes is taken from the second dimension of ``y_pred``.

    Parameters
    ----------
    epsilon : float
        Small constant added to denominators to avoid division by zero; also
        used to clamp each per-class F1 into ``[epsilon, 1 - epsilon]``.

    Returns
    -------
    torch.Tensor
        `ndim` == 0. ``1 - mean(per-class F1)``, so
        epsilon <= val <= 1 - epsilon and lower is better.
    '''

    def __init__(self, epsilon=1e-6):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, y_pred, y_true):
        '''Compute the loss.

        Parameters
        ----------
        y_pred : torch.Tensor
            Raw scores (logits) of shape ``(num_samples, num_classes)``.
        y_true : torch.Tensor
            Integer class indices of shape ``(num_samples,)``; must be a
            dtype accepted by ``F.one_hot``.
        '''
        assert y_pred.ndim == 2, "y_pred must have shape (num_samples, num_classes)"
        assert y_true.ndim == 1, "y_true must have shape (num_samples,)"

        # Derive the class count from the logits instead of hard-coding 3 so
        # the loss also works for other label sets.
        num_classes = y_pred.shape[1]
        y_true = F.one_hot(y_true, num_classes).to(torch.float32)
        y_pred = F.softmax(y_pred, dim=1)

        # Soft confusion counts, one entry per class.
        tp = (y_true * y_pred).sum(dim=0).to(torch.float32)
        fp = ((1 - y_true) * y_pred).sum(dim=0).to(torch.float32)
        fn = (y_true * (1 - y_pred)).sum(dim=0).to(torch.float32)

        precision = tp / (tp + fp + self.epsilon)
        recall = tp / (tp + fn + self.epsilon)

        f1 = 2 * (precision * recall) / (precision + recall + self.epsilon)
        f1 = f1.clamp(min=self.epsilon, max=1 - self.epsilon)
        return 1 - f1.mean()

    
# Define the Sequential Network
def define_model_GATConv(trial, CLASSES, NODE_FEATURES):
    """Assemble a GATConv-based network whose shape is sampled from ``trial``.

    Parameters
    ----------
    trial
        Object exposing ``suggest_int`` / ``suggest_float``; every
        architecture choice is drawn from it.
    CLASSES : int
        Output size of the final linear layer.
    NODE_FEATURES : int
        Input size of the first convolution.

    Returns
    -------
    A ``Sequential`` taking ``x, edge_index``.

    Notes
    -----
    The dense loop runs over ``range(1, n_layers_linear)``, so the sampled
    values 0 and 1 both yield no hidden linear layer.
    """
    depth = trial.suggest_int("n_layers", 1, 5)
    width = NODE_FEATURES

    # Input stage: attention convolution, then ReLU on the features while the
    # edge index handed onwards is passed through ``to_undirected``.
    hidden = trial.suggest_int("n_units_l0", 10, 500)
    n_heads = trial.suggest_int("head_l0", 1, 6)
    stack = [
        (GATConv(width, hidden, n_heads), 'x, edge_index -> x'),
        (
            lambda x, edge_index: (F.relu(x), to_undirected(edge_index)),
            'x, edge_index -> x, edge_index',
        ),
    ]
    # The next layer's input width is units * heads.
    width = hidden * n_heads

    # Remaining convolution stages: conv -> ReLU -> dropout.
    for idx in range(1, depth):
        hidden = trial.suggest_int(f"n_units_l{idx}", 10, 500)
        n_heads = trial.suggest_int(f"head_l{idx}", 1, 6)
        conv = GATConv(width, hidden, n_heads)
        drop_p = trial.suggest_float(f"dropout_l{idx}", 0.01, 0.2)
        stack += [
            (conv, 'x, edge_index -> x'),
            nn.ReLU(inplace=True),
            nn.Dropout(drop_p),
        ]
        width = hidden * n_heads

    # Optional hidden dense layers, then the classification layer.
    n_dense = trial.suggest_int("n_layers_linear", 0, 4)
    for idx in range(1, n_dense):
        hidden = trial.suggest_int(f"n_units_lin_l{idx}", 4, 250)
        stack += [
            (nn.Linear(width, hidden), 'x -> x'),
            nn.ReLU(inplace=True),
        ]
        width = hidden

    stack.append((nn.Linear(width, CLASSES), 'x -> x'))
    return Sequential('x, edge_index', stack)


def define_model_SAGEConv(trial, CLASSES, NODE_FEATURES):
    """Assemble a SAGEConv-based network whose shape is sampled from ``trial``.

    Parameters
    ----------
    trial
        Object exposing ``suggest_int`` / ``suggest_float``; every
        architecture choice is drawn from it.
    CLASSES : int
        Output size of the final linear layer.
    NODE_FEATURES : int
        Input size of the first convolution.

    Returns
    -------
    A ``Sequential`` taking ``x, edge_index``.

    Notes
    -----
    The dense loop runs over ``range(1, n_layers_linear)``, so the sampled
    values 0 and 1 both yield no hidden linear layer.
    """
    depth = trial.suggest_int("n_layers", 1, 6)
    width = NODE_FEATURES

    # Input stage: convolution, then ReLU on the features while the edge
    # index handed onwards is passed through ``to_undirected``.
    hidden = trial.suggest_int("n_units_l0", 4, 500)
    stack = [
        (SAGEConv(width, hidden), 'x, edge_index -> x'),
        (
            lambda x, edge_index: (F.relu(x), to_undirected(edge_index)),
            'x, edge_index -> x, edge_index',
        ),
    ]
    width = hidden

    # Remaining convolution stages: conv -> ReLU -> dropout.
    for idx in range(1, depth):
        hidden = trial.suggest_int(f"n_units_l{idx}", 4, 500)
        conv = SAGEConv(width, hidden)
        drop_p = trial.suggest_float(f"dropout_l{idx}", 0.01, 0.2)
        stack += [
            (conv, 'x, edge_index -> x'),
            nn.ReLU(inplace=True),
            nn.Dropout(drop_p),
        ]
        width = hidden

    # Optional hidden dense layers, then the classification layer.
    n_dense = trial.suggest_int("n_layers_linear", 0, 4)
    for idx in range(1, n_dense):
        hidden = trial.suggest_int(f"n_units_lin_l{idx}", 4, 250)
        stack += [
            (nn.Linear(width, hidden), 'x -> x'),
            nn.ReLU(inplace=True),
        ]
        width = hidden

    stack.append((nn.Linear(width, CLASSES), 'x -> x'))
    return Sequential('x, edge_index', stack)


def define_model_GCNConv(trial, CLASSES, NODE_FEATURES):
    """Assemble a GCNConv-based network whose shape is sampled from ``trial``.

    Parameters
    ----------
    trial
        Object exposing ``suggest_int`` / ``suggest_float``; every
        architecture choice is drawn from it.
    CLASSES : int
        Output size of the final linear layer.
    NODE_FEATURES : int
        Input size of the first convolution.

    Returns
    -------
    A ``Sequential`` taking ``x, edge_index``.

    Notes
    -----
    * The first convolution samples its width in [4, 500]; later ones in
      [4, 250].
    * The dense loop runs over ``range(1, n_layers_linear)``, so the sampled
      values 0 and 1 both yield no hidden linear layer.
    """
    depth = trial.suggest_int("n_layers", 1, 6)
    width = NODE_FEATURES

    # Input stage: convolution, then ReLU on the features while the edge
    # index handed onwards is passed through ``to_undirected``.
    hidden = trial.suggest_int("n_units_l0", 4, 500)
    stack = [
        (GCNConv(width, hidden), 'x, edge_index -> x'),
        (
            lambda x, edge_index: (F.relu(x), to_undirected(edge_index)),
            'x, edge_index -> x, edge_index',
        ),
    ]
    width = hidden

    # Remaining convolution stages: conv -> ReLU -> dropout.
    for idx in range(1, depth):
        hidden = trial.suggest_int(f"n_units_l{idx}", 4, 250)
        conv = GCNConv(width, hidden)
        drop_p = trial.suggest_float(f"dropout_l{idx}", 0.001, 0.2)
        stack += [
            (conv, 'x, edge_index -> x'),
            nn.ReLU(inplace=True),
            nn.Dropout(drop_p),
        ]
        width = hidden

    # Optional hidden dense layers, then the classification layer.
    n_dense = trial.suggest_int("n_layers_linear", 0, 4)
    for idx in range(1, n_dense):
        hidden = trial.suggest_int(f"n_units_lin_l{idx}", 4, 250)
        stack += [
            (nn.Linear(width, hidden), 'x -> x'),
            nn.ReLU(inplace=True),
        ]
        width = hidden

    stack.append((nn.Linear(width, CLASSES), 'x -> x'))
    return Sequential('x, edge_index', stack)


# Define The optuna objective function
def _train_one_epoch(model, loader, optimizer, criterion, device, max_grad_norm):
    """Run one optimisation pass over ``loader`` with gradient-norm clipping."""
    model.train()
    for batch in loader:
        optimizer.zero_grad()
        batch = batch.to(device)
        out = model(batch["x"].float(), batch.edge_index)
        loss = criterion(out, batch["y"].long())
        loss.backward()
        # Clip the global gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
        optimizer.step()


def _evaluate_micro_f1(model, loader, device, num_classes):
    """Return the micro-F1 over the entries of ``loader`` whose ``indice_nt_tensor`` is > 0."""
    model.eval()
    true_labels = []
    predicted_labels = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch["x"].float(), batch.edge_index)

            # Boolean-mask selection: no sentinel values that could collide
            # with real scores/labels, and no hard-coded class count.
            # Assumes ``indice_nt_tensor`` has one entry per row of ``out`` /
            # ``y`` -- TODO confirm against the data-building code.
            keep = batch["indice_nt_tensor"] > 0
            true_labels.extend(batch["y"].long()[keep].tolist())
            predicted_labels.extend(out[keep].argmax(dim=1).tolist())

    # scikit-learn expects (y_true, y_pred) in that order.
    return f1_score(
        true_labels,
        predicted_labels,
        labels=list(range(num_classes)),
        average="micro",
    )


# Define The optuna objective function
def objective(trial, DEVICE, EPOCHS, CLASSES, NODE_FEATURES, train_loader, dev_loader, test_loader):
    """Optuna objective: train a sampled model and return its dev micro-F1.

    Parameters
    ----------
    trial
        Optuna trial used to sample the architecture, learning rate and
        gradient-clipping norm, and to report per-epoch scores.
    DEVICE
        Device the model and every batch are moved to.
    EPOCHS : int
        Number of train/validate rounds; must be at least 1.
    CLASSES : int
        Number of output classes (also the label set scored by the F1).
    NODE_FEATURES : int
        Size of the input node features.
    train_loader, dev_loader
        Loaders yielding batches with ``x``, ``y``, ``edge_index`` and, for
        the dev loader, ``indice_nt_tensor``.
    test_loader
        Accepted for interface compatibility; not used here.

    Returns
    -------
    float
        Micro-F1 on ``dev_loader`` after the last epoch.

    Raises
    ------
    ValueError
        If ``EPOCHS`` is smaller than 1 (there would be no score to return).
    """
    if EPOCHS < 1:
        raise ValueError("EPOCHS must be at least 1, got {}".format(EPOCHS))

    # Choose the model ('GCN' and 'Sage' are supported by the dispatch but
    # currently not part of the sampled choices).
    classifier_name = trial.suggest_categorical('classifier', ['GAT'])
    builders = {
        'GCN': define_model_GCNConv,
        'Sage': define_model_SAGEConv,
    }
    build_model = builders.get(classifier_name, define_model_GATConv)
    model = build_model(trial, CLASSES, NODE_FEATURES).to(DEVICE)

    # Generate the optimizer.
    optimizer_name = trial.suggest_categorical("optimizer", ["AdamW"])
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

    # Multiply the learning rate by 0.99 at every scheduler step.
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.99)
    max_grad_norm = trial.suggest_float("max_grad_norm", 1e-1, 1e+1)

    # Built once instead of once per batch.
    criterion = nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        _train_one_epoch(model, train_loader, optimizer, criterion, DEVICE, max_grad_norm)
        f1_score_token = _evaluate_micro_f1(model, dev_loader, DEVICE, CLASSES)
        # The intermediate value is reported; no pruning is performed here.
        trial.report(f1_score_token, epoch)
        scheduler.step()
    return f1_score_token


# Define the main function to call by script
def main():
    """Parse the command line, run the Optuna study and print a summary.

    The study is stored in (and resumed from) an SQLite file under
    ``optuna_db/`` in the current working directory. After optimisation the
    trial counts, the best trial and the parameter importances are printed;
    the last two are skipped when too few trials completed.
    """
    import os  # local import: only needed to create the storage directory

    parser = argparse.ArgumentParser(description='Optuna Optimization for GrammarNet')

    parser.add_argument('-b', '--batch_size', type=int, default=190, help='Size of the batch')
    parser.add_argument('-c', '--num_trials', type=int, default=500, help='Number of Optuna Trials')
    parser.add_argument('-d', '--data_dir', type=str, default='../data/aurc/bert/large_depth', help='Location of the folders Train, Dev, Test containing data')
    parser.add_argument('-e', '--epoch', type=int, default=25, help='Number of EPOCH for each trial.')
    parser.add_argument('-f', '--node_features', type=int, default=1024, help='Size of the input features.')
    parser.add_argument('-g', '--gpu_index', type=int, default=2, help='Index of the GPU to choose')
    parser.add_argument('-n', '--num_labels', type=int, default=3, help='Using either 3 (pro, con, non) or 2 (arg, non) labels.')
    parser.add_argument('-t', '--timeout', type=int, default=3600, help='Optuna Timeout')
    parser.add_argument('--num_data_train', type=int, default=3960, help='Size of the train dataset')
    parser.add_argument('--num_data_dev', type=int, default=790, help='Size of the dev dataset')
    parser.add_argument('--num_data_test', type=int, default=1959, help='Size of the test dataset')

    args = parser.parse_args()

    ## Define the parameters
    batch_size = args.batch_size
    CLASSES = args.num_labels
    DEVICE = torch.device(f'cuda:{args.gpu_index}' if torch.cuda.is_available() else 'cpu')
    data_dir = args.data_dir
    EPOCHS = args.epoch
    NODE_FEATURES = args.node_features

    ## Load the Data
    train_dataset = MyOwnDataset(num_data=args.num_data_train, root=osp.join(data_dir, "Train"))
    test_dataset = MyOwnDataset(num_data=args.num_data_test, root=osp.join(data_dir, "Test"))
    dev_dataset = MyOwnDataset(num_data=args.num_data_dev, root=osp.join(data_dir, "Dev"))

    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    ## Initialize the optuna research
    study_name = "Trained_BERT_3_labels_large_first_loop_no_prune_depth"
    # SQLite cannot create missing parent directories for the database file.
    os.makedirs("optuna_db", exist_ok=True)
    study = optuna.create_study(
        direction="maximize",
        study_name=study_name,
        storage=f"sqlite:///optuna_db/{study_name}.db",
        load_if_exists=True)

    study.optimize(
        lambda trial: objective(trial, DEVICE, EPOCHS, CLASSES, NODE_FEATURES,
                                train_loader, dev_loader, test_loader),
        n_trials=args.num_trials,
        timeout=args.timeout)

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))

    print("Name of the study : " + study.study_name)

    # ``study.best_trial`` raises when no trial has completed (e.g. the
    # timeout expired during the first trial), so stop the report here.
    if not complete_trials:
        print("No completed trial: best trial and parameter importances are unavailable.")
        return

    print("Best trial:")
    trial = study.best_trial

    print("  Value: ", trial.value)

    print("  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")

    print("  Params Importance: ")
    # Importance evaluation needs more than one completed trial.
    if len(complete_trials) < 2:
        print("    (skipped: at least two completed trials are required)")
        return
    for key, value in get_param_importances(study).items():
        print(f"    {key}: {value}")
        
# Run the hyper-parameter search only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()        

