import warnings
from collections import OrderedDict
import copy
import numpy as np

import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from torch.jit.annotations import Tuple, List, Dict, Optional
import torch.autograd as autograd
from torchvision.ops import MultiScaleRoIAlign

from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork
from torchvision.models.detection.transform import GeneralizedRCNNTransform

from utils.roi_header_util import RoIHeads


class FasterRCNNBase(nn.Module):
    """
    Main class for Generalized R-CNN, trained with a selectable domain-generalization objective.

    Arguments:
        backbone (nn.Module): maps the transformed image tensor to a feature map (Tensor) or an
            OrderedDict of feature maps.
        rpn (nn.Module): region proposal network; returns the proposals and its own losses.
        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
            detections / masks from it.
        transform (nn.Module): performs the data transformation from the inputs to feed into
            the model.
        algorithm (str): training objective used by ``forward_train``; one of the keys of
            ``_TRAIN_HANDLERS``.
    """

    # Loss entries produced by the RoI heads and by the RPN respectively.
    _DETECTOR_LOSS_KEYS = ('loss_classifier', 'loss_box_reg')
    _PROPOSAL_LOSS_KEYS = ('loss_objectness', 'loss_rpn_box_reg')

    # ``self.algorithm`` -> name of the method implementing that training objective.
    _TRAIN_HANDLERS = {
        'ERM': '_train_erm',
        'IRM': '_train_irm',
        'VREx': '_train_vrex',
        'SD': '_train_sd',
        'GroupDRO': '_train_group_dro',
        'MLDG': '_train_mldg',
        'AndMask': '_train_and_mask',
        'SAndMask': '_train_sand_mask',
        'IGA': '_train_iga',
        'Mixup': '_train_mixup',
        'IB_ERM': '_train_ib_erm',
        'IB_IRM': '_train_ib_irm',
        'MMD': '_train_mmd_or_coral',
        'CORAL': '_train_mmd_or_coral',
        'RSC': '_train_rsc',
    }

    def __init__(self, backbone, rpn, roi_heads, transform, algorithm):
        super(FasterRCNNBase, self).__init__()
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        self.algorithm = algorithm
        # Number of IB_ERM / IB_IRM training calls so far; drives their penalty annealing.
        self.update_count = 0

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        """Return ``losses`` in training mode and ``detections`` otherwise."""
        if self.training:
            return losses

        return detections

    def extract_features(self, images, targets=None, transform=True, is_feature_hook=False):
        """
        Run transform -> backbone -> RPN -> RoI heads.

        :param images: input images; passed through ``self.transform`` when ``transform`` is True,
            otherwise they must already be the transform's output (exposing ``.tensors`` and
            ``.image_sizes``).
        :param targets: optional ground-truth targets forwarded to the RPN and the RoI heads.
        :param transform: whether to apply ``self.transform`` first.
        :param is_feature_hook: when True the RoI heads are also asked for their box features,
            which are appended to the returned tuple.
        :return: ``(detections, proposal_losses, detector_losses, class_logits, labels)``, plus
            ``box_features`` as a sixth element when ``is_feature_hook`` is True.
        """
        if transform:
            images, targets = self.transform(images, targets)

        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([('0', features)])

        proposals, proposal_losses = self.rpn(images, features, targets)

        if is_feature_hook:
            detections, detector_losses, class_logits, labels, box_features = self.roi_heads(
                features, proposals, images.image_sizes, targets, is_feature_hook)
            return detections, proposal_losses, detector_losses, class_logits, labels, box_features

        detections, detector_losses, class_logits, labels = self.roi_heads(
            features, proposals, images.image_sizes, targets)
        return detections, proposal_losses, detector_losses, class_logits, labels

    def forward_train(self, images, targets, env_num=None, mini_batch_size=None):
        """
        Compute the training objective selected by ``self.algorithm``.

        For multi-environment algorithms, ``images``/``targets`` are sliced into ``env_num``
        consecutive chunks of ``mini_batch_size`` samples, one chunk per environment.

        :return: dict of loss terms. MLDG additionally accumulates gradients into the parameters'
            ``.grad``; AndMask and SAndMask set ``.grad`` and take an internal Adam step.
        :raises ValueError: if ``self.algorithm`` is not supported.
        """
        handler_name = self._TRAIN_HANDLERS.get(self.algorithm)
        if handler_name is None:
            raise ValueError("Algorithm is not supported !!")
        return getattr(self, handler_name)(images, targets, env_num, mini_batch_size)

    # ------------------------------------------------------------------ shared helpers

    @classmethod
    def _task_losses(cls, proposal_losses, detector_losses):
        """Return the four standard detection losses in a fixed key order."""
        task = {key: detector_losses[key] for key in cls._DETECTOR_LOSS_KEYS}
        task.update((key, proposal_losses[key]) for key in cls._PROPOSAL_LOSS_KEYS)
        return task

    @classmethod
    def _zero_losses(cls, *extra_keys):
        """Return a loss dict with every detection loss (plus ``extra_keys``) set to 0."""
        keys = cls._DETECTOR_LOSS_KEYS + cls._PROPOSAL_LOSS_KEYS + extra_keys
        return {key: 0. for key in keys}

    @classmethod
    def _add_task_losses(cls, losses, proposal_losses, detector_losses, divisor):
        """Add every detection loss divided by ``divisor`` into ``losses`` (in place)."""
        for key, value in cls._task_losses(proposal_losses, detector_losses).items():
            losses[key] += value / divisor

    @staticmethod
    def _env_slice(images, targets, env_idx, mini_batch_size):
        """Return the ``(images, targets)`` lists of environment ``env_idx``."""
        start = env_idx * mini_batch_size
        stop = start + mini_batch_size
        return list(images[start:stop]), list(targets[start:stop])

    @classmethod
    def _env_batches(cls, images, targets, env_num, mini_batch_size):
        """Yield the ``(images, targets)`` of each environment in order."""
        for env_idx in range(env_num):
            yield cls._env_slice(images, targets, env_idx, mini_batch_size)

    @classmethod
    def _random_pairs_of_minibatches(cls, images, targets, env_num, mini_batch_size):
        """Pair each environment of a random permutation with the next one, cyclically."""
        perm = torch.randperm(env_num).tolist()
        return [
            (cls._env_slice(images, targets, perm[i], mini_batch_size),
             cls._env_slice(images, targets, perm[(i + 1) % env_num], mini_batch_size))
            for i in range(env_num)
        ]

    @staticmethod
    def _irm_penalty(logits, y):
        """IRM penalty: product of the even/odd-split loss gradients w.r.t. a dummy scale of 1."""
        scale = torch.tensor(1., device=logits.device).requires_grad_()
        loss_1 = F.cross_entropy(logits[::2] * scale, y[::2])
        loss_2 = F.cross_entropy(logits[1::2] * scale, y[1::2])
        grad_1 = autograd.grad(loss_1, [scale], create_graph=True)[0]
        grad_2 = autograd.grad(loss_2, [scale], create_graph=True)[0]
        return torch.sum(grad_1 * grad_2)

    def _persistent_adam(self, lr, weight_decay):
        """Return an Adam optimizer over all parameters, created on first use and then reused.

        Reusing it keeps Adam's moment estimates across steps; rebuilding it every call would
        restart from the bias-corrected first step each time.
        """
        opt = getattr(self, '_masked_grad_optimizer', None)
        if opt is None:
            opt = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
            self._masked_grad_optimizer = opt
        return opt

    def _collect_env_gradients(self, images, targets, env_num, mini_batch_size):
        """Return the mean task loss and, per parameter, the list of per-environment gradients."""
        # autograd.grad over self.parameters() fails on frozen parameters, so unfreeze them all.
        for param in self.parameters():
            if param.requires_grad is False:
                param.requires_grad = True

        mean_loss = 0
        param_gradients = [[] for _ in self.parameters()]
        for images_i, targets_i in self._env_batches(images, targets, env_num, mini_batch_size):
            _, proposal_losses, detector_losses, _, _ = self.extract_features(images_i, targets_i)
            env_loss = sum(v for v in self._task_losses(proposal_losses, detector_losses).values())
            mean_loss += env_loss / env_num

            env_grads = autograd.grad(env_loss, self.parameters())
            for grads, env_grad in zip(param_gradients, env_grads):
                grads.append(env_grad)
        return mean_loss, param_gradients

    @staticmethod
    def _pairwise_sq_dists(x1, x2):
        """Squared Euclidean distances between rows of ``x1`` and ``x2``, clamped to >= 1e-30."""
        x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
        res = torch.addmm(x2_norm.transpose(-2, -1),
                          x1,
                          x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
        return res.clamp_min_(1e-30)

    @classmethod
    def _gaussian_kernel(cls, x, y, gamma=(0.001, 0.01, 0.1, 1, 10, 100, 1000)):
        """Sum of RBF kernels exp(-g * ||x - y||^2) over the bandwidths in ``gamma``."""
        dists = cls._pairwise_sq_dists(x, y)
        kernel = torch.zeros_like(dists)
        for g in gamma:
            kernel.add_(torch.exp(dists.mul(-g)))
        return kernel

    @classmethod
    def _mmd(cls, x, y, use_gaussian):
        """Gaussian-kernel MMD (``use_gaussian``) or CORAL mean/covariance discrepancy."""
        if use_gaussian:
            k_xx = cls._gaussian_kernel(x, x).mean()
            k_yy = cls._gaussian_kernel(y, y).mean()
            k_xy = cls._gaussian_kernel(x, y).mean()
            return k_xx + k_yy - 2 * k_xy

        mean_x = x.mean(0, keepdim=True)
        mean_y = y.mean(0, keepdim=True)
        cent_x = x - mean_x
        cent_y = y - mean_y
        cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
        cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)

        mean_diff = (mean_x - mean_y).pow(2).mean()
        cova_diff = (cova_x - cova_y).pow(2).mean()
        return mean_diff + cova_diff

    # ------------------------------------------------------------------ algorithms

    def _train_erm(self, images, targets, env_num, mini_batch_size):
        """ERM: plain detection losses on the whole batch."""
        _, proposal_losses, detector_losses, _, _ = self.extract_features(images, targets)
        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return losses

    def _train_irm(self, images, targets, env_num, mini_batch_size):
        """IRM: env-averaged detection losses plus the IRM penalty on the RoI class logits."""
        penalty_weight = 1
        losses = self._zero_losses('penalty')
        for images_i, targets_i in self._env_batches(images, targets, env_num, mini_batch_size):
            _, proposal_losses, detector_losses, class_logits, labels = self.extract_features(
                images_i, targets_i)
            self._add_task_losses(losses, proposal_losses, detector_losses, env_num)
            losses['penalty'] += (self._irm_penalty(class_logits, torch.cat(labels, dim=0))
                                  * penalty_weight / env_num)
        return losses

    def _train_vrex(self, images, targets, env_num, mini_batch_size):
        """V-REx: env-averaged detection losses plus the variance of per-env classification NLLs."""
        penalty_weight = 1
        losses = self._zero_losses('penalty')
        env_nlls = []
        for images_i, targets_i in self._env_batches(images, targets, env_num, mini_batch_size):
            _, proposal_losses, detector_losses, class_logits, labels = self.extract_features(
                images_i, targets_i)
            self._add_task_losses(losses, proposal_losses, detector_losses, env_num)
            env_nlls.append(F.cross_entropy(class_logits, torch.cat(labels, dim=0)))

        # Stacked on the losses' own device (instead of a CPU buffer filled per index).
        penalty_loss = torch.stack(env_nlls)
        penalty_mean = penalty_loss.mean()
        losses['penalty'] += penalty_weight * ((penalty_loss - penalty_mean) ** 2).mean()
        return losses

    def _train_sd(self, images, targets, env_num, mini_batch_size):
        """Spectral decoupling: detection losses plus an L2 penalty on the RoI class logits."""
        sd_reg = 0.1
        _, proposal_losses, detector_losses, class_logits, _ = self.extract_features(images, targets)
        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        losses['penalty'] = sd_reg * (class_logits ** 2).mean()
        return losses

    def _train_group_dro(self, images, targets, env_num, mini_batch_size):
        """GroupDRO: classification loss reweighted by exponentiated per-environment losses.

        The group weights ``q`` persist on the instance across calls (outside ``state_dict``),
        so environments with persistently high loss keep gaining weight; they restart from
        uniform if ``env_num`` changes.
        """
        group_dro_eta = 1e-2
        losses = self._zero_losses()
        env_nlls = []
        for images_i, targets_i in self._env_batches(images, targets, env_num, mini_batch_size):
            _, proposal_losses, detector_losses, class_logits, labels = self.extract_features(
                images_i, targets_i)
            losses['loss_box_reg'] += detector_losses['loss_box_reg'] / env_num
            losses['loss_objectness'] += proposal_losses['loss_objectness'] / env_num
            losses['loss_rpn_box_reg'] += proposal_losses['loss_rpn_box_reg'] / env_num
            env_nlls.append(F.cross_entropy(class_logits, torch.cat(labels, dim=0)))

        cls_losses = torch.stack(env_nlls)
        q = getattr(self, '_group_dro_q', None)
        if q is None or q.numel() != env_num:
            q = torch.ones(env_num)
        q = q.to(cls_losses) * (group_dro_eta * cls_losses.detach()).exp()
        q = q / q.sum()
        self._group_dro_q = q

        losses['loss_classifier'] += torch.dot(cls_losses, q)
        return losses

    def _train_mldg(self, images, targets, env_num, mini_batch_size):
        """MLDG: first-order meta-learning over random pairs of environments.

        Meta-gradients are accumulated directly into ``self``'s ``.grad`` fields. The returned
        objective is detached and only useful for logging.
        """
        mldg_inner_lr = 1e-3
        mldg_inner_wd = 0.
        mldg_beta = 1.

        objective = 0
        for param in self.parameters():
            if param.grad is None:
                param.grad = torch.zeros_like(param)

        pairs = self._random_pairs_of_minibatches(images, targets, env_num, mini_batch_size)
        for (xi, yi), (xj, yj) in pairs:
            inner_net = copy.deepcopy(self)
            inner_opt = torch.optim.Adam(inner_net.parameters(), lr=mldg_inner_lr,
                                         weight_decay=mldg_inner_wd)
            inner_net.train()
            inner_net.algorithm = 'ERM'

            # Meta-train step on environment i with the cloned network.
            inner_obj = sum(loss for loss in inner_net.forward_train(xi, yi).values())
            inner_opt.zero_grad()
            inner_obj.backward()
            inner_opt.step()

            for p_tgt, p_src in zip(self.parameters(), inner_net.parameters()):
                if p_src.grad is not None:
                    p_tgt.grad.data.add_(p_src.grad.data / env_num)

            # Detached so the returned value does not keep the clone's autograd graph alive.
            objective += inner_obj.detach()

            # Meta-test loss on environment j, evaluated with the updated clone.
            loss_inner_j = sum(loss for loss in inner_net.forward_train(xj, yj).values())
            for inner_param in inner_net.parameters():
                if inner_param.requires_grad is False:
                    inner_param.requires_grad = True
            grad_inner_j = autograd.grad(loss_inner_j, inner_net.parameters(), allow_unused=True)

            objective += mldg_beta * loss_inner_j.detach()

            for p, g_j in zip(self.parameters(), grad_inner_j):
                if g_j is not None:
                    p.grad.data.add_(mldg_beta * g_j.data / env_num)

        return {'objective_and_step': objective}

    def _train_and_mask(self, images, targets, env_num, mini_batch_size):
        """AND-mask: step only along gradient components whose sign agrees across environments."""
        tau = 1
        lr = 1e-3
        wd = 0.

        opt = self._persistent_adam(lr, wd)
        mean_loss, param_gradients = self._collect_env_gradients(
            images, targets, env_num, mini_batch_size)

        opt.zero_grad()
        for param, grads in zip(self.parameters(), param_gradients):
            grads = torch.stack(grads, dim=0)
            mask = (torch.mean(torch.sign(grads), dim=0).abs() >= tau).to(torch.float32)
            avg_grad = torch.mean(grads, dim=0)
            # Rescale so the masked gradient keeps the magnitude of the unmasked average.
            mask_t = mask.sum() / mask.numel()
            param.grad = mask * avg_grad
            param.grad *= (1. / (1e-10 + mask_t))
        opt.step()

        return {'objective': mean_loss}

    @staticmethod
    def _apply_sand_mask(k, tau, gradients, params):
        """
        Here a mask with continuous values in the range [0,1] is formed to control the amount of
        update for each parameter based on the agreement of gradients coming from different
        environments.
        """
        device = gradients[0][0].device
        for param, grads in zip(params, gradients):
            grads = torch.stack(grads, dim=0)
            avg_grad = torch.mean(grads, dim=0)
            grad_signs = torch.sign(grads)
            gamma = torch.tensor(1.0).to(device)
            grads_var = grads.var(dim=0)
            grads_var[torch.isnan(grads_var)] = 1e-17
            lam = (gamma * grads_var).pow(-1)
            mask = torch.tanh(k * lam * (torch.abs(grad_signs.mean(dim=0)) - tau))
            mask = torch.max(mask, torch.zeros_like(mask))
            mask[torch.isnan(mask)] = 1e-17
            mask_t = (mask.sum() / mask.numel())
            param.grad = mask * avg_grad
            param.grad *= (1. / (1e-10 + mask_t))

    def _train_sand_mask(self, images, targets, env_num, mini_batch_size):
        """Smoothed AND-mask: soft, variance-aware version of the AND-mask gradient filter."""
        tau = 1.
        k = 1e+1
        lr = 1e-3
        wd = 0.

        opt = self._persistent_adam(lr, wd)
        mean_loss, param_gradients = self._collect_env_gradients(
            images, targets, env_num, mini_batch_size)

        opt.zero_grad()
        # gradient masking applied here
        self._apply_sand_mask(k, tau, param_gradients, self.parameters())
        opt.step()

        return {'objective': mean_loss}

    def _train_iga(self, images, targets, env_num, mini_batch_size):
        """IGA: mean detection loss plus the spread of per-environment gradients around the mean."""
        penalty = 1000

        for param in self.parameters():
            if param.requires_grad is False:
                param.requires_grad = True

        total_loss = 0
        grads = []
        for images_i, targets_i in self._env_batches(images, targets, env_num, mini_batch_size):
            _, proposal_losses, detector_losses, _, _ = self.extract_features(images_i, targets_i)
            # Per-environment loss only (previously accumulated across environments).
            env_loss = sum(v for v in self._task_losses(proposal_losses, detector_losses).values())
            total_loss += env_loss
            # NOTE(review): taken without create_graph=True, so the penalty below carries no
            # gradient w.r.t. the parameters. Double backward through torchvision's roi_align may
            # be unsupported -- confirm before enabling it.
            grads.append(autograd.grad(env_loss, self.parameters(), retain_graph=True))

        mean_loss = total_loss / env_num
        mean_grad = autograd.grad(mean_loss, self.parameters(), retain_graph=True)

        # compute trace penalty
        penalty_value = 0
        for grad in grads:
            for g, mean_g in zip(grad, mean_grad):
                penalty_value += (g - mean_g).pow(2).sum()

        return {'losses': mean_loss, 'penalty': penalty * penalty_value}

    def _train_mixup(self, images, targets, env_num, mini_batch_size):
        """Inter-domain Mixup: mixed images scored against both source environments' targets."""
        mixup_alpha = 0.2
        losses = self._zero_losses()
        pairs = self._random_pairs_of_minibatches(images, targets, env_num, mini_batch_size)
        for (xi, yi), (xj, yj) in pairs:
            lam = np.random.beta(mixup_alpha, mixup_alpha)
            x = [lam * img_i + (1 - lam) * img_j for img_i, img_j in zip(xi, xj)]

            # Weight lam for environment i's targets and (1 - lam) for environment j's targets.
            for weight, y in ((lam, yi), (1 - lam, yj)):
                _, proposal_losses, detector_losses, _, _ = self.extract_features(x, y)
                for key, value in self._task_losses(proposal_losses, detector_losses).items():
                    losses[key] += weight * value

        objective = sum(v for v in losses.values()) / env_num
        return {'losses': objective}

    def _train_ib_erm(self, images, targets, env_num, mini_batch_size):
        """IB-ERM: ERM plus an annealed penalty on the variance of the RoI box features."""
        ib_lambda = 1e2
        ib_penalty_anneal_iters = 500
        ib_penalty_weight = ib_lambda if self.update_count >= ib_penalty_anneal_iters else 0.0

        losses = self._zero_losses('ib_penalty')
        for images_i, targets_i in self._env_batches(images, targets, env_num, mini_batch_size):
            _, proposal_losses, detector_losses, _, _, features = self.extract_features(
                images_i, targets_i, is_feature_hook=True)
            self._add_task_losses(losses, proposal_losses, detector_losses, env_num)
            losses['ib_penalty'] += (ib_penalty_weight * features.var(dim=0).mean()) / env_num

        self.update_count += 1
        return losses

    def _train_ib_irm(self, images, targets, env_num, mini_batch_size):
        """IB-IRM: IRM penalty plus the annealed RoI-feature variance penalty of IB-ERM."""
        irm_lambda = 1e2
        irm_penalty_anneal_iters = 500
        ib_lambda = 1e2
        ib_penalty_anneal_iters = 500

        irm_penalty_weight = irm_lambda if self.update_count >= irm_penalty_anneal_iters else 1.0
        ib_penalty_weight = ib_lambda if self.update_count >= ib_penalty_anneal_iters else 0.0

        losses = self._zero_losses('irm_penalty', 'ib_penalty')
        for images_i, targets_i in self._env_batches(images, targets, env_num, mini_batch_size):
            _, proposal_losses, detector_losses, class_logits, labels, features = self.extract_features(
                images_i, targets_i, is_feature_hook=True)
            self._add_task_losses(losses, proposal_losses, detector_losses, env_num)
            losses['irm_penalty'] += (self._irm_penalty(class_logits, torch.cat(labels, dim=0))
                                      * irm_penalty_weight / env_num)
            losses['ib_penalty'] += (ib_penalty_weight * features.var(dim=0).mean()) / env_num

        self.update_count += 1
        return losses

    def _train_mmd_or_coral(self, images, targets, env_num, mini_batch_size):
        """MMD (Gaussian kernel) or CORAL (mean/covariance) penalty between RoI box features
        of every pair of environments, averaged over the pairs."""
        use_gaussian = self.algorithm == 'MMD'
        mmd_gamma = 1.

        losses = self._zero_losses('penalty')
        features = []
        for images_i, targets_i in self._env_batches(images, targets, env_num, mini_batch_size):
            _, proposal_losses, detector_losses, _, _, feature = self.extract_features(
                images_i, targets_i, is_feature_hook=True)
            self._add_task_losses(losses, proposal_losses, detector_losses, env_num)
            features.append(feature)

        for i in range(env_num):
            for j in range(i + 1, env_num):
                losses['penalty'] += (mmd_gamma * self._mmd(features[i], features[j], use_gaussian))

        if env_num > 1:
            losses['penalty'] /= (env_num * (env_num - 1) / 2)

        return losses

    def _train_rsc(self, images, targets, env_num, mini_batch_size):
        """RSC: train the box classifier on RoI features whose most-predictive entries are muted."""
        rsc_f_drop_factor = 1 / 3
        rsc_b_drop_factor = 1 / 3
        # np.percentile expects q in [0, 100]: keep features below the (1 - factor) percentile.
        drop_f = (1 - rsc_f_drop_factor) * 100
        drop_b = (1 - rsc_b_drop_factor) * 100
        # Local reference: assigning it to ``self`` would register a duplicate submodule and add
        # extra ``classifier.*`` keys to the state_dict.
        classifier = self.roi_heads.box_predictor.cls_score

        _, proposal_losses, detector_losses, all_p, all_y, all_f = self.extract_features(
            images, targets, is_feature_hook=True)

        losses = self._zero_losses()
        losses['loss_box_reg'] = detector_losses['loss_box_reg']
        losses['loss_objectness'] = proposal_losses['loss_objectness']
        losses['loss_rpn_box_reg'] = proposal_losses['loss_rpn_box_reg']

        # one-hot labels
        all_y = torch.cat(all_y, dim=0).to(all_p.device)
        all_o = F.one_hot(all_y, all_p.shape[1])

        # Equation (1): compute gradients with respect to representation
        all_g = autograd.grad((all_p * all_o).sum(), all_f)[0]

        # Equation (2): compute top-gradient-percentile mask
        percentiles = np.percentile(all_g.detach().cpu().numpy(), drop_f, axis=1)
        percentiles = torch.as_tensor(percentiles, dtype=all_g.dtype, device=all_g.device)
        mask_f = all_g.lt(percentiles.unsqueeze(1)).float()

        # Equation (3): mute top-gradient-percentile activations
        all_f_muted = all_f * mask_f

        # Equation (4): compute muted predictions
        all_p_muted = classifier(all_f_muted)

        # Section 3.3: Batch Percentage
        all_s = F.softmax(all_p, dim=1)
        all_s_muted = F.softmax(all_p_muted, dim=1)
        changes = (all_s * all_o).sum(1) - (all_s_muted * all_o).sum(1)
        percentile = np.percentile(changes.detach().cpu().numpy(), drop_b)
        mask_b = changes.lt(percentile).float().view(-1, 1)
        mask = torch.logical_or(mask_f, mask_b).float()

        # Equations (3) and (4) again, this time muting over examples
        all_p_muted_again = classifier(all_f * mask)

        # Equation (5): update
        losses['loss_classifier'] = F.cross_entropy(all_p_muted_again, all_y)
        return losses

    # ------------------------------------------------------------------ inference / entry

    def simple_test(self, images, ori_img_sizes):
        """Run inference and map detections back to the original image sizes."""
        images, _ = self.transform(images, None)
        detections = self.extract_features(images, transform=False)[0]
        detections = self.transform.postprocess(detections, images.image_sizes, ori_img_sizes)
        return detections

    def forward(self, images, targets=None, env_num=None, mini_batch_size=None):
        """
        Arguments:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
            env_num (int): number of environments in the batch (multi-environment algorithms)
            mini_batch_size (int): samples per environment (multi-environment algorithms)

        Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).

        Raises:
            ValueError: in training mode, when targets are missing or their boxes are not
                tensors of shape [N, 4].
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        if self.training:
            assert targets is not None

            for target in targets:
                boxes = target["boxes"]
                if not isinstance(boxes, torch.Tensor):
                    raise ValueError("Expected target boxes to be of type "
                                     "Tensor, got {:}.".format(type(boxes)))
                if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                    raise ValueError("Expected target boxes to be a tensor "
                                     "of shape [N, 4], got {:}.".format(boxes.shape))

            return self.forward_train(images, targets, env_num, mini_batch_size)

        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))

        return self.simple_test(images, original_image_sizes)


class TwoMLPHead(nn.Module):
    """
    Box head made of two fully connected layers, applied after RoI pooling/align.

    :param in_channels: number of input features per RoI once flattened
    :param representation_size: width of both fully connected layers
    """

    def __init__(self, in_channels, representation_size):
        super().__init__()

        self.fc6 = nn.Linear(in_channels, representation_size)
        self.fc7 = nn.Linear(representation_size, representation_size)

    def forward(self, x):
        # Collapse every dimension after the first so each RoI becomes one row.
        flat = torch.flatten(x, start_dim=1)
        hidden = F.relu(self.fc6(flat))
        return F.relu(self.fc7(hidden))


class FastRCNNPredictor(nn.Module):
    """
    Standard Fast R-CNN output layers: per-class scores and per-class box regression deltas.

    :param in_channels: number of input features per RoI
    :param num_classes: number of output classes (including background)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        # Four regression deltas for every class.
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)

    def forward(self, x):
        if x.dim() == 4:
            # 4-D input must already be pooled down to a 1x1 spatial map.
            assert list(x.shape[2:]) == [1, 1]
        flat = torch.flatten(x, start_dim=1)
        return self.cls_score(flat), self.bbox_pred(flat)


class FasterRCNN(FasterRCNNBase):
    """
    Implementation of Faster R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or inference mode.

    During training, the model expects both the input tensors and targets (a list of dictionaries),
    each containing:
        - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values
          between 0 and W for x and between 0 and H for y
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values
          between 0 and W for x and between 0 and H for y
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the score of each prediction

    :param backbone: (nn.Module), the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
    :param algorithm: (str), name of the training algorithm; stored on the model as
            ``self.algorithm`` by FasterRCNNBase.
    :param num_classes: (int), number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
    :param min_size: (int), minimum size of the image to be rescaled before feeding it to the backbone
    :param max_size: (int), maximum size of the image to be rescaled before feeding it to the backbone
    :param image_mean: (Tuple[float, float, float]), mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained.
            If None, the ImageNet mean [0.485, 0.456, 0.406] is used.
    :param image_std: (Tuple[float, float, float]), std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained.
            If None, the ImageNet std [0.229, 0.224, 0.225] is used.
    :param rpn_anchor_generator: (AnchorGenerator), module that generates the anchors for a set of
            feature maps. If None, five anchor sizes (32, 64, 128, 256, 512), one per feature level,
            with aspect ratios (0.5, 1.0, 2.0) are used; this presumes a multi-level (FPN-style) backbone.
    :param rpn_head: (nn.Module), module that computes the objectness and regression deltas from the RPN
    :param rpn_pre_nms_top_n_train: (int), number of proposals to keep before applying NMS during training
    :param rpn_pre_nms_top_n_test: (int), number of proposals to keep before applying NMS during testing
    :param rpn_post_nms_top_n_train: (int), number of proposals to keep after applying NMS during training
    :param rpn_post_nms_top_n_test: (int), number of proposals to keep after applying NMS during testing
    :param rpn_nms_thresh: (float), NMS threshold used for postprocessing the RPN proposals
    :param rpn_fg_iou_thresh: (float), minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
    :param rpn_bg_iou_thresh: (float), maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
    :param rpn_batch_size_per_image: (int), number of anchors that are sampled during training of the RPN
            for computing the loss
    :param rpn_positive_fraction: (float), proportion of positive anchors in a mini-batch during training
            of the RPN
    :param box_roi_pool: (MultiScaleRoIAlign), the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes. If None, a MultiScaleRoIAlign over feature
            maps '0'..'3' with output_size=7 and sampling_ratio=2 is used.
    :param box_head: (nn.Module), module that takes the cropped feature maps as input
    :param box_predictor: (nn.Module), module that takes the output of box_head and returns the
            classification logits and box regression deltas.
    :param box_score_thresh: (float), during inference, only return proposals with a classification score
            greater than box_score_thresh
    :param box_nms_thresh: (float), NMS threshold for the prediction head. Used during inference
    :param box_detections_per_img: (int), maximum number of detections per image, for all classes.
    :param box_fg_iou_thresh: (float), minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
    :param box_bg_iou_thresh: (float), maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
    :param box_batch_size_per_image: (int), number of proposals that are sampled during training of the
            classification head
    :param box_positive_fraction: (float), proportion of positive proposals in a mini-batch during training
            of the classification head
    :param bbox_reg_weights: (Tuple[float, float, float, float]), weights for the encoding/decoding of the
            bounding boxes
    :raises ValueError: if the backbone has no out_channels attribute, or if num_classes and
            box_predictor are both given or both omitted.
    """

    def __init__(self, backbone, algorithm='ERM', num_classes=None,
                 # transform parameters
                 min_size=800, max_size=1333,  # minimum / maximum image size used when rescaling
                 image_mean=None, image_std=None,  # normalization statistics (ImageNet if None)

                 # RPN parameters
                 rpn_anchor_generator=None, rpn_head=None,
                 rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,  # proposals kept before NMS
                 rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,  # proposals kept after NMS
                 rpn_nms_thresh=0.7,  # IoU threshold used by NMS
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,  # foreground / background IoU thresholds
                 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,  # sampled anchors, positive share

                 # Box parameters
                 box_roi_pool=None, box_head=None, box_predictor=None,

                 # inference-time filtering of final detections
                 box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
                 # training-time proposal sampling for the box head
                 box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
                 box_batch_size_per_image=512, box_positive_fraction=0.25,
                 bbox_reg_weights=None
                 ):

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )

        assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None)))
        assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None)))

        # Exactly one of num_classes / box_predictor must be provided.
        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError("num_classes should be None when box_predictor "
                                 "is specified")
        elif box_predictor is None:
            raise ValueError("num_classes should not be None when box_predictor "
                             "is not specified")

        # output channels of the backbone (shared by every feature level)
        out_channels = backbone.out_channels

        # None is an accepted value (see the assert above), so build a default
        # instead of crashing on attribute access below / inside the RPN.
        # These match torchvision's FasterRCNN defaults (one anchor size per FPN level).
        if rpn_anchor_generator is None:
            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
            rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

        if rpn_head is None:
            rpn_head = RPNHead(
                out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
            )

        rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
        rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)

        rpn = RegionProposalNetwork(
            rpn_anchor_generator, rpn_head,
            rpn_fg_iou_thresh, rpn_bg_iou_thresh,
            rpn_batch_size_per_image, rpn_positive_fraction,
            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)

        # Default RoI pooling (torchvision's default); needed before box_head reads output_size.
        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(
                featmap_names=['0', '1', '2', '3'],
                output_size=7,
                sampling_ratio=2)

        # two fc layers after roi pooling
        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            representation_size = 1024
            box_head = TwoMLPHead(
                out_channels * resolution ** 2,
                representation_size
            )

        # classification + box regression on top of the box head
        if box_predictor is None:
            # must match the default TwoMLPHead width above
            representation_size = 1024
            box_predictor = FastRCNNPredictor(
                representation_size,
                num_classes)

        roi_heads = RoIHeads(
            # box
            box_roi_pool, box_head, box_predictor,
            box_fg_iou_thresh, box_bg_iou_thresh,
            box_batch_size_per_image, box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh, box_nms_thresh, box_detections_per_img)

        # GeneralizedRCNNTransform normalizes with these values, so None is not usable there.
        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)

        super(FasterRCNN, self).__init__(backbone, rpn, roi_heads, transform, algorithm)
