import os
import random

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader


def set_seed(seed=666):
    """Seed every random number generator used for training.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU and
    CUDA generators. Also configures cuDNN to use deterministic kernels without
    autotuning.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Hash randomization of the *current* interpreter is fixed at startup, so
    # this only influences Python processes launched after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Previously missing: without this, anything drawing from ``random``
    # (e.g. shuffling helpers) stays non-reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Autotuning may pick different algorithms per run; disable it and force
    # deterministic cuDNN implementations.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class MyDataset(Dataset):
    """In-memory dataset holding samples and labels as float32 tensors on a device.

    Args:
        X: Array-like of samples. If it is 3-D (i.e. every sample is 2-D), a
            singleton axis is inserted at position 1, giving (N, 1, H, W);
            any other rank is stored unchanged.
        Y: Array-like of labels, indexed in parallel with ``X``.
        device: Device the tensors are moved to. Defaults to CUDA when
            available, otherwise CPU.
    """

    def __init__(self, X, Y, device=None):
        super().__init__()

        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

        # asarray accepts lists/array-likes (from_numpy needs an ndarray) and is
        # a no-op for ndarrays. Checking X.ndim instead of X[0].ndim keeps an
        # empty dataset from raising IndexError.
        X = np.asarray(X)
        Y = np.asarray(Y)
        if X.ndim == 3:
            # Samples are 2-D: add a channel axis so each sample is (1, H, W).
            X = np.expand_dims(X, axis=1)
        self.data = torch.from_numpy(X).float().to(self.device)
        self.labels = torch.from_numpy(Y).float().to(self.device)

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair at ``index``."""
        return self.data[index], self.labels[index]

    def __len__(self):
        """Number of labels (one per sample)."""
        return len(self.labels)


class EarlyStopping:
    """Track validation loss and flag when training should stop.

    Each call compares ``val_loss`` with the best loss seen so far. A loss
    strictly lower than ``best_loss - delta`` counts as an improvement: the
    model's ``state_dict`` is saved to ``path`` and the counter resets.
    Otherwise the counter increments. Once it reaches ``patience``, the
    ``early_stop`` flag is set; this class never clears it.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        delta: Minimum decrease in loss that counts as an improvement.
        path: File the best model's ``state_dict`` is written to.
        verbose: Print progress messages when True.
    """

    def __init__(self, patience=5, delta=0, path='checkpoint.pt', verbose=False):
        # Configuration.
        self.patience = patience
        self.delta = delta
        self.path = path
        self.verbose = verbose
        # Tracking state.
        self.best_loss = np.inf
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Record one validation loss and checkpoint ``model`` if it improved."""
        threshold = self.best_loss - self.delta
        if not val_loss < threshold:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
            return
        self.best_loss = val_loss
        self.counter = 0
        self.save_checkpoint(model)

    def save_checkpoint(self, model):
        """Write ``model.state_dict()`` to ``self.path``."""
        weights = model.state_dict()
        torch.save(weights, self.path)
        if not self.verbose:
            return
        print(f"Validation loss decreased. Saving model to {self.path}")


class FCN(nn.Module):
    """Two-layer fully connected network.

    Layers: ReLU -> Linear(input_dim, input_dim // 2) -> BatchNorm1d ->
    ReLU -> Linear(input_dim // 2, output_dim).

    Args:
        input_dim: Size of the last input dimension. Presumably inputs are
            (batch, input_dim); BatchNorm1d then normalises the hidden features.
        output_dim: Size of the output.
    """

    def __init__(self, input_dim, output_dim=1):
        super().__init__()
        hidden_dim = input_dim // 2
        self.linear = nn.Sequential(
            # Not inplace: this ReLU acts directly on the caller's input tensor,
            # and an in-place version would silently overwrite it.
            nn.ReLU(),
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            # Safe in place: operates on an intermediate activation.
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Apply the network to ``x`` without modifying it."""
        return self.linear(x)


class FCN_for_fitting_match(nn.Module):
    """Patch-wise linear layer followed by a fully connected head of half width.

    ``beta_layer`` is a bias-free Conv2d(1, out_channel, 4, stride=4). Because
    the stride equals the kernel size, it combines non-overlapping 4x4 patches
    of a single-channel input. Its flattened output (``d`` features) feeds
    ReLU -> Linear(d, d // 2) -> BatchNorm1d -> ReLU -> Linear(d // 2, output_dim).

    Args:
        input_shape: Spatial size (H, W) of one input sample.
        out_channel: Number of channels produced by ``beta_layer``.
        output_dim: Size of the output.
    """

    def __init__(self, input_shape=(28, 28), out_channel=1, output_dim=1):
        super().__init__()
        self.beta_layer = nn.Sequential(nn.Conv2d(1, out_channel, 4, stride=4, bias=False))
        self.flatten = nn.Flatten()
        # Measure the per-sample feature count with a dummy single-sample batch.
        with torch.no_grad():
            n_features = self.beta_layer(torch.zeros(1, 1, *input_shape)).numel()
        n_hidden = n_features // 2
        self.linear = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Linear(n_features, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(inplace=True),
            nn.Linear(n_hidden, output_dim),
        )

    def forward(self, x):
        """Map a (batch, 1, H, W) input to (batch, output_dim)."""
        features = self.flatten(self.beta_layer(x))
        return self.linear(features)


class FCN_for_fitting_mismatch(nn.Module):
    """Patch-wise linear layer followed by a fully connected head of full width.

    ``beta_layer`` is a bias-free Conv2d(1, out_channel, 4, stride=4). Because
    the stride equals the kernel size, it combines non-overlapping 4x4 patches
    of a single-channel input. Its flattened output (``d`` features) feeds
    ReLU -> Linear(d, d) -> BatchNorm1d -> ReLU -> Linear(d, output_dim).

    Args:
        input_shape: Spatial size (H, W) of one input sample.
        out_channel: Number of channels produced by ``beta_layer``.
        output_dim: Size of the output.
    """

    def __init__(self, input_shape=(28, 28), out_channel=1, output_dim=1):
        super().__init__()
        self.beta_layer = nn.Sequential(nn.Conv2d(1, out_channel, 4, stride=4, bias=False))
        self.flatten = nn.Flatten()
        # Measure the per-sample feature count with a dummy single-sample batch.
        with torch.no_grad():
            width = self.beta_layer(torch.zeros(1, 1, *input_shape)).numel()
        self.linear = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Linear(width, width),
            nn.BatchNorm1d(width),
            nn.ReLU(inplace=True),
            nn.Linear(width, output_dim),
        )

    def forward(self, x):
        """Map a (batch, 1, H, W) input to (batch, output_dim)."""
        features = self.flatten(self.beta_layer(x))
        return self.linear(features)


# #################### CNN ####################
class CNN(nn.Module):
    """Single conv block followed by a two-layer fully connected head.

    Layers: Conv2d(in_channel, 10, 4) -> BatchNorm2d -> ReLU -> MaxPool2d(2) ->
    Flatten -> Linear(d, d // 2) -> BatchNorm1d -> ReLU -> Linear(d // 2, output_dim).
    Here ``d`` is the flattened conv output size for one sample of spatial size
    ``input_shape``.

    Args:
        input_shape: Spatial size (H, W) of one input sample.
        in_channel: Number of input channels.
        output_dim: Size of the output.
    """

    def __init__(self, input_shape=(7, 7), in_channel=1, output_dim=1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, 10, 4),
            nn.BatchNorm2d(10),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        self.flatten = nn.Flatten()
        # Infer the flattened feature size with a dummy sample. The probe runs in
        # eval mode: a training-mode forward would make BatchNorm2d fold the
        # dummy batch into running_mean/running_var and bump
        # num_batches_tracked before any real data has been seen.
        self.conv.eval()
        with torch.no_grad():
            input_dim = self.conv(torch.zeros(1, in_channel, *input_shape)).numel()
        self.conv.train(self.training)
        hidden_dim = input_dim // 2
        self.linear = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Map a (batch, in_channel, H, W) input to (batch, output_dim)."""
        x = self.conv(x)
        return self.linear(self.flatten(x))


class CNN_for_fitting_match(nn.Module):
    """Patch-wise linear layer, a conv block, then a half-width FC head.

    ``beta_layer`` is a bias-free Conv2d(1, out_channel, 4, stride=4). Because
    the stride equals the kernel size, it combines non-overlapping 4x4 patches
    of a single-channel input. It is followed by Conv2d(out_channel, 10, 4) ->
    BatchNorm2d -> ReLU -> MaxPool2d(2), then Flatten ->
    Linear(d, d // 2) -> BatchNorm1d -> ReLU -> Linear(d // 2, output_dim).

    Args:
        input_shape: Spatial size (H, W) of one input sample.
        out_channel: Number of channels produced by ``beta_layer``.
        output_dim: Size of the output.
    """

    def __init__(self, input_shape=(28, 28), out_channel=1, output_dim=1):
        super().__init__()
        self.beta_layer = nn.Sequential(
            nn.Conv2d(1, out_channel, 4, stride=4, bias=False),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(out_channel, 10, 4),
            nn.BatchNorm2d(10),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        self.flatten = nn.Flatten()
        # Infer the flattened feature size with a dummy sample. The conv block is
        # probed in eval mode so BatchNorm2d does not fold the dummy batch into
        # its running statistics (running_mean/running_var/num_batches_tracked).
        self.conv.eval()
        with torch.no_grad():
            probe = self.conv(self.beta_layer(torch.zeros(1, 1, *input_shape)))
        self.conv.train(self.training)
        input_dim = probe.numel()
        hidden_dim = input_dim // 2
        self.linear = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Map a (batch, 1, H, W) input to (batch, output_dim)."""
        x = self.beta_layer(x)
        x = self.conv(x)
        return self.linear(self.flatten(x))


class CNN_for_fitting_mismatch(nn.Module):
    """Patch-wise linear layer, a conv block, then a full-width FC head.

    ``beta_layer`` is a bias-free Conv2d(1, out_channel, 4, stride=4). Because
    the stride equals the kernel size, it combines non-overlapping 4x4 patches
    of a single-channel input. It is followed by Conv2d(out_channel, 10, 4) ->
    BatchNorm2d -> ReLU -> MaxPool2d(2), then Flatten ->
    Linear(d, d) -> BatchNorm1d -> ReLU -> Linear(d, output_dim).

    Args:
        input_shape: Spatial size (H, W) of one input sample.
        out_channel: Number of channels produced by ``beta_layer``.
        output_dim: Size of the output.
    """

    def __init__(self, input_shape=(28, 28), out_channel=1, output_dim=1):
        super().__init__()
        self.beta_layer = nn.Sequential(
            nn.Conv2d(1, out_channel, 4, stride=4, bias=False),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(out_channel, 10, 4),
            nn.BatchNorm2d(10),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        self.flatten = nn.Flatten()
        # Infer the flattened feature size with a dummy sample. The conv block is
        # probed in eval mode so BatchNorm2d does not fold the dummy batch into
        # its running statistics (running_mean/running_var/num_batches_tracked).
        self.conv.eval()
        with torch.no_grad():
            probe = self.conv(self.beta_layer(torch.zeros(1, 1, *input_shape)))
        self.conv.train(self.training)
        input_dim = probe.numel()
        self.linear = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.BatchNorm1d(input_dim),
            nn.ReLU(inplace=True),
            nn.Linear(input_dim, output_dim),
        )

    def forward(self, x):
        """Map a (batch, 1, H, W) input to (batch, output_dim)."""
        x = self.beta_layer(x)
        x = self.conv(x)
        return self.linear(self.flatten(x))

