
import torch
import torch.nn as nn
from torch import optim
from torch.nn import functional as F
from woods.objectives.ERM import ERM


class FEDNet(ERM):
    """ERM variant that additionally optimizes an auxiliary loss from the model.

    During training the wrapped model is called as ``model(X, Y, d)`` and must
    return ``(y_hat, loss_train, loss_con)``. ``loss_train`` is minimized with
    the optimizer passed to the constructor; ``loss_con`` (when not ``None``)
    is minimized with a separate Adam optimizer built over the model's
    parameters.

    Args:
        model: Network called as ``model(X, Y, d)`` in training and as
            ``model(x, y=None, d=None)`` in :meth:`predict`.
        dataset: Object providing ``get_nb_training_domains()`` and
            ``get_next_batch()``.
        optimizer: Optimizer stepped on ``loss_train``.
        hparams: Hyper-parameter mapping; must contain the key
            ``'aux_optim_learning_rate'``.
    """

    def __init__(self, model, dataset, optimizer, hparams):
        super().__init__(model, dataset, optimizer, hparams)

        # Save training components
        self.model = model
        self.dataset = dataset
        self.optimizer = optimizer

        # Separate optimizer for the auxiliary loss (the original code noted
        # 1e-3 as the learning rate it was used with).
        # NOTE(review): this optimizer covers *all* model parameters, not a
        # dedicated auxiliary sub-module -- confirm that is intended.
        aux_optim_learning_rate = hparams['aux_optim_learning_rate']
        self.aux_optim = optim.Adam(self.model.parameters(), lr=aux_optim_learning_rate)
        self.nb_training_domains = self.dataset.get_nb_training_domains()

    def predict(self, all_x):
        """Return the model output for ``all_x`` in eval mode, without labels or domain ids."""
        self.model.eval()
        return self.model(all_x, y=None, d=None)

    def update(self):
        """Run one training step on the next batch from the dataset.

        The batch is assumed to be the concatenation of equally sized chunks,
        one per training domain, in domain order; each sample is labelled with
        the index of its chunk and that label is passed to the model as ``d``.

        Raises:
            AssertionError: if the batch size is not a multiple of the number
                of training domains.
        """
        self.model.train()
        X, Y = self.dataset.get_next_batch()
        B = X.shape[0]
        assert B % self.nb_training_domains == 0, (
            f"batch size {B} is not divisible by the number of training "
            f"domains ({self.nb_training_domains})"
        )
        per_domain = B // self.nb_training_domains
        # Domain ids laid out as [0]*per_domain + [1]*per_domain + ...
        # (self.device is assumed to be provided by ERM -- not visible here).
        d = torch.arange(
            self.nb_training_domains, device=self.device
        ).repeat_interleave(per_domain)
        _, loss_train, loss_con = self.model(X, Y, d)

        # Both losses come from the same forward pass and can share graph
        # nodes. loss_train.backward() would free that graph and
        # optimizer.step() modifies the parameters in place, so a later
        # loss_con.backward() would fail. Compute the auxiliary gradients
        # first, keeping the graph alive for the main backward pass.
        aux_params = None
        aux_grads = None
        if loss_con is not None:
            aux_params = [
                p
                for group in self.aux_optim.param_groups
                for p in group['params']
                if p.requires_grad
            ]
            aux_grads = torch.autograd.grad(
                loss_con, aux_params, retain_graph=True, allow_unused=True
            )

        self.optimizer.zero_grad()
        loss_train.backward()
        self.optimizer.step()

        if aux_grads is not None:
            self.aux_optim.zero_grad()
            # Install only the auxiliary-loss gradients; parameters that
            # loss_con does not depend on get None and are skipped by Adam.
            for param, grad in zip(aux_params, aux_grads):
                param.grad = grad
            self.aux_optim.step()

