import torch 
import numpy as np
from sklearn.preprocessing import OneHotEncoder
import os
import sys

def GetGradNorm(model):
    """Return the global L2 norm of the gradients of ``model``'s trainable parameters.

    Computes ``sqrt(sum_p ||p.grad||_2 ** 2)`` over every parameter with
    ``requires_grad=True``.

    Parameters whose ``.grad`` is ``None`` are skipped, for example parameters
    that did not take part in the last backward pass. If no parameter has a
    gradient, a zero tensor is returned instead of raising.

    Args:
        model: a ``torch.nn.Module`` (anything exposing ``.parameters()``).

    Returns:
        A 0-d tensor holding the norm.
    """
    squared_norms = [
        p.grad.norm(2) ** 2
        for p in model.parameters()
        # Unused parameters keep grad=None; calling .norm on None would crash.
        if p.requires_grad and p.grad is not None
    ]
    if not squared_norms:
        # torch.sqrt(sum([])) would be torch.sqrt(0) on a Python int -> TypeError.
        return torch.tensor(0.0)
    return torch.sqrt(sum(squared_norms))

def DefineDevice(device):
    """Resolve which torch device to use.

    An explicitly supplied ``device`` (any non-``None`` value) is returned
    unchanged. For ``None`` the result is ``"cuda"`` when CUDA is available
    and ``"cpu"`` otherwise.
    """
    if device is not None:
        return device
    return "cuda" if torch.cuda.is_available() else "cpu"
        
def OneHotEncode(Y, set_Y, device=None):
    """One-hot encode the integer labels in ``Y``.

    Args:
        Y: tensor whose first dimension indexes samples; its values (cast to
            ``long``) are used as column indices. Presumably a 1-D label
            vector — TODO confirm with callers.
        set_Y: tensor of possible labels. Only ``set_Y.shape[0]`` (the number
            of classes) is used.
        device: target device for the result (``None`` = torch default).

    Returns:
        A float64 tensor of shape ``(Y.shape[0], set_Y.shape[0])``. Row ``i``
        has a 1 in column ``Y[i]`` and zeros elsewhere.
    """
    n_samples = Y.shape[0]
    # Allocate directly as float64 on the target device instead of building a
    # float32 tensor and copying it twice (.to(device) then .double()).
    one_hot = torch.zeros(n_samples, set_Y.shape[0], dtype=torch.float64, device=device)
    # Keep the index tensors on the same device as the buffer being written.
    rows = torch.arange(n_samples, device=one_hot.device)
    cols = Y.long().to(one_hot.device)
    one_hot[rows, cols] = 1
    return one_hot

## Specifically for the experiments ##
def FindRowIndex(X, x):
    """Return the index of the first row of ``X`` equal to ``x``, or -1.

    ``x`` is flattened to shape ``(1, -1)`` and compared element-wise (with
    broadcasting) against each row ``X[i]``. A row matches when every
    compared element is equal.

    Args:
        X: tensor whose first dimension indexes rows.
        x: query tensor. Any memory layout is accepted, including
            non-contiguous tensors.

    Returns:
        The Python ``int`` index of the first matching row, or ``-1`` if no
        row matches.
    """
    # reshape (unlike view) also works for non-contiguous query tensors.
    x = x.reshape(1, -1)
    if X.dim() == 2 and X.shape[1] == x.shape[1]:
        # Fast path: compare all rows in one vectorized op instead of one
        # Python iteration (and, on GPU, one host sync) per row.
        hits = (X == x).all(dim=1).nonzero()
        return int(hits[0, 0]) if hits.numel() > 0 else -1
    # General fallback that keeps the original per-row broadcasting semantics.
    for i in range(X.shape[0]):
        if torch.all(torch.eq(X[i], x)):
            return i
    return -1

def GetP_Y_Z(Y, Z, set_Y, set_Z):
    """Estimate the empirical conditional table P(Y | Z) for binary ``Y``.

    For every value ``z`` in ``set_Z``, ``p`` is the mean of ``Y`` over the
    samples where ``Z == z``. If no sample has that value, ``p = 0.5``.

    Args:
        Y: tensor of labels, indexed by the same boolean mask as ``Z``.
        Z: tensor of conditioning values, compared element-wise with each ``z``.
        set_Y: must equal ``[0, 1]`` (checked with ``assert``).
        set_Z: values of ``Z`` to tabulate. Each ``z`` is used directly as a
            row index, so presumably these are 0-based integer ids — TODO
            confirm with callers.

    Returns:
        A float tensor of shape ``(len(set_Y), len(set_Z))``. Column ``z``
        holds ``[1 - p, p]``.
    """
    assert set_Y.tolist()==[0,1], "Function designed to be used when set_Y = [0,1]"

    table = torch.zeros((set_Z.shape[0], set_Y.shape[0]))
    for z_value in set_Z:
        mask = Z == z_value
        # Fall back to an uninformative 0.5 when z never occurs in Z.
        prob_one = Y[mask].float().mean().item() if mask.any() else .5
        row = table[z_value]  # view into `table`, so in-place adds stick
        row[0] += 1 - prob_one
        row[1] += prob_one
    # Transpose so rows index y and columns index z.
    return table.T


class SuppressPrints:
    """Context manager that silences ``sys.stdout`` and ``sys.stderr``.

    On entry both streams are redirected to ``os.devnull``. On exit the
    original streams are restored and the devnull handles opened on entry
    are closed. Exceptions raised inside the ``with`` body propagate
    normally.

    Example::

        with SuppressPrints():
            noisy_function()
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        self._original_stderr = sys.stderr
        # Keep our own references so __exit__ closes exactly what it opened,
        # even if the wrapped code rebinds sys.stdout / sys.stderr.
        self._devnull_stdout = open(os.devnull, 'w')
        try:
            self._devnull_stderr = open(os.devnull, 'w')
        except BaseException:
            # Don't leak the first handle if the second open fails.
            self._devnull_stdout.close()
            raise
        sys.stdout = self._devnull_stdout
        sys.stderr = self._devnull_stderr
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first so nothing can write to a handle while it is closing.
        sys.stdout = self._original_stdout
        sys.stderr = self._original_stderr
        self._devnull_stdout.close()
        self._devnull_stderr.close()
        # Implicitly returns None (falsy): exceptions from the body propagate.