import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import *
import numpy as np
import argparse
from itertools import product
from torch.utils.data import DataLoader
from transformers.modeling_outputs import SequenceClassifierOutput

def main(args):
    test_hard_accs, test_accs = [], []
    data_dir = os.path.abspath("ranking_datasets")
    for run in range(args.nruns):
        model, optimizer = load_model(args)

        tasks = args.tasks
        train_sets, test_loaders, hard_test_loaders = [], [], []
        for task in tasks:
            train_sets.append(load_process_data(args, data_dir, task, "train", args.model))
            test_data = load_process_data(args, data_dir, task, "test", args.model)
            test_loaders.append(DataLoader(test_data, batch_size=args.batch_size // 2, shuffle=False))
            if task != 'power':
                test_hard_data = load_process_data(args, data_dir, task, "test_hard", args.model)
                hard_test_loaders.append(DataLoader(test_hard_data, batch_size=args.batch_size // 2, shuffle=False))
            else:
                hard_test_loaders.append([])

        train_data = torch.utils.data.ConcatDataset(train_sets)
        train_dataloader = DataLoader(train_data, batch_size=args.batch_size // 2, shuffle=True)

        for epoch in range(1, args.nepochs + 1):
            print('Epoch', epoch)
            train(model, optimizer, train_dataloader, epoch, verbose=args.verbose)
            for task, loader, hard_loader in zip(tasks, test_loaders, hard_test_loaders):
                print(f"{task} acc")

                test_acc = evaluate(model, loader)
                if task != 'power':
                    print(f"{task} hard acc")
                    test_hard_acc = evaluate(model, hard_loader)

            if args.save:
                model_suffix = args.model.split("/")[-1] # discard organization and slash
                save_path = "multi_{}_{}_{}_{}.pkl".format(model_suffix, args.learning_rate, args.batch_size, epoch)
                print("SAVING to", save_path)
                torch.save(model.module.state_dict(), save_path)

    return 0,0

def flatten(tensor):
    tensor = torch.cat([tensor[:, 0], tensor[:, 1]])
    return tensor

def unflatten(tensor):
    tensor = torch.stack([tensor[:tensor.shape[0] // 2], tensor[tensor.shape[0] // 2:]], axis=1)
    return tensor

def train(model, optimizer, train_dataloader, epoch, log_interval=100, verbose=False):
    # Set model to training mode
    model.train()
    criterion = torch.nn.BCEWithLogitsLoss()
    ntrain_steps = len(train_dataloader)

    # Loop over each batch from the training set
    for step, batch in enumerate(train_dataloader):

        # Copy data to GPU if needed
        batch = tuple(t.cuda() for t in batch)

        # Unpack the inputs from our dataloader
        b_input_ids, b_input_mask, b_labels = batch

        # reshape
        b_input_ids = flatten(b_input_ids)
        b_input_mask = flatten(b_input_mask)

        # Zero gradient buffers
        optimizer.zero_grad()

        # Forward pass
        output = model(b_input_ids, attention_mask=b_input_mask)
        if isinstance(output, SequenceClassifierOutput): # this is for deberta, etc. outputs only; not flan
            output = model(b_input_ids, attention_mask=b_input_mask)[0]  # dim 1
        output = unflatten(output)
        diffs = output[:, 0] - output[:, 1]
        
        # optional
        diffs = torch.masked_select(diffs, b_labels.bool()[:,0]).unsqueeze(-1)

        loss = criterion(diffs.squeeze(dim=1), torch.ones(diffs.shape[0]).cuda())

        # Backward pass
        loss.backward()

        # Update weights
        optimizer.step()

        if step % log_interval == 0 and step > 0 and verbose:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, step, ntrain_steps, 100. * step / ntrain_steps, loss))

def evaluate(model, dataloader):
    model.eval()
    cors = []

    for step, batch in enumerate(dataloader):
        # Copy data to GPU if needed
        batch = tuple(t.cuda() for t in batch)

        # Unpack the inputs from our dataloader
        b_input_ids, b_input_mask, b_labels = batch

        # reshape
        b_input_ids = flatten(b_input_ids)
        b_input_mask = flatten(b_input_mask)

        # Forward pass
        with torch.no_grad():
            output = model(b_input_ids, attention_mask=b_input_mask)
            if isinstance(output, SequenceClassifierOutput): # this is for deberta, etc. outputs only; not flan
                output = output[0]  # dim 1
            output = unflatten(output)
            diffs = output[:, 0] - output[:, 1]
            
            # optional
            diffs = torch.masked_select(diffs, b_labels.bool()[:,0]).unsqueeze(-1)

            diffs = diffs.squeeze(dim=1).detach().cpu().numpy()
        cors.append(diffs > 0)

    cors = np.concatenate(cors)
    acc = np.mean(cors)

    print('Acc {:.3f}'.format(acc))
    return acc

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", type=str, default="bert-base-uncased")
    parser.add_argument("--tasks", "-ta", nargs="+", default=["paired_cm", "util", "power"])
    parser.add_argument("--ngpus", "-n", type=int, default=2)
    parser.add_argument("--nepochs", "-e", type=int, default=2)
    parser.add_argument("--batch_size", "-b", type=int, default=16)
    parser.add_argument("--max_length", "-t", type=int, default=64)
    parser.add_argument("--weight_decay", "-w", type=float, default=0.01)
    parser.add_argument("--learning_rate", "-l", type=float, default=1e-5)
    parser.add_argument("--verbose", "-v", action="store_true")
    parser.add_argument("--nruns", "-r", type=int, default=1)
    parser.add_argument("--grid_search", "-g", action="store_true")
    parser.add_argument("--save", "-s", action="store_true")
    parser.add_argument("--dropout", "-d", type=float, default=0.0)
    args = parser.parse_args()
    if 'custom_tokens' in args:
        try:
            args.custom_tokens = [int(x) for x in args.custom_tokens]
        except:
            raise ValueError("custom tokens must be integers")
    

    if args.grid_search:
        file = "grid_search_results.txt"
        args.nruns = 1
        models = ["bert-base-uncased", "bert-large-uncased", "roberta-large", "albert-xxlarge-v2"]
        lrs = [1e-5, 3e-5]
        batch_sizes = [8, 16]
        epochs = [2,4]

        with open(file, "a") as f:
            f.write("{}\n".format(args))
            f.write("models: {}, lrs: {}, batch_sizes: {}, epochs: {}\n".format(models, lrs, batch_sizes, epochs))

        for model, lr, bs, nepoch in product(models, lrs, batch_sizes, epochs):
            args.model = model
            args.learning_rate = lr
            args.batch_size = bs
            args.nepochs = nepoch

            test_hard_acc, test_acc = main(args)

            with open(file, "a") as f:
                f.write("model: {}, lr: {}, batch_size: {}, nepoch: {}.\n test hard accuracy: {:.3f}, test accuracy: {:.3f}\n".format(model, lr, bs, nepoch, test_hard_acc, test_acc))
    else:
        main(args)




