import numpy as np
import torch
from .strategy import Strategy

import copy
import torch.nn.functional as F
from torch.utils.data import Subset, DataLoader

class SAALSampling(Strategy):
    """Sharpness-aware query strategy: rank unlabeled samples by perturbed loss.

    Each unlabeled sample is scored by its cross-entropy loss on a copy of
    ``self.net`` whose weights were pushed along a loss-increasing direction
    of radius ``args['rho']``; the ``n`` highest-scoring samples are queried.

    Deprecated: the constructor always raises ``NotImplementedError`` because
    the strategy manipulates the model directly.
    """

    def __init__(self, train_dataset, train_raw_dataset, test_dataset, idxs_lb, net, args):
        raise NotImplementedError('SAALSampling is deprecated since this directly accesses to the model')
        # Unreachable; kept so the strategy can be re-enabled by removing the raise above.
        super(SAALSampling, self).__init__(train_dataset, train_raw_dataset, test_dataset, idxs_lb, net, args)
        self.rho = args['rho']  # perturbation radius used in get_max_perturbed_loss

    def get_max_perturbed_loss(self, dataset):
        """Compute a per-sample loss after perturbing a copy of the network.

        For every batch, a fresh deep copy of ``self.net`` is made (so
        ``self.net`` itself is never modified). The copy's gradient ``g`` of
        the mean cross-entropy against its own argmax predictions is
        computed. Each trainable weight ``w`` is then moved by
        ``rho * w**2 * g / ||abs(w) * g||`` (norm taken over all parameters),
        and the per-sample loss of the perturbed copy is recorded.

        Args:
            dataset: dataset whose items unpack as ``(x, _, _)``; iterated in
                order without shuffling.

        Returns:
            1-D CPU tensor with one loss per sample, in dataset order (empty
            if ``dataset`` is empty).
        """
        dataloader = DataLoader(dataset, shuffle=False, **self.args['loader_te_args'])
        self.net.eval()

        max_perturbed_loss = []
        for x, _, __ in dataloader:
            model_copy = copy.deepcopy(self.net).to(self.device)
            # Make sure no stale gradients leak into the ascent direction.
            model_copy.zero_grad(set_to_none=True)

            x = x.to(self.device)
            out, _ = model_copy(x)
            # Labels are unknown, so the model's own predictions act as pseudo-labels.
            loss = F.cross_entropy(out, out.argmax(dim=1), reduction='none')
            loss.mean().backward()

            # Frozen or unused parameters have no gradient; they are left unperturbed.
            params = [p for p in model_copy.parameters() if p.grad is not None]
            norm = torch.norm(
                torch.stack([(torch.abs(p) * p.grad).norm(p=2) for p in params]),
                p=2,
            )
            scale = self.rho / (norm + 1e-12)  # epsilon avoids division by zero
            with torch.no_grad():
                for p in params:
                    e_w = torch.pow(p, 2) * p.grad * scale.to(p)
                    p.add_(e_w)
                # The evaluation pass needs no autograd graph.
                output_updated, _ = model_copy(x)
                # NOTE(review): labels here are the perturbed model's own argmax,
                # not the pseudo-labels used for the ascent step — confirm intent.
                loss_updated = F.cross_entropy(output_updated, output_updated.argmax(dim=1), reduction='none')
            max_perturbed_loss.append(loss_updated.cpu())

            del model_copy, out, loss, loss_updated, output_updated, x

        if not max_perturbed_loss:
            # torch.cat rejects an empty list; an empty pool yields an empty score vector.
            return torch.empty(0)
        return torch.cat(max_perturbed_loss, dim=0)

    def query(self, n):
        """Return pool indices of the ``n`` unlabeled samples with the largest perturbed loss.

        Args:
            n: number of samples to query.

        Returns:
            numpy array of indices into the training pool (at most ``n`` long).
        """
        idxs_unlabeled = np.arange(self.n_pool)[~self.idxs_lb]
        dataset = Subset(self.train_raw_dataset, idxs_unlabeled)
        max_perturbed_loss = self.get_max_perturbed_loss(dataset)
        # Highest perturbed loss first: these are the samples the strategy targets.
        order = max_perturbed_loss.sort(descending=True)[1][:n]
        return idxs_unlabeled[order.numpy()]
        
