import os
import sys
from datetime import datetime

import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModel, AutoModelForPreTraining, AutoModelWithLMHead, AutoTokenizer

import dataloader
import graph_gnn_layer
import model

class KGAP_trainer(pl.LightningModule):
    """Lightning module that trains and evaluates the KGAP graph classifier.

    Each sample's node features go through ``Linear1`` and two
    ``graph_gnn_layer.GatedRGCN`` layers (each followed by SELU and dropout),
    are mean-pooled over nodes ``1..num_sent`` and classified by ``Linear2``.
    For ``"RCVP"`` the pooled vector is paired with every row of
    ``x["leg_rep"]``; for ``"FC"`` a RoBERTa embedding of ``x["summary"]`` is
    concatenated in.

    Args:
        dataset: one of ``"SemEval"``, ``"Allsides"``, ``"FND"``, ``"FC"``,
            ``"RCVP"``.
        fold: fold index. For ``"FND"`` it selects the label setting:
            0 -> two classes, 1 -> four classes.
        hyperparams: dict that must contain ``"part"``, ``"lm_name"``,
            ``"log_dir"``, ``"dropout_p"`` and ``"lm_freeze_epochs"``.
            Optional ``"lr"`` / ``"weight_decay"`` (defaults 1e-3 / 1e-5)
            configure the RAdam optimizer.

    Raises:
        ValueError: if the dataset / fold combination is not supported.
    """

    # Output classes per dataset; "FND" is resolved by fold instead.
    _NUM_CLASSES = {"SemEval": 2, "Allsides": 3, "FC": 2, "RCVP": 2}
    # "FND": fold 0 is the two-class setting, fold 1 the four-class setting.
    _FND_NUM_CLASSES = {0: 2, 1: 4}

    def __init__(self, dataset, fold, hyperparams):
        super().__init__()
        self.part = hyperparams["part"]
        self.dataset = dataset
        self.fold = fold
        # NOTE(review): stored but unused here; the LM loaded below is always "roberta-base".
        self.lm_name = hyperparams["lm_name"]
        self.log_dir = hyperparams["log_dir"]
        self.dropout_p = hyperparams["dropout_p"]
        self.lm_freeze_epochs = hyperparams["lm_freeze_epochs"]  # not read by this class
        self.lr = hyperparams.get("lr", 1e-3)
        self.weight_decay = hyperparams.get("weight_decay", 1e-5)
        self.in_channels = 768  # the KGAP used pre-trained RoBERTa

        # Fail fast on an unsupported setting, before any layer is built.
        self.num_class = self._resolve_num_class(dataset, fold)

        # Layers are created in the original order so that parameter
        # initialisation (RNG consumption) and state_dict keys are unchanged.
        # NOTE(review): the third GatedRGCN argument (3) is presumably the number
        # of relation / edge types — confirm against graph_gnn_layer.
        self.gnn1 = graph_gnn_layer.GatedRGCN(self.in_channels, self.in_channels, 3)
        self.gnn2 = graph_gnn_layer.GatedRGCN(self.in_channels, self.in_channels, 3)
        self.Linear1 = nn.Linear(self.in_channels, self.in_channels)
        # RCVP classifies [leg_rep projection ; graph vector], hence the doubled width.
        classifier_in = 2 * self.in_channels if self.dataset == "RCVP" else self.in_channels
        self.Linear2 = nn.Linear(classifier_in, self.num_class)
        self.activation = nn.SELU()
        self.dropout = nn.Dropout(self.dropout_p)

        self.tokenizer = AutoTokenizer.from_pretrained("roberta-base")
        self.model = AutoModel.from_pretrained("roberta-base")

        self.CELoss = nn.CrossEntropyLoss()

        # Best validation accuracy seen so far, and the F1 recorded alongside it.
        self.valmaxacc = 0
        self.valmaxf1 = 0

        if self.dataset == "RCVP":
            self.linear_rcvp = nn.Sequential(
                nn.Linear(512, self.in_channels), self.activation, self.dropout,
                nn.Linear(self.in_channels, self.in_channels), self.activation, self.dropout,
            )
            self.RingLM2in = nn.Sequential(
                nn.Linear(self.in_channels, self.in_channels), self.activation, self.dropout,
                nn.Linear(self.in_channels, self.in_channels), self.activation, self.dropout,
            )

    @classmethod
    def _resolve_num_class(cls, dataset, fold):
        """Return the number of output classes for ``dataset``/``fold``.

        Raises:
            ValueError: if the combination is not supported.
        """
        if dataset == "FND":
            num_class = cls._FND_NUM_CLASSES.get(fold)
        else:
            num_class = cls._NUM_CLASSES.get(dataset)
        if num_class is None:
            raise ValueError(
                f"Unsupported dataset/fold combination: dataset={dataset!r}, fold={fold!r}"
            )
        return num_class

    def lm_extract(self, text):
        """Embed ``text`` with RoBERTa, mean-pooling the last hidden state over real tokens.

        Padding positions are excluded from both attention and the average; for
        unpadded input (e.g. a single string) this equals a plain mean over tokens.
        Returns a ``(batch, hidden)`` tensor on the module's device.
        """
        encoded = self.tokenizer(
            text, truncation=True, padding=True, max_length=200, return_tensors="pt"
        )
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)
        hidden = self.model(input_ids, attention_mask=attention_mask).last_hidden_state
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    def forward(self, x):
        """Return classification logits for one sample ``x``.

        ``x["graph"]`` must provide ``node_features``, ``edge_index`` and
        ``edge_type``. Returns logits of shape ``(1, num_class)``, or for
        ``"RCVP"`` one row of logits per row of ``x["leg_rep"]``.
        """
        graph = x["graph"]
        edge_index, edge_type = graph["edge_index"], graph["edge_type"]

        # The length of the leading run of type-0 edges is used as the number of
        # sentence nodes to pool. NOTE(review): assumes the dataloader emits
        # sentence edges first — confirm.
        num_sent = 0
        for etype in edge_type:
            if etype != 0:
                break
            num_sent += 1

        hidden = self.dropout(self.activation(self.Linear1(graph["node_features"])))
        hidden = self.dropout(self.activation(self.gnn1(hidden, edge_index, edge_type)))
        hidden = self.dropout(self.activation(self.gnn2(hidden, edge_index, edge_type)))
        # Pool nodes 1..num_sent; node 0 is excluded (presumably a document/root
        # node — confirm). If num_sent == 0 the slice is empty and the mean is NaN.
        pooled = torch.mean(hidden[1:num_sent + 1], dim=0)

        if self.dataset == "FC":
            # NOTE(review): if node_features is (num_nodes, 768), ``pooled`` is 1-D
            # here, so this dim=1 concatenation (and Linear2's in_channels input
            # width) look shape-incompatible — verify before using "FC".
            pooled = torch.cat((self.lm_extract(x["summary"]), pooled), dim=1)

        if self.dataset == "RCVP":
            leg_vec = self.linear_rcvp(x["leg_rep"])  # one in_channels row per leg_rep row
            graph_vec = self.RingLM2in(pooled).repeat(len(leg_vec), 1)
            return self.Linear2(torch.cat((leg_vec, graph_vec), dim=1))
        return self.Linear2(pooled).unsqueeze(0)

    def configure_optimizers(self):
        """RAdam over all parameters (graph layers, heads and RoBERTa)."""
        return torch.optim.RAdam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    def _evaluate_sample(self, sample):
        """Run one sample through the model.

        Returns ``(loss, labels, preds)``: ``labels``/``preds`` are lists with one
        entry per label row for "RCVP", and a single entry otherwise.
        """
        logit = self.forward(sample)
        if self.dataset == "RCVP":
            labels = list(sample["label"])
            target = torch.tensor(sample["label"])
        else:
            labels = [sample["label"]]
            target = torch.tensor(labels)
        preds = [int(p) for p in torch.argmax(logit, dim=1)]
        loss = self.CELoss(logit, target.long().to(logit.device))
        return loss, labels, preds

    def _run_batch(self, batch):
        """Evaluate every sample in ``batch``.

        Returns ``(mean_loss, truth, pred, per_sample)`` where ``truth``/``pred``
        are flat label/prediction lists and ``per_sample`` holds
        ``(sample, labels, preds)`` for each sample, in batch order.
        """
        batch_loss = 0
        truth, pred, per_sample = [], [], []
        for sample in batch:
            loss, labels, preds = self._evaluate_sample(sample)
            batch_loss += loss
            truth += labels
            pred += preds
            per_sample.append((sample, labels, preds))
        batch_loss /= len(batch)
        return batch_loss, truth, pred, per_sample

    def _compute_metrics(self, truth, pred):
        """Return ``(accuracy, f1)`` using this dataset's metric conventions.

        "FC" uses balanced accuracy; binary F1 is used for two-class datasets
        except "RCVP", which (like multi-class datasets) uses macro F1.
        """
        if self.dataset == "FC":
            acc = balanced_accuracy_score(truth, pred)
        else:
            acc = accuracy_score(truth, pred)
        if self.num_class == 2 and self.dataset != "RCVP":
            f1 = f1_score(truth, pred, average="binary")
        else:
            f1 = f1_score(truth, pred, average="macro")
        return acc, f1

    def training_step(self, train_batch, batch_idx):
        """Return the mean cross-entropy loss over ``train_batch``."""
        batch_loss, _, _, _ = self._run_batch(train_batch)
        return batch_loss

    def validation_step(self, val_batch, batch_idx):
        """Log ``val_loss`` and track the best validation accuracy / F1.

        NOTE(review): the best accuracy is tracked per validation batch, not
        per epoch — confirm that is intended.
        """
        batch_loss, truth, pred, _ = self._run_batch(val_batch)
        acc, f1 = self._compute_metrics(truth, pred)
        self.log("val_loss", batch_loss)
        if acc > self.valmaxacc:
            self.valmaxacc = acc
            self.valmaxf1 = f1
        return batch_loss

    def test_step(self, test_batch, batch_idx):
        """Evaluate ``test_batch``, save per-sample correctness ids, and append results to ``log_dir``.

        NOTE(review): the record file and log entry are produced per test batch;
        the record is overwritten by each batch — confirm the test loader yields
        a single batch.
        """
        _, truth, pred, per_sample = self._run_batch(test_batch)

        correct, incorrect = [], []
        for sample, labels, preds in per_sample:
            # NOTE(review): for "RCVP" only the last label row decides the bucket.
            bucket = correct if preds[-1] == labels[-1] else incorrect
            bucket.append(sample["id"])
        os.makedirs("testrecord", exist_ok=True)
        torch.save(
            {"correct": correct, "incorrect": incorrect},
            os.path.join("testrecord", self.dataset + str(self.fold) + "KGAP"),
        )

        acc, f1 = self._compute_metrics(truth, pred)
        self.log("test_acc", acc)
        self._write_test_log(acc, f1, truth, pred)

    def _write_test_log(self, acc, f1, truth, pred):
        """Append a result entry to ``self.log_dir``; datasets other than SemEval/Allsides/RCVP/FND are not logged."""
        if self.dataset in ("SemEval", "Allsides", "RCVP"):
            test_line = "Test accuracy: " + str(acc) + " Test F1-score: " + str(f1) + "\n"
        elif self.dataset == "FND":
            mif1 = f1_score(truth, pred, average="micro")
            mapre = precision_score(truth, pred, average="macro")
            marec = recall_score(truth, pred, average="macro")
            test_line = ("Test mif1: " + str(mif1) + " Test maf1: " + str(f1)
                         + " Test mapre: " + str(mapre) + " Test marec: " + str(marec) + "\n")
        else:
            return

        dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
        with open(self.log_dir, "a") as f:
            f.write("Dataset: " + self.dataset + " Fold: " + str(self.fold) + " Time: " + dt_string + " Part: " + str(self.part) + "\n")
            f.write("Val accuracy: " + str(self.valmaxacc) + " Val F1-score: " + str(self.valmaxf1) + "\n")
            f.write(test_line)
            f.write("--------------------\n")

def train_once(dataset, fold, hyperparams):
    """Train, validate and test a single KGAP model for ``dataset`` / ``fold``.

    Loaders come from ``dataloader.get_dataloaders_KGAP`` (sized by
    ``hyperparams["batch_size"]`` and ``hyperparams["part"]``); a one-GPU,
    16-bit-precision Lightning trainer then runs for
    ``hyperparams["max_epochs"]`` epochs and evaluates on the test loader.
    """
    loaders = dataloader.get_dataloaders_KGAP(
        dataset, fold, hyperparams["batch_size"], hyperparams["part"]
    )
    train_loader, dev_loader, test_loader = loaders

    # Named so it does not shadow the imported `model` module.
    lightning_module = KGAP_trainer(dataset, fold, hyperparams)

    # NOTE: `gpus=` is deprecated in Lightning 1.7+ and removed in 2.0
    # (use `accelerator="gpu", devices=1` there).
    trainer = pl.Trainer(
        gpus=1,
        num_nodes=1,
        precision=16,
        max_epochs=hyperparams["max_epochs"],
    )
    trainer.fit(lightning_module, train_loader, dev_loader)
    trainer.test(lightning_module, test_loader)



if __name__ == "__main__":
    dataset = "FND"  # dataset to run; change as needed
    # Several keys (e.g. "in_channels", "lr_lm", "ringlm_layer", "nhead") are
    # not read by the code in this file; the model width is fixed at 768.
    hyperparams = {
        "optimizer": "RAdam",
        "lm_name": "roberta",
        "in_channels": 128,
        "lr_lm": 1e-5,
        "lr": 1e-3,
        "weight_decay": 1e-5,
        "lm_freeze_epochs": 1000,
        "log_dir": "logs/log_KGAP.txt",  # change as needed
        "ringlm_layer": 2,
        "nhead": 1,
        "dropout_p": 0.5,
        "batch_size": 16,
        "max_epochs": 3,
        "part": 1,
    }

    # Sweep the fraction of training data ("part") in the FND two-class
    # setting (fold 0). For other datasets, loop `fold` over range(3) instead.
    # For the FND four-class setting on the full data, use part = 1, fold = 1.
    # `part` stays a module-level name because the trainer's log reads it.
    fold = 0
    for part in [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]:
        hyperparams["part"] = part
        train_once(dataset, fold, hyperparams)