import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms
from codes.components.utils import log_dict

import numpy as np
from torch.utils.data.dataset import Dataset
import math


class SimpleModel(nn.Module):
    """Module whose only learnable state is a coordinate vector.

    The coordinates are stored as a float ``nn.Parameter`` so an optimizer
    can move them directly; the forward pass simply hands them back.
    """

    def __init__(self, init_coords=(0, 0)):
        """Create the parameter from ``init_coords`` (converted to float)."""
        super().__init__()
        start = torch.tensor(init_coords)
        self.coords = nn.Parameter(start.float())

    def forward(self, x=None):
        """Return the coordinate parameter; ``x`` is accepted but ignored."""
        return self.coords

def read_txt(path):
    """Parse a four-column numeric text file into four parallel lists.

    Every non-blank line must hold exactly four whitespace-separated
    numbers, read in the order ``ap sg x y``.

    Args:
        path: Path of the text file to read.

    Returns:
        A tuple ``(aps, sgs, xs, ys)`` of lists of floats, one entry per
        non-blank line, in file order.

    Raises:
        ValueError: If a non-blank line does not contain exactly four
            fields, or a field cannot be converted to ``float``.
    """
    aps, sgs, xs, ys = [], [], [], []
    with open(path, 'r') as f:
        # Iterate the file lazily instead of materialising every line.
        for sample in f:
            # split() with no argument collapses runs of spaces/tabs, so
            # doubled separators do not produce empty fields.
            fields = sample.split()
            if not fields:
                # Blank lines (e.g. a trailing newline at EOF) carry no data.
                continue
            ap, sg, x, y = fields
            aps.append(float(ap))
            sgs.append(float(sg))
            xs.append(float(x))
            ys.append(float(y))
    return aps, sgs, xs, ys


class Ackley(nn.Module):
    """Two-dimensional Ackley test surface as a module.

    For a point ``(x, y)`` the forward pass evaluates::

        -a * exp(-b * sqrt((x^2 + y^2) / 2))
        - exp((cos(c*pi*x) + cos(c*pi*y)) / 2)
        + a + e

    Note that ``c`` is multiplied by pi inside ``forward``.
    """

    def __init__(self, a=20, b=0.2, c=2):
        """Store the three shape constants of the surface."""
        super().__init__()
        self.a, self.b, self.c = a, b, c

    def forward(self, xy):
        """Evaluate the surface at ``xy``, which must unpack into ``x, y``."""
        x, y = xy
        freq = self.c * np.pi

        # Radial part: exponential of the scaled root-mean-square radius.
        rms_radius = torch.sqrt((x ** 2 + y ** 2) / 2)
        radial_term = -self.a * torch.exp(-self.b * rms_radius)

        # Oscillatory part: exponential of the mean of the two cosines.
        mean_cos = (torch.cos(freq * x) + torch.cos(freq * y)) / 2
        cosine_term = -torch.exp(mean_cos)

        return radial_term + cosine_term + self.a + np.exp(1)


class CombinedGaussian(nn.Module):
    def __init__(self, sigma=0.7):
        super(CombinedGaussian, self).__init__()

    def forward(self, xy,
                centers=torch.tensor([(1, 0.7, 2, 2), (1, 0.7, -2, -2),
                                      (1, 0.7, -2, 2), (1, 0.7, 2, -2),
                                      (1, 0.7, 0, 0), (1, 0.7, 0, -3)])
                ):
        x, y = xy
        value = 0
        for center in centers:
            a, sigma, c_x, c_y = center
            value = a * torch.exp(-((x - c_x) ** 2 + (y - c_y) ** 2) / (2 * sigma ** 2))
        return 1 - value


class TwoMinima(nn.Module):
    """Surface given by one minus a single unit-width exponential bump.

    The bump is described by the first row of ``theta``, laid out as
    ``(a, sigma, theta1, theta2)``; the ``sigma`` slot is read but unused.
    """

    def __init__(self):
        super().__init__()

    def forward(self, xy, theta=torch.tensor([(1, 0, 1, 0)])):
        """Return ``1 - a * exp(-((theta1 - x)^2 + (theta2 - y)^2))``.

        ``xy`` must unpack into ``x, y``; only ``theta[0]`` is consulted.
        """
        x, y = xy
        amplitude, _sigma, center_x, center_y = theta[0]
        sq_dist = (center_x - x) ** 2 + (center_y - y) ** 2
        return 1 - amplitude * torch.exp(-sq_dist)


class ToyDataset(Dataset):
    """Dataset over the four numeric columns returned by ``read_txt``.

    Each item is a float32 tensor of length four holding, in order, the
    ``ap``, ``sg``, ``x`` and ``y`` values of one row.
    """

    def __init__(self, txtpath, transform=None):
        """Load all rows from ``txtpath`` into memory.

        ``transform`` is stored on the instance but is not applied in
        ``__getitem__``.
        """
        super().__init__()
        columns = read_txt(txtpath)
        self.aps, self.sgs, self.xs, self.ys = columns
        self.numbers = len(self.xs)
        self.transform = transform

    def __getitem__(self, index):
        """Return row ``index`` as a float32 tensor ``[ap, sg, x, y]``."""
        row = [
            self.aps[index],
            self.sgs[index],
            self.xs[index],
            self.ys[index],
        ]
        return torch.from_numpy(np.asarray(row, dtype=np.float32))

    def __len__(self):
        """Number of rows, taken from the length of the ``xs`` column."""
        return len(self.xs)


def toy_data(
    data_dir,
    train,
    download,
    batch_size,
    shuffle=None,
    sampler_callback=None,
    dataset_cls=ToyDataset,
    drop_last=True,
    **loader_kwargs
):
    """Build a ``DataLoader`` over a toy dataset and log the setup.

    Args:
        data_dir: Passed as the first argument to ``dataset_cls``.
        train: Only recorded in the setup log.
        download: Only recorded in the setup log.
        batch_size: Forwarded to the ``DataLoader``.
        shuffle: Forwarded to the ``DataLoader``.
        sampler_callback: Optional callable that receives the dataset and
            returns the sampler to use.
        dataset_cls: Dataset class, called as
            ``dataset_cls(data_dir, transform=None)``.
        drop_last: Forwarded to the ``DataLoader``.
        **loader_kwargs: Extra keyword arguments for the ``DataLoader``.

    Returns:
        The constructed ``torch.utils.data.DataLoader``.
    """
    # NOTE: there is no check here for ``sampler_callback`` and ``shuffle``
    # both being supplied; the combination is passed to DataLoader as-is.
    dataset = dataset_cls(data_dir, transform=None)

    sampler = None
    if sampler_callback:
        sampler = sampler_callback(dataset)

    setup_record = {
        "Type": "Setup",
        "Dataset": "toydata",
        "data_dir": data_dir,
        "train": train,
        "download": download,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "sampler": str(sampler) if sampler else None,
    }
    log_dict(setup_record)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        drop_last=drop_last,
        **loader_kwargs,
    )
    return loader
