import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data as Data
from torch.autograd import Variable,grad
import random
import time
import os
# Set the process-wide default floating-point dtype to float64. Tensors created
# without an explicit dtype (e.g. torch.rand in generate_synthetic_dataset) and
# nn.Linear parameters are therefore double precision, so inputs such as
# true_weights must be float64 as well to be matmul-compatible.
torch.set_default_dtype(torch.float64)

class FNN(nn.Module):
    """Fully connected feed-forward network with ReLU between layers.

    ``dim_vec`` lists the layer widths, input width first and output width
    last. Each consecutive pair of widths becomes one ``nn.Linear``. A ReLU
    follows every linear layer except the final one, so the output is left
    unactivated. All linear layers are re-initialised after construction
    (see ``initialize_weights``).
    """

    def __init__(self, dim_vec):
        super().__init__()
        width_pairs = list(zip(dim_vec[:-1], dim_vec[1:]))
        final_index = len(width_pairs) - 1
        modules = []
        for position, (fan_in, fan_out) in enumerate(width_pairs):
            modules.append(nn.Linear(fan_in, fan_out))
            # No activation after the output layer.
            if position != final_index:
                modules.append(nn.ReLU())
        self.network = nn.Sequential(*modules)
        self.initialize_weights()

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the raw output."""
        return self.network(x)

    def initialize_weights(self):
        """Kaiming-uniform (ReLU gain) weights and zero biases for each Linear."""
        linear_layers = (m for m in self.network if isinstance(m, nn.Linear))
        for linear in linear_layers:
            nn.init.kaiming_uniform_(linear.weight, nonlinearity='relu')
            nn.init.zeros_(linear.bias)


def generate_synthetic_dataset(n, d, model, true_weights, seed):
    """Generate a synthetic pairwise-preference dataset.

    Features ``x`` are drawn uniformly from [0, 1) with shape (n, d). The
    latent rewards are ``r1 = 2*sin(4*sin(x) @ true_weights)`` and
    ``r0 = -r1``. The probability of preferring option 1 is
    ``link(r1 - r0)``, where ``link`` is the logistic sigmoid for ``'bt'``
    (Bradley-Terry) or the standard normal CDF for ``'thurstonian'``.
    Preference labels are sampled as Bernoulli draws from those probabilities.

    Args:
        n: Number of samples.
        d: Feature dimension.
        model: ``'bt'`` or ``'thurstonian'``.
        true_weights: Weights applied to the sin-transformed features. They
            must be matmul-compatible with an (n, d) tensor and have a
            matching dtype.
        seed: Passed to ``torch.manual_seed``. This reseeds the global torch
            RNG as a side effect.

    Returns:
        ``TensorDataset(features, preferences)``, where ``features`` are the
        raw uniform draws (not their sine).

    Raises:
        ValueError: If ``model`` is not ``'bt'`` or ``'thurstonian'``.
    """
    # Seed before validating, matching the original global-RNG side effect.
    torch.manual_seed(seed)
    if model == 'bt':  # Bradley-Terry model
        link = torch.sigmoid
    elif model == 'thurstonian':
        link = torch.distributions.Normal(0, 1).cdf
    else:
        raise ValueError("Model must be 'bt' or 'thurstonian'")

    state_features = torch.rand((n, d))
    # Build the sine transform in a NEW tensor. The previous code aliased
    # ``state_features2 = state_features`` and wrote into it column by column,
    # which silently overwrote the raw features stored in the dataset.
    transformed = torch.sin(state_features)
    score = 4 * transformed @ true_weights
    r1 = 2 * torch.sin(score)
    r0 = -2 * torch.sin(score)

    probabilities = link(r1 - r0)
    preferences = torch.bernoulli(probabilities)
    return TensorDataset(state_features, preferences)

