import wandb
import plot
import torch

from models.LinearModel import get_error

from torch.utils.data import DataLoader


def ks_setup(dataset, batch_size, test_ratio=0.1):
    """Split ``dataset`` into train/test parts and wrap each in a DataLoader.

    Args:
        dataset: object exposing ``train_test_split(test_ratio=...)`` that
            returns ``(train_set, test_set)``.
        batch_size: batch size used for both loaders.
        test_ratio: fraction forwarded to ``dataset.train_test_split``
            (defaults to the previously hard-coded 0.1).

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles and the
        test loader does not.
    """
    train_set, test_set = dataset.train_test_split(test_ratio=test_ratio)
    # Train only on the training split. Wrapping the full dataset here would
    # leak every test sample into training and leave `train_set` unused.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

def save_real_image(test_loader, train_type):
    """Log the first ground-truth label from ``test_loader`` to wandb.

    Takes the first batch yielded by ``test_loader`` and keeps only its first
    sample and first label. Each is given a leading dimension of size 1 via
    ``unsqueeze(0)``. The label is rendered with ``plot.to_image`` and logged
    under ``'real_image'``. ``train_type`` is then logged in a second,
    separate ``wandb.log`` call.

    Returns:
        ``(test_sample, test_label)``: the single-item sample/label pair.
    """
    first_batch = next(iter(test_loader))
    samples, labels = first_batch
    sample = samples[0].unsqueeze(0)
    label = labels[0].unsqueeze(0)

    rendered = plot.to_image(label)
    wandb.log({'real_image': wandb.Image(rendered, caption='Real Image')})
    # Logged in its own call rather than merged into the call above.
    wandb.log({'train_type': train_type})

    return sample, label

def split_test_train(dataset, labels, test_ratio=0.1):
    """Split parallel ``dataset`` / ``labels`` sequences without shuffling.

    The first ``int(len(dataset) * test_ratio)`` items become the test split
    and the rest become the training split. ``labels`` is cut at the same
    index.

    Returns:
        ``(train_set, test_set, train_labels, test_labels)``.
    """
    cut = int(len(dataset) * test_ratio)
    test_set, train_set = dataset[:cut], dataset[cut:]
    test_labels, train_labels = labels[:cut], labels[cut:]
    return train_set, test_set, train_labels, test_labels

def save_generated_image(model, test_sample, test_label, i):
    """Run ``model`` on a sample and log its output and residual to wandb.

    Args:
        model: callable applied to ``test_sample``.
        test_sample: model input.
        test_label: target compared elementwise against the model output.
        i: epoch index, used only in the image captions.

    Logs ``'gen_image'`` and ``'diff_image'`` in one ``wandb.log`` call.
    Then logs ``'max_resid'``, the largest absolute value of
    ``output - label``, in a separate call.
    """
    # This forward pass is for logging only, so skip building an autograd
    # graph. NOTE: the model's train/eval mode is left untouched.
    with torch.no_grad():
        model_output = model(test_sample)
        diff = model_output - test_label

    sample = plot.to_image(model_output)
    diff_sample = plot.to_image(diff)
    gen_image = wandb.Image(sample, caption=f'Generated, Epoch {i}')
    diff_image = wandb.Image(diff_sample, caption=f'Difference, Epoch {i}')
    wandb.log({'gen_image': gen_image, 'diff_image': diff_image})
    wandb.log({'max_resid': torch.abs(diff).max().item()})

def get_res_err(x, y):
    """Print the residual range and ``get_error`` value for two tensors.

    Both inputs are moved to CPU, detached, and converted to NumPy arrays.
    ``get_error`` is evaluated on them, then the min/max of ``x - y`` and the
    error are printed. Returns ``None``.
    """
    x_np, y_np = [t.cpu().detach().numpy() for t in (x, y)]
    err = get_error(x_np, y_np)
    residual = x_np - y_np
    lowest, highest = residual.min(), residual.max()
    print('res: ', lowest, highest)
    print('err: ', err)

def define_metrics():
    """Register separate wandb step axes for ``train/*`` and ``test/*`` metrics.

    ``train/step`` and ``test/step`` are defined first as metrics. Then every
    metric under each prefix is bound to its own step metric.

    Returns:
        ``(0, 0)``. Presumably these are the initial train and test step
        counters; confirm at the call site.
    """
    axes = (
        ("train/step", "train/*"),
        ("test/step", "test/*"),
    )
    for step_key, pattern in axes:
        wandb.define_metric(step_key)
        wandb.define_metric(pattern, step_metric=step_key)
    return 0, 0