import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os

from model_wrapper import BasicModelWrapper, MNBaBDeepPolyModelWrapper, PGDModelWrapper
from optimizers import Adam, ClipUp, Optimizer

# Select the global default floating-point dtype from the TORCH_USE_FP64
# environment variable. Only the exact value "1" enables float64; an unset
# variable is treated like any other value and falls back to float32 instead
# of raising KeyError at import time.
if os.environ.get('TORCH_USE_FP64', '0') == "1":
    torch.set_default_dtype(torch.float64)
else:
    torch.set_default_dtype(torch.float32)

from AIDomains.concrete_layers import Normalization
from AIDomains.zonotope import HybridZonotope
from torch.optim.lr_scheduler import MultiStepLR

from typing import Callable, Iterable, List, Optional, Union
from utils import Identity, Scheduler, Statistics, seed_everything
from tqdm.auto import tqdm
import math
import copy
import numpy as np
import sys
from time import time
from regularization import compute_L1_reg, compute_fast_reg

import torch.distributed as dist

# sys.path.append(f'path/to/prima4complete/')
# from src.utilities.config import IntermediateBoundsMethod

# Alias for a network state dict: maps entry names (as produced by
# ``net.state_dict()``) to their tensors.
StateDict = dict[str,torch.Tensor]

def get_loss_value(model_wrapper: BasicModelWrapper, x, y, eps, precomp_bounds=False):
    '''
    Evaluate the loss of the wrapped network on one batch, without gradients bookkeeping.

    The wrapper is put into evaluation configuration first: ``net.eval()``,
    ``store_box_bounds = False``, ``summary_accu_stat = True`` and
    ``freeze_BN = True``. These settings are left in place on return.

    Args:
        model_wrapper: wrapper exposing ``net`` and ``compute_model_stat``.
        x, y: batch inputs and labels, forwarded unchanged to ``compute_model_stat``.
        eps: perturbation radius forwarded to ``compute_model_stat``.
        precomp_bounds: when true, ``model_wrapper.interm_bounds`` (if present
            and not ``None``) is passed as ``intermediate_bounds``; otherwise a
            warning is printed and ``None`` is passed.

    Returns:
        The first loss returned by ``compute_model_stat`` as a Python float.
    '''
    model_wrapper.net.eval()
    model_wrapper.store_box_bounds = False
    model_wrapper.summary_accu_stat = True
    model_wrapper.freeze_BN = True

    # Reuse bounds cached on the wrapper only when the caller asked for it.
    cached_bounds = getattr(model_wrapper, 'interm_bounds', None) if precomp_bounds else None
    if precomp_bounds and cached_bounds is None:
        print('Warning: called get_loss_value with precomp_bounds=True, but model_wrapper has no saved interm_bounds')

    # With summary_accu_stat enabled the wrapper returns exactly two tuples:
    # three losses and two accuracies. Only the first loss is needed here.
    (total_loss, _nat_loss, _robust_loss), (_nat_accu, _robust_accu) = model_wrapper.compute_model_stat(
        x, y, eps, intermediate_bounds=cached_bounds
    )

    # Wrappers that carry an abstract network get their input bounds cleared
    # after every evaluation.
    if hasattr(model_wrapper, 'anet'):
        model_wrapper.anet.reset_input_bounds()

    return total_loss.item()

def get_grad_value(model_wrapper:"BasicModelWrapper", x, y, eps, optimizer, args):
    '''
    Compute the training loss at the network's current weights and return its gradients.

    The wrapper is switched to its training configuration (``net.train()``,
    ``summary_accu_stat = False``, ``freeze_BN = False``); these settings are
    left in place on return.

    Args:
        model_wrapper: wrapper exposing ``net`` and ``compute_model_stat``.
        x, y: batch inputs and labels; moved to ``args.device``.
        eps: perturbation radius forwarded to ``compute_model_stat``.
        optimizer: object whose ``zero_grad()`` clears the gradients of
            ``model_wrapper.net``. May be ``None``, in which case
            ``model_wrapper.net.zero_grad()`` is used instead.
        args: namespace read for ``device``, ``train_steps``, ``fast_reg``,
            ``train_eps``, ``min_eps_reg``, ``L1_reg`` and ``grad_clip``.

    Returns:
        ``(grad_dict, loss)``: ``grad_dict`` maps every ``state_dict`` key to a
        copy of that tensor's gradient (zeros of the tensor's own shape when it
        has no gradient, e.g. buffers), and ``loss`` is the scalar loss
        including the optional regularisers, as a Python float.
    '''
    device = args.device
    model_wrapper.net.train()
    model_wrapper.summary_accu_stat = False
    model_wrapper.freeze_BN = False

    # Design of TAPS: use IBP for annealing to increase speed and performance.
    model_wrapper.num_steps = args.train_steps

    # Design of fast regularization: use fast reg only for annealing.
    fast_reg = (args.fast_reg > 0) and eps < args.train_eps
    if fast_reg:
        model_wrapper.store_box_bounds = True

    x, y = x.to(device), y.to(device)

    # Start from clean gradients. backward() accumulates into .grad and
    # load_state_dict() does not reset it, so without this every call would
    # return the running sum over all previously evaluated samples.
    if optimizer is not None:
        optimizer.zero_grad()
    else:
        model_wrapper.net.zero_grad()

    # With summary_accu_stat disabled the wrapper returns a third tuple of
    # per-sample indicators; only the total loss is used here.
    (loss, _nat_loss, _cert_loss), (_nat_accu, _cert_accu), (_is_nat_accu, _is_cert_accu) = model_wrapper.compute_model_stat(x, y, eps)

    if fast_reg:
        # add fast reg to the loss
        reg_eps = max(eps, args.min_eps_reg)
        if reg_eps != eps:
            # recompute bounds for fast reg
            model_wrapper.net.reset_bounds()
            abs_x = HybridZonotope.construct_from_noise(x, reg_eps, "box")
            model_wrapper.net(abs_x)
        reg = args.fast_reg * (1 - reg_eps/args.train_eps) * compute_fast_reg(model_wrapper.net, reg_eps)
        loss = loss + reg

    if args.L1_reg > 0:
        loss = loss + args.L1_reg * compute_L1_reg(model_wrapper.net)

    model_wrapper.net.reset_bounds()
    loss.backward()
    if args.grad_clip:
        torch.nn.utils.clip_grad_norm_(model_wrapper.net.parameters(), args.grad_clip)

    grad_dict = {}
    for k, v in model_wrapper.net.state_dict(keep_vars=True).items():
        if not hasattr(v, 'grad'):
            continue
        if v.grad is not None:
            grad_dict[k] = v.grad.clone()
        else:
            # Keep the entry's own shape so that flattening and concatenating
            # the gradients lines up element-for-element with the parameters.
            fill_dtype = v.dtype if v.is_floating_point() else torch.get_default_dtype()
            grad_dict[k] = torch.zeros_like(v, dtype=fill_dtype)

    return grad_dict, loss.item()


def evo_train_loop(evo_wrapper: "PGPEEvoWrapper", eps_scheduler:Scheduler, train_loader, epoch_idx, device, args, verbose:bool=False, nep_log=None, extra_adv_attack:bool=True):
    '''
    Run one epoch of evolutionary training over ``train_loader``.

    For every batch the network is loaded with the current search centre, the
    base loss and accuracies are measured under ``torch.no_grad()``, the
    wrapper estimates gradients for the centre and the standard deviation,
    the estimates are averaged across ranks when ``args.use_ddp`` is set, and
    ``evo_wrapper.descend`` applies the update.

    Args:
        evo_wrapper: evolutionary wrapper providing ``model_wrapper``,
            ``center``, ``state_dict_from_vector``, ``estimate_grads``,
            ``descend``, ``log_with_neptune`` and ``log_unstable``.
        eps_scheduler: queried with the global step index for the current eps.
        train_loader: iterable of ``(x, y)`` batches; must support ``len``.
        epoch_idx: index of the current epoch, used for the global step.
        device: device the batches are moved to.
        args: namespace read for ``rank``, ``log_unstable``,
            ``precomp_bounds``, ``use_ddp`` and ``device``.
        verbose: print per-batch statistics on rank 0.
        nep_log: unused here; logging goes through ``evo_wrapper``.
        extra_adv_attack: additionally measure accuracy under a 200-step
            ``PGDModelWrapper`` attack.

    Returns:
        ``(nat_accu, cert_accu, adv_accu, loss, mean_eval)`` averages when
        ``extra_adv_attack`` is true, otherwise the same tuple without
        ``adv_accu``.
    '''
    model_wrapper: BasicModelWrapper = evo_wrapper.model_wrapper
    model_wrapper.net.train()
    model_wrapper.summary_accu_stat = True
    model_wrapper.freeze_BN = False
    ibp_stability, dp_stability = None, None

    nat_accu_stat, cert_accu_stat, loss_stat, mean_eval_stat = Statistics.get_statistics(4)

    if extra_adv_attack:
        attack_wrapper = PGDModelWrapper(model_wrapper.net, nn.CrossEntropyLoss(), model_wrapper.input_dim, device, args)
        attack_wrapper.num_steps = 200
        adv_accu_stat = Statistics()

    ret_all = args.log_unstable or args.precomp_bounds
    pbar = tqdm(train_loader, disable=args.rank)
    for batch_idx, (x, y) in enumerate(pbar):
        x, y = x.to(device), y.to(device)
        eps = eps_scheduler.getcurrent(epoch_idx * len(train_loader) + batch_idx)
        model_wrapper.robust_weight = 1
        evo_wrapper.current_eps = eps
        # estimate_grads() may go through get_grad_value(), which switches
        # summary_accu_stat off; restore it so that every batch (not only the
        # first) gets the summarised return shape unpacked below.
        model_wrapper.summary_accu_stat = True
        # NOTE(review): net.train()/freeze_BN are only set once before the
        # loop, while get_loss_value() switches to eval()/freeze_BN=True, so
        # with loss-based wrappers later batches run this forward pass in
        # eval mode. Left unchanged -- confirm which mode is intended.
        model_wrapper.net.load_state_dict(evo_wrapper.state_dict_from_vector(evo_wrapper.center))

        with torch.no_grad():
            (loss, nat_loss, cert_loss), (nat_accu, cert_accu), *rest = model_wrapper.compute_model_stat(x, y, eps, return_all=ret_all, compute_bounds=args.precomp_bounds)
        if rest and isinstance(model_wrapper, MNBaBDeepPolyModelWrapper):
            # Cache the intermediate bounds on the wrapper so that
            # get_loss_value(..., precomp_bounds=True) can reuse them.
            ibp_stability, dp_stability, model_wrapper.interm_bounds = rest
            if hasattr(model_wrapper,'anet'): model_wrapper.anet.reset_input_bounds()
        if extra_adv_attack:
            with torch.no_grad():
                _adv_loss, adv_accu, _is_adv_accu = attack_wrapper.get_robust_stat_from_input_noise(eps, x, y)
                adv_accu_stat.update(adv_accu.item(), len(x))

        if verbose and args.rank == 0:
            print(f"Batch {batch_idx}:", nat_accu, cert_accu, loss.item())

        l_base = loss.item()
        loss_stat.update(l_base, len(x))
        nat_accu_stat.update(nat_accu, len(x))
        cert_accu_stat.update(cert_accu, len(x))

        grad_mu, grad_sigma, mean_loss = evo_wrapper.estimate_grads(x,y, eps, l_base, precomp_bounds=args.precomp_bounds)
        mean_eval_stat.update(mean_loss, len(x))

        # Average the estimates over all ranks; slot 0 counts the
        # participating ranks, slot 1 carries this rank's mean loss.
        mean_eval_tensor = torch.tensor([1, mean_loss],device=args.device)
        if args.use_ddp:
            dist.all_reduce(mean_eval_tensor, dist.ReduceOp.SUM)
            dist.all_reduce(grad_mu, dist.ReduceOp.SUM)
            dist.all_reduce(grad_sigma, dist.ReduceOp.SUM)
            tot = mean_eval_tensor[0]
            mean_loss = (mean_eval_tensor[1] / tot).item()
            grad_mu = grad_mu / tot
            grad_sigma = grad_sigma / tot

        with torch.no_grad():
            evo_wrapper.descend(grad_mu,grad_sigma)

        postfix_str = f"nat_accu: {nat_accu_stat.avg:.3f}, cert_accu: {cert_accu_stat.avg:.3f}, train_loss: {loss_stat.avg:.3f}, mean_eval: {mean_loss:.3f}"

        pbar.set_postfix_str(postfix_str)
        if args.rank == 0:
            evo_wrapper.log_with_neptune()
            if args.log_unstable:
                evo_wrapper.log_unstable(ibp_stability,dp_stability)

    # The wrappers in this file leave the network loaded with the last
    # perturbed sample after estimate_grads(); put the updated centre back so
    # that whoever uses the network after this epoch sees the trained weights.
    model_wrapper.net.load_state_dict(evo_wrapper.state_dict_from_vector(evo_wrapper.center))

    # TODO: maybe also sync the epoch statistics across ranks, but not really
    # relevant unless we change the dataloader

    if extra_adv_attack:
        return nat_accu_stat.avg, cert_accu_stat.avg, adv_accu_stat.avg, loss_stat.avg, mean_eval_stat.avg
    else:
        return nat_accu_stat.avg, cert_accu_stat.avg, loss_stat.avg, mean_eval_stat.avg

class BasicEvoWrapper():
    '''
    Common base for the evolutionary trainers: converts between the network's
    state dict and one flat vector holding all trainable entries.

    An entry is trainable unless its name contains ``.<layer>.`` for one of
    the frozen layers. Frozen entries always keep the values they had when
    the wrapper was created.
    '''
    def __init__(self, model_wrapper:"BasicModelWrapper", device, args, subbatch_size:int=None, nep_log=None):
        '''
        Args:
            model_wrapper: wrapper exposing the network as ``net``.
            device: stored as ``self.device``.
            args: namespace; ``args.freeze_layers`` is read here.
            subbatch_size: accepted for interface compatibility, not used.
            nep_log: optional logger stored as ``self.nep_log``.
        '''
        self.model_wrapper = model_wrapper
        # Snapshot of the initial weights; source of every frozen entry later on.
        self.original_state_dict:"StateDict" = copy.deepcopy(model_wrapper.net.state_dict())
        self.frozen_layers = self.parse_frozen_layers(args.freeze_layers)

        # Sorted names fix the layout of the flat parameter vector.
        self.all_param_names = sorted(self.original_state_dict.keys())
        self.all_param_shapes = {k:v.shape for k,v in self.original_state_dict.items()}
        self.all_param_numels = {k:v.numel() for k,v in self.original_state_dict.items()}
        self.trainable_param_names = [pn for pn in self.all_param_names if self.is_trainable(pn)]

        self.original_center = self.vectorize_state_dict(self.original_state_dict)
        self.device = device
        self.args = args
        self.nep_log = nep_log
        self.current_eps=0
        # NOTE(review): nothing in this class fills this list, so step_lr()
        # does nothing unless a subclass or caller registers optimizers here.
        self.optimizers = []

    def step_lr(self, factor):
        '''Forward ``factor`` to ``step_lr`` of every optimizer in ``self.optimizers``.'''
        for opt in self.optimizers:
            opt.step_lr(factor)

    def parse_frozen_layers(self, layers, include_zero=True):
        '''
        Turn a comma-separated layer list into a list of layer-name strings.

        Args:
            layers: comma-separated string such as ``"1,3"``, an iterable of
                layer identifiers, or ``None``/empty for no extra layers.
            include_zero: always freeze layer ``'0'`` in addition.

        Returns:
            List of layer names with surrounding whitespace removed and empty
            tokens dropped.
        '''
        layers_to_freeze = ['0'] if include_zero else []
        if layers is None:
            tokens = []
        elif isinstance(layers, str):
            tokens = layers.split(',')
        else:
            tokens = [str(tok) for tok in layers]
        # An empty or padded token could never match a parameter name, so it
        # would silently freeze nothing; normalise instead.
        layers_to_freeze += [tok.strip() for tok in tokens if tok.strip()]
        return layers_to_freeze

    def is_trainable(self, param):
        '''Return False when ``param`` contains ``.<layer>.`` for any frozen layer.'''
        return not any(f'.{fl}.' in param for fl in self.frozen_layers)

    def vectorize_state_dict(self, sd:"StateDict"):
        '''Concatenate the flattened trainable entries of ``sd`` in sorted-name order.'''
        return torch.cat([sd[pn].flatten() for pn in self.trainable_param_names])

    def state_dict_from_vector(self, vector:torch.Tensor):
        '''
        Rebuild a full state dict from a flat vector of trainable entries.

        Trainable entries are slices of ``vector`` reshaped to their original
        shapes; all remaining entries are clones from the initial snapshot.
        '''
        sd = {}

        start = 0
        for pn in self.trainable_param_names:
            numel = self.all_param_numels[pn]
            sd[pn] = vector[start:start+numel].reshape(self.all_param_shapes[pn])
            start += numel

        for k, v in self.original_state_dict.items():
            if k not in sd:
                sd[k] = v.clone()

        return sd

    def estimate_grads(self, x, y, *args, **kwargs):
        '''
        Estimate gradients for one batch; must be provided by subclasses.

        Extra positional arguments are accepted so that the call convention
        used by the subclasses, ``estimate_grads(x, y, eps, l_base, ...)``,
        reaches this placeholder and raises ``NotImplementedError``.
        '''
        raise NotImplementedError()

class PGPEEvoWrapper(BasicEvoWrapper):
    '''
    Evolutionary trainer that updates a Gaussian search distribution
    (``center``, ``std``) over the trainable weights from loss values only.

    Each direction ``d`` from ``sample_directions`` is evaluated at
    ``center + d`` and ``center - d`` with ``get_loss_value``; the resulting
    loss differences give the update direction for ``center`` and ``std``.

    torch_net must be connected to the abstract net wrapped in model_wrapper.

    ``use_current_std`` and ``subbatch_size`` are accepted but not used by
    this class: ``std`` always starts at ``std_init`` for every entry.
    '''
    def __init__(self, model_wrapper:"BasicModelWrapper", center_lr, std_lr,  device, args, popsize:int=1000, symmetric:bool=True, std_init:float=1e-2, use_current_std:bool=True, std_min=None, std_max=None, subbatch_size:int=None, nep_log=None, optimizer="adam"):
        '''
        Args:
            model_wrapper: wrapper exposing the network as ``net``.
            center_lr, std_lr: step sizes for the centre and the std update.
            device: device handed to the optimizers.
            args: namespace; ``scale_grads`` is read here, ``freeze_layers``
                by the base class and ``fitness_factor`` when logging.
            popsize: number of directions sampled per gradient estimate.
            symmetric: use the antithetic estimator (the only one implemented).
            std_init: initial standard deviation of every entry.
            std_min, std_max: optional absolute bounds for ``std``.
            nep_log: optional logger.
            optimizer: ``'adam'``, ``'clipup'``, or anything else for a plain
                ``-lr * grad`` step.
        '''
        # record current params to initialize the searcher solution
        super().__init__(model_wrapper, device, args, subbatch_size, nep_log)

        self.center = self.vectorize_state_dict(self.original_state_dict)
        self.std = torch.ones_like(self.center) * std_init
        self.std_min = std_min
        self.std_max = std_max
        self.popsize = popsize
        self.symmetric = symmetric
        self.scale_grads = args.scale_grads
        # Largest relative change of std allowed in a single descend() call.
        self.max_std_change = 0.02

        print('PGPE center lr', center_lr)
        print('PGPE std lr', std_lr)
        self.center_lr = center_lr
        self.std_lr = std_lr

        opt_args = dict(
            solution_shape=self.center.shape,
            dtype=self.center.dtype,
            device=self.device,
        )
        args_center = opt_args | dict(
            stepsize=self.center_lr,
        )
        args_std = opt_args | dict(
            stepsize=self.std_lr,
        )

        if optimizer == 'adam':
            self.optimizer_center = Adam(**args_center)
            self.optimizer_std = Adam(**args_std)
        elif optimizer == 'clipup':
            self.optimizer_center = ClipUp(**args_center)
            self.optimizer_std = ClipUp(**args_std)
        else:
            # No optimizer object: descend() falls back to -lr * grad.
            self.optimizer_center = None
            self.optimizer_std = None

        # NOTE(review): the optimizers are not added to self.optimizers, so
        # step_lr() inherited from the base class has nothing to act on
        # unless the caller registers them -- confirm this is intended.
        self.lr_scheduler = None

    def estimate_grads(self, x, y, eps, l_base, precomp_bounds=False):
        '''Dispatch to the symmetric or asymmetric estimator according to ``self.symmetric``.'''
        if self.symmetric:
            return self.estimate_grads_sym(x,y, eps, l_base, precomp_bounds=precomp_bounds)
        else:
            return self.estimate_grads_asym(x,y, eps, l_base, precomp_bounds=precomp_bounds)

    def estimate_grads_asym(self, x, y, eps, l_base, precomp_bounds=False):
        '''Asymmetric estimator; not implemented for this wrapper.'''
        raise NotImplementedError()

    def estimate_grads_sym(self, x, y, eps, l_base, precomp_bounds=False):
        '''
        Antithetic gradient estimate from ``popsize`` sampled directions.

        Args:
            x, y, eps: batch and perturbation radius passed to ``get_loss_value``.
            l_base: loss at the unperturbed centre, used as the baseline for
                the std gradient.
            precomp_bounds: forwarded to ``get_loss_value``.

        Returns:
            ``(grad_mu, grad_sigma, mean_loss)``, each averaged over
            ``popsize``. The network is left loaded with the last evaluated
            sample.
        '''
        model_wrapper = self.model_wrapper

        grad_mu = torch.zeros_like(self.center)
        grad_sigma = torch.zeros_like(self.std)
        sig = self.std
        sig_sq = self.std ** 2
        total_loss = 0

        for d in self.sample_directions():
            model_wrapper.net.load_state_dict(self.state_dict_from_vector(self.center + d))
            l_plus = get_loss_value(model_wrapper,x,y,eps, precomp_bounds=precomp_bounds)
            model_wrapper.net.load_state_dict(self.state_dict_from_vector(self.center - d))
            l_minus = get_loss_value(model_wrapper,x,y,eps, precomp_bounds=precomp_bounds)

            gmu = d * (l_plus - l_minus)
            gsig = (d ** 2 - sig_sq) / sig * ((l_plus + l_minus) / 2 - l_base)
            if self.scale_grads:
                # Normalise by the loss magnitude so the step size does not
                # depend on the scale of the loss.
                gmu = gmu / (l_plus + l_minus)
                gsig = gsig / l_base

            grad_mu += gmu
            grad_sigma += gsig

            total_loss += (l_plus + l_minus) / 2

        self.mean_eval = total_loss / self.popsize

        return grad_mu/self.popsize, grad_sigma/self.popsize, total_loss/self.popsize

    def sample_directions(self):
        '''Yield ``popsize`` perturbations drawn from ``N(0, std**2)``, one per entry of ``std``.'''
        for _ in range(self.popsize):
            yield torch.randn_like(self.std) * self.std

    def descend(self, grad_mu:torch.Tensor, grad_sigma:torch.Tensor):
        '''
        Apply one update to ``center`` and ``std``.

        ``std`` may change by at most ``max_std_change`` (relative) per call
        and is afterwards clipped to ``[std_min, std_max]`` when at least one
        of the two bounds is set. Also records ``grad_size``,
        ``grad_step_size`` and ``total_change`` for logging.
        '''
        if self.optimizer_center is not None:
            step = self.optimizer_center.descend(grad_mu)
        else:
            step = -grad_mu * self.center_lr
        self.grad_size = grad_mu.norm(2)
        self.grad_step_size = step.norm(2)
        self.center += step

        self.total_change = (self.center - self.original_center).norm(2)

        old_std = self.std.clone()
        if self.optimizer_std is not None:
            step = self.optimizer_std.descend(grad_sigma)
        else:
            step = -grad_sigma * self.std_lr
        self.std += step
        self.std = torch.clip(self.std,(1-self.max_std_change)*old_std,(1+self.max_std_change)*old_std)
        # torch.clip raises when both bounds are None, which is the default
        # configuration of this class; only clip when a bound exists.
        if self.std_min is not None or self.std_max is not None:
            self.std = torch.clip(self.std, self.std_min, self.std_max)

    def log_unstable(self, ibp_unstable=None, dp_unstable=None):
        '''Append the two stability values to the logger, if one is attached.'''
        if self.nep_log is not None:
            self.nep_log["result/ibp_unstable"].append(ibp_unstable)
            self.nep_log["result/dp_unstable"].append(dp_unstable)

    def log_with_neptune(self):
        '''Append the current training diagnostics to the logger, if one is attached.'''
        if self.nep_log is not None:
            if self.lr_scheduler is not None:
                self.nep_log["result/lr"].append(self.optimizer_center.stepsize)
            self.nep_log["result/eps"].append(self.current_eps)
            # mean_eval only exists once a gradient estimate has been made.
            if getattr(self, "mean_eval", None) is not None:
                self.nep_log["result/mean_eval"].append(self.mean_eval/self.args.fitness_factor)
            self.nep_log["result/mean_pgpe_std_mean"].append(self.std.mean())
            if hasattr(self,"delta") and self.delta is not None:
                self.nep_log["result/real_loss_delta"].append(self.delta)
                self.nep_log["result/loss_error"].append(self.loss_error)
            if hasattr(self,"grad_step_size") and self.grad_step_size is not None:
                self.nep_log["result/grad_step_size"].append(self.grad_step_size)
            if hasattr(self,"grad_size") and self.grad_size is not None:
                self.nep_log["result/grad_size"].append(self.grad_size)
            if hasattr(self, "total_change") and self.total_change is not None:
                self.nep_log["result/total_change"].append(self.total_change)

class GSMEvoWrapper(BasicEvoWrapper):
    '''
    Trainer that averages back-propagated gradients over weights sampled
    around ``center``.

    Each direction ``d`` from ``sample_directions`` is loaded as
    ``center + d`` and ``get_grad_value`` returns the gradient at that point;
    the average of these gradients drives the update of ``center``. The
    gradient for ``std`` is always zero in this wrapper.

    torch_net must be connected to the abstract net wrapped in model_wrapper.

    ``use_current_std`` and ``subbatch_size`` are accepted but not used by
    this class: ``std`` always starts at ``std_init`` for every entry.
    '''
    def __init__(self, model_wrapper:"BasicModelWrapper", center_lr, std_lr,  device, args, popsize:int=32, symmetric:bool=False, std_init:float=1e-2, use_current_std:bool=True, std_min=None, std_max=None, subbatch_size:int=None, nep_log=None, optimizer="adam"):
        '''
        Args:
            model_wrapper: wrapper exposing the network as ``net``.
            center_lr, std_lr: step sizes for the centre and the std update.
            device: device handed to the optimizers.
            args: namespace; ``scale_grads`` is read here, ``freeze_layers``
                by the base class, ``fitness_factor`` when logging, and the
                whole namespace is passed on to ``get_grad_value``.
            popsize: number of weight samples per gradient estimate.
            symmetric: evaluate samples in antithetic pairs ``+d`` / ``-d``.
            std_init: initial standard deviation of every entry.
            std_min, std_max: optional absolute bounds for ``std``.
            nep_log: optional logger.
            optimizer: ``'adam'``, ``'clipup'``, or anything else for a plain
                ``-lr * grad`` step.
        '''
        # record current params to initialize the searcher solution
        super().__init__(model_wrapper, device, args, subbatch_size, nep_log)

        self.center = self.vectorize_state_dict(self.original_state_dict)
        self.std = torch.ones_like(self.center) * std_init
        self.std_min = std_min
        self.std_max = std_max
        self.popsize = popsize
        self.symmetric = symmetric
        self.scale_grads = args.scale_grads
        # Largest relative change of std allowed in a single descend() call.
        self.max_std_change = 0.02
        # Zero-learning-rate optimizer handed to get_grad_value as its
        # ``optimizer`` argument; it never moves the weights.
        self.grad_cleaner = torch.optim.SGD(self.model_wrapper.net.parameters(), lr=0)

        print('GSM center lr', center_lr)
        print('GSM std lr', std_lr)
        self.center_lr = center_lr
        self.std_lr = std_lr

        opt_args = dict(
            solution_shape=self.center.shape,
            dtype=self.center.dtype,
            device=self.device,
        )
        args_center = opt_args | dict(
            stepsize=self.center_lr,
        )
        args_std = opt_args | dict(
            stepsize=self.std_lr,
        )

        if optimizer == 'adam':
            self.optimizer_center = Adam(**args_center)
            self.optimizer_std = Adam(**args_std)
        elif optimizer == 'clipup':
            self.optimizer_center = ClipUp(**args_center)
            self.optimizer_std = ClipUp(**args_std)
        else:
            # No optimizer object: descend() falls back to -lr * grad.
            self.optimizer_center = None
            self.optimizer_std = None

        # NOTE(review): the optimizers are not added to self.optimizers, so
        # step_lr() inherited from the base class has nothing to act on
        # unless the caller registers them -- confirm this is intended.
        self.lr_scheduler = None

    def estimate_grads(self, x, y, eps, l_base, precomp_bounds=False):
        '''Dispatch to the symmetric or asymmetric estimator according to ``self.symmetric``.'''
        if self.symmetric:
            return self.estimate_grads_sym(x,y, eps, l_base, precomp_bounds=precomp_bounds)
        else:
            return self.estimate_grads_asym(x,y, eps, l_base, precomp_bounds=precomp_bounds)

    def estimate_grads_asym(self, x, y, eps, l_base, precomp_bounds=False):
        '''
        Average the gradients at ``center + d`` over ``popsize`` directions.

        ``l_base`` and ``precomp_bounds`` are accepted for interface
        compatibility and not used.

        Returns:
            ``(grad_mu, grad_sigma, mean_loss)`` where ``grad_sigma`` is all
            zeros. The network is left loaded with the last sample.
        '''
        net = self.model_wrapper.net

        grad_mu = torch.zeros_like(self.center)
        grad_sigma = torch.zeros_like(self.std)
        total_loss = 0

        for d in self.sample_directions():
            net.load_state_dict(self.state_dict_from_vector(self.center + d))
            sample_grad, sample_loss = get_grad_value(self.model_wrapper,x,y,eps, self.grad_cleaner, self.args)

            grad_mu += self.vectorize_state_dict(sample_grad)
            total_loss += sample_loss

        self.mean_eval = total_loss / self.popsize

        return grad_mu/self.popsize, grad_sigma/self.popsize, total_loss/self.popsize

    def estimate_grads_sym(self, x, y, eps, l_base, precomp_bounds=False):
        '''
        Average the gradients over antithetic pairs ``center + d`` / ``center - d``.

        ``popsize // 2`` pairs are evaluated (at least one), and every
        average is taken over the number of samples actually evaluated.
        ``l_base`` and ``precomp_bounds`` are accepted for interface
        compatibility and not used.

        Returns:
            ``(grad_mu, grad_sigma, mean_loss)`` where ``grad_sigma`` is all
            zeros. The network is left loaded with the last sample.
        '''
        net = self.model_wrapper.net

        grad_mu = torch.zeros_like(self.center)
        grad_sigma = torch.zeros_like(self.std)
        total_loss = 0
        n_samples = 0

        # zip() with a range stops after exactly n_pairs directions, so the
        # evaluation budget matches the divisor used below.
        n_pairs = max(1, self.popsize // 2)
        for _, d0 in zip(range(n_pairs), self.sample_directions()):
            for d in (d0, -d0):
                net.load_state_dict(self.state_dict_from_vector(self.center + d))
                sample_grad, sample_loss = get_grad_value(self.model_wrapper,x,y,eps, self.grad_cleaner, self.args)

                grad_mu += self.vectorize_state_dict(sample_grad)
                total_loss += sample_loss
                n_samples += 1

        self.mean_eval = total_loss / n_samples

        return grad_mu/n_samples, grad_sigma/n_samples, total_loss/n_samples

    def sample_directions(self):
        '''Yield ``popsize`` perturbations drawn from ``N(0, std**2)``, one per entry of ``std``.'''
        for _ in range(self.popsize):
            yield torch.randn_like(self.std) * self.std

    def descend(self, grad_mu:torch.Tensor, grad_sigma:torch.Tensor):
        '''
        Apply one update to ``center`` and ``std``.

        ``std`` may change by at most ``max_std_change`` (relative) per call
        and is afterwards clipped to ``[std_min, std_max]`` when at least one
        of the two bounds is set. Also records ``grad_size``,
        ``grad_step_size`` and ``total_change`` for logging.
        '''
        if self.optimizer_center is not None:
            step = self.optimizer_center.descend(grad_mu)
        else:
            step = -grad_mu * self.center_lr
        self.grad_size = grad_mu.norm(2)
        self.grad_step_size = step.norm(2)
        self.center += step

        self.total_change = (self.center - self.original_center).norm(2)

        old_std = self.std.clone()
        if self.optimizer_std is not None:
            step = self.optimizer_std.descend(grad_sigma)
        else:
            step = -grad_sigma * self.std_lr
        self.std += step
        self.std = torch.clip(self.std,(1-self.max_std_change)*old_std,(1+self.max_std_change)*old_std)
        # torch.clip raises when both bounds are None, which is the default
        # configuration of this class; only clip when a bound exists.
        if self.std_min is not None or self.std_max is not None:
            self.std = torch.clip(self.std, self.std_min, self.std_max)

    def log_unstable(self, ibp_unstable=None, dp_unstable=None):
        '''Append the two stability values to the logger, if one is attached.'''
        if self.nep_log is not None:
            self.nep_log["result/ibp_unstable"].append(ibp_unstable)
            self.nep_log["result/dp_unstable"].append(dp_unstable)

    def log_with_neptune(self):
        '''Append the current training diagnostics to the logger, if one is attached.'''
        if self.nep_log is not None:
            if self.lr_scheduler is not None:
                self.nep_log["result/lr"].append(self.optimizer_center.stepsize)
            self.nep_log["result/eps"].append(self.current_eps)
            # mean_eval only exists once a gradient estimate has been made.
            if getattr(self, "mean_eval", None) is not None:
                self.nep_log["result/mean_eval"].append(self.mean_eval/self.args.fitness_factor)
            self.nep_log["result/mean_pgpe_std_mean"].append(self.std.mean())
            if hasattr(self,"delta") and self.delta is not None:
                self.nep_log["result/real_loss_delta"].append(self.delta)
                self.nep_log["result/loss_error"].append(self.loss_error)
            if hasattr(self,"grad_step_size") and self.grad_step_size is not None:
                self.nep_log["result/grad_step_size"].append(self.grad_step_size)
            if hasattr(self,"grad_size") and self.grad_size is not None:
                self.nep_log["result/grad_size"].append(self.grad_size)
            if hasattr(self, "total_change") and self.total_change is not None:
                self.nep_log["result/total_change"].append(self.total_change)

class MixedEvoWrapper(BasicEvoWrapper):
    '''
    Wrapper that keeps a flat `center` vector (the network weights) and a
    per-parameter `std` vector, and updates both in `descend` from the gradients
    returned by `estimate_grads`.

    torch_net must be connected to the abstract net wrapped in model_wrapper.

    NOTE: `use_current_std` is accepted for interface compatibility but is not read
    anywhere in this class; `std` is always initialised to `std_init` for every entry.
    NOTE: the sampling loop in `estimate_grads_sym` is currently disabled (commented
    out), so the sampled terms `grad_mu` / `grad_sigma` are always zero.
    '''
    def __init__(self, model_wrapper:BasicModelWrapper, center_lr, std_lr,  device, args, popsize:int=32, GSM_popsize:int=32, symmetric:bool=True, std_init:float=1e-2, use_current_std:bool=True, std_min=None, std_max=None, subbatch_size:int=None, nep_log=None, optimizer="adam", PGPE_weight:float=1.):
        """Set up the center/std vectors and their optimizers.

        Args:
            model_wrapper: wrapper whose `net` is loaded with candidate weights.
            center_lr: step size for the center update.
            std_lr: step size for the std update.
            device, args, subbatch_size, nep_log: forwarded to the base class;
                `args.scale_grads` is additionally stored on this instance.
            popsize: divisor applied to the sampled terms and the loss in
                `estimate_grads_sym`.
            GSM_popsize: negative values fall back to `popsize`; 0 makes
                `estimate_grads_sym` evaluate one gradient at the center.
            symmetric: selects `estimate_grads_sym` (True) or
                `estimate_grads_asym` (False, not implemented).
            std_init: initial value of every entry of `std`.
            use_current_std: currently unused.
            std_min, std_max: absolute bounds applied to `std` in `descend`;
                when both are None no absolute clipping is applied.
            optimizer: 'adam' or 'clipup'; any other value disables the
                optimizers and `descend` takes a plain `-lr * grad` step.
            PGPE_weight: multiplier on the sampled `grad_mu` / `grad_sigma` terms.
        """
        super().__init__(model_wrapper, device, args, subbatch_size, nep_log)

        # `original_state_dict` is not assigned here — presumably set by the base class.
        self.center = self.vectorize_state_dict(self.original_state_dict)
        self.std = torch.ones_like(self.center) * std_init
        self.std_min = std_min
        self.std_max = std_max
        self.popsize = popsize
        self.GSM_popsize = GSM_popsize if GSM_popsize >= 0 else popsize
        print('GSM popsize', self.GSM_popsize)
        self.symmetric = symmetric
        self.scale_grads = args.scale_grads
        # Largest relative change of std allowed in a single `descend` call.
        self.max_std_change = 0.02
        # lr=0 SGD instance handed to get_grad_value; presumably only used to
        # zero gradients — TODO confirm against get_grad_value.
        self.grad_cleaner = torch.optim.SGD(self.model_wrapper.net.parameters(), lr=0)

        print('GSM center lr', center_lr)
        print('GSM std lr', std_lr)
        self.center_lr = center_lr
        self.std_lr = std_lr

        self.PGPE_weight = PGPE_weight
        print('PGPE weight', PGPE_weight)

        opt_args = dict(
            solution_shape=self.center.shape,
            dtype=self.center.dtype,
            device=self.device,
        )
        args_center = opt_args | dict(
            stepsize=self.center_lr,
        )
        args_std = opt_args | dict(
            stepsize=self.std_lr,
        )

        if optimizer == 'adam':
            self.optimizer_center = Adam(**args_center)
            self.optimizer_std = Adam(**args_std)
        elif optimizer == 'clipup':
            self.optimizer_center = ClipUp(**args_center)
            self.optimizer_std = ClipUp(**args_std)
        else:
            # No optimizer: `descend` falls back to plain -lr * grad steps.
            self.optimizer_center:Optimizer = None
            self.optimizer_std:Optimizer = None

        # Optimizer-state resume and lr scheduling are not wired up for this
        # wrapper; lr_scheduler stays None, so "result/lr" is never logged.
        self.lr_scheduler = None

    def estimate_grads(self, x, y, eps, l_base, precomp_bounds=False):
        """Dispatch to the symmetric or asymmetric estimator based on `self.symmetric`."""
        if self.symmetric:
            return self.estimate_grads_sym(x,y, eps, l_base, precomp_bounds=precomp_bounds)
        else:
            return self.estimate_grads_asym(x,y, eps, l_base, precomp_bounds=precomp_bounds)

    def estimate_grads_asym(self, x, y, eps, l_base, precomp_bounds=False):
        """Asymmetric estimator; not implemented.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("estimate_grads_asym is not implemented; construct the wrapper with symmetric=True")

    def estimate_grads_sym(self, x, y, eps, l_base, precomp_bounds=False):
        """Return `(grad_center, grad_sigma, mean_loss)` for the current center.

        When `GSM_popsize == 0` the center is loaded into the network and one
        gradient/loss pair is obtained from `get_grad_value`; otherwise the
        gradient term stays zero. The sampled terms are always zero because the
        sampling loop below is disabled. Also stores `total_loss / popsize` in
        `self.mean_eval`.

        `l_base` and `precomp_bounds` are only referenced by the disabled loop.
        """
        model_wrapper:BasicModelWrapper = self.model_wrapper

        grad_gsm = torch.zeros_like(self.center)
        grad_mu = torch.zeros_like(self.center)
        grad_sigma = torch.zeros_like(self.std)
        total_loss = 0
        GSM_count = self.GSM_popsize

        if self.GSM_popsize == 0:
            # Single gradient evaluation exactly at the center.
            model_wrapper.net.load_state_dict(self.state_dict_from_vector(self.center))
            center_grad, loss_center = get_grad_value(model_wrapper,x,y,eps, self.grad_cleaner, self.args)
            grad_gsm = self.vectorize_state_dict(center_grad)
            GSM_count = 1
            total_loss = loss_center

        # NOTE(review): the sampling loop is disabled. With GSM_popsize > 0 this
        # method therefore returns all-zero gradients and a zero loss — confirm
        # that is intended before relying on that configuration.
        # sig = self.std
        # sig_sq = self.std ** 2
        # for cnt, d in enumerate(self.sample_directions(self.popsize//2)):

        #     if cnt < self.GSM_popsize//2:
        #         model_wrapper.net.load_state_dict(self.state_dict_from_vector(self.center + d))
        #         sample_grad_plus, sample_loss_plus = get_grad_value(model_wrapper,x,y,eps, self.grad_cleaner, self.args)
        #         sample_grad_plus = self.vectorize_state_dict(sample_grad_plus)
        #         l_plus = sample_loss_plus
        #         # l_plus = get_loss_value(model_wrapper,x,y,eps, precomp_bounds=precomp_bounds)
        #         model_wrapper.net.load_state_dict(self.state_dict_from_vector(self.center - d))
        #         sample_grad_minus, sample_loss_minus = get_grad_value(model_wrapper,x,y,eps, self.grad_cleaner, self.args)
        #         sample_grad_minus = self.vectorize_state_dict(sample_grad_minus)
        #         l_minus = sample_loss_minus
        #         # l_minus = get_loss_value(model_wrapper,x,y,eps, precomp_bounds=precomp_bounds)

        #     else:
        #         model_wrapper.net.load_state_dict(self.state_dict_from_vector(self.center + d))
        #         l_plus = get_loss_value(model_wrapper,x,y,eps, precomp_bounds=precomp_bounds)
        #         model_wrapper.net.load_state_dict(self.state_dict_from_vector(self.center - d))
        #         l_minus = get_loss_value(model_wrapper,x,y,eps, precomp_bounds=precomp_bounds)
        #         sample_grad_plus = sample_grad_minus = 0

        #     gmu = d * (l_plus - l_minus)
        #     gsig = (d ** 2 - sig_sq) / sig * ((l_plus + l_minus) / 2 - l_base)
        #     if self.scale_grads:
        #         gmu = gmu / (l_plus + l_minus)
        #         gsig = gsig / l_base

        #     grad_mu += gmu
        #     grad_gsm += sample_grad_plus + sample_grad_minus
        #     grad_sigma += gsig

        #     total_loss += (l_plus + l_minus) / 2

        self.mean_eval = total_loss / self.popsize

        grad_center = 2 * grad_mu * self.PGPE_weight / self.popsize + grad_gsm / GSM_count

        return grad_center, grad_sigma*self.PGPE_weight/self.popsize, total_loss/self.popsize

    def sample_directions(self, count):
        """Yield `count` random perturbations: standard-normal noise scaled elementwise by `self.std`."""
        for _ in range(count):
            yield torch.randn_like(self.std) * self.std

    def descend(self, grad_mu:torch.Tensor, grad_sigma:torch.Tensor):
        """Apply one update to `self.center` and `self.std`.

        The step comes from the configured optimizer's `descend`, or is
        `-grad * lr` when no optimizer is set. The std change is limited to
        +/- `max_std_change` relative to its previous value, then clipped to
        `[std_min, std_max]` when at least one of those bounds is set.
        Also records `grad_size`, `grad_step_size` and `total_change`.
        """
        if self.optimizer_center is not None:
            step = self.optimizer_center.descend(grad_mu)
        else:
            step = -grad_mu * self.center_lr
        self.grad_size = grad_mu.norm(2)
        self.grad_step_size = step.norm(2)
        self.center += step

        # `original_center` is not assigned in this class — presumably set by
        # the base class; TODO confirm.
        self.total_change = (self.center - self.original_center).norm(2)

        old_std = self.std.clone()
        if self.optimizer_std is not None:
            step = self.optimizer_std.descend(grad_sigma)
        else:
            step = -grad_sigma * self.std_lr
        self.std += step
        self.std = torch.clip(self.std,(1-self.max_std_change)*old_std,(1+self.max_std_change)*old_std)
        # torch.clip raises when both bounds are None (the constructor defaults),
        # so the absolute clip is applied only if a bound was actually given.
        if self.std_min is not None or self.std_max is not None:
            self.std = torch.clip(self.std, self.std_min, self.std_max)

    def log_unstable(self, ibp_unstable=None, dp_unstable=None):
        """Append `ibp_unstable` and `dp_unstable` to `self.nep_log`; no-op when it is None."""
        if self.nep_log is not None:
            self.nep_log["result/ibp_unstable"].append(ibp_unstable)
            self.nep_log["result/dp_unstable"].append(dp_unstable)

    def log_with_neptune(self):
        """Append the current training diagnostics to `self.nep_log`; no-op when it is None.

        Always logged: eps, `mean_eval / args.fitness_factor`, and the mean of
        `std`. The remaining diagnostics are logged only if the attribute exists
        and is not None.
        """
        run = self.nep_log
        if run is None:
            return

        if self.lr_scheduler is not None:
            run["result/lr"].append(self.optimizer_center.stepsize)
        run["result/eps"].append(self.current_eps)
        run["result/mean_eval"].append(self.mean_eval/self.args.fitness_factor)
        run["result/mean_pgpe_std_mean"].append(self.std.mean())

        # `delta` gates both entries; `loss_error` is read without its own check.
        if getattr(self, "delta", None) is not None:
            run["result/real_loss_delta"].append(self.delta)
            run["result/loss_error"].append(self.loss_error)

        optional_series = (
            ("grad_step_size", "result/grad_step_size"),
            ("grad_size", "result/grad_size"),
            ("total_change", "result/total_change"),
        )
        for attr_name, key in optional_series:
            value = getattr(self, attr_name, None)
            if value is not None:
                run[key].append(value)

