import sys
import model
import torch
import torch.nn as nn
import dataloader
from tqdm import tqdm
import os
import pytorch_lightning as pl
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, balanced_accuracy_score
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForPreTraining, AutoModelWithLMHead, AutoModel

class LM_trainer(pl.LightningModule):
    """Pretrained-language-model text classifier trained with PyTorch Lightning.

    The sentences in ``data["original_text"]`` are encoded by a Hugging Face
    ``AutoModel``, mean-pooled into one document vector and fed through a small
    head (``Linear1`` -> SELU -> dropout -> ``Linear2``). Dataset-specific extras:

    * ``FC``: a pooled encoding of ``data["summary"]`` is concatenated to the
      document vector before ``Linear2``.
    * ``RCVP``: ``data["leg_rep"]`` (a 2-D tensor whose last dim is 512) is
      projected and paired row-by-row with the document vector, giving one
      prediction per row; ``sample["label"]`` is then a sequence of labels.

    Batches are iterables of per-sample mappings; each sample is run through
    :meth:`forward` on its own and the losses are averaged over the batch.

    Args:
        dataset: one of ``"SemEval"``, ``"Allsides"``, ``"FND"``, ``"FC"``, ``"RCVP"``.
        fold: fold index; for ``"FND"`` fold 0 is the 2-class and fold 1 the
            4-class setting.
        hyperparams: dict with at least ``part``, ``lm_name``, ``log_dir``,
            ``dropout_p`` and ``lm_freeze_epochs``. ``lr`` (default 1e-3),
            ``lr_lm`` (default: ``lr``) and ``weight_decay`` (default 1e-5) are
            optional.

    Raises:
        ValueError: for an unknown ``lm_name`` or an unsupported dataset/fold.
    """

    # lm_name -> (Hugging Face checkpoint, hidden size of that checkpoint).
    LM_CHECKPOINTS = {
        "roberta": ("roberta-base", 768),
        "deberta": ("microsoft/deberta-v3-base", 768),
        "electra": ("google/electra-small-discriminator", 256),
        "bart": ("facebook/bart-base", 768),
        "longformer": ("allenai/longformer-base-4096", 768),
    }

    def __init__(self, dataset, fold, hyperparams):
        super().__init__()
        self.part = hyperparams["part"]
        self.dataset = dataset
        self.fold = fold
        self.lm_name = hyperparams["lm_name"]
        self.log_dir = hyperparams["log_dir"]
        self.dropout_p = hyperparams["dropout_p"]
        self.lm_freeze_epochs = hyperparams["lm_freeze_epochs"]
        # Optimiser settings; the defaults are the values that used to be hard-coded.
        self.lr = hyperparams.get("lr", 1e-3)
        self.lr_lm = hyperparams.get("lr_lm", self.lr)
        self.weight_decay = hyperparams.get("weight_decay", 1e-5)

        if self.lm_name not in self.LM_CHECKPOINTS:
            raise ValueError(f"Unknown lm_name: {self.lm_name!r}")
        checkpoint, self.in_channels = self.LM_CHECKPOINTS[self.lm_name]
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModel.from_pretrained(checkpoint)

        self.num_class = self._num_classes(dataset, fold)

        self.Linear1 = nn.Linear(self.in_channels, self.in_channels)
        # RCVP feeds [leg_rep projection; text vector] and FC feeds
        # [summary vector; text vector] to Linear2, so its input is twice as wide.
        if self.dataset in ("RCVP", "FC"):
            self.Linear2 = nn.Linear(2 * self.in_channels, self.num_class)
        else:
            self.Linear2 = nn.Linear(self.in_channels, self.num_class)
        self.activation = nn.SELU()
        self.dropout = nn.Dropout(self.dropout_p)

        self.CELoss = nn.CrossEntropyLoss()

        # Best epoch-level validation accuracy so far and the F1 of that epoch.
        self.valmaxacc = 0
        self.valmaxf1 = 0
        # Per-sample (labels, predictions) gathered over the current eval epoch.
        self._val_outputs = []
        # Per-sample (id, labels, predictions) gathered over the test epoch.
        self._test_outputs = []

        if self.dataset == "RCVP":
            self.linear_rcvp = nn.Sequential(nn.Linear(512, self.in_channels), self.activation, self.dropout, nn.Linear(self.in_channels, self.in_channels), self.activation, self.dropout)
            self.RingLM2in = nn.Sequential(nn.Linear(self.in_channels, self.in_channels), self.activation, self.dropout, nn.Linear(self.in_channels, self.in_channels), self.activation, self.dropout)

    @staticmethod
    def _num_classes(dataset, fold):
        """Return the number of target classes for ``dataset``/``fold``."""
        if dataset in ("SemEval", "FC", "RCVP"):
            return 2
        if dataset == "Allsides":
            return 3
        if dataset == "FND" and fold == 0:  # two class setting
            return 2
        if dataset == "FND" and fold == 1:  # four class setting
            return 4
        raise ValueError(f"Unsupported dataset/fold: {dataset!r}, fold {fold!r}")

    def lm_extract(self, text):
        """Encode ``text`` and mean-pool the LM's last hidden state over real tokens.

        Args:
            text: a string or list of strings; they are tokenised together,
                padded to the longest and truncated to 200 tokens.

        Returns:
            Tensor of shape ``(number of strings, in_channels)``.
        """
        inputs = self.tokenizer(text, truncation=True, padding=True, max_length=200, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        # Pass the mask so the encoder does not attend to padding positions.
        hidden = self.model(input_ids, attention_mask=attention_mask).last_hidden_state
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        # Average only over non-padding tokens; clamp guards an all-padding row.
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    def forward(self, data):
        """Return class logits for one sample.

        Shape is ``(1, num_class)``, except for RCVP where it is
        ``(len(data["leg_rep"]), num_class)``.
        """
        # One vector per sentence, then averaged into a single document vector.
        text = torch.mean(self.lm_extract(data["original_text"]), dim=0)
        text = self.dropout(self.activation(self.Linear1(text)))

        if self.dataset == "FC":
            summary = torch.mean(self.lm_extract(data["summary"]), dim=0)
            text = torch.cat((summary, text), dim=0)

        if self.dataset == "RCVP":
            summary_vec = self.linear_rcvp(data["leg_rep"])  # k * in_channels
            # Pair the same document vector with each of the k rows.
            text = torch.cat((summary_vec, self.RingLM2in(text).repeat(len(summary_vec), 1)), dim=1)

        logits = self.Linear2(text)
        if self.dataset != "RCVP":
            logits = logits.unsqueeze(0)
        return logits

    def configure_optimizers(self):
        """RAdam over the head (``lr``) and the language model (``lr_lm``)."""
        lm_params = list(self.model.parameters())
        lm_ids = {id(p) for p in lm_params}
        head_params = [p for p in self.parameters() if id(p) not in lm_ids]
        return torch.optim.RAdam(
            [{"params": head_params}, {"params": lm_params, "lr": self.lr_lm}],
            lr=self.lr,
            weight_decay=self.weight_decay,
        )

    def _set_lm_trainable(self, trainable):
        """Toggle ``requires_grad`` on every language-model parameter."""
        for param in self.model.parameters():
            param.requires_grad = trainable

    def on_train_epoch_start(self):
        # Freeze the LM at the first epoch and unfreeze it once
        # ``lm_freeze_epochs`` epochs have passed (never, if that exceeds max_epochs).
        if self.current_epoch == 0:
            self._set_lm_trainable(False)
        if self.current_epoch == self.lm_freeze_epochs:
            self._set_lm_trainable(True)

    def _run_batch(self, batch):
        """Forward every sample in ``batch``.

        Returns:
            ``(mean loss, records)`` where ``records[i]`` is
            ``(labels, predictions)`` (two equal-length lists of ints) for ``batch[i]``.
        """
        batch_loss = 0
        records = []
        for sample in batch:
            logit = self.forward(sample)
            # RCVP carries one label per leg_rep row; other datasets one label per sample.
            labels = list(sample["label"]) if self.dataset == "RCVP" else [sample["label"]]
            target = torch.tensor(labels, device=self.device).long()
            batch_loss += self.CELoss(logit, target)
            preds = [int(x) for x in torch.argmax(logit, dim=1)]
            records.append((labels, preds))
        return batch_loss / len(batch), records

    @staticmethod
    def _flatten(records):
        """Flatten ``(labels, predictions)`` records into two parallel lists."""
        truth, pred = [], []
        for labels, preds in records:
            truth.extend(labels)
            pred.extend(preds)
        return truth, pred

    def _metrics(self, truth, pred):
        """Return ``(accuracy, f1)`` using the dataset's reporting conventions."""
        if self.dataset == "FC":
            acc = balanced_accuracy_score(truth, pred)
        else:
            acc = accuracy_score(truth, pred)
        # RCVP is reported with macro-F1 even though it is binary.
        if self.num_class == 2 and self.dataset != "RCVP":
            f1 = f1_score(truth, pred, average="binary")
        else:
            f1 = f1_score(truth, pred, average="macro")
        return acc, f1

    def training_step(self, train_batch, batch_idx):
        loss, _ = self._run_batch(train_batch)
        return loss

    def on_validation_epoch_start(self):
        self._val_outputs = []

    def validation_step(self, val_batch, batch_idx):
        loss, records = self._run_batch(val_batch)
        self.log("val_loss", loss)
        self._val_outputs.extend(records)
        return loss

    def on_validation_epoch_end(self):
        # Metrics are computed over the whole validation epoch, not per batch.
        records, self._val_outputs = self._val_outputs, []
        # Ignore the sanity-check pass Lightning runs before training starts.
        if not records or getattr(self.trainer, "sanity_checking", False):
            return
        acc, f1 = self._metrics(*self._flatten(records))
        if acc > self.valmaxacc:
            self.valmaxacc = acc
            self.valmaxf1 = f1

    def on_test_epoch_start(self):
        self._test_outputs = []

    def test_step(self, test_batch, batch_idx):
        loss, records = self._run_batch(test_batch)
        for sample, (labels, preds) in zip(test_batch, records):
            self._test_outputs.append((sample["id"], labels, preds))
        return loss

    def on_test_epoch_end(self):
        """Save per-sample correctness, log test metrics and append them to ``log_dir``."""
        outputs, self._test_outputs = self._test_outputs, []
        if not outputs:
            return

        # A sample counts as correct only if every one of its predictions matches
        # (for non-RCVP datasets each sample has exactly one prediction).
        correct = [sid for sid, labels, preds in outputs if labels == preds]
        incorrect = [sid for sid, labels, preds in outputs if labels != preds]
        os.makedirs("testrecord", exist_ok=True)
        torch.save({"correct": correct, "incorrect": incorrect},
                   os.path.join("testrecord", self.dataset + str(self.fold) + self.lm_name))

        truth, pred = self._flatten([(labels, preds) for _, labels, preds in outputs])
        acc, f1 = self._metrics(truth, pred)
        self.log("test_acc", acc)

        dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
        lines = [
            "Dataset: " + self.dataset + " Fold: " + str(self.fold) + " Part: " + str(self.part) + " Time: " + dt_string + "\n",
            "Val accuracy: " + str(self.valmaxacc) + " Val F1-score: " + str(self.valmaxf1) + "\n",
        ]
        if self.dataset == "FND":
            mif1 = f1_score(truth, pred, average="micro")
            mapre = precision_score(truth, pred, average="macro")
            marec = recall_score(truth, pred, average="macro")
            lines.append("Test mif1: " + str(mif1) + " Test maf1: " + str(f1) + " Test mapre: " + str(mapre) + " Test marec: " + str(marec) + "\n")
        else:
            lines.append("Test accuracy: " + str(acc) + " Test F1-score: " + str(f1) + "\n")
        lines.append("--------------------\n")
        with open(self.log_dir, "a") as f:
            f.writelines(lines)

def train_once(dataset, fold, hyperparams):
    """Train a fresh ``LM_trainer`` on one dataset/fold split, then run its test split.

    Args:
        dataset: dataset name forwarded to the dataloader factory and the model.
        fold: fold index forwarded to both.
        hyperparams: dict; ``batch_size``, ``part`` and ``max_epochs`` are read
            here, and the whole dict is handed to ``LM_trainer``.
    """
    loaders = dataloader.get_dataloaders_trainpart(
        dataset, fold, hyperparams["batch_size"], hyperparams["part"]
    )
    train_loader, dev_loader, test_loader = loaders

    lightning_module = LM_trainer(dataset, fold, hyperparams)
    runner = pl.Trainer(
        gpus=1,
        num_nodes=1,
        precision=16,
        max_epochs=hyperparams["max_epochs"],
    )

    runner.fit(lightning_module, train_loader, dev_loader)
    runner.test(lightning_module, test_loader)

dataset = "FND" # change
hyperparams = {
    "optimizer": "RAdam",
    "lm_name": "longformer", # change
    "in_channels": 128,
    "lr_lm": 1e-5,
    "lr": 1e-3,
    "weight_decay": 1e-5,
    "lm_freeze_epochs": 1000,
    "log_dir": "logs/log_lm5.txt", # change
    "ringlm_layer": 2,
    "nhead": 1,
    "dropout_p": 0.5,
    "batch_size": 16,
    "max_epochs": 10, # change
    "part": 1
}

# How many times to repeat the run. None repeats until the process is
# interrupted (each run appends its own entry to hyperparams["log_dir"]).
NUM_RUNS = None

# for part in [0.3, 0.2, 0.1]:
#     hyperparams["part"] = part
#     #for fold in range(10): # change
#     fold = 0
#     train_once(dataset, fold, hyperparams)
# # part = 1
# # for fold in [1]:
# #     train_once(dataset, fold, hyperparams)

# Guarded so that importing this module does not start training.
if __name__ == "__main__":
    run = 0
    while NUM_RUNS is None or run < NUM_RUNS:
        fold = 0
        train_once(dataset, fold, hyperparams)
        run += 1