import os
import sys
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

import transformers
from model import load_nlvr_model
from data import NLVRLoader
from tensorboardX import SummaryWriter
from tqdm import tqdm

def _to_device(batch, device):
    """Unpack one NLVR batch and move its tensor fields onto ``device``.

    The batch is unpacked as ``(img1, img2, input_ids, sentence, label)``;
    ``sentence`` is passed through untouched because the model is never fed it.
    """
    img1, img2, input_ids, sentence, label = batch
    return (
        img1.to(device),
        img2.to(device),
        input_ids.to(device),
        sentence,
        label.to(device),
    )


def _train_one_epoch(model, dataloader, loss_func, optim, device):
    """Run one optimisation pass over ``dataloader``.

    Returns the mean per-example training loss, or NaN if the loader yields
    no batches.
    """
    model.train()
    loss_sum, n_examples = 0.0, 0
    for batch in tqdm(dataloader, desc = 'Iterating over training batch'):
        img1, img2, input_ids, _, label = _to_device(batch, device)
        loss = loss_func(model(img1, img2, input_ids), label)
        # Clear first so gradients left on the model before this call can
        # never leak into the first optimizer step.
        optim.zero_grad()
        loss.backward()
        optim.step()
        # loss_func uses mean reduction (CrossEntropyLoss default), so weight
        # by the *actual* batch size: the last batch may be smaller than
        # args.batch_size.
        batch_size = label.size(0)
        loss_sum += loss.item() * batch_size
        n_examples += batch_size
    return loss_sum / n_examples if n_examples else float('nan')


def _evaluate_loss(model, dataloader, loss_func, device):
    """Return the mean per-example loss of ``model`` on ``dataloader``.

    Runs in eval mode without autograd; returns NaN if the loader is empty.
    """
    model.eval()
    loss_sum, n_examples = 0.0, 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc = 'Iterating over validation batch'):
            img1, img2, input_ids, _, label = _to_device(batch, device)
            loss = loss_func(model(img1, img2, input_ids), label)
            batch_size = label.size(0)
            loss_sum += loss.item() * batch_size
            n_examples += batch_size
    return loss_sum / n_examples if n_examples else float('nan')


def _evaluate_accuracy(model, dataloader, device):
    """Return the argmax classification accuracy of ``model`` as a Python float.

    Runs in eval mode without autograd; returns NaN if the loader is empty.
    """
    model.eval()
    n_correct, n_total = 0, 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc = 'Iterating over test batch'):
            img1, img2, input_ids, _, label = _to_device(batch, device)
            preds = model(img1, img2, input_ids).argmax(dim = -1)
            # Reduce on-device and pull a single int back, instead of a
            # Python-level sum over every element of the tensor.
            n_correct += (preds == label).sum().item()
            n_total += preds.size(0)
    return n_correct / n_total if n_total else float('nan')


def train_nlvr(args, train_dataloader, val_dataloader, test_dataloader, model):
    """Train ``model`` on NLVR and evaluate it after every epoch.

    Args:
        args: namespace read for ``lr`` and ``num_epochs``, plus
            ``save_ckpt`` and (when it is truthy) ``save_path`` and ``exp_name``.
        train_dataloader, val_dataloader, test_dataloader: iterables of
            ``(img1, img2, input_ids, sentence, label)`` batches.
        model: module called as ``model(img1, img2, input_ids)`` that returns
            class logits; must expose ``fusion_encoder`` when checkpointing.

    Returns:
        ``(model, epoch_train_losses, epoch_val_losses, epoch_test_acc)``.
        Each list holds one Python float per epoch. Losses are mean
        per-example values. A split whose loader is empty reports NaN.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Move the model *before* constructing the optimizer, as the torch.optim
    # documentation recommends.
    model = model.to(device)
    loss_func = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), lr = args.lr)

    ckpt_dir = None
    if args.save_ckpt:
        ckpt_dir = os.path.join(args.save_path, args.exp_name)
        # torch.save does not create missing parent directories.
        os.makedirs(ckpt_dir, exist_ok = True)

    epoch_train_losses = []
    epoch_val_losses = []
    epoch_test_acc = []
    for epoch in range(args.num_epochs):
        train_loss = _train_one_epoch(model, train_dataloader, loss_func, optim, device)
        val_loss = _evaluate_loss(model, val_dataloader, loss_func, device)
        test_acc = _evaluate_accuracy(model, test_dataloader, device)

        epoch_train_losses.append(train_loss)
        epoch_val_losses.append(val_loss)
        epoch_test_acc.append(test_acc)

        if ckpt_dir is not None:
            torch.save(model.fusion_encoder.state_dict(), os.path.join(ckpt_dir, f'{epoch}.pt'))

        print(f'Training loss on epoch {epoch}: {train_loss}')
        print(f'Validation loss on epoch {epoch}: {val_loss}')
        print(f'Testing accuracy on epoch {epoch}: {test_acc}')

    return model, epoch_train_losses, epoch_val_losses, epoch_test_acc

def run_train(args):
    """Build the NLVR model and its data loaders from ``args``, then train it.

    The model is built first via ``load_nlvr_model(args, args.text_model_str)``,
    and then the train/val/test loaders via ``NLVRLoader(args)``. The result of
    ``train_nlvr`` is returned unchanged.
    """
    nlvr_model = load_nlvr_model(args, args.text_model_str)
    # NOTE: wrapping nlvr_model in nn.DataParallel would move the wrapped
    # module under `.module`, which breaks the `model.fusion_encoder`
    # checkpointing in train_nlvr.
    train_dl, val_dl, test_dl = NLVRLoader(args)
    return train_nlvr(
        args,
        train_dataloader = train_dl,
        val_dataloader = val_dl,
        test_dataloader = test_dl,
        model = nlvr_model,
    )
