import torch
from colored import fg, attr
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np

def dataloader(state, event):
    """Build the train, test and validation data loaders.

    The datasets are fetched through ``event.dataset("train" | "test" | "val")``.
    If no validation set is returned, ``state["val_prop"] > 0`` and ``event``
    has a ``validate`` attribute, a random ``val_prop`` fraction of the train
    set is split off as validation set. A non-zero ``state["data_subset"]``
    restricts training to a random subset of that many samples.

    The counters ``state.all["trainset_size"]`` and ``state.all["valset_size"]``
    are updated along the way.

    Args:
        state: Settings object supporting item access for ``"val_prop"``,
            ``"drop_last"``, ``"batchsize"`` and ``"data_subset"``, and
            exposing a ``.all`` mapping with the ``*_size`` counters.
        event: Object providing ``dataset(name)`` and ``optional.set_seed()``.

    Returns:
        Tuple ``(trainloader, testloader, valloader)``. ``valloader`` is
        ``None`` when ``val_prop`` is not positive or when no validation set
        is available.
    """
    trainset = event.dataset("train")
    testset = event.dataset("test")
    valset = event.dataset("val")
    valloader = None

    def worker_init(worker_id):
        # Set the seed in every data loader worker.
        return event.optional.set_seed()

    # Split train into train and val if val_prop != 0 and there is no valset.
    if valset is None and state["val_prop"] > 0 and hasattr(event, "validate"):
        state.all["valset_size"] = int(np.floor(state.all["trainset_size"] * state["val_prop"]))
        state.all["trainset_size"] -= state.all["valset_size"]
        trainset, valset = torch.utils.data.random_split(
            trainset,
            [state.all["trainset_size"], state.all["valset_size"]],
        )
        print("Splitted %d %% of trainset as valset" % int(state["val_prop"] * 100))

    # Account for the samples of the incomplete last batch.
    # NOTE(review): only the size counters are adjusted here; neither
    # `trainset` nor `valset` is modified, so the samples are not actually
    # moved into the validation set -- confirm this is intended.
    surplus = 0
    if state["drop_last"]:
        surplus = state.all["trainset_size"] % state["batchsize"]
        state.all["trainset_size"] -= surplus
        state.all["valset_size"] += surplus

        print("  -> "+attr('bold')+fg('green')+"batches.drop_last"+attr('reset')+" moved %d samples to Valset." % surplus)

    # Only train on `data_subset` randomly chosen samples.
    if state["data_subset"] != 0:
        trainset_size = state.all["trainset_size"]
        state.all["trainset_size"] = state["data_subset"]
        # `surplus` is added back because it was only subtracted from the
        # counter above; the split lengths have to sum to len(trainset).
        remainder = trainset_size - state["data_subset"] + surplus
        trainset, _ = torch.utils.data.random_split(trainset, [state["data_subset"], remainder])
        print("Using subset of %s data points." % state["data_subset"])

    # The train loader is configured identically with or without a subset.
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=state["batchsize"],
        shuffle=True,
        num_workers=8,
        drop_last=state["drop_last"],
        worker_init_fn=worker_init,
        pin_memory=True,
    )

    # Only build a validation loader when there really is a validation set;
    # otherwise a loader wrapping `None` would fail later on len()/iteration.
    if valset is not None and state["val_prop"] > 0:
        valloader = torch.utils.data.DataLoader(
            valset,
            batch_size=state["batchsize"],
            shuffle=False,
            num_workers=2,
            worker_init_fn=worker_init,
            pin_memory=True,
        )

    # NOTE(review): the test loader uses state["batchsize"]; the registered
    # "testbatchsize" default is not read here -- confirm this is intended.
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=state["batchsize"],
        shuffle=False,
        num_workers=2,
        worker_init_fn=worker_init,
        pin_memory=True,
    )

    print("Trainset size: %d. Valset size: %d. Testset size: %d" %(state.all["trainset_size"], state.all["valset_size"], state.all["testset_size"]))
    return trainloader, testloader, valloader


def register(mf):
    """Register this module's default settings and its ``dataloader`` event on ``mf``."""
    defaults = {
        "batchsize": 128,
        "testbatchsize": 32,
        "data_subset": 0,
        "val_prop": 0.05,
        "drop_last": True,
    }
    mf.register_defaults(defaults)

    # `unique=True` is forwarded unchanged to the framework's register_event.
    mf.register_event("dataloader", dataloader, unique=True)
