from woods.objectives.ERM import ERM
import torch
import torch.nn.functional as F

import copy
from woods.models.lstm import LSTM
from woods.models.mnist import MNIST_LSTM
from woods.models.ActNetwork import ActNetwork

class Transfer(ERM):
    """Transfer objective: Algorithm 1 of "Quantifying and Improving Transferability in
    Domain Generalization" (https://arxiv.org/abs/2106.03632).

    Tries to ensure transferability among the source domains, and thus transferability
    between the source and target domains. Each ``update`` call takes one descent step on
    the task loss plus ``t_lambda`` times the transferability gap, then resets
    ``self.adv_classifier`` from the updated model and takes ``d_steps_per_g`` ascent steps
    on the gap, projecting the adversarial classifier network back onto the ``delta``-ball
    around the model's classifier network after each step.
    """

    def __init__(self, model, dataset, optimizer, hparams):
        super(Transfer, self).__init__(model, dataset, optimizer, hparams)
        # Step counter read only by ``update_second``; kept as a registered buffer so the
        # module's state_dict keys stay the same.
        self.register_buffer('update_count', torch.tensor([0]))
        self.d_steps_per_g = hparams['d_steps_per_g']

        # Number of domain definition: one domain is excluded from training when the
        # dataset defines a test environment.
        self.num_domains = len(dataset.ENVS)
        self.num_train_domains = self.num_domains - 1 if dataset.test_env is not None else self.num_domains

        # Architecture
        self.model = model

        # Quick fix for the deep copy problem with stored datasets: temporarily clear the
        # dataset reference so deepcopy does not copy the stored dataset. MNIST_LSTM keeps
        # that reference on its inner ``home_lstm``; LSTM and all other models keep it on
        # the model itself (LSTM is checked first, as before).
        if isinstance(model, LSTM):
            dataset_holder = model
        elif isinstance(model, MNIST_LSTM):
            dataset_holder = model.home_lstm
        else:
            dataset_holder = model
        dataset_holder.dataset = None
        try:
            # deepcopy already duplicates all parameters and buffers, so no extra
            # load_state_dict is needed afterwards.
            self.adv_classifier = copy.deepcopy(model).to(self.model.device)
        finally:
            # Restore the reference even if the copy fails.
            dataset_holder.dataset = dataset
        self.model.dataset = dataset
        # NOTE(review): for MNIST_LSTM the copy's ``home_lstm.dataset`` stays None; only the
        # top-level attribute below is set — confirm that is sufficient.
        self.adv_classifier.dataset = dataset

        # Optimizers. Without gda, ``self.optimizer`` is whatever the parent constructor
        # stored (presumably the ``optimizer`` argument — see ERM).
        if self.hparams['gda']:
            # NOTE(review): this optimizer is built over adv_classifier's parameters, and
            # ``update`` changes self.model only through ``self.optimizer.step()``, so the
            # model appears not to be trained when gda is set — confirm intent.
            lr = optimizer.param_groups[0]['lr']
            self.optimizer = torch.optim.SGD(self.adv_classifier.parameters(), lr=lr)

        self.adv_opt = torch.optim.SGD(self.adv_classifier.parameters(), lr=self.hparams['lr_d'])

    def update(self):
        """Run one Transfer training step on the next batch of ``self.dataset``.

        Min step: ``self.optimizer`` descends on ``dataset.loss`` plus ``t_lambda`` times
        ``loss_gap``. Max step: ``self.adv_classifier`` is reset to the updated model, then
        ``adv_opt`` takes ``d_steps_per_g`` ascent steps on the gap, each followed by a
        projection of the adversarial classifier network onto the ``delta``-ball around the
        model's classifier network.
        """
        # Put model into training mode
        self.model.train()

        # Get next batch
        X, Y = self.dataset.get_next_batch()

        # Min step on the task loss plus the weighted transferability gap.
        preds, _ = self.model(X)
        loss = self.dataset.loss(preds, Y)
        gap = self.hparams['t_lambda'] * self.loss_gap(X, Y)
        objective = loss + gap
        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        # Max step: restart the adversary from the freshly updated model.
        self.adv_classifier.load_state_dict(self.model.state_dict())
        for _ in range(self.d_steps_per_g):
            # Only adv_classifier is stepped here, so the features need no autograd graph
            # through self.model. The forward pass still runs each step in train mode, as
            # before; only the wasted backward through the model is skipped.
            with torch.no_grad():
                _, feats = self.model(X)
            self.adv_opt.zero_grad()
            adv_gap = -self.hparams['t_lambda'] * self.loss_gap(X, Y, feats=feats)
            adv_gap.backward()
            self.adv_opt.step()
            updated_adv_classifier = self.proj(
                self.hparams['delta'],
                self.adv_classifier.get_classifier_network(),
                self.model.get_classifier_network(),
            )
            self.adv_classifier.get_classifier_network().load_state_dict(updated_adv_classifier.state_dict())

    def loss_gap(self, X, Y, feats=None):
        """Return the transferability gap of the adversarial classifier on one batch.

        gap = max(losses) - min(losses), where ``losses`` is what
        ``self.dataset.loss_by_domain`` returns for the adversarial predictions
        (presumably one loss per training domain). Only the gap is returned, not the
        arg-max/arg-min domains.

        Args:
            X: input batch, as returned by ``self.dataset.get_next_batch()``.
            Y: targets for ``X``.
            feats: optional precomputed features for ``X`` (the second output of
                ``self.model(X)``). When None they are computed here, with gradients
                flowing into ``self.model``.
        """
        if feats is None:
            _, feats = self.model(X)
        pred = self.adv_classifier.classify(feats)
        losses = self.dataset.loss_by_domain(pred, Y, self.num_train_domains)
        return max(losses) - min(losses)

    def distance(self, h1, h2):
        """Return the Euclidean distance between the parameters of networks h1 and h2.

        Computed as sqrt(sum_k ||p1_k - p2_k||_F^2) over paired ``parameters()``, the same
        tensors ``proj`` moves. Buffers (e.g. integer counters, which ``torch.norm``
        rejects) are excluded, and no autograd graph is built. Returns a 0-dim tensor
        (0 when the networks have no parameters).
        """
        sq_dist = 0.0
        with torch.no_grad():
            for p1, p2 in zip(h1.parameters(), h2.parameters()):
                sq_dist = sq_dist + torch.norm(p1 - p2) ** 2  # Frobenius norm for matrices
        return torch.sqrt(torch.as_tensor(sq_dist))

    def proj(self, delta, adv_h, h):
        """Project adv_h onto the Euclidean ball B(h, delta) in parameter space.

        adv_h and h are two classifier networks. If adv_h is already within distance
        ``delta`` of h it is returned untouched. Otherwise each parameter of adv_h is
        moved, in place, along the segment towards h so the total distance equals
        ``delta``. Returns adv_h.
        """
        dist = self.distance(adv_h, h)
        if dist <= delta:
            return adv_h
        ratio = delta / dist
        # In-place update outside autograd: the optimizer keeps the same Parameter objects
        # and no graph is recorded for the projection.
        with torch.no_grad():
            for param_h, param_adv_h in zip(h.parameters(), adv_h.parameters()):
                param_adv_h.copy_(param_h + ratio * (param_adv_h - param_h))
        return adv_h

    def _minibatch_gap(self, minibatches):
        """Return max - min of the adversarial cross-entropy loss over (x, y) minibatches."""
        losses = []
        for x, y in minibatches:
            _, feats = self.model(x)
            losses.append(F.cross_entropy(self.adv_classifier.classify(feats), y))
        return max(losses) - min(losses)

    def update_second(self, minibatches, unlabeled=None):
        """Alternative update taking a list of per-domain ``(x, y)`` minibatches.

        Out of every ``1 + d_steps_per_g`` calls (tracked by the ``update_count`` buffer),
        one performs the model step and the others perform one adversarial step each.

        Model step: a descent step on cross-entropy over all minibatches, then a descent
        step on ``t_lambda`` times the gap, then the adversary is reset to the model.
        Returns {'loss', 'gap'}.

        Adversarial step: an ascent step on the gap, followed by projection onto the
        ``delta``-ball. Returns {'gap'}.

        ``unlabeled`` is accepted for interface compatibility and ignored.

        NOTE(review): like the previous version of this method, it assumes the model's
        predictions are compatible with ``F.cross_entropy`` — confirm for the dataset.
        """
        self.update_count = (self.update_count + 1) % (1 + self.d_steps_per_g)
        if self.update_count.item() == 1:
            all_x = torch.cat([x for x, y in minibatches])
            all_y = torch.cat([y for x, y in minibatches])
            preds, _ = self.model(all_x)
            loss = F.cross_entropy(preds, all_y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            del all_x, all_y
            gap = self.hparams['t_lambda'] * self._minibatch_gap(minibatches)
            self.optimizer.zero_grad()
            gap.backward()
            self.optimizer.step()
            self.adv_classifier.load_state_dict(self.model.state_dict())
            return {'loss': loss.item(), 'gap': gap.item()}

        self.adv_opt.zero_grad()
        gap = -self.hparams['t_lambda'] * self._minibatch_gap(minibatches)
        gap.backward()
        self.adv_opt.step()
        updated_adv_classifier = self.proj(
            self.hparams['delta'],
            self.adv_classifier.get_classifier_network(),
            self.model.get_classifier_network(),
        )
        self.adv_classifier.get_classifier_network().load_state_dict(updated_adv_classifier.state_dict())
        return {'gap': -gap.item()}
