import argparse
from typing import Union, Tuple, Dict

import json
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

from transformers import AutoTokenizer

from tqdm import tqdm
import wandb

from model import AutoModelForFEVER

class FTDataset(Dataset):
    """Dataset of prompt / answer rows loaded from two JSON files.

    The rows of ``train_path`` and ``valid_path`` are concatenated into a
    single list. Each row is read through the keys ``"prompt"`` (text fed to
    the tokenizer) and ``"ans"`` (the label is 1.0 when it equals
    ``"SUPPORTS"`` and 0.0 otherwise).
    """

    def __init__(
        self,
        train_path: str,
        valid_path: str,
        tok: "AutoTokenizer",
        device: Union[str, int, torch.device]
    ):
        """Load both JSON files and remember the tokenizer and target device.

        Args:
            train_path: Path to a JSON file holding a list of rows.
            valid_path: Path to a second JSON file whose rows are appended.
            tok: Callable tokenizer; invoked as ``tok(text, ...)``.
            device: Device every collated batch is moved to.
        """
        # Explicit encoding so loading does not depend on the platform locale.
        with open(train_path, encoding = "utf-8") as file:
            self.data = json.load(file)
        with open(valid_path, encoding = "utf-8") as file:
            self.data += json.load(file)
        self.tok = tok
        self.device = device

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and attach its binary label.

        Returns:
            The tokenizer output (tensors with a leading batch axis of 1)
            with an extra ``"labels"`` entry of shape ``(1, 1)``.
        """
        row = self.data[idx]

        # truncation=True keeps over-long prompts within the tokenizer's
        # configured maximum length instead of producing a sequence the
        # model cannot consume.
        tuples = self.tok(
            row["prompt"],
            truncation = True,
            return_tensors = "pt"
        )

        tuples["labels"] = torch.tensor(
            [[float(row["ans"] == "SUPPORTS")]],
            dtype = torch.float32
        )

        return tuples

    def __len__(self):
        """Return the number of rows loaded from both files."""
        return len(self.data)

    def collate_fn(
        self,
        tuples: Tuple[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        """Pad a list of samples into one batch on ``self.device``.

        Every key of the first sample is batched: the leading axis of size 1
        is dropped from each tensor and the results are right-padded to the
        longest one. ``"input_ids"`` is padded with the tokenizer's pad token
        id (0 when the tokenizer does not define one); every other key is
        padded with 0.
        """
        pad_id = getattr(self.tok, "pad_token_id", None)
        pad_values = {"input_ids": 0 if pad_id is None else pad_id}

        return {
            k: pad_sequence(
                [t[k].squeeze(0) for t in tuples],
                batch_first = True,
                padding_value = pad_values.get(k, 0)
            ).to(self.device)
            for k in tuples[0].keys()
        }

if __name__ == "__main__":
    # Script entry point: parse CLI options, build the dataset / loader /
    # model / optimizer, then for each epoch run one training pass followed
    # by one accuracy pass, saving the weights whenever the accuracy improves.

    parser = argparse.ArgumentParser()

    # Paths to the pretrained model, the two data files and the output weights.
    parser.add_argument("--model-name-or-path", type = str, default = "/Users/tanchenmien/model/bert")
    parser.add_argument("--train-data-path", type = str, default = "/Users/tanchenmien/data/fever/fever_train.json")
    parser.add_argument("--valid-data-path", type = str, default = "/Users/tanchenmien/data/fever/fever_eval.json")
    parser.add_argument("--weight-path", type = str, default = "/Users/tanchenmien/model/bert.pth")

    # Optimisation hyper-parameters.
    parser.add_argument("--n-epochs", type = int, default = 2)
    parser.add_argument("--batch-size", type = int, default = 64)
    parser.add_argument("--lr", type = float, default = 3e-5)

    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    wandb.init(project = "ft")

    tok = AutoTokenizer.from_pretrained(args.model_name_or_path)
    dataset = FTDataset(
        args.train_data_path,
        args.valid_data_path,
        tok,
        device
    )
    loader = DataLoader(
        dataset,
        batch_size = args.batch_size,
        shuffle = True,
        collate_fn = dataset.collate_fn
    )
    model = AutoModelForFEVER(args.model_name_or_path).to(device)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr = args.lr
    )

    best_score = None
    for _ in range(args.n_epochs):
        # Training pass. Switch to training mode explicitly because the
        # evaluation pass below leaves the model in eval mode.
        model.train()
        for tuples in tqdm(loader, ncols = 50):
            logits = model(**tuples)["logits"]
            loss = F.binary_cross_entropy_with_logits(logits, tuples["labels"])
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            wandb.log({"loss": loss.item()})

        # Accuracy pass. eval() disables train-only behaviour such as dropout
        # so the score is deterministic for fixed weights.
        # NOTE(review): this iterates the same loader used for training (the
        # dataset concatenates both data files), so "ES" is not a held-out
        # score — confirm this is intended.
        model.eval()
        corrs = []
        with torch.no_grad():
            for tuples in tqdm(loader, ncols = 50):
                logits = model(**tuples)["logits"]
                # A logit above 0 is a positive prediction; compare it with
                # the 0/1 labels element-wise.
                hits = (logits > 0) == tuples["labels"]
                corrs += hits.squeeze(-1).to("cpu").numpy().tolist()
        score = np.mean(corrs)

        # Keep only the weights of the best-scoring epoch so far.
        if best_score is None or score > best_score:
            best_score = score
            torch.save(model.state_dict(), args.weight_path)

        wandb.log({"ES": score})