import torch
import numpy as np
from torch import autograd as ag
from sklearn.preprocessing import PowerTransformer

def totorch(x, device):
    """Convert array-like data to float32 tensors placed on ``device``.

    Args:
        x: an array-like (NumPy array, nested list, scalar, tensor) or a tuple
            of such array-likes.
        device: target passed to ``Tensor.to`` (e.g. ``"cpu"`` or ``"cuda"``).

    Returns:
        A float32 tensor, or a tuple of float32 tensors when ``x`` is a tuple
        (only the top-level tuple is split; its elements are converted as-is).
    """
    # ``autograd.Variable`` has been a no-op wrapper since PyTorch 0.4, and the
    # legacy ``torch.Tensor(x)`` constructor interprets an int argument as a
    # *size* (returning uninitialised memory). ``as_tensor`` with an explicit
    # float32 dtype gives the same dtype for real data without either pitfall.
    if isinstance(x, tuple):
        return tuple(torch.as_tensor(e, dtype=torch.float32).to(device) for e in x)
    return torch.as_tensor(x, dtype=torch.float32).to(device)
    
def prepare_data(indexes, support, Lambda, response, metafeatures, output_transform=False, pairwise=False):
    """Build a batch of (support set, query point) inputs.

    Every entry ``dim`` of ``indexes`` yields one batch element. All elements
    share the same context (the support rows ``Lambda[support]``, their
    responses ``response[support]``, and ``metafeatures``). They differ only in
    the query row ``Lambda[dim]`` and its target ``response[dim]``.

    Args:
        indexes: iterable of indices into ``Lambda``/``response``, one per
            batch element.
        support: index array or boolean mask into ``Lambda``/``response``
            selecting the support set.
        Lambda: NumPy array indexed by ``support`` and by each ``dim``.
        response: NumPy array of responses (indexed as
            ``response[support, np.newaxis]``).
        metafeatures: value repeated once per batch element.
        output_transform: if True, fit a Yeo-Johnson ``PowerTransformer`` on
            the support responses and apply it to both the support responses
            and the targets.
        pairwise: if True, also return features for every unordered pair of
            support points. Requires ``Lambda[support]`` to be 2-D.

    Returns:
        ``(inputs, y)`` with ``inputs = (E[..., None], r, X[..., None], Z)``.
        ``E``, ``r``, ``X`` and ``Z`` stack, along a leading batch axis, the
        support rows, the support responses (``(batch, n_support, 1)`` for a
        1-D ``response``), the query rows and the metafeatures. ``y`` stacks
        the targets.

        When ``pairwise`` is True, ``inputs`` ends with an extra ``pairs``
        array of shape ``(batch, n_pairs, 2 * (E.shape[-1] + 1))``. Pair
        ``(i, j)`` with ``i < j`` is ``[E_i, r_i, E_j, r_j]``, and pairs are
        listed in row-major order.
    """
    # The support set is identical for every batch element, so gather it once.
    e_support = Lambda[support]
    r_support = response[support, np.newaxis]
    power = None
    X, y = [], []
    for dim in indexes:
        X.append(Lambda[dim])
        y_ = response[dim]
        if output_transform:
            if power is None:
                # Fit once: refitting on the same support responses for every
                # batch element yields the same transform each time.
                power = PowerTransformer(method="yeo-johnson")
                r_support = power.fit_transform(r_support)
            y_ = power.transform(y_.reshape(-1, 1)).reshape(-1,)
        y.append(y_)
    n_batch = len(X)
    X = np.array(X)
    E = np.array([e_support] * n_batch)
    Z = np.array([metafeatures] * n_batch)
    r = np.array([r_support] * n_batch)
    y = np.array(y)
    inputs = (np.expand_dims(E, axis=-1), r, np.expand_dims(X, axis=-1), Z)
    if not pairwise:
        return inputs, y
    # Indices of every pair i < j, in the same row-major order as a nested
    # ``for i: for j > i`` loop. Counting the gathered support rows (rather
    # than ``len(support)``) keeps boolean masks working. With fewer than two
    # support points the pair axis is simply empty.
    i, j = np.triu_indices(len(e_support), k=1)
    pairs = np.concatenate([E[:, i], r[:, i], E[:, j], r[:, j]], -1)
    return inputs + (pairs,), y

class Metric(object):
    """Accumulate per-step loss / noise / MSE values and summarise their means.

    Args:
        prefix: text prepended to the string returned by :meth:`report`.
    """

    def __init__(self, prefix='train: '):
        self.reset()
        # The "noise" values are formatted through the ``log_var`` placeholder.
        self.message = prefix + "loss: {loss:.2f} - noise: {log_var:.2f} - mse: {mse:.2f}"

    @staticmethod
    def _to_scalar(value):
        """Return ``value`` as a Python scalar.

        ``np.asscalar`` was removed in NumPy 1.23. ``.item()`` is its
        documented replacement and works for NumPy scalars, size-1 arrays and
        torch tensors. Plain Python numbers (which lack ``.item``) pass
        through unchanged.
        """
        return value.item() if hasattr(value, "item") else value

    def update(self, loss, noise, mse):
        """Record one observation of each tracked quantity.

        Raises:
            ValueError: if an argument is an array/tensor with more than one
                element (raised by ``.item()``).
        """
        self.loss.append(self._to_scalar(loss))
        self.noise.append(self._to_scalar(noise))
        self.mse.append(self._to_scalar(mse))

    def reset(self):
        """Discard all recorded observations."""
        self.loss = []
        self.noise = []
        self.mse = []

    def report(self):
        """Return the formatted means of the recorded values.

        With no recorded values the means are NaN (``np.mean`` of an empty list).
        """
        return self.message.format(loss=np.mean(self.loss),
                                   log_var=np.mean(self.noise),
                                   mse=np.mean(self.mse))

    def get(self):
        """Return the means as a ``{"loss", "noise", "mse"}`` dict."""
        return {"loss": np.mean(self.loss),
                "noise": np.mean(self.noise),
                "mse": np.mean(self.mse)}
