import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms
from codes.components.utils import log_dict

import numpy as np
from torch.utils.data.dataset import Dataset


class ToyNet(nn.Module):
    """Two-parameter toy model whose output is the learnable point (x, y).

    The forward pass ignores its input and returns the current parameter
    values stacked into a tensor of shape ``(1, 2)``.
    """

    def __init__(self):
        super().__init__()
        # nn.Parameter already has requires_grad=True, so no explicit flag.
        self.x = nn.Parameter(torch.zeros(1))
        self.y = nn.Parameter(torch.zeros(1))

    def forward(self, inputs):
        """Return ``[[x, y]]``; *inputs* is accepted only for API compatibility."""
        point = torch.cat((self.x, self.y), dim=0)  # shape (2,)
        return point.unsqueeze(0)  # shape (1, 2)

def read_txt(path):
    """Read a whitespace-separated two-column text file.

    Each non-blank line must contain exactly two fields, e.g. ``"0.5 1.25"``.
    Fields may be separated by any run of spaces or tabs. Blank lines
    (including a trailing empty line at end of file) are skipped. Fields are
    returned as unconverted strings.

    Args:
        path: Path of the text file to read.

    Returns:
        A ``(xs, ys)`` tuple of lists holding the first and second field of
        every non-blank line, in file order.

    Raises:
        ValueError: if a non-blank line does not contain exactly two fields.
    """
    xs, ys = [], []
    with open(path, "r", encoding="utf-8") as f:
        # Iterate lazily instead of loading every line with readlines().
        for lineno, line in enumerate(f, start=1):
            # split() with no argument also strips and collapses whitespace.
            fields = line.split()
            if not fields:
                continue
            if len(fields) != 2:
                raise ValueError(
                    f"{path}:{lineno}: expected 2 whitespace-separated "
                    f"fields, got {len(fields)}"
                )
            x, y = fields
            xs.append(x)
            ys.append(y)
    return xs, ys


class CustomLoss(nn.Module):
    """Mean squared error between two tensors.

    Computes ``mean((x - y) ** 2)`` over every element of the (broadcast)
    difference and returns a 0-dim tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        squared_error = (x - y) ** 2
        return squared_error.mean()


class ToyDataset(Dataset):
    """Dataset of 2-D points read from a two-column text file.

    Each item is a float64 tensor ``[x, y]`` of shape ``(2,)``. If a
    ``transform`` was given, the item is passed through it first.
    """

    def __init__(self, txtpath, transform=None):
        """
        Args:
            txtpath: Path to a text file parsed by ``read_txt``.
            transform: Optional callable applied to each item tensor in
                ``__getitem__``.
        """
        super().__init__()
        self.xs, self.ys = read_txt(txtpath)
        # Number of samples; kept as a public attribute (equals len(self)).
        self.numbers = len(self.xs)
        self.transform = transform

    def __getitem__(self, index):
        # The fields are stored as strings; np.asarray converts them to float64.
        z = torch.from_numpy(
            np.asarray([self.xs[index], self.ys[index]], dtype=float)
        )
        if self.transform is not None:
            # Honor the transform accepted by the constructor.
            z = self.transform(z)
        return z

    def __len__(self):
        return len(self.xs)

def toy_data(
    data_dir,
    train,
    download,
    batch_size,
    shuffle=None,
    sampler_callback=None,
    dataset_cls=ToyDataset,
    drop_last=True,
    **loader_kwargs
):
    """Build a DataLoader over the toy two-column dataset and log its setup.

    Args:
        data_dir: First positional argument passed to ``dataset_cls`` (for
            ``ToyDataset``, the path of the text file to read).
        train: Recorded in the setup log only; does not affect loading.
        download: Recorded in the setup log only; does not affect loading.
        batch_size: Forwarded to ``DataLoader``.
        shuffle: Forwarded to ``DataLoader``. Note that torch's ``DataLoader``
            rejects ``shuffle=True`` together with a sampler.
        sampler_callback: Optional callable ``dataset -> sampler``; its result
            is forwarded to ``DataLoader`` as ``sampler``.
        dataset_cls: Dataset class, instantiated as
            ``dataset_cls(data_dir, transform=None)``.
        drop_last: Forwarded to ``DataLoader``.
        **loader_kwargs: Any further ``DataLoader`` keyword arguments.

    Returns:
        A ``torch.utils.data.DataLoader`` over the constructed dataset.
    """
    dataset = dataset_cls(
        data_dir,
        transform=None,
    )

    # Compare with None explicitly: samplers commonly define __len__, so a
    # truthiness test would treat an empty sampler as "no sampler".
    if sampler_callback is not None:
        sampler = sampler_callback(dataset)
    else:
        sampler = None
    log_dict(
        {
            "Type": "Setup",
            "Dataset": "toydata",
            "data_dir": data_dir,
            "train": train,
            "download": download,
            "batch_size": batch_size,
            "shuffle": shuffle,
            "sampler": str(sampler) if sampler is not None else None,
        }
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        drop_last=drop_last,
        **loader_kwargs,
    )
