import util
import torch
import torch.optim as optim

class trainer():
    """Pair a model with an Adam optimizer and masked-error train/eval steps.

    Each step runs ``model(input, idx)``, transposes dims 1 and 3 of the output,
    maps it back through ``scaler.inverse_transform`` and scores it against
    ``real_val`` with the masked metrics from ``util``.
    """

    def __init__(self, device, model, scaler, lrate, wdecay, clip, null_val=0.0):
        """Move ``model`` to ``device`` and build its Adam optimizer.

        Args:
            device: Target passed to ``model.to``.
            model: Network invoked as ``model(input, idx)``.
            scaler: Object whose ``inverse_transform`` is applied to model output.
            lrate: Adam learning rate.
            wdecay: Adam weight decay.
            clip: Max gradient norm for ``clip_grad_norm_``; ``None`` disables
                clipping (previously ``None`` raised a TypeError).
            null_val: Third argument forwarded to the ``util`` masked metrics
                (presumably the label value to mask out -- confirm in ``util``).
                Defaults to ``0.0``, the value that used to be hard-coded.
        """
        self.model = model
        self.model.to(device)

        self.clip = clip
        self.scaler = scaler
        self.null_val = null_val
        # Metric callables, all invoked as fn(predict, real_val, null_val).
        self.loss = util.masked_mae
        self.mape = util.masked_mape
        self.rmse = util.masked_rmse
        self.optimizer = optim.Adam(self.model.parameters(), lr=lrate, weight_decay=wdecay)

    def _predict(self, input, idx):
        """Run the model, swap dims 1 and 3, and undo the scaler transform."""
        output = self.model(input, idx).transpose(1, 3)
        return self.scaler.inverse_transform(output)

    def _metrics(self, predict, real_val):
        """Return (mape, rmse) as Python scalars without building autograd graph."""
        # These values are only reported, never backpropagated.
        with torch.no_grad():
            mape = self.mape(predict, real_val, self.null_val).item()
            rmse = self.rmse(predict, real_val, self.null_val).item()
        return mape, rmse

    def train(self, input, real_val, idx):
        """Run one optimization step on a batch.

        ``real_val`` is compared directly against the inverse-transformed
        prediction, so it is presumably in the original (unscaled) units.

        Returns:
            ``(loss, mape, rmse)`` as Python scalars for this batch.
        """
        self.model.train()
        self.optimizer.zero_grad()

        predict = self._predict(input, idx)
        loss = self.loss(predict, real_val, self.null_val)
        mape, rmse = self._metrics(predict, real_val)

        loss.backward()
        if self.clip is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
        self.optimizer.step()
        return loss.item(), mape, rmse

    def eval(self, input, real_val, idx):
        """Score a batch in eval mode without gradients or parameter updates.

        Returns:
            ``(loss, mape, rmse)`` as Python scalars for this batch.
        """
        self.model.eval()
        with torch.no_grad():
            predict = self._predict(input, idx)
            loss = self.loss(predict, real_val, self.null_val)
        mape, rmse = self._metrics(predict, real_val)
        return loss.item(), mape, rmse
