import torch
from logistic_regression_data import get_caravan_lr_data, get_ripley_lr_data, get_germancredit_lr_data, get_australiancredit_lr_data, get_heart_lr_data
import numpy as np
class CaravanLR:
    """Bayesian logistic-regression target built from ``get_caravan_lr_data``.

    The design matrix is the loaded data standardized column-wise, with a
    leading column of ones so that weight 0 is the intercept. Labels must be
    0/1 (they are used as indices into a size-2 axis).

    The state ``x`` passed to :meth:`log_prob` holds the regression weights
    (``self.predictors.shape[1]`` values, intercept first) followed by any
    number of auxiliary coordinates.

    Args:
        filename: Path handed to ``get_caravan_lr_data``.
        device: Torch device for the design matrix and label tensors.
        prior_std: Standard deviation of the zero-mean Gaussian prior placed
            independently on every regression weight, intercept included.
    """

    def __init__(self, filename='./Datasets/caravan-insurance-challenge.csv', device=torch.device('cpu'), prior_std=10.0):
        self.device = device
        data, outcomes = get_caravan_lr_data(filename)
        # Work on a float64 copy: integer-coded columns standardize correctly
        # and the loader's array is not modified in place.
        data = np.array(data, dtype=np.float64)
        # Standardize every column; the epsilon keeps constant columns finite.
        data = (data - data.mean(axis=0)) / (data.std(axis=0) + 1e-10)
        # Prepend a column of ones so weight 0 acts as the intercept.
        predictors = np.ones([data.shape[0], data.shape[1] + 1])
        predictors[:, 1:] = data
        self.predictors = torch.from_numpy(predictors).to(device)

        outcomes = np.asarray(outcomes)
        self.outcomes = torch.from_numpy(outcomes).to(device)
        # outcome_vectors[i] is 1 everywhere except column int(outcomes[i]).
        n_obs = outcomes.shape[0]
        labels = torch.as_tensor(outcomes.astype(np.int64), device=device)
        self.outcome_vectors = torch.ones([n_obs, 2], device=device)
        self.outcome_vectors[torch.arange(n_obs, device=device), labels] = 0.0
        self.prior = prior_std

    def log_prob(self, x):
        """Return the unnormalized log-density of one or more states.

        Args:
            x: Tensor of shape ``(D,)`` or ``(B, D)``. The first
                ``self.predictors.shape[1]`` entries of each row are the
                regression weights; the remaining entries are auxiliary
                coordinates.

        Returns:
            Tensor of shape ``(B,)`` (``B == 1`` for 1-D input) with dtype
            ``x.dtype``: the Bernoulli log-likelihood of all observations with
            ``P(y=1) = sigmoid(w . row)``, plus a ``N(0, prior_std**2)``
            log-prior on each weight, plus a standard-normal log-density on
            each auxiliary coordinate (additive constants omitted).
        """
        if x.dim() < 2:
            x = x.unsqueeze(0)
        n_weights = self.predictors.shape[1]
        starts = x[:, :n_weights]
        aux = x[:, n_weights:]

        # Linear predictor for every (state, observation) pair, shape (B, N).
        # A matmul avoids materializing a (B, N, D) product, and the
        # likelihood is accumulated in the wider of the two dtypes.
        dtype = torch.promote_types(starts.dtype, self.predictors.dtype)
        logits = starts.to(dtype) @ self.predictors.to(dtype).T
        labels = self.outcomes.to(dtype)
        log_lik = (labels * torch.nn.functional.logsigmoid(logits)
                   + (1.0 - labels) * torch.nn.functional.logsigmoid(-logits)).sum(dim=-1)

        log_prior = -(starts ** 2).sum(dim=-1) / (2.0 * self.prior ** 2)
        log_aux = -(aux ** 2).sum(dim=-1) / 2.0
        return log_lik.to(x.dtype) + log_prior + log_aux
        
class RipleyLR:
    """Bayesian logistic-regression target built from ``get_ripley_lr_data``.

    The design matrix is the loaded data standardized column-wise, with a
    leading column of ones so that weight 0 is the intercept. Labels must be
    0/1 (they are used as indices into a size-2 axis).

    The state ``x`` passed to :meth:`log_prob` holds the regression weights
    (``self.predictors.shape[1]`` values, intercept first) followed by any
    number of auxiliary coordinates.

    Args:
        filename: Path handed to ``get_ripley_lr_data``.
        device: Torch device for the design matrix and label tensors.
        prior_std: Standard deviation of the zero-mean Gaussian prior placed
            independently on every regression weight, intercept included.
    """

    def __init__(self, filename='./Datasets/ripley.csv', device=torch.device('cpu'), prior_std=10.0):
        self.device = device
        data, outcomes = get_ripley_lr_data(filename)
        # Work on a float64 copy: integer-coded columns standardize correctly
        # and the loader's array is not modified in place.
        data = np.array(data, dtype=np.float64)
        # Standardize every column; the epsilon keeps constant columns finite.
        data = (data - data.mean(axis=0)) / (data.std(axis=0) + 1e-10)
        # Prepend a column of ones so weight 0 acts as the intercept.
        predictors = np.ones([data.shape[0], data.shape[1] + 1])
        predictors[:, 1:] = data
        self.predictors = torch.from_numpy(predictors).to(device)

        outcomes = np.asarray(outcomes)
        self.outcomes = torch.from_numpy(outcomes).to(device)
        # outcome_vectors[i] is 1 everywhere except column int(outcomes[i]).
        n_obs = outcomes.shape[0]
        labels = torch.as_tensor(outcomes.astype(np.int64), device=device)
        self.outcome_vectors = torch.ones([n_obs, 2], device=device)
        self.outcome_vectors[torch.arange(n_obs, device=device), labels] = 0.0
        self.prior = prior_std

    def log_prob(self, x):
        """Return the unnormalized log-density of one or more states.

        Args:
            x: Tensor of shape ``(D,)`` or ``(B, D)``. The first
                ``self.predictors.shape[1]`` entries of each row are the
                regression weights; the remaining entries are auxiliary
                coordinates.

        Returns:
            Tensor of shape ``(B,)`` (``B == 1`` for 1-D input) with dtype
            ``x.dtype``: the Bernoulli log-likelihood of all observations with
            ``P(y=1) = sigmoid(w . row)``, plus a ``N(0, prior_std**2)``
            log-prior on each weight, plus a standard-normal log-density on
            each auxiliary coordinate (additive constants omitted).
        """
        if x.dim() < 2:
            x = x.unsqueeze(0)
        n_weights = self.predictors.shape[1]
        starts = x[:, :n_weights]
        aux = x[:, n_weights:]

        # Linear predictor for every (state, observation) pair, shape (B, N).
        # A matmul avoids materializing a (B, N, D) product, and the
        # likelihood is accumulated in the wider of the two dtypes.
        dtype = torch.promote_types(starts.dtype, self.predictors.dtype)
        logits = starts.to(dtype) @ self.predictors.to(dtype).T
        labels = self.outcomes.to(dtype)
        log_lik = (labels * torch.nn.functional.logsigmoid(logits)
                   + (1.0 - labels) * torch.nn.functional.logsigmoid(-logits)).sum(dim=-1)

        log_prior = -(starts ** 2).sum(dim=-1) / (2.0 * self.prior ** 2)
        log_aux = -(aux ** 2).sum(dim=-1) / 2.0
        return log_lik.to(x.dtype) + log_prior + log_aux


class GermanCreditLR:
    """Bayesian logistic-regression target built from ``get_germancredit_lr_data``.

    The design matrix is the loaded data standardized column-wise, with a
    leading column of ones so that weight 0 is the intercept. Labels must be
    0/1 (they are used as indices into a size-2 axis).

    The state ``x`` passed to :meth:`log_prob` holds the regression weights
    (``self.predictors.shape[1]`` values, intercept first) followed by any
    number of auxiliary coordinates.

    Args:
        filename: Path handed to ``get_germancredit_lr_data``.
        device: Torch device for the design matrix and label tensors.
        prior_std: Standard deviation of the zero-mean Gaussian prior placed
            independently on every regression weight, intercept included.
    """

    def __init__(self, filename='./Datasets/SouthGermanCredit.asc', device=torch.device('cpu'), prior_std=10.0):
        self.device = device
        data, outcomes = get_germancredit_lr_data(filename)
        # Work on a float64 copy: integer-coded columns standardize correctly
        # and the loader's array is not modified in place.
        data = np.array(data, dtype=np.float64)
        # Standardize every column; the epsilon keeps constant columns finite.
        data = (data - data.mean(axis=0)) / (data.std(axis=0) + 1e-10)
        # Prepend a column of ones so weight 0 acts as the intercept.
        predictors = np.ones([data.shape[0], data.shape[1] + 1])
        predictors[:, 1:] = data
        self.predictors = torch.from_numpy(predictors).to(device)

        outcomes = np.asarray(outcomes)
        self.outcomes = torch.from_numpy(outcomes).to(device)
        # outcome_vectors[i] is 1 everywhere except column int(outcomes[i]).
        n_obs = outcomes.shape[0]
        labels = torch.as_tensor(outcomes.astype(np.int64), device=device)
        self.outcome_vectors = torch.ones([n_obs, 2], device=device)
        self.outcome_vectors[torch.arange(n_obs, device=device), labels] = 0.0
        self.prior = prior_std

    def log_prob(self, x):
        """Return the unnormalized log-density of one or more states.

        Args:
            x: Tensor of shape ``(D,)`` or ``(B, D)``. The first
                ``self.predictors.shape[1]`` entries of each row are the
                regression weights; the remaining entries are auxiliary
                coordinates.

        Returns:
            Tensor of shape ``(B,)`` (``B == 1`` for 1-D input) with dtype
            ``x.dtype``: the Bernoulli log-likelihood of all observations with
            ``P(y=1) = sigmoid(w . row)``, plus a ``N(0, prior_std**2)``
            log-prior on each weight, plus a standard-normal log-density on
            each auxiliary coordinate (additive constants omitted).
        """
        if x.dim() < 2:
            x = x.unsqueeze(0)
        n_weights = self.predictors.shape[1]
        starts = x[:, :n_weights]
        aux = x[:, n_weights:]

        # Linear predictor for every (state, observation) pair, shape (B, N).
        # A matmul avoids materializing a (B, N, D) product, and the
        # likelihood is accumulated in the wider of the two dtypes.
        dtype = torch.promote_types(starts.dtype, self.predictors.dtype)
        logits = starts.to(dtype) @ self.predictors.to(dtype).T
        labels = self.outcomes.to(dtype)
        log_lik = (labels * torch.nn.functional.logsigmoid(logits)
                   + (1.0 - labels) * torch.nn.functional.logsigmoid(-logits)).sum(dim=-1)

        log_prior = -(starts ** 2).sum(dim=-1) / (2.0 * self.prior ** 2)
        log_aux = -(aux ** 2).sum(dim=-1) / 2.0
        return log_lik.to(x.dtype) + log_prior + log_aux

class AustralianCreditLR:
    """Bayesian logistic-regression target built from ``get_australiancredit_lr_data``.

    The design matrix is the loaded data standardized column-wise, with a
    leading column of ones so that weight 0 is the intercept. Labels must be
    0/1 (they are used as indices into a size-2 axis).

    The state ``x`` passed to :meth:`log_prob` holds the regression weights
    (``self.predictors.shape[1]`` values, intercept first) followed by any
    number of auxiliary coordinates.

    Args:
        filename: Path handed to ``get_australiancredit_lr_data``.
        device: Torch device for the design matrix and label tensors.
        prior_std: Standard deviation of the zero-mean Gaussian prior placed
            independently on every regression weight, intercept included.
    """

    def __init__(self, filename='./Datasets/australian.dat', device=torch.device('cpu'), prior_std=10.0):
        self.device = device
        data, outcomes = get_australiancredit_lr_data(filename)
        # Work on a float64 copy: integer-coded columns standardize correctly
        # and the loader's array is not modified in place.
        data = np.array(data, dtype=np.float64)
        # Standardize every column; the epsilon keeps constant columns finite.
        data = (data - data.mean(axis=0)) / (data.std(axis=0) + 1e-10)
        # Prepend a column of ones so weight 0 acts as the intercept.
        predictors = np.ones([data.shape[0], data.shape[1] + 1])
        predictors[:, 1:] = data
        self.predictors = torch.from_numpy(predictors).to(device)

        outcomes = np.asarray(outcomes)
        self.outcomes = torch.from_numpy(outcomes).to(device)
        # outcome_vectors[i] is 1 everywhere except column int(outcomes[i]).
        n_obs = outcomes.shape[0]
        labels = torch.as_tensor(outcomes.astype(np.int64), device=device)
        self.outcome_vectors = torch.ones([n_obs, 2], device=device)
        self.outcome_vectors[torch.arange(n_obs, device=device), labels] = 0.0
        self.prior = prior_std

    def log_prob(self, x):
        """Return the unnormalized log-density of one or more states.

        Args:
            x: Tensor of shape ``(D,)`` or ``(B, D)``. The first
                ``self.predictors.shape[1]`` entries of each row are the
                regression weights; the remaining entries are auxiliary
                coordinates.

        Returns:
            Tensor of shape ``(B,)`` (``B == 1`` for 1-D input) with dtype
            ``x.dtype``: the Bernoulli log-likelihood of all observations with
            ``P(y=1) = sigmoid(w . row)``, plus a ``N(0, prior_std**2)``
            log-prior on each weight, plus a standard-normal log-density on
            each auxiliary coordinate (additive constants omitted).
        """
        if x.dim() < 2:
            x = x.unsqueeze(0)
        n_weights = self.predictors.shape[1]
        starts = x[:, :n_weights]
        aux = x[:, n_weights:]

        # Linear predictor for every (state, observation) pair, shape (B, N).
        # A matmul avoids materializing a (B, N, D) product, and the
        # likelihood is accumulated in the wider of the two dtypes.
        dtype = torch.promote_types(starts.dtype, self.predictors.dtype)
        logits = starts.to(dtype) @ self.predictors.to(dtype).T
        labels = self.outcomes.to(dtype)
        log_lik = (labels * torch.nn.functional.logsigmoid(logits)
                   + (1.0 - labels) * torch.nn.functional.logsigmoid(-logits)).sum(dim=-1)

        log_prior = -(starts ** 2).sum(dim=-1) / (2.0 * self.prior ** 2)
        log_aux = -(aux ** 2).sum(dim=-1) / 2.0
        return log_lik.to(x.dtype) + log_prior + log_aux
   
class HeartLR:
    """Bayesian logistic-regression target built from ``get_heart_lr_data``.

    The design matrix is the loaded data standardized column-wise, with a
    leading column of ones so that weight 0 is the intercept. Labels must be
    0/1 (they are used as indices into a size-2 axis).

    The state ``x`` passed to :meth:`log_prob` holds the regression weights
    (``self.predictors.shape[1]`` values, intercept first) followed by any
    number of auxiliary coordinates.

    Args:
        filename: Path handed to ``get_heart_lr_data``.
        device: Torch device for the design matrix and label tensors.
        prior_std: Standard deviation of the zero-mean Gaussian prior placed
            independently on every regression weight, intercept included.
    """

    def __init__(self, filename='./Datasets/heart.csv', device=torch.device('cpu'), prior_std=10.0):
        self.device = device
        data, outcomes = get_heart_lr_data(filename)
        # Work on a float64 copy: integer-coded columns standardize correctly
        # and the loader's array is not modified in place.
        data = np.array(data, dtype=np.float64)
        # Standardize every column; the epsilon keeps constant columns finite.
        data = (data - data.mean(axis=0)) / (data.std(axis=0) + 1e-10)
        # Prepend a column of ones so weight 0 acts as the intercept.
        predictors = np.ones([data.shape[0], data.shape[1] + 1])
        predictors[:, 1:] = data
        self.predictors = torch.from_numpy(predictors).to(device)

        outcomes = np.asarray(outcomes)
        self.outcomes = torch.from_numpy(outcomes).to(device)
        # outcome_vectors[i] is 1 everywhere except column int(outcomes[i]).
        n_obs = outcomes.shape[0]
        labels = torch.as_tensor(outcomes.astype(np.int64), device=device)
        self.outcome_vectors = torch.ones([n_obs, 2], device=device)
        self.outcome_vectors[torch.arange(n_obs, device=device), labels] = 0.0
        self.prior = prior_std

    def log_prob(self, x):
        """Return the unnormalized log-density of one or more states.

        Args:
            x: Tensor of shape ``(D,)`` or ``(B, D)``. The first
                ``self.predictors.shape[1]`` entries of each row are the
                regression weights; the remaining entries are auxiliary
                coordinates.

        Returns:
            Tensor of shape ``(B,)`` (``B == 1`` for 1-D input) with dtype
            ``x.dtype``: the Bernoulli log-likelihood of all observations with
            ``P(y=1) = sigmoid(w . row)``, plus a ``N(0, prior_std**2)``
            log-prior on each weight, plus a standard-normal log-density on
            each auxiliary coordinate (additive constants omitted).
        """
        if x.dim() < 2:
            x = x.unsqueeze(0)
        n_weights = self.predictors.shape[1]
        starts = x[:, :n_weights]
        aux = x[:, n_weights:]

        # Linear predictor for every (state, observation) pair, shape (B, N).
        # A matmul avoids materializing a (B, N, D) product, and the
        # likelihood is accumulated in the wider of the two dtypes.
        dtype = torch.promote_types(starts.dtype, self.predictors.dtype)
        logits = starts.to(dtype) @ self.predictors.to(dtype).T
        labels = self.outcomes.to(dtype)
        log_lik = (labels * torch.nn.functional.logsigmoid(logits)
                   + (1.0 - labels) * torch.nn.functional.logsigmoid(-logits)).sum(dim=-1)

        log_prior = -(starts ** 2).sum(dim=-1) / (2.0 * self.prior ** 2)
        log_aux = -(aux ** 2).sum(dim=-1) / 2.0
        return log_lik.to(x.dtype) + log_prior + log_aux
   

class PimaLR:
    """Bayesian logistic-regression target for the data in ``diabetes.csv``.

    The design matrix is the loaded data standardized column-wise, with a
    leading column of ones so that weight 0 is the intercept. Labels must be
    0/1 (they are used as indices into a size-2 axis).

    The state ``x`` passed to :meth:`log_prob` holds the regression weights
    (``self.predictors.shape[1]`` values, intercept first) followed by any
    number of auxiliary coordinates.

    Args:
        filename: Path handed to the data loader (currently
            ``get_heart_lr_data``).
        device: Torch device for the design matrix and label tensors.
        prior_std: Standard deviation of the zero-mean Gaussian prior placed
            independently on every regression weight, intercept included.
    """

    def __init__(self, filename='./Datasets/diabetes.csv', device=torch.device('cpu'), prior_std=10.0):
        self.device = device
        # NOTE(review): this reuses the heart-data loader for diabetes.csv;
        # confirm get_heart_lr_data parses this file's layout (label column,
        # header) correctly, or add a dedicated Pima loader.
        data, outcomes = get_heart_lr_data(filename)
        # Work on a float64 copy: integer-coded columns standardize correctly
        # and the loader's array is not modified in place.
        data = np.array(data, dtype=np.float64)
        # Standardize every column; the epsilon keeps constant columns finite.
        data = (data - data.mean(axis=0)) / (data.std(axis=0) + 1e-10)
        # Prepend a column of ones so weight 0 acts as the intercept.
        predictors = np.ones([data.shape[0], data.shape[1] + 1])
        predictors[:, 1:] = data
        self.predictors = torch.from_numpy(predictors).to(device)

        outcomes = np.asarray(outcomes)
        self.outcomes = torch.from_numpy(outcomes).to(device)
        # outcome_vectors[i] is 1 everywhere except column int(outcomes[i]).
        n_obs = outcomes.shape[0]
        labels = torch.as_tensor(outcomes.astype(np.int64), device=device)
        self.outcome_vectors = torch.ones([n_obs, 2], device=device)
        self.outcome_vectors[torch.arange(n_obs, device=device), labels] = 0.0
        self.prior = prior_std

    def log_prob(self, x):
        """Return the unnormalized log-density of one or more states.

        Args:
            x: Tensor of shape ``(D,)`` or ``(B, D)``. The first
                ``self.predictors.shape[1]`` entries of each row are the
                regression weights; the remaining entries are auxiliary
                coordinates.

        Returns:
            Tensor of shape ``(B,)`` (``B == 1`` for 1-D input) with dtype
            ``x.dtype``: the Bernoulli log-likelihood of all observations with
            ``P(y=1) = sigmoid(w . row)``, plus a ``N(0, prior_std**2)``
            log-prior on each weight, plus a standard-normal log-density on
            each auxiliary coordinate (additive constants omitted).
        """
        if x.dim() < 2:
            x = x.unsqueeze(0)
        n_weights = self.predictors.shape[1]
        starts = x[:, :n_weights]
        aux = x[:, n_weights:]

        # Linear predictor for every (state, observation) pair, shape (B, N).
        # A matmul avoids materializing a (B, N, D) product, and the
        # likelihood is accumulated in the wider of the two dtypes.
        dtype = torch.promote_types(starts.dtype, self.predictors.dtype)
        logits = starts.to(dtype) @ self.predictors.to(dtype).T
        labels = self.outcomes.to(dtype)
        log_lik = (labels * torch.nn.functional.logsigmoid(logits)
                   + (1.0 - labels) * torch.nn.functional.logsigmoid(-logits)).sum(dim=-1)

        log_prior = -(starts ** 2).sum(dim=-1) / (2.0 * self.prior ** 2)
        log_aux = -(aux ** 2).sum(dim=-1) / 2.0
        return log_lik.to(x.dtype) + log_prior + log_aux
   
