import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import random
import copy

def create_optimizer(args, base_optimizer, model, **kwargs):
    """Build the sharpness-aware optimizer selected by ``args.opt``.

    Supported values of ``args.opt``:

    * ``'sam'``         -> ``ESAM`` with ``beta=1.0`` and ``gamma=1.0``
    * ``'esam'``        -> ``ESAM`` with ``args.beta`` / ``args.gamma``
    * ``'fisher-esam'`` -> ``ESAM_Mask`` (Fisher-masked perturbation)
    * ``'esam-abeta'``  -> ``ESAM_adaptiveBeta`` with ``beta=1.0``

    Args:
        args: namespace providing ``opt`` and ``rho``. Depending on ``opt`` it
            also needs ``beta``, ``gamma``, ``num_samples``, ``keep_ratio`` and
            ``mask_iter_e``.
        base_optimizer: optimizer class (or factory) wrapped by the returned
            optimizer.
        model: module whose parameters are optimized.
        **kwargs: accepted for call-site compatibility.
            NOTE(review): currently NOT forwarded to the optimizer.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: if ``args.opt`` is not one of the supported names.
    """
    opt = args.opt
    if opt == 'sam':
        # Plain SAM: no stochastic weight freezing, no data selection.
        return ESAM(model.parameters(), base_optimizer=base_optimizer, rho=args.rho, beta=1.0, gamma=1.0)
    if opt == 'esam':
        return ESAM(model.parameters(), base_optimizer=base_optimizer, rho=args.rho, beta=args.beta, gamma=args.gamma)
    if opt == 'fisher-esam':
        return ESAM_Mask(model.parameters(), model=model, base_optimizer=base_optimizer,
                         num_samples=args.num_samples,
                         keep_ratio=args.keep_ratio,
                         mask_iter_e=args.mask_iter_e,
                         rho=args.rho,
                         gamma=args.gamma)
    if opt == 'esam-abeta':
        return ESAM_adaptiveBeta(model.parameters(), base_optimizer=base_optimizer, rho=args.rho, beta=1.0, gamma=args.gamma)

    raise ValueError(
        f"Unsupported optimizer {opt!r}; expected one of "
        "'sam', 'esam', 'fisher-esam', 'esam-abeta'"
    )


class ESAM(torch.optim.Optimizer):
    """Two-pass sharpness-aware optimizer (SAM / ESAM style) around a base optimizer.

    Each :meth:`step` does the following:

    1. Computes the gradient ``g`` at the current weights ``w``.
    2. Moves to ``w + e(w)``, where ``e(w) = rho * g / ||g|| / beta``. When
       ``adaptive`` is set, ``e(w)`` is additionally scaled element-wise by ``w**2``.
    3. Recomputes the gradient there.
    4. Restores ``w`` and lets ``base_optimizer`` apply the perturbed-point gradient.

    Args:
        params: iterable of parameters or param-group dicts.
        base_optimizer: optimizer class (e.g. ``torch.optim.SGD``) that performs
            the actual update; it is instantiated with ``**kwargs``.
        rho: perturbation radius; must be non-negative.
        beta: divides the perturbation scale. In :meth:`second_step` each
            parameter is also frozen (``requires_grad=False``) whenever
            ``random.random() > beta``.
        gamma: stored on the instance. NOTE(review): :meth:`step` does not
            perform any gamma-based data selection.
        adaptive: use the parameter-scaled (ASAM-style) perturbation.
        **kwargs: forwarded to ``base_optimizer`` and stored in ``defaults``.
    """

    def __init__(self, params, base_optimizer, rho=0.05, beta=1.0, gamma=1.0, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        self.beta = beta
        self.gamma = gamma

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super().__init__(params, defaults)

        # The wrapper and the base optimizer share the very same param-group
        # dicts, so changes made through one are seen by the other.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

        for group in self.param_groups:
            group["rho"] = rho
            group["adaptive"] = adaptive
        # NOTE(review): not read anywhere in this class.
        self.paras = None

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Move every parameter with a gradient to ``w + e(w)``.

        The applied offset is stored in ``self.state[p]["e_w"]`` so that
        :meth:`second_step` can undo it.

        Args:
            zero_grad: clear gradients after perturbing.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-7) / self.beta
            for p in group["params"]:
                # Re-enable parameters that a previous second_step may have frozen.
                p.requires_grad = True
                if p.grad is None:
                    continue
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Undo the perturbation, randomly freeze parameters, and run the base update.

        Args:
            zero_grad: clear gradients after the base optimizer step.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or not self.state[p]:
                    continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"
                self.state[p]["e_w"] = 0

                # Stochastic weight freezing: skipped with probability 1 - beta.
                if random.random() > self.beta:
                    p.requires_grad = False

        self.base_optimizer.step()  # do the actual "sharpness-aware" update
        if zero_grad:
            self.zero_grad()

    def step(self, closure=None):
        """Run one sharpness-aware update.

        Args:
            closure: callable that performs a full forward-backward pass and
                returns ``(loss, sample_size, logging_output)``. It is invoked
                twice: once at ``w`` and once at ``w + e(w)``.

        Returns:
            ``(loss, sample_size, logging_output)`` from the first (unperturbed)
            pass.

        Raises:
            ValueError: if ``closure`` is None.
        """
        if closure is None:
            raise ValueError("ESAM requires a closure, but none was provided")
        # The closure must build a graph even if step() is called under no_grad.
        closure = torch.enable_grad()(closure)

        # Pass 1: gradient at w. Then climb to w + e(w) and clear the grads so
        # that pass 2 starts from zero instead of accumulating onto pass 1.
        loss, sample_size, logging_output = closure()
        self.first_step(zero_grad=True)

        # Pass 2: gradient at w + e(w). Only its gradients are needed. It runs
        # exactly once, so the perturbed gradient is not accumulated twice.
        closure()

        # Restore w and apply the perturbed-point gradient.
        self.second_step(zero_grad=True)
        return loss, sample_size, logging_output

    def _grad_norm(self):
        """Return the global L2 norm of all gradients (|p|-weighted when adaptive)."""
        # Put everything on the same device, in case of model parallelism.
        shared_device = self.param_groups[0]["params"][0].device
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2,
        )
        return norm

    def load_state_dict(self, state_dict):
        """Load state and re-link the base optimizer to the restored param groups."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

class ESAM_Mask(ESAM):
    """ESAM variant whose perturbation is restricted by a Fisher-information mask.

    ``self.mask`` maps parameter names of ``model`` to 0/1 tensors of the same
    shape. In :meth:`first_step` the perturbation of a named parameter is
    multiplied by its mask. Parameters without a mask entry (for example
    biases and BatchNorm parameters removed in :meth:`init_mask`) are
    perturbed fully.

    Typical setup: call :meth:`init_mask`, then :meth:`set_fisher_mask`.

    Args:
        params: parameters to optimize (expected to belong to ``model``).
        base_optimizer: optimizer class performing the actual update.
        model: module used to name parameters and to estimate Fisher
            information. It is moved to the current CUDA device.
        num_samples: maximum number of training examples used for the Fisher estimate.
        keep_ratio: fraction (0..1) of masked entries that are kept (set to 1).
        mask_iter_e: stored; not used within this class.
        rho, gamma, adaptive, **kwargs: forwarded to :class:`ESAM` (``beta`` is fixed to 1.0).
    """

    def __init__(self, params, base_optimizer, model, num_samples=128, keep_ratio=0.1, mask_iter_e=1, rho=0.05, gamma=1, adaptive=False, **kwargs):
        self.model = model.cuda()
        self.num_samples = num_samples
        self.keep_ratio = keep_ratio
        self.mask_iter_e = mask_iter_e
        assert 0.0 <= keep_ratio <= 1.0
        super().__init__(params=params,
                         base_optimizer=base_optimizer,
                         rho=rho,
                         beta=1.0,
                         gamma=gamma,
                         adaptive=adaptive, **kwargs)
        self.mask = {}

    def init_mask(self):
        """Create an all-zero mask for every parameter, then drop biases and BatchNorm params."""
        for name, param in self.model.named_parameters():
            # zeros_like keeps the parameter's (CUDA) device.
            self.mask[name] = torch.zeros_like(param, dtype=torch.float32, requires_grad=False)

        self.remove_based_partial('bias')
        self.remove_based_nntype(nn.BatchNorm1d)
        self.remove_based_nntype(nn.BatchNorm2d)

    def remove_weight(self, name):
        """Remove ``name`` from the mask (if present) and log its size."""
        if name in self.mask:
            print('Removing `{}` (size:{};(param:{}).'.format(name, self.mask[name].shape, self.mask[name].numel()))
            self.mask.pop(name)

    def remove_based_nntype(self, nn_type):
        """Remove the mask entries of every submodule that is an instance of ``nn_type``."""
        for name, module in self.model.named_modules():
            if isinstance(module, nn_type):
                self.remove_weight(name)
                self.remove_weight(name + '.weight')
                self.remove_weight(name + '.bias')

    def remove_based_partial(self, partial_name):
        """Remove every mask entry whose name contains ``partial_name``."""
        for name in list(self.mask):  # copy: the dict is mutated while looping
            if partial_name in name:
                print('Removing `{}` (size:{};(param:{}).'.format(name, self.mask[name].shape, self.mask[name].numel()))
                self.mask.pop(name)

    def set_fisher_mask(self, traindata):
        """Set ``self.mask`` to the top-``keep_ratio`` entries by estimated Fisher information.

        The Fisher information is estimated as the sum of squared per-example
        gradients of ``nn.CrossEntropyLoss``. The examples are up to
        ``num_samples`` shuffled items of ``traindata``, drawn with batch size 1.
        Only parameters already present in ``self.mask`` are scored, so
        :meth:`init_mask` must be called first. Model gradients are cleared
        before and after each example.

        Args:
            traindata: dataset yielding ``(image, label)`` pairs.
        """
        fisher_dict = {
            name: torch.zeros_like(param, requires_grad=False)
            for name, param in self.model.named_parameters()
            if name in self.mask
        }
        criterion = nn.CrossEntropyLoss()
        train_dataloader = DataLoader(
            dataset=traindata,
            batch_size=1,
            num_workers=0,
            shuffle=True,
        )

        # Clear stale gradients (e.g. from the last training step) so they do
        # not leak into the first example's Fisher contribution.
        self.model.zero_grad()
        for idx, (image, label) in enumerate(train_dataloader):
            if idx >= self.num_samples:
                break
            image, label = image.cuda(), label.cuda()

            loss = criterion(self.model(image), label)
            loss.backward()

            for name, param in self.model.named_parameters():
                # Parameters not reached by this loss have no gradient.
                if name in fisher_dict and param.grad is not None:
                    fisher_dict[name] += torch.square(param.grad.detach())
            self.model.zero_grad()

        # Flatten all scores to pick a global top-k across parameters.
        param_shape = {name: fisher_info.shape for name, fisher_info in fisher_dict.items()}
        fisher_value = torch.cat([fisher_info.view(-1) for fisher_info in fisher_dict.values()], 0)
        all_param_size = fisher_value.numel()

        keep_num = int(all_param_size * self.keep_ratio)
        assert keep_num > 0

        param_to_be_update = torch.topk(fisher_value, keep_num)[1]
        mask_position = torch.zeros_like(fisher_value, dtype=torch.float, requires_grad=False)
        mask_position[param_to_be_update] = 1
        assert fisher_value.numel() == self.mask_info()[1]

        # Scatter the flat mask back into per-parameter tensors (same order as above).
        start_idx = 0
        for name, shape in param_shape.items():
            end_idx = start_idx + fisher_dict[name].numel()
            self.mask[name] = mask_position[start_idx:end_idx].reshape(shape).clone()
            start_idx = end_idx
        assert start_idx == len(mask_position)

    def mask_info(self):
        """Summarize the mask.

        Returns:
            ``[info, all_param, nonzero_param, zero_param, sparse_ratio]``,
            where ``info`` is a human-readable string. ``sparse_ratio`` is 0.0
            when the mask is empty.
        """
        all_param = 0
        zero_param = 0
        nonzero_param = 0
        for mask_value in self.mask.values():
            count = mask_value.numel()
            fired = torch.sum(mask_value).item()
            all_param += count
            nonzero_param += fired
            zero_param += count - fired
        sparse_ratio = zero_param / float(all_param) if all_param else 0.0
        info = 'Mask has {:.3f}Mb param to choose, {:.3f}Mb params fire, {:.3f}Mb params freeze, sparse ratio:{:.3f}'.format(all_param /1024. /1024.,
                                                                                                                            nonzero_param /1024. /1024.,
                                                                                                                            zero_param /1024. /1024.,
                                                                                                                            sparse_ratio)
        return [info, all_param, nonzero_param, zero_param, sparse_ratio]

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb the optimizer's parameters to ``w + e(w)``, masking by parameter name.

        The loop runs over the optimizer's own param groups, so each parameter
        is perturbed exactly once with its own group's ``rho``/``adaptive``.
        It is therefore restored exactly once by :meth:`second_step`. Names are
        looked up through ``self.model``.
        """
        grad_norm = self._grad_norm()
        # Map each parameter object (by identity) to its name in the model.
        param_names = {id(p): name for name, p in self.model.named_parameters()}
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-7) / self.beta
            for p in group["params"]:
                p.requires_grad = True
                if p.grad is None:
                    continue
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                name = param_names.get(id(p))
                if name in self.mask:
                    e_w = e_w * self.mask[name]
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad:
            self.zero_grad()


class ESAM_adaptiveBeta(ESAM):
    """ESAM variant that perturbs only entries whose gradient changed strongly.

    After each update, :meth:`second_step` marks an entry with 1 when the
    relative gradient change ``(g_perturbed - g) / g`` exceeds ``alpha``. The
    next :meth:`first_step` multiplies the perturbation by that mask and
    divides the scale by ``kp_ratio``, the fraction of entries marked 1.

    Per-parameter state is keyed by a flat index that is unique across all
    param groups (for a single group it equals the index within the group).

    Args:
        alpha: threshold on the relative gradient change.
        params, base_optimizer, rho, beta, gamma, adaptive, **kwargs: see :class:`ESAM`.
    """

    def __init__(self, params, base_optimizer, rho=0.05, beta=1, gamma=1, adaptive=False, alpha=1.0, **kwargs):
        self.alpha = alpha

        self.his_grad = {}   # flat index -> gradient at w from the latest first_step
        self.mask = {}       # flat index -> 0/1 mask computed in the latest second_step
        self.mask_num = 0    # entries marked 1 in the latest second_step
        self.total_num = 0   # entries inspected in the latest second_step
        self.kp_ratio = 1    # mask_num / total_num; divides the next perturbation scale

        super().__init__(params, base_optimizer, rho, beta, gamma, adaptive, **kwargs)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Save the gradient at ``w`` and move to the (masked) ``w + e(w)``."""
        grad_norm = self._grad_norm()
        offset = 0
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-7) / self.beta / self.kp_ratio
            for i, p in enumerate(group["params"], start=offset):
                p.requires_grad = True
                if p.grad is None:
                    continue
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                # Keep the unperturbed gradient for the ratio test in second_step.
                self.his_grad[i] = p.grad.detach().clone()
                # Apply the mask from the previous step, if one exists for this parameter.
                mask = self.mask.get(i)
                if mask is not None:
                    e_w = e_w * mask

                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w
            offset += len(group["params"])

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore ``w``, rebuild the mask and ``kp_ratio``, then run the base update."""
        self.mask_num = 0
        self.total_num = 0
        offset = 0
        for group in self.param_groups:
            for i, p in enumerate(group["params"], start=offset):
                if p.grad is None or not self.state[p]:
                    continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"
                self.state[p]["e_w"] = 0

                prev_grad = self.his_grad[i]
                ratio = (p.grad - prev_grad) / prev_grad
                mask = torch.where(ratio > self.alpha, 1.0, 0.0)
                # count_nonzero is exact, whereas a float32 sum loses precision on large tensors.
                self.mask_num += torch.count_nonzero(mask).item()
                self.total_num += mask.numel()

                self.mask[i] = mask

                if random.random() > self.beta:
                    p.requires_grad = False
            offset += len(group["params"])

        # A zero ratio (nothing selected, or nothing inspected) would make the
        # next scale infinite, and masked entries would become inf * 0 = NaN.
        # Fall back to 1 so the masked perturbation is simply zero.
        if self.mask_num > 0:
            self.kp_ratio = float(self.mask_num) / self.total_num
        else:
            self.kp_ratio = 1.0
        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()