"""
The script is adapted from https://github.com/NOVAglow646/NIPS22-MAT-and-LDAT-for-OOD/blob/master/domainbed/algorithms.py
Author: Jeng-Lin (John) Li
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.autograd import Variable
from torch.optim import optimizer
import collections
import os
import copy
import numpy as np
from collections import defaultdict


#from domainbed import networks
#from domainbed.lib.misc import random_pairs_of_minibatches, ParamDict
from torchvision import utils as vutils
from torchvision import transforms

# Names of the algorithms this registry advertises, in canonical order.
# NOTE(review): only ERM and LDAT are defined in the visible part of this
# file; get_algorithm_class() resolves names via globals(), so the other
# entries only work if they are defined/imported elsewhere - confirm.
ALGORITHMS = [
    # Baselines and invariance / robustness methods.
    'ERM', 'Fish', 'IRM', 'GroupDRO', 'Mixup', 'MLDG',
    # Distribution-matching and adversarial domain alignment.
    'CORAL', 'MMD', 'DANN', 'CDANN', 'MTL', 'SagNet',
    # Meta-learning, risk extrapolation and gradient-based methods.
    'ARM', 'VREx', 'RSC', 'SD', 'ANDMask',
    'SANDMask',  # SAND-mask
    'IGA', 'SelfReg',
    # Adversarial-training variants.
    'AT', 'UAT', 'MAT', 'LDAT', 'LAT',
]


def get_algorithm_class(algorithm_name):
    """Look up ``algorithm_name`` among this module's global names.

    Args:
        algorithm_name: name of a class (or any other global) defined in
            this module.

    Returns:
        The module-level object bound to ``algorithm_name``.

    Raises:
        NotImplementedError: if no such global name exists.
    """
    module_namespace = globals()
    if algorithm_name in module_namespace:
        return module_namespace[algorithm_name]
    raise NotImplementedError("Algorithm not found: {}".format(algorithm_name))


def load_model(model_path, algo):
    """Rebuild an algorithm instance from a saved checkpoint.

    Args:
        model_path: path passed to ``torch.load``; the loaded object must be
            a dict with keys ``model_input_shape``, ``model_num_classes``,
            ``model_num_domains``, ``model_hparams`` and ``model_dict``.
        algo: algorithm name resolved through ``get_algorithm_class``.

    Returns:
        The algorithm constructed from the stored constructor arguments,
        with ``model_dict`` loaded as its state dict.
    """
    # NOTE(review): torch.load unpickles the file - only load trusted
    # checkpoints.
    checkpoint = torch.load(model_path)
    algorithm_cls = get_algorithm_class(algo)
    constructor_args = (
        checkpoint['model_input_shape'],
        checkpoint['model_num_classes'],
        checkpoint['model_num_domains'],
        checkpoint['model_hparams'],
    )
    restored = algorithm_cls(*constructor_args)
    restored.load_state_dict(checkpoint['model_dict'])
    return restored


class Algorithm(torch.nn.Module):
    """
    Base class for domain generalization algorithms.

    Concrete algorithms subclass this and override:
    - update()
    - predict()

    The constructor only stores ``hparams``; the remaining constructor
    arguments are accepted so every subclass shares one signature.
    """
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__()
        self.hparams = hparams

    def update(self, minibatches, unlabeled=None):
        """
        Run one optimization step on a list of (x, y) tuples, one per
        training environment.

        ``unlabeled`` optionally carries minibatches from the test domains
        when the task is domain adaptation.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def predict(self, x):
        """Return the model output for ``x``; must be overridden."""
        raise NotImplementedError


class ERM(Algorithm):
    """
    Empirical Risk Minimization (ERM).

    Pools the minibatches of all training environments and minimizes the
    cross-entropy of a featurizer + classifier network with Adam.
    """
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__(input_shape, num_classes, num_domains, hparams)
        hp = self.hparams
        # NOTE(review): `networks` is not imported in the visible file (the
        # import near the top is commented out) - confirm it is in scope.
        self.featurizer = networks.Featurizer(input_shape, hp)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs,
            num_classes,
            hp['nonlinear_classifier'],
        )
        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=hp["lr"],
            weight_decay=hp['weight_decay'],
        )

    def update(self, minibatches, unlabeled=None, epoch=None, delta_dir=None):
        """Take one Adam step on the concatenated minibatches.

        ``unlabeled``, ``epoch`` and ``delta_dir`` are accepted for interface
        compatibility and are not used.

        Returns:
            ``{'loss': float}`` with the cross-entropy before the step.
        """
        inputs = torch.cat([x for x, _ in minibatches])
        targets = torch.cat([y for _, y in minibatches])
        logits = self.predict(inputs)
        objective = F.cross_entropy(logits, targets)

        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        return {'loss': objective.item()}

    def predict(self, x):
        """Return the network's unnormalized class scores for ``x``."""
        return self.network(x)


class LDAT(ERM):
    '''Low-rank Decomposed Adversarial Training (LDAT).

    Each training domain i owns a perturbation
    ``delta_i = projection(bmm(B[i], A[i]), at_eps)`` whose factors
    ``B[i]`` (C x H x rank) and ``A[i]`` (C x rank x W) are pushed toward
    higher loss before every training step; the network is then trained on
    the inputs perturbed by the refined ``delta_i``.

    Hyper-parameters read from ``hparams``: ``is_cmnist``, ``at_eps``,
    ``at_alpha``, ``A_lr``, ``B_lr``, ``at_cb_rank``.
    '''
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(LDAT, self).__init__(input_shape, num_classes, num_domains, hparams)
        # Per-domain low-rank factors, allocated lazily in update() once the
        # input geometry is known. They carry no autograd history: FGSM
        # differentiates private leaf copies, so the graph cannot grow
        # across steps.
        self.B = None
        self.A = None
        self.register_buffer('update_count', torch.tensor([0]))

        # Same placement as the former hard-coded .cuda() when CUDA exists,
        # CPU otherwise. Non-persistent buffers follow .to(device) but stay
        # out of state_dict, so existing checkpoints still load.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if hparams['is_cmnist'] == 0:
            # Per-channel normalization statistics for 3-channel inputs.
            mean = [0.485, 0.456, 0.406]
            std = [0.229, 0.224, 0.225]
            # Shaped (1, 3, 1, 1) so they broadcast over any batch size; the
            # former expand(batch_size, ...) failed for other batch sizes.
            self.register_buffer(
                'std', torch.tensor(std, device=device).view(1, 3, 1, 1),
                persistent=False)
            self.register_buffer(
                'mean', torch.tensor(mean, device=device).view(1, 3, 1, 1),
                persistent=False)
            # Normalized images of the pixel range [0, 1]: clamp bounds.
            self.register_buffer(
                'high',
                torch.tensor([(1 - m) / s for m, s in zip(mean, std)],
                             device=device),
                persistent=False)
            self.register_buffer(
                'low',
                torch.tensor([-m / s for m, s in zip(mean, std)],
                             device=device),
                persistent=False)
        else:
            # Presumably un-normalized 2-channel inputs (CMNIST): clamp each
            # channel to [0, 1] and use identity normalization.
            self.register_buffer(
                'high', torch.tensor([1., 1.], device=device), persistent=False)
            self.register_buffer(
                'low', torch.tensor([0., 0.], device=device), persistent=False)
            self.std = 1
            self.mean = 0

    def projection(self, x, eps):
        '''Scale x onto the L2 ball of radius eps (unchanged if inside).'''
        # torch.clamp keeps the factor on-device (no host sync); it equals
        # min(1, ...) in value, and like it passes no gradient when capped.
        scale = torch.clamp(eps / (torch.linalg.norm(x) + 1e-10), max=1.0)
        return x * scale

    def clamp(self, x):
        '''Clip every channel c of a 4-D (N, C, H, W) batch to
        [self.low[c], self.high[c]].'''
        # Move channels last so the (C,)-shaped bounds broadcast per channel.
        x = x.permute(0, 2, 3, 1)
        x = torch.max(torch.min(x, self.high), self.low)
        return x.permute(0, 3, 1, 2)

    def FGSM(self, x, y, B, A, bsz, num_iter=1, norm=2):
        '''Refine the factors (B, A) by normalized-gradient ascent on the loss.

        Args:
            x, y: one domain's batch (x normalized like the training data).
            B, A: current factors; they are not modified in place.
            bsz: unused, kept for interface compatibility.
            num_iter: number of ascent steps.
            norm: only 2 is supported (L2-normalized gradient steps).

        Returns:
            Detached ``(B, A)`` after ``num_iter`` steps.

        Raises:
            NotImplementedError: for any ``norm`` other than 2.
        '''
        if norm != 2:
            # The former 'inf' branch referenced undefined names and could
            # never run; no L-inf variant is defined for the factors.
            raise NotImplementedError(
                'LDAT.FGSM only supports norm=2, got {!r}'.format(norm))

        alpha = self.hparams["at_alpha"]
        x_denorm = x * self.std + self.mean
        # Fresh leaves: gradients are taken w.r.t. these only, and the
        # caller's storage never becomes part of an autograd graph.
        B = B.detach().clone().requires_grad_(True)
        A = A.detach().clone().requires_grad_(True)
        for _ in range(num_iter):
            # Rebuild delta from the current factors every iteration; building
            # it once before the loop broke num_iter > 1.
            delta = self.projection(torch.bmm(B, A), self.hparams['at_eps'])
            x_adv = self.clamp((x_denorm + delta - self.mean) / self.std)
            loss = F.cross_entropy(self.predict(x_adv), y, reduction='none')
            loss = torch.mean(loss.clamp(0, 2))  # adding beta as upperbound
            A_grad, B_grad = autograd.grad(loss, [A, B])
            B_step = alpha * B_grad / (torch.linalg.norm(B_grad) + 1e-10)
            A_step = alpha * A_grad / (torch.linalg.norm(A_grad) + 1e-10)
            B = (B + B_step * self.hparams['B_lr']).detach().requires_grad_(True)
            A = (A + A_step * self.hparams['A_lr']).detach().requires_grad_(True)

        return B.detach(), A.detach()

    def update(self, minibatches, unlabeled=None, epoch=None, delta_dir=None):
        '''Run one optimization step on adversarially perturbed minibatches.

        For each domain i, the factors ``B[i]``/``A[i]`` are refined by FGSM
        and the network loss on the perturbed batch is backpropagated; the
        gradients of all domains are applied in a single optimizer step.
        ``unlabeled``, ``epoch`` and ``delta_dir`` are not used.

        Returns:
            ``{'loss': float}``: sum over domains of the adversarial loss.
        '''
        first_x = minibatches[0][0]
        bsz = len(first_x)
        env_num = len(minibatches)
        channel_num = first_x[0].shape[0]
        x_h = first_x[0].shape[1]
        x_w = first_x[0].shape[2]
        if self.B is None and self.A is None:
            rank = int(self.hparams["at_cb_rank"])
            device = first_x.device
            # Drawn from the CPU generator, then moved (as before), so seeded
            # runs reproduce the same initial factors.
            self.B = torch.randn((env_num, channel_num, x_h, rank)).to(device)
            self.A = torch.randn((env_num, channel_num, rank, x_w)).to(device)

        self.optimizer.zero_grad()
        total_loss = 0.
        for i, (x, y) in enumerate(minibatches):
            # Update A and B for domain i on its batch (x, y).
            self.B[i], self.A[i] = self.FGSM(x, y, self.B[i], self.A[i], bsz, norm=2)
            delta = self.projection(torch.bmm(self.B[i], self.A[i]),
                                    self.hparams["at_eps"])
            x_denorm = x * self.std + self.mean
            x_adv = self.clamp((x_denorm + delta - self.mean) / self.std).detach()
            loss_adv = F.cross_entropy(self.predict(x_adv), y)
            total_loss += loss_adv.item()
            loss_adv.backward()
        self.optimizer.step()
        self.update_count += 1
        return {'loss': total_loss}

