import torch
import torch.nn as nn
import torch.nn.functional as F
from woods.objectives.ERM import ERM


def get_index(num_domain=2):
    """Return every unordered index pair ``(i, j)`` with ``0 <= i < j <= num_domain``.

    Note that the pairs range over ``num_domain + 1`` indices, so passing
    ``n - 1`` yields all pairs over ``n`` items. Pairs are listed in
    lexicographic order; a non-positive ``num_domain`` yields an empty list.

    >>> get_index(2)
    [(0, 1), (0, 2), (1, 2)]
    """
    return [
        (first, second)
        for first in range(num_domain)
        for second in range(first + 1, num_domain + 1)
    ]

class AdaRNN(ERM):
    """AdaRNN training objective.

    Each call to :meth:`update` draws a batch, splits it by training domain
    and, for every unordered pair of training domains, feeds the concatenated
    pair through ``model.forward_pre_train``. The summed cross-entropy loss
    plus the ``dw``-weighted transfer loss returned by the model is
    backpropagated. Gradients are clipped by value to 3 before the optimizer
    step.

    Hyperparameters read from ``hparams``:
        device: device on which the loss accumulator is allocated.
        len_win: forwarded to ``model.forward_pre_train`` as ``len_win``.
        dw: weight of the transfer loss relative to the classification loss.
    """

    def __init__(self, model, dataset, optimizer, hparams):
        super(AdaRNN, self).__init__(model, dataset, optimizer, hparams)
        # Save hparams (self.hparams is presumably set by ERM.__init__)
        self.device = self.hparams['device']
        self.len_win = self.hparams['len_win']
        self.dw = self.hparams['dw']

        # Save training components
        self.model = model
        self.dataset = dataset
        self.optimizer = optimizer

        # Number of domains each training batch is split into
        self.nb_training_domains = dataset.get_nb_training_domains()

    def predict(self, all_x):
        """Delegate prediction to the wrapped model's ``predict`` method."""
        return self.model.predict(all_x)

    def update(self):
        """Run one optimisation step over all pairs of training domains.

        Raises:
            ValueError: if there are fewer than two training domains (there
                would be no domain pair and hence no loss to backpropagate),
                or if the training domains in the batch do not all contain
                the same number of samples.
        """
        # All unordered pairs (s, t) of training-domain indices.
        index = get_index(self.nb_training_domains - 1)
        if not index:
            # Without a pair the objective never gets a grad_fn and backward()
            # would fail with an opaque autograd error.
            raise ValueError("AdaRNN requires at least two training domains.")

        # Put model into training mode
        self.model.train()

        # Get next batch and split it per training domain
        X, Y = self.dataset.get_next_batch()
        list_feat, list_label = self.dataset.split_tensor_by_domains(
            X, Y, self.nb_training_domains)

        # Domain pairs are concatenated along the batch dimension, so every
        # domain must contribute the same number of samples.
        if any(list_feat[s].shape[0] != list_feat[t].shape[0] for s, t in index):
            raise ValueError("The number of samples in different domains is not equal.")

        objective = torch.zeros(1, device=self.device)
        for s, t in index:
            feature_all = torch.cat((list_feat[s], list_feat[t]), 0)
            # The third output (per-layer weights) is not needed here.
            pred_all, loss_transfer, _ = self.model.forward_pre_train(
                feature_all, len_win=self.len_win)
            # squeeze(1) presumably drops a trailing singleton label dim
            # (labels assumed (N, 1)) — TODO confirm against the dataset.
            label_all = torch.cat((list_label[s], list_label[t])).squeeze(1)

            loss_cls = F.cross_entropy(pred_all, label_all)
            objective += loss_cls + self.dw * loss_transfer

        # Back propagate
        self.optimizer.zero_grad()
        objective.backward()
        torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.)
        self.optimizer.step()

