import torch
import torch.nn as nn
import torch.nn.functional as F
from alg.algs.ERM import ERM
from network.adarnn_network import AdaRNN_Network


def get_index(num_domain=2):
    """Enumerate every index pair ``(i, j)`` with ``0 <= i < j <= num_domain``.

    The upper bound is inclusive, so the pairs range over
    ``num_domain + 1`` distinct indices (``0 .. num_domain``).

    Args:
        num_domain: Largest index that may appear in a pair.

    Returns:
        A list of ``(i, j)`` tuples ordered by ``i`` and then by ``j``.
        The list is empty when ``num_domain`` is smaller than 1.
    """
    upper = num_domain + 1
    return [
        (first, second)
        for first in range(num_domain)
        for second in range(first + 1, upper)
    ]

class AdaRNN(ERM):
    """
    AdaRNN training wrapper built on top of ``ERM``.

    Owns an ``AdaRNN_Network`` and trains it on every pair of domain
    minibatches, combining a cross-entropy classification loss with a
    transfer loss weighted by ``args.dw``.

    model_type:  'Boosting', 'AdaRNN'
    (NOTE(review): kept from the original docstring; presumably an option
    consumed elsewhere -- confirm.)
    """

    def __init__(self, args):
        """Build the network and cache the hyper-parameters read from ``args``.

        Args:
            args: Configuration object. ``args.len_win`` and ``args.dw`` are
                read here; the whole object is also passed to
                ``ERM.__init__`` and to ``AdaRNN_Network``.
        """
        super(AdaRNN, self).__init__(args)
        # Save training components
        self.model = AdaRNN_Network(args)
        # Save hparams
        self.args = args
        self.len_win = args.len_win
        self.dw = args.dw

    def predict(self, all_x):
        """Return ``self.model.predict`` for ``all_x`` with the model in eval mode.

        The model is switched to eval mode and is not switched back here
        (``update`` re-enables training mode on its own).

        Args:
            all_x: Input tensor. Axis 2 is squeezed and the last two axes of
                the result are swapped, so the squeezed tensor must be 3-D
                (presumably the input is 4-D with a singleton axis 2 --
                TODO confirm against the data loader).

        Returns:
            Whatever ``self.model.predict`` returns for the reshaped input.
        """
        self.model.eval()
        all_x = all_x.squeeze(2).permute(0, 2, 1)
        return self.model.predict(all_x)

    def update(self, minibatches, opt, sch):
        """Run one optimisation step over all pairs of domain minibatches.

        For every pair of minibatches the two feature batches are
        concatenated along axis 0, passed through
        ``self.model.forward_pre_train`` and scored with cross-entropy plus
        ``self.dw`` times the transfer loss returned by the network. The
        summed loss is back-propagated once, gradients are clipped
        element-wise to ``[-3, 3]`` and the optimizer (and the scheduler, if
        given) is stepped.

        Args:
            minibatches: Sequence of per-domain batches; element ``[0]`` of
                each batch is the feature tensor and element ``[1]`` the
                label tensor. At least two batches are required and all of
                them must contain the same number of samples.
            opt: Optimizer exposing ``zero_grad()`` and ``step()``.
            sch: Optional scheduler; ``step()`` is called when it is truthy.

        Returns:
            dict with float entries ``'total'`` (summed loss), ``'class'``
            (summed classification loss) and ``'trans'`` (``self.dw`` times
            the summed transfer loss).

        Raises:
            ValueError: If fewer than two minibatches are supplied or the
                minibatches do not all have the same number of samples.
        """
        # Put model into training mode
        self.model.train()

        # With fewer than two domains there is no pair to train on; fail
        # early with a clear message instead of an autograd error raised by
        # calling backward() on a constant zero loss.
        if len(minibatches) < 2:
            raise ValueError(
                "AdaRNN.update requires at least two domain minibatches.")

        # self.device is not defined in this class (presumably set by ERM);
        # the inputs are placed on the same device as the loss accumulators
        # below rather than on a hard-coded CUDA device.
        device = self.device
        list_feat = [batch[0].to(device).float() for batch in minibatches]
        list_label = [batch[1].to(device).long() for batch in minibatches]

        # Every pair has equal batch sizes exactly when all batch sizes are
        # identical, so a single pass over the sizes is sufficient.
        if len({feat.shape[0] for feat in list_feat}) > 1:
            raise ValueError("The number of samples in different domains is not equal.")

        index = get_index(len(minibatches) - 1)

        loss = torch.zeros(1).to(device)
        total_loss_cls = torch.zeros(1).to(device)
        total_loss_transfer = torch.zeros(1).to(device)
        for src, tgt in index:
            feature_all = torch.cat((list_feat[src], list_feat[tgt]), 0)
            # Same reshaping as in predict(): squeeze axis 2, swap last two axes.
            feature_all = feature_all.squeeze(2).permute(0, 2, 1)
            # The third return value (per-layer weights) is not used here.
            pred_all, loss_transfer, _ = self.model.forward_pre_train(
                feature_all, len_win=self.len_win)
            label_all = torch.cat((list_label[src], list_label[tgt]))
            loss_cls = F.cross_entropy(pred_all, label_all)

            loss += loss_cls + self.dw * loss_transfer
            total_loss_cls += loss_cls
            total_loss_transfer += loss_transfer

        # Back propagate
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.)
        opt.step()
        if sch:
            sch.step()

        return {'total': loss.item(),
                'class': total_loss_cls.item(),
                'trans': self.dw * total_loss_transfer.item()}

