import os
import sys
from datetime import datetime

import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from sklearn.metrics import balanced_accuracy_score
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import dataloader
import model

class GreaseLM_trainer(pl.LightningModule):
    """LightningModule wrapping the GreaseLM baseline for one dataset/fold.

    The network is ``model.InputProcess`` followed by a stack of
    ``model.GreaseLM`` layers and a linear classification head.

    Args:
        dataset: One of "SemEval", "Allsides", "FND", "FC" or "RCVP".
        fold: Fold identifier. For "FND" it also selects the label set
            (0 -> two classes, 1 -> four classes).
        hyperparams: Mapping that must provide the keys "part", "optimizer",
            "lm_name", "in_channels", "lr_lm", "lr", "weight_decay",
            "lm_freeze_epochs", "log_dir", "ringlm_layer", "nhead" and
            "dropout_p".

    Raises:
        ValueError: If the dataset/fold combination is not supported.

    Attributes written during evaluation:
        valmaxacc: Highest accuracy returned by any ``validation_step`` call.
        valmaxf1: F1 score obtained in that same call.

    Every ``*_step`` iterates over its batch sample by sample, so a batch is
    an iterable of per-sample dicts carrying at least a "label" entry.
    NOTE(review): validation/test metrics are computed per ``*_step`` call;
    presumably the dev/test loaders deliver the whole split as a single
    batch -- confirm against ``dataloader``.
    """

    def __init__(self, dataset, fold, hyperparams):
        super().__init__()
        self.part = hyperparams["part"]
        self.dataset = dataset
        self.fold = fold
        self.optimizer = hyperparams["optimizer"]
        self.lm_name = hyperparams["lm_name"]
        self.in_channels = hyperparams["in_channels"]
        self.lr_lm = hyperparams["lr_lm"]
        self.lr = hyperparams["lr"]
        self.weight_decay = hyperparams["weight_decay"]
        self.lm_freeze_epochs = hyperparams["lm_freeze_epochs"]
        self.log_dir = hyperparams["log_dir"]
        # Fail early (before any sub-module is built) on unsupported input.
        self.num_class = self._resolve_num_class(dataset, fold)

        self.input_process_layer = model.InputProcess(self.lm_name, self.dataset, self.in_channels)
        # Built from a list so it does not depend on nn.Sequential.append.
        self.RingLM_seq = nn.Sequential(*[
            model.GreaseLM(self.in_channels, hyperparams["nhead"], hyperparams["dropout_p"], self.dataset)
            for _ in range(hyperparams["ringlm_layer"])
        ])
        # RCVP classifies rows of width 2 * in_channels, every other dataset
        # a single row of width 3 * in_channels (see forward()).
        head_width = (2 if self.dataset == "RCVP" else 3) * self.in_channels
        self.LinearOut = nn.Linear(head_width, self.num_class)
        self.CELoss = nn.CrossEntropyLoss()
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(hyperparams["dropout_p"])

        self.valmaxacc = 0
        self.valmaxf1 = 0

        if self.dataset == "RCVP":
            # Projects data["leg_rep"] rows (512 features) to in_channels.
            self.linear_rcvp = nn.Sequential(
                nn.Linear(512, self.in_channels), self.activation, self.dropout,
                nn.Linear(self.in_channels, self.in_channels), self.activation, self.dropout,
            )
            # Compresses the fused 3 * in_channels vector to in_channels.
            self.RingLM2in = nn.Sequential(
                nn.Linear(3 * self.in_channels, self.in_channels), self.activation, self.dropout,
                nn.Linear(self.in_channels, self.in_channels), self.activation, self.dropout,
            )

    @staticmethod
    def _resolve_num_class(dataset, fold):
        """Return the number of target classes for ``dataset``/``fold``."""
        if dataset == "FND":
            # fold 0: two class setting, fold 1: four class setting
            per_fold = {0: 2, 1: 4}
            if fold in per_fold:
                return per_fold[fold]
        else:
            per_dataset = {"SemEval": 2, "Allsides": 3, "FC": 2, "RCVP": 2}
            if dataset in per_dataset:
                return per_dataset[dataset]
        raise ValueError(
            "unsupported dataset/fold combination: dataset=%r, fold=%r" % (dataset, fold)
        )

    def forward(self, data):
        """Return the class logits for a single sample.

        Args:
            data: One sample in the format ``model.InputProcess`` accepts;
                for "RCVP" it must also contain "leg_rep".

        Returns:
            The output of ``LinearOut``: one row of logits, or for "RCVP"
            one row per row of ``data["leg_rep"]``.
        """
        # Keep the feature dict under its own name: the original code rebound
        # it to the fused tensor and then tried to index that tensor with
        # "summary", which can never work.
        feats = self.RingLM_seq(self.input_process_layer(data))
        node_features = feats["knowledge"]["node_features"]

        text_vec = feats["text"][0].unsqueeze(0)
        first_node = node_features[0].unsqueeze(0)
        rest_pooled = model.att_agg(text_vec, node_features[1:])
        x = torch.cat((text_vec, first_node, rest_pooled), dim=1)

        if self.dataset == "FC":
            # NOTE(review): LinearOut is sized for 3 * in_channels inputs, so
            # this only fits if lm_extract contributes no extra columns --
            # confirm the intended head width for FC.
            x = torch.cat((self.input_process_layer.lm_extract(feats["summary"]), x), dim=1)
        if self.dataset == "RCVP":
            summary_vec = self.linear_rcvp(data["leg_rep"])  # (k, in_channels)
            # Pair every summary row with the same compressed document vector.
            x = torch.cat((summary_vec, self.RingLM2in(x).repeat(len(summary_vec), 1)), dim=1)

        return self.LinearOut(x)

    def configure_optimizers(self):
        """Build the optimizer named by ``hyperparams["optimizer"]``.

        The language model inside ``input_process_layer`` uses ``lr_lm``;
        every other parameter group uses ``lr``.

        Raises:
            ValueError: If the optimizer name is not "RAdam", "AdamW" or "Adam".
        """
        frontend = self.input_process_layer
        param_groups = [
            {"params": frontend.model.parameters(), "lr": self.lr_lm},
            {"params": frontend.LinearT.parameters()},
            {"params": frontend.LinearK.parameters()},
            {"params": frontend.LinearG.parameters()},
            {"params": self.RingLM_seq.parameters()},
            {"params": self.LinearOut.parameters()},
        ]
        if self.dataset == "RCVP":
            # These layers feed the loss in forward() but were previously
            # never handed to the optimizer, so they stayed at random init.
            param_groups.append({"params": self.linear_rcvp.parameters()})
            param_groups.append({"params": self.RingLM2in.parameters()})

        if self.optimizer not in ("RAdam", "AdamW", "Adam"):
            raise ValueError("unsupported optimizer: %r" % (self.optimizer,))
        # Looked up lazily so an unused optimizer class need not exist.
        optimizer_cls = getattr(torch.optim, self.optimizer)
        return optimizer_cls(param_groups, lr=self.lr, weight_decay=self.weight_decay)

    def _set_lm_trainable(self, trainable):
        """Switch ``requires_grad`` on every language-model parameter."""
        for param in self.input_process_layer.model.parameters():
            param.requires_grad = trainable

    def on_epoch_start(self):
        """Freeze the language model at epoch 0, unfreeze at ``lm_freeze_epochs``.

        NOTE(review): ``on_epoch_start`` is a hook of older
        pytorch_lightning releases -- confirm it still fires on the
        installed version.
        """
        if self.current_epoch == 0:
            self._set_lm_trainable(False)
        # Checked second so lm_freeze_epochs == 0 leaves the LM trainable.
        if self.current_epoch == self.lm_freeze_epochs:
            self._set_lm_trainable(True)

    def _score_batch(self, batch):
        """Run every sample of ``batch`` through the model.

        Returns:
            A tuple ``(loss_sum, truth, pred, hits)``: the summed
            cross-entropy loss, the flat label and prediction lists, and per
            sample whether the most recent prediction equals the most recent
            label.
        """
        multi_label = self.dataset == "RCVP"  # RCVP samples carry a list of labels
        loss_sum = 0
        truth, pred, hits = [], [], []
        for sample in batch:
            labels = sample["label"] if multi_label else [sample["label"]]
            truth.extend(labels)
            logit = self(sample)
            predicted = torch.argmax(logit, dim=1)
            if multi_label:
                pred.extend(int(p) for p in predicted)
            else:
                pred.append(int(predicted))
            # Put the target wherever the logits live instead of forcing CUDA.
            target = torch.tensor(labels, device=logit.device).long()
            loss_sum = loss_sum + self.CELoss(logit, target)
            hits.append(pred[-1] == truth[-1])
        return loss_sum, truth, pred, hits

    def _accuracy(self, truth, pred):
        """Balanced accuracy for "FC", plain accuracy for every other dataset."""
        if self.dataset == "FC":
            return balanced_accuracy_score(truth, pred)
        return accuracy_score(truth, pred)

    def _f1(self, truth, pred):
        """Binary F1 for two-class datasets other than "RCVP", else macro F1."""
        if self.num_class == 2 and self.dataset != "RCVP":
            return f1_score(truth, pred, average="binary")
        return f1_score(truth, pred, average="macro")

    def training_step(self, train_batch, batch_idx):
        """Return the mean cross-entropy loss over the samples of the batch."""
        loss_sum, _, _, _ = self._score_batch(train_batch)
        return loss_sum / len(train_batch)

    def validation_step(self, val_batch, batch_idx):
        """Log "val_loss"/"val_acc" and track the best accuracy seen so far.

        Returns:
            The mean loss over the batch (averaged like the train/test steps).
        """
        loss_sum, truth, pred, _ = self._score_batch(val_batch)
        batch_loss = loss_sum / len(val_batch)
        acc = self._accuracy(truth, pred)
        f1 = self._f1(truth, pred)
        self.log("val_loss", batch_loss)
        self.log("val_acc", acc)
        if acc > self.valmaxacc:
            self.valmaxacc = acc
            self.valmaxf1 = f1
        return batch_loss

    def test_step(self, test_batch, batch_idx):
        """Evaluate one test batch, save per-sample outcomes and append to the log.

        Side effects:
            * Saves ``{"correct": ids, "incorrect": ids}`` with ``torch.save``
              under ``testrecord/`` (overwritten on every call).
            * Logs "test_acc".
            * Appends a result entry to ``log_dir`` for "SemEval",
              "Allsides", "RCVP" and "FND"; nothing is written for "FC".
        """
        _, truth, pred, hits = self._score_batch(test_batch)
        correct = [sample["id"] for sample, hit in zip(test_batch, hits) if hit]
        incorrect = [sample["id"] for sample, hit in zip(test_batch, hits) if not hit]

        # Create the output directory so a long run cannot die at save time.
        os.makedirs("testrecord", exist_ok=True)
        torch.save({"correct": correct, "incorrect": incorrect}, "testrecord/" + self.dataset + str(self.fold) + "baseline")

        acc = self._accuracy(truth, pred)
        f1 = self._f1(truth, pred)
        mif1 = f1_score(truth, pred, average="micro")
        mapre = precision_score(truth, pred, average="macro")
        marec = recall_score(truth, pred, average="macro")
        self.log("test_acc", acc)

        # logging
        if self.dataset in ("SemEval", "Allsides", "RCVP"):
            test_line = "Test accuracy: " + str(acc) + " Test F1-score: " + str(f1) + "\n"
        elif self.dataset == "FND":
            test_line = "Test mif1: " + str(mif1) + " Test maf1: " + str(f1) + " Test mapre: " + str(mapre) + " Test marec: " + str(marec) + "\n"
        else:
            test_line = None

        log_parent = os.path.dirname(self.log_dir)
        if log_parent:
            os.makedirs(log_parent, exist_ok=True)
        dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
        # The context manager closes the handle on every path; the original
        # left it open whenever neither branch above applied.
        with open(self.log_dir, "a") as log_file:
            if test_line is not None:
                log_file.write("Dataset: " + self.dataset + " Fold: " + str(self.fold) + " Part: " + str(self.part) + " Time: " + dt_string + "\n")
                log_file.write("Val accuracy: " + str(self.valmaxacc) + " Val F1-score: " + str(self.valmaxf1) + "\n")
                log_file.write(test_line)
                log_file.write("--------------------\n")

def train_once(dataset, fold, hyperparams):
    """Fit a fresh GreaseLM_trainer on one dataset fold, then run its test pass.

    Args:
        dataset: Dataset name forwarded to the loaders and the trainer module.
        fold: Fold identifier forwarded to the loaders and the trainer module.
        hyperparams: Settings mapping; "batch_size", "part" and "max_epochs"
            are read here, the rest is consumed by GreaseLM_trainer.
    """
    loaders = dataloader.get_dataloaders_trainpart(
        dataset,
        fold,
        hyperparams["batch_size"],
        hyperparams["part"],
    )
    train_loader, dev_loader, test_loader = loaders

    # Named so it does not hide the imported `model` module inside this function.
    lit_module = GreaseLM_trainer(dataset, fold, hyperparams)

    trainer = pl.Trainer(
        gpus=1,
        num_nodes=1,
        precision=16,
        max_epochs=hyperparams["max_epochs"],
    )
    trainer.fit(lit_module, train_loader, dev_loader)
    trainer.test(lit_module, test_loader)



# Run configuration: dataset to train on and the settings shared by every run.
dataset = "RCVP"
hyperparams = {
    # optimisation
    "optimizer": "RAdam",
    "lm_name": "bart",
    "in_channels": 128,
    "lr_lm": 1e-5,
    "lr": 1e-4,
    "weight_decay": 1e-5,
    # Larger than max_epochs, so the unfreeze epoch is never reached.
    "lm_freeze_epochs": 1000,
    "log_dir": "logs/log_greaselm.txt",
    # architecture
    "ringlm_layer": 2,
    "nhead": 1,
    "dropout_p": 0.5,
    # schedule
    "batch_size": 4,
    "max_epochs": 100,  # TODO: marked "change" by the original author
    "part": 1,
}

# Alternative experiment kept from the original script as a reminder: sweep
# hyperparams["part"] over 0.9, 0.8, ..., 0.1 and call train_once with fold = 1.

# Repeat runs indefinitely, alternating between fold 0 and fold 1; the process
# has to be stopped from outside.
fold = 0
while True:
    train_once(dataset, fold, hyperparams)
    fold = 1 - fold