import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import evotorch
from evotorch.neuroevolution import SupervisedNE
from evotorch.decorators import vectorized, on_aux_device, on_cuda
from evotorch.algorithms import PGPE
from evotorch.neuroevolution.net.misc import count_parameters, parameter_vector, fill_parameters
from AIDomains.concrete_layers import Normalization
from torch.optim.lr_scheduler import MultiStepLR
from evotorch.neuroevolution.neproblem import NEProblem
from evotorch import Problem, Solution
from evotorch.algorithms.searchalgorithm import SearchAlgorithm, SinglePopulationAlgorithmMixin
from evotorch.logging import PandasLogger, StdOutLogger

from typing import Callable, Iterable, List, Optional, Union
from utils import Identity, Statistics, seed_everything
from tqdm.auto import tqdm
import math
import copy
import numpy as np
import os
import sys
from time import time

sys.path.append('prima4complete/')
from src.utilities.config import IntermediateBoundsMethod

# String prepended to the temporary file names (intermediate bounds / mean net)
# that CertifyNEProblem reads and writes; empty means the names are used as-is.
TMP_FILE_PREFIX = ""

# Alias for a mapping from entry name to tensor, as used by add / mul / L2_norm below.
StateDict = dict[str,torch.Tensor]

def add(sd1:StateDict,sd2:StateDict):
    """Return the key-wise sum of two state dicts, with every tensor on the CPU.

    Keys present in only one input are carried over unchanged (moved to CPU);
    keys present in both are summed after both operands are moved to CPU.
    Neither input is modified.

    The result has a deterministic key order: all keys of ``sd1`` in their
    original order, followed by the keys that occur only in ``sd2``. (Iterating
    a set of string keys would make the order vary between interpreter runs.)
    """
    res = {}
    for k, v in sd1.items():
        if k in sd2:
            res[k] = v.to('cpu') + sd2[k].to('cpu')
        else:
            res[k] = v.to('cpu')
    for k, v in sd2.items():
        if k not in sd1:
            res[k] = v.to('cpu')
    return res

def mul(sd1:StateDict,scalar):
    """Return a new dict in which every value of ``sd1`` is multiplied by ``scalar``."""
    scaled = {}
    for name, value in sd1.items():
        scaled[name] = value * scalar
    return scaled

def L2_norm(sd:StateDict, verbose:bool=True):
    """Return the L2 norm over all tensors of a state dict, as a 0-dim tensor.

    Args:
        sd: mapping from entry name to tensor.
        verbose: when true (the default, matching the previous behaviour),
            print each entry name together with that entry's own L2 norm.

    Returns:
        ``sqrt(sum_k sum(sd[k] ** 2))``; a zero tensor when ``sd`` is empty
        (previously an empty dict raised because the accumulator was an int).
    """
    total = None
    for k, v in sd.items():
        sq_sum = (v**2).sum()
        if verbose:
            print(k, sq_sum.sqrt())
        total = sq_sum if total is None else total + sq_sum
    if total is None:
        return torch.zeros(())
    return total.sqrt()

# Note: dataloader cannot be created before _evaluate to enable parallelism
class CertifyNEProblem(NEProblem):
    '''
    Problem class for certified training.

    @param
        network: a torch net with/without a Normalization layer defined in AIDomains.concrete_layers
        model_wrapper: the abs_net contained must be connected to network (torch net)

    '''
    def __init__(self, network:nn.Module, dataset, model_wrapper, eps_scheduler, num_actors, num_gpus_per_actor, device, subbatch_size, args):
        """Create the minimisation problem and initialise its batch / eps / bounds state.

        Args:
            network: torch network; it is indexed (``network[0]``) to detect a
                leading ``Normalization`` layer.
            dataset: dataset from which ``make_dataloader`` builds the DataLoader.
            model_wrapper: object used to compute losses and bounds for ``network``.
            eps_scheduler: object whose ``getcurrent(step)`` yields the current eps.
            num_actors, num_gpus_per_actor, device, subbatch_size: forwarded
                unchanged to ``NEProblem.__init__``.
            args: run configuration; only ``args.repeat_batch`` is read here.
        """
        # Snapshot the normalization layer's state so it can be restored after a
        # candidate solution has been loaded (normalization params must not be trained).
        # The network is not split into normalization + rest because, per the original
        # author, that breaks the data range handling in the model wrapper.
        # NOTE: norm_state_dict is only defined when the first layer is a Normalization.
        if isinstance(network[0], Normalization):
            self.norm_state_dict = copy.deepcopy(network[0].state_dict())

        super().__init__(
            objective_sense="min",
            network=network,
            num_actors=num_actors,
            num_gpus_per_actor=num_gpus_per_actor,
            device=device,
            subbatch_size=subbatch_size,
        )
        self.network = network
        self.dataset = dataset
        self.model_wrapper = model_wrapper
        # eps used for the current step; set by next_eps() or assigned by the caller
        self.current_eps = None
        self.eps_scheduler = eps_scheduler
        # number of next_eps() calls so far; used as the scheduler step index
        self.steps_taken = 0
        self.args = args
        # most recently computed / loaded intermediate bounds (None until available)
        self.interm_bounds = None
        self.loaded_bounds = False
        self.computed_bounds = False
        # number of times the batch iterator has been exhausted and restarted
        self.epochs = 0
        # default temp file names; TMP_FILE_PREFIX is prepended when they are used
        self.tmp_fn = "tmp_interm_bounds.pt"
        self.tmp_net_fn = "tmp_mean_net.pt"
        # how many times the current minibatch has been handed out / may be handed out
        self.batch_repeat_counter = 0
        self.batch_repeat = args.repeat_batch

    def _restore_normalization_param(self, network=None):
        """Reload the saved normalization state into the first layer of ``network``.

        Falls back to ``self.network`` when no network is given. Does nothing
        unless the first layer is a ``Normalization`` instance.
        """
        target = self.network if network is None else network
        if not isinstance(target[0], Normalization):
            return
        target[0].load_state_dict(self.norm_state_dict)

    def make_dataloader(self):
        """Build ``self.dataloader`` over ``self.dataset``.

        Batches are shuffled unless ``args.only_compute_one_step`` is set.
        """
        shuffle_batches = not self.args.only_compute_one_step
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.args.train_batch,
            shuffle=shuffle_batches,
            num_workers=8,
        )

    def _prepare(self, epoch_idx=0) -> None:
        """(Re)create the dataloader and its iterator for the given epoch.

        When ``args.sync_batches_across_actors`` is set, everything is seeded with
        ``random_seed + epoch_idx`` before the dataloader is built, and re-seeded
        with an additional per-process offset (the PID) afterwards. The first
        ``args.batch_idx`` batches of the fresh iterator are skipped.
        """
        sync = self.args.sync_batches_across_actors
        if sync:
            # same seed in every process while the dataloader is being set up
            seed_everything(self.args.random_seed + epoch_idx)
            print(os.getpid(),'rand seed',self.args.random_seed + epoch_idx)

        self.make_dataloader()
        self.reset_or_init()
        for _ in range(self.args.batch_idx):
            next(self.batch_iterator)

        if sync:
            # afterwards each process gets its own seed again
            seed_everything(self.args.random_seed + epoch_idx + os.getpid())
            print(os.getpid(),'rand seed',self.args.random_seed + epoch_idx + os.getpid())
    
    def next_eps(self):
        """Query the eps scheduler at the current step, advance the step counter, return eps."""
        step = self.steps_taken
        self.current_eps = self.eps_scheduler.getcurrent(step)
        self.steps_taken = step + 1
        return self.current_eps
    
    def sync_net(self, net:nn.Module):
        """Copy the state dict of ``net`` into ``self.network``."""
        source_state = net.state_dict()
        self.network.load_state_dict(source_state)

    def reset_or_init(self):
        """Start a fresh iterator over ``self.dataloader``."""
        fresh_iterator = iter(self.dataloader)
        self.batch_iterator = fresh_iterator

    def get_next_batch(self):
        """Return the next batch from the dataloader, restarting it when exhausted.

        The dataloader is created lazily on the first call (via ``_prepare(0)``).
        When the iterator is exhausted, ``self.epochs`` is incremented and the
        iterator is rebuilt: through ``_prepare`` (re-seeding) when
        ``args.sync_batches_across_actors`` is set, otherwise through
        ``reset_or_init``. The batch repeat counter is reset to 0 on every call.

        Only ``StopIteration`` triggers a restart; any other error raised while
        loading data (worker failure, interrupt, ...) now propagates instead of
        being silently treated as the end of an epoch.
        """
        if not hasattr(self, "dataloader"):
            self._prepare(0)
            self.epochs = 0
        elif not hasattr(self, "batch_iterator"):
            # dataloader was created directly via make_dataloader(); the former
            # catch-all handler also recovered from this state
            self.reset_or_init()

        try:
            batch = next(self.batch_iterator)
        except StopIteration:
            self.epochs += 1
            if self.args.sync_batches_across_actors:
                self._prepare(self.epochs)
            else:
                self.reset_or_init()
            batch = next(self.batch_iterator)
        self.batch_repeat_counter = 0
        return batch

    def get_current_batch(self):
        """Return the cached minibatch, fetching a new one once it has been reused enough.

        While the repeat counter is below ``self.batch_repeat`` the cached
        ``self._current_minibatch`` is returned and the counter is incremented.
        Otherwise a new batch is fetched, moved to ``self.device``, cached, and
        the counter restarts at 1.
        """
        if self.batch_repeat_counter >= self.batch_repeat:
            inputs, labels = self.get_next_batch()
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            self._current_minibatch = (inputs, labels)
            self.batch_repeat_counter = 1
        else:
            self.batch_repeat_counter += 1
        return self._current_minibatch

    def _evaluate_network(self, network: torch.nn.Module) -> Union[torch.Tensor, float]:
        """Evaluate one candidate network on the current minibatch.

        The candidate's parameters are copied into ``self.network`` and the loss
        returned by ``model_wrapper.compute_model_stat`` is scaled by
        ``args.fitness_factor``. With ``args.precomp_bounds`` the intermediate
        bounds are either computed locally from the saved mean net
        (``args.compute_bounds_each_actor``) or loaded from the temp file; if no
        bounds are available, a zero tensor of shape ``(1,)`` is returned instead.
        """
        args = self.args
        eps = self.current_eps
        wrapper = self.model_wrapper
        x, y = self.get_current_batch()

        # Snapshot the candidate first: self.network may be overwritten by the
        # mean net below before the candidate is loaded into it.
        candidate_state = copy.deepcopy(network.state_dict())

        if args.precomp_bounds and args.compute_bounds_each_actor and not self.loaded_bounds:
            # bounds are derived from the mean net, not from the candidate
            self.load_mean_net()
            self.model_wrapper.anet.set_activation_layers()
            self.compute_intermediate_bounds()
            self.loaded_bounds = True

        self.network.load_state_dict(candidate_state)

        interm_bounds = None
        if args.precomp_bounds:
            if not self.loaded_bounds:
                self.load_intermediate_bounds()
            interm_bounds = self.interm_bounds
            if args.transform_bounds != "none" and interm_bounds is not None:
                interm_bounds = wrapper.anet.transform_bounds(interm_bounds, x, y, args.transform_bounds)
            if interm_bounds is None:
                # no bounds yet (first batch)
                return torch.zeros(1, device=x.device)

        # normalization params must stay fixed, so undo whatever the candidate carried
        self._restore_normalization_param()
        (loss, _nat_loss, _robust_loss), (_nat_accu, _robust_accu) = wrapper.compute_model_stat(x, y, eps, intermediate_bounds=interm_bounds)
        return loss * args.fitness_factor

    def compute_batch_loss(self, precomp_bounds=False):
        """Return the loss of ``model_wrapper`` on the cached minibatch at the current eps.

        The stored intermediate bounds are passed along only when both
        ``args.precomp_bounds`` and the ``precomp_bounds`` argument are true.
        If the wrapper has an ``anet``, its input bounds are reset afterwards.
        """
        eps = self.current_eps
        wrapper = self.model_wrapper
        inputs, labels = self._current_minibatch
        inputs, labels = inputs.to(self.device), labels.to(self.device)

        use_stored_bounds = self.args.precomp_bounds and precomp_bounds
        bounds = self.interm_bounds if use_stored_bounds else None

        (loss, _nat_loss, _robust_loss), (_nat_accu, _robust_accu) = wrapper.compute_model_stat(inputs, labels, eps, intermediate_bounds=bounds)
        if hasattr(wrapper,'anet'):
            wrapper.anet.reset_input_bounds()
        return loss

    def prepare_next_batch(self):
        """Fetch the next batch, cache it on ``self.device`` and advance eps.

        The cached batch is then shared by all solutions evaluated in this
        process. ``loaded_bounds`` is cleared so bounds are re-read for the new batch.
        """
        inputs, labels = self.get_next_batch()
        inputs, labels = inputs.to(self.device), labels.to(self.device)
        self._current_minibatch = (inputs, labels)
        self.next_eps()
        self.loaded_bounds = False

    def _evaluate_batch(self, batch):
        """Set up the shared minibatch and eps, then delegate to the parent evaluation.

        A new minibatch is fetched here when ``args.num_actors > 1`` or when
        ``args.precomp_bounds`` is off; otherwise the already cached minibatch
        is kept. eps is advanced and both bounds flags are cleared in every case.
        """
        fetch_here = self.args.num_actors > 1 or not self.args.precomp_bounds
        if fetch_here:
            inputs, labels = self.get_next_batch()
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            self._current_minibatch = (inputs, labels)
        self.next_eps()
        self.loaded_bounds = False
        self.computed_bounds = False
        return super()._evaluate_batch(batch)

    def get_unstable(self):
        """Return the ``(ibp_stability, dp_stability)`` statistics for the cached minibatch.

        They are the 4th and 5th entries of what
        ``model_wrapper.get_robust_stat_from_input_noise(..., return_all=True)``
        returns; the call runs without gradient tracking.
        """
        with torch.no_grad():
            inputs, labels = self._current_minibatch
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            eps = self.current_eps
            stats = self.model_wrapper.get_robust_stat_from_input_noise(
                                                           eps, inputs, labels, return_all=True)
            _, _, _, ibp_stability, dp_stability, _ = stats
        return ibp_stability, dp_stability
    
    def compute_intermediate_bounds(self, return_unstable=False):
        """Compute intermediate bounds on the cached minibatch and store them.

        The bounds are the last entry returned by
        ``model_wrapper.get_robust_stat_from_input_noise(..., compute_bounds=True)``
        and are kept in ``self.interm_bounds``. When ``args.transform_bounds`` is
        not ``"none"``, the activations from ``anet.get_activations`` are stored
        under the ``"activations"`` key as well. The abstract net's input bounds
        are reset before and after the computation.

        Returns:
            The bounds, or ``(bounds, ibp_stability, dp_stability)`` when
            ``return_unstable`` is true.
        """
        with torch.no_grad():
            inputs, labels = self._current_minibatch
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            eps = self.current_eps

            self.model_wrapper.anet.reset_input_bounds()
            stats = self.model_wrapper.get_robust_stat_from_input_noise(eps, inputs, labels, return_all=True, compute_bounds=True)
            _, _, _, ibp_stability, dp_stability, self.interm_bounds = stats

            if self.args.transform_bounds != "none":
                self.interm_act = self.model_wrapper.anet.get_activations(inputs, labels)
                self.interm_bounds["activations"] = self.interm_act
            self.model_wrapper.anet.reset_input_bounds()

        if not return_unstable:
            return self.interm_bounds
        return self.interm_bounds, ibp_stability, dp_stability

    def save_intermediate_bounds(self,fn=None):
        """Write ``self.interm_bounds`` to ``fn`` (default: the prefixed temp bounds file).

        Prints a warning and writes nothing when there are no bounds.
        """
        bounds = self.interm_bounds
        if bounds is None:
            print("WARNING: Trying to save interm bounds but there are None!")
            return
        target = TMP_FILE_PREFIX + self.tmp_fn if fn is None else fn
        torch.save(bounds, target)

    def clean_intermediate_bounds(self,fn=None):
        """Delete the bounds file ``fn`` (default: the prefixed temp bounds file) if it exists."""
        target = TMP_FILE_PREFIX + self.tmp_fn if fn is None else fn
        if not os.path.exists(target):
            return
        os.remove(target)
    
    def load_intermediate_bounds(self, fn=None):
        """Load intermediate bounds from ``fn`` (default: the prefixed temp bounds file).

        On success the bounds are stored in ``self.interm_bounds`` and
        ``loaded_bounds`` is set; a missing file only prints a warning and leaves
        the previous state untouched. Whenever bounds are available afterwards,
        ``anet.set_activation_layers()`` is called.

        Returns:
            ``self.interm_bounds`` (possibly ``None``).
        """
        if fn is None:
            fn = TMP_FILE_PREFIX + self.tmp_fn

        if not os.path.exists(fn):
            print(f"WARNING: Trying to load interm bounds but file {fn} does not exist!")
        else:
            self.interm_bounds = torch.load(fn)
            self.loaded_bounds = True

        if self.interm_bounds is not None:
            self.model_wrapper.anet.set_activation_layers()

        return self.interm_bounds

    def load_mean_net(self,fn=None):
        """Load a saved state dict from ``fn`` (default: the prefixed temp net file) into ``self.network``.

        A missing file only prints a warning.
        """
        if fn is None:
            fn = TMP_FILE_PREFIX + self.tmp_net_fn

        if not os.path.exists(fn):
            print(f"WARNING: Trying to load mean net but file {fn} does not exist!")
            return
        self.network.load_state_dict(torch.load(fn))

    def save_mean_net(self,fn=None):
        """Save the state dict of ``self.network`` to ``fn`` (default: the prefixed temp net file)."""
        target = TMP_FILE_PREFIX + self.tmp_net_fn if fn is None else fn
        torch.save(self.network.state_dict(), target)

class BasicEvoWrapper():
    def __init__(self, train_loader, torch_net, model_wrapper, eps_scheduler, num_actors:int, device, args, subbatch_size:int=None, num_gpus_per_actor="max", nep_log=None):
        """Build the ``CertifyNEProblem`` and initialise the bookkeeping shared by all wrappers.

        The problem is created from ``train_loader.dataset``; the number of steps
        per epoch defaults to ``len(train_loader)``. Subclasses are expected to
        assign ``self.searcher``. The wrapped model's net is moved to ``device``.
        """
        self.problem = CertifyNEProblem(
            network=torch_net,
            dataset=train_loader.dataset,
            model_wrapper=model_wrapper,
            eps_scheduler=eps_scheduler,
            num_actors=num_actors,
            num_gpus_per_actor=num_gpus_per_actor,
            device=device,
            subbatch_size=subbatch_size,
            args=args,
        )

        # references to the collaborating objects
        self.torch_net = torch_net
        self.model_wrapper = model_wrapper
        self.args = args
        self.nep_log = nep_log
        self.num_batches = len(train_loader)

        # state filled in while training
        self.searcher: SearchAlgorithm = None
        self.init_center = None
        self.delta = None
        self.loss_error = None

        self.model_wrapper.net.to(device)


    def adjust_search_based_on_eps(self, eps):
        """Hook invoked by ``run`` with the current eps at every step; does nothing here."""
        return None

    def get_solution(self):
        """Return the current solution as a network; must be provided by subclasses."""
        raise NotImplementedError()
    
    def step_lr(self):
        """Hook for advancing a learning-rate schedule; does nothing here."""
        return None

    def next_hyperparam(self):
        """Hook invoked by ``run`` at every step to update search hyperparameters; does nothing here."""
        return None

    def run(self, epoch_idx, steps):
        """Run up to ``steps`` search iterations for epoch ``epoch_idx``.

        Each iteration: set eps from the scheduler, save the current mean net,
        prepare the minibatch, optionally pre-compute and save intermediate
        bounds, run the searcher for one generation, copy the new solution into
        the problem network and record diagnostics:

        * ``self.delta``: loss after the step minus loss before it,
        * ``self.loss_error``: loss with fresh bounds minus loss with the
          pre-computed bounds (only when pre-computed bounds were used),
        * ``self.grad_step_size`` / ``self.total_change``: L2 distance of the
          search centre to its previous / initial value.

        The loop stops early when the iteration index reaches
        ``args.max_batches``. With ``args.only_compute_one_step`` the parameter
        difference of the first step is saved next to ``args.load_model`` and
        the process exits.

        On the very first batch of training with more than one actor no
        minibatch exists in this process before the search step, so batch
        preparation and bound pre-computation are deferred until after it.
        """
        args = self.args
        problem = self.problem

        if self.init_center is None:
            self.init_center = self.searcher.status['center'].clone()
        pbar = tqdm(range(steps))
        total_delta = 0
        if args.neptune_id is not None:
            # per-run temp file names so that concurrent runs do not clash
            problem.tmp_fn = f"tmp_interm_bounds_{args.neptune_id}.pt"
            problem.tmp_net_fn = f"tmp_mean_net_{args.neptune_id}.pt"
        problem.clean_intermediate_bounds()

        for i in pbar:
            if i == args.max_batches: break
            eps = problem.eps_scheduler.getcurrent(epoch_idx * self.num_batches + i)
            problem.current_eps = eps
            self.adjust_search_based_on_eps(eps)
            self.next_hyperparam()

            this_first_batch = (epoch_idx == args.start_epoch and i == 0)
            # whether this process can prepare a minibatch before the search step
            batch_ready_before_search = not this_first_batch or args.num_actors <= 1
            precomp_bounds = args.precomp_bounds and batch_ready_before_search

            problem.save_mean_net()

            if batch_ready_before_search:
                problem.prepare_next_batch()
            old_center = self.searcher.status['center'].clone()

            if args.only_compute_one_step:
                old_state_dict = copy.deepcopy(problem.network.state_dict())
                for k in old_state_dict:
                    old_state_dict[k] = old_state_dict[k].to(problem.device)

            if precomp_bounds:
                _, ibp_unstable, dp_unstable = problem.compute_intermediate_bounds(return_unstable=True)
                problem.save_intermediate_bounds()
            else:
                ibp_unstable, dp_unstable = None, None

            if batch_ready_before_search:
                with torch.no_grad():
                    # The value is recomputed after the search step below; the call
                    # is kept because it also resets the abstract net's input bounds.
                    problem.compute_batch_loss(precomp_bounds=False)
                    self._select_dp_bounds_method()

            self.searcher.run(1)

            if this_first_batch and args.num_actors > 1:
                problem.prepare_next_batch()

            with torch.no_grad():
                if args.precomp_bounds and not precomp_bounds:
                    problem.compute_intermediate_bounds()
                # NOTE(review): this "before" loss is taken after searcher.run(1) on
                # whatever problem.network currently holds. _evaluate_network loads
                # candidates into that same network, so if evaluation happens in this
                # process it is presumably the last candidate rather than the previous
                # mean - confirm this is intended.
                loss_before = problem.compute_batch_loss(precomp_bounds=False)
                self._select_dp_bounds_method()

            problem.sync_net(self.get_solution())
            problem._restore_normalization_param()

            with torch.no_grad():
                loss_new_bounds = problem.compute_batch_loss(precomp_bounds=False)
                self.delta = loss_new_bounds - loss_before
                total_delta += self.delta

            if args.only_compute_one_step:
                new_state_dict = copy.deepcopy(problem.network.state_dict())
                gradient = add(new_state_dict, mul(old_state_dict,-1))
                print(L2_norm(new_state_dict))
                print(L2_norm(old_state_dict))
                print(L2_norm(gradient))
                model_dir = os.path.dirname(args.load_model)
                fname = f'pgpe_{self.model_wrapper.name}_B{args.batch_idx}_gradient.pt'
                torch.save(gradient,os.path.join(model_dir,fname))
                print('saved grad',model_dir,fname)
                # sys.exit instead of the site-provided exit(), which is not always defined
                sys.exit(0)

            if precomp_bounds:
                with torch.no_grad():
                    problem.load_intermediate_bounds()
                    loss_old_bounds = problem.compute_batch_loss(precomp_bounds=True)
                    # guarded like the other two call sites: not every model wrapper
                    # has an optimizer
                    self._select_dp_bounds_method()
                    self.loss_error = loss_new_bounds - loss_old_bounds

            with torch.no_grad():
                new_center = self.searcher.status['center'].clone()
                grad = new_center - old_center
                self.grad_step_size = grad.flatten().norm(2).item()
                self.total_change = (new_center - self.init_center).flatten().norm(2).item()

            if args.log_unstable or precomp_bounds:
                self.log_unstable(ibp_unstable, dp_unstable)
            self.log_with_neptune()
            pbar.set_postfix_str(f"Mean eval = {self.searcher.status['mean_eval']/args.fitness_factor:.4f}, Total delta = {total_delta:.6f}")

    def _select_dp_bounds_method(self):
        """Set the wrapper optimizer's intermediate-bounds method to ``"dp"``, if it has an optimizer."""
        wrapper = self.problem.model_wrapper
        if hasattr(wrapper,'optimizer'):
            wrapper.optimizer.backsubstitution_config.intermediate_bounds_method = IntermediateBoundsMethod["dp"]


    def log_with_neptune(self):
        """Hook invoked by ``run`` at the end of every step to log metrics; does nothing here."""
        return None

    def log_unstable(self, ibp_unstable=None, dp_unstable=None):
        if self.nep_log is not None:
            if ibp_unstable is None or dp_unstable is None:
                ibp_unstable, dp_unstable = self.problem.get_unstable()
            self.nep_log["result/ibp_unstable"].append(ibp_unstable)
            self.nep_log["result/dp_unstable"].append(dp_unstable)
            
    def train_one_epoch(self, epoch_idx):
        with torch.no_grad():
            self.run(epoch_idx, self.num_batches)

    def calc_mutation_power(self, torch_net, verbose:bool=True):
        # return shape: (all_params, )
        with torch.no_grad():
            all_vectors = []
            sigmas = []
            for p in torch_net.parameters():
                sigma = torch.std(p.view(-1))
                sigmas.append(round(sigma.item(), 3))
                all_vectors.append(sigma.repeat(p.numel()))
            sigma_vec = torch.nan_to_num(torch.cat(all_vectors), nan=0).to(self.problem.device)
            if verbose:
                print("Initializing with std:", sigma_vec)
            return sigma_vec.clone()
        

class PGPEEvoWrapper(BasicEvoWrapper):
    '''
    torch_net must be connected to the abstract net wrapped in model_wrapper.

    If use_current_std = True, std_init will be overriden by the std of current params in the torch_net for every param group.
    '''
    def __init__(self, train_loader, torch_net, model_wrapper, eps_scheduler, center_lr, std_lr,  num_actors:int, device, args, popsize:int=1000, std_init:float=1e-2, use_current_std:bool=True, std_min=None, std_max=None, subbatch_size:int=None, num_gpus_per_actor="max", logging:bool=False, nep_log=None, optimizer="adam"):
        """Create a PGPE searcher centred on the current parameters of ``torch_net``.

        Args:
            center_lr, std_lr: learning rates for the search centre and the stdev.
            popsize: number of solutions sampled per iteration.
            std_init: initial stdev; also used as both stdev bounds, which keeps
                the stdev fixed until ``adjust_search_based_on_eps`` replaces the bounds.
            use_current_std: if true, the initial sigma is taken from
                ``calc_mutation_power(torch_net)`` instead of ``std_init``.
            std_min, std_max: not read by this constructor (the bounds applied
                later come from ``args.std_min`` / ``args.std_max``).
            logging: attach a ``StdOutLogger`` to the searcher.
            optimizer: optimizer name passed to PGPE; an LR scheduler is only
                created for ``"adam"`` and ``"sgd"``.
            Remaining arguments are forwarded to ``BasicEvoWrapper.__init__``.

        ``args`` fields read here: ``save_root``, ``resume_train``, ``device``,
        ``lr_milestones``, ``lr_decay_factor``, ``start_epoch``.
        """
        # record current params to initialize the searcher solution
        init_params = parameter_vector(torch_net).clone().detach()
        
        super().__init__(train_loader, torch_net, model_wrapper, eps_scheduler, num_actors, device, args, subbatch_size, num_gpus_per_actor, nep_log)

        searcher = PGPE(
            self.problem,
            center_init=0, # placeholder; replaced by init_params below
            stdev_init=std_init, # Initial radius of the search distribution
            center_learning_rate=center_lr, # Learning rate of the centre optimizer
            stdev_learning_rate=std_lr, # Learning rate for the standard deviation
            stdev_min=std_init, # Lock std during eps annealing
            stdev_max=std_init, # Lock std during eps annealing
            popsize=popsize, # Number of solutions sampled per iteration
            distributed=True, # Gradients are computed locally at actors and averaged
            optimizer=optimizer, # Centre optimizer, "adam" by default
            ranking_method=None, # No rank-based fitness shaping is used
            symmetric=True, # PGPE's symmetric sampling option
        )

        # set initial solution to current params in the torch net
        # (this overwrites the searcher's private distribution object)
        if not use_current_std:
            searcher._distribution = searcher._distribution.modified_copy(mu=init_params)
        else:
            searcher._distribution = searcher._distribution.modified_copy(mu=init_params, sigma=self.calc_mutation_power(torch_net))

        opt_state_path = os.path.join(args.save_root,'optimizer.ckpt')
        print("opt path",opt_state_path,flush=True)
        if args.resume_train:
            # restore the optimizer state of a previous run when a checkpoint exists
            if os.path.exists(opt_state_path):
                searcher._optimizer._optim.load_state_dict(torch.load(opt_state_path))
                print(args.device,"loaded optimizer state:",opt_state_path,flush=True)
            # make sure every param group carries 'initial_lr' before the scheduler
            # is built with last_epoch != -1
            for group in searcher._optimizer._optim.param_groups:
                group.setdefault('initial_lr', group['lr'])

        # NOTE(review): 'initial_lr' is only filled in when resume_train is set; with
        # start_epoch > 0 and resume_train off, MultiStepLR may reject the optimizer - confirm.
        self.lr_scheduler = MultiStepLR(searcher._optimizer._optim, args.lr_milestones, gamma=args.lr_decay_factor, last_epoch=args.start_epoch-1) if optimizer in ["adam", "sgd"] else None

        self.searcher = searcher
        # stdev bounds are still pinned to std_init; see adjust_search_based_on_eps
        self.std_locked = True
        if logging:
            StdOutLogger(self.searcher, interval=1)



    def step_lr(self):
        """Advance the learning-rate scheduler by one step, if one exists."""
        scheduler = self.lr_scheduler
        if scheduler is None:
            return
        scheduler.step()

    def get_solution(self):
        """Return a network built from the searcher's current centre."""
        make_net = self.problem.make_net
        center = self.searcher.status["center"]
        return make_net(center)
    
    def adjust_search_based_on_eps(self, eps):
        """Unlock the stdev bounds the first time eps equals ``args.train_eps``.

        From then on the searcher uses ``args.std_max`` / ``args.std_min`` as
        bounds instead of the initial fixed value.
        """
        if not (eps == self.args.train_eps and self.std_locked):
            return
        self.searcher._stdev_max = self.args.std_max
        self.searcher._stdev_min = self.args.std_min
        self.std_locked = False

    def log_with_neptune(self):
        """Append the per-step metrics to ``self.nep_log`` (no-op without a logger).

        Always logged: eps, mean evaluation (divided by ``args.fitness_factor``)
        and the mean of the search stdev; the learning rate is logged when a
        scheduler exists. The optional diagnostics (loss delta, loss error,
        step size, total change) are each logged only once they have a value,
        so ``None`` is never appended. ``loss_error`` in particular stays
        ``None`` while no pre-computed bounds are used.
        """
        log = self.nep_log
        if log is None:
            return

        if self.lr_scheduler is not None:
            log["result/lr"].append(self.lr_scheduler.get_last_lr()[0])
        log["result/eps"].append(self.problem.current_eps)
        log["result/mean_eval"].append(self.searcher.status["mean_eval"]/self.args.fitness_factor)
        log["result/mean_pgpe_std_mean"].append(self.searcher.status["stdev"].mean())

        # grad_step_size / total_change only exist after the first run() step
        optional_metrics = (
            ("result/real_loss_delta", self.delta),
            ("result/loss_error", self.loss_error),
            ("result/grad_step_size", getattr(self, "grad_step_size", None)),
            ("result/total_change", getattr(self, "total_change", None)),
        )
        for key, value in optional_metrics:
            if value is not None:
                log[key].append(value)



# TODO: define an evo wrapper for the simple GA.
# Note: the simple GA keeps multiple copies of the network, so the frozen normalization must be enforced for all of them.
class GeneticsEvoWrapper(BasicEvoWrapper):
    """Evolution wrapper that optimises the network with the `GA` searcher.

    On construction it builds a `GA` over ``self.problem``, seeds its
    population with perturbed copies of ``torch_net``'s parameter vector, and
    rescales the mutation-power scheduler (when one is given) by
    ``self.num_batches``.

    ``self.problem`` and ``self.num_batches`` are read here but not assigned;
    they are presumably set by ``BasicEvoWrapper.__init__`` (confirm there).
    """

    def __init__(
        self,
        train_loader,
        torch_net,
        model_wrapper,
        eps_scheduler,
        popsize,
        num_elites,
        num_parents,
        mutation_power,
        pert_space_ratio,
        mp_scheduler,
        psr_scheduler,
        popsize_scheduler,
        num_actors: int,
        device,
        args,
        use_current_std: bool = False,
        subbatch_size: Optional[int] = None,
        num_gpus_per_actor="max",
        nep_log=None,
    ):
        """Create the GA searcher and initialise it around ``torch_net``.

        Args:
            train_loader, torch_net, model_wrapper, eps_scheduler, num_actors,
            device, args, subbatch_size, num_gpus_per_actor, nep_log:
                Forwarded unchanged to ``BasicEvoWrapper.__init__``.
            popsize, num_elites, num_parents, mutation_power, pert_space_ratio,
            mp_scheduler, psr_scheduler, popsize_scheduler:
                Forwarded unchanged to `GA`. Any of the schedulers may be None.
            use_current_std: When true, the searcher's per-parameter mutation
                power is replaced by ``self.calc_mutation_power(torch_net)``.
        """
        super().__init__(train_loader, torch_net, model_wrapper, eps_scheduler, num_actors, device, args, subbatch_size, num_gpus_per_actor, nep_log)
        self.searcher = GA(
            self.problem,
            popsize=popsize,
            num_elites=num_elites,
            num_parents=num_parents,
            mutation_power=mutation_power,
            pert_space_ratio=pert_space_ratio,
            mp_scheduler=mp_scheduler,
            psr_scheduler=psr_scheduler,
            popsize_scheduler=popsize_scheduler,
        )
        if use_current_std:
            self.searcher._mutation_power = self.calc_mutation_power(torch_net)

        # Start the search from the given network: row 0 of the population is
        # the unperturbed parameter vector, the rest are noisy copies.
        baseline = parameter_vector(torch_net).clone().to(self.problem.device).detach()
        self.searcher.init_with_baseline(baseline, noise_scale_factor=1)

        # Use at least as many batches as it takes to cover every parameter
        # once when only `_num_param_perturbed` of them are mutated per step.
        min_batches = math.ceil(self.problem.solution_length / self.searcher._num_param_perturbed)
        self.num_batches = max(min_batches, self.num_batches)

        # The GA queries its schedulers with the number of steps taken, so the
        # epoch-based bounds are stretched by the number of batches
        # (presumably one GA step per batch -- confirm against the training loop).
        # GA accepts `mp_scheduler=None`, so only rescale when one is present.
        scheduler = self.searcher._mp_scheduler
        if scheduler is not None:
            scheduler.start_epoch *= self.num_batches
            scheduler.end_epoch *= self.num_batches

    def next_hyperparam(self):
        """Let the searcher refresh its scheduled hyperparameters."""
        self.searcher.next_hyperparam()

    def get_solution(self):
        """Return the network built from the first solution of the population."""
        return self.problem.make_net(self.searcher.population[0])

    def log_with_neptune(self):
        """Append the GA's current hyperparameters and best evaluation to ``self.nep_log``.

        Does nothing when ``self.nep_log`` is None.
        """
        if self.nep_log is None:
            return
        searcher = self.searcher
        # Effective mutation power = per-parameter base power times the scheduled factor.
        self.nep_log["result/avg_mutation_power"].append(searcher._mutation_power.mean() * searcher._mutation_factor)
        self.nep_log["result/num_param_perturb"].append(searcher._num_param_perturbed)
        self.nep_log["result/popsize"].append(searcher._popsize)
        self.nep_log["result/pop_best_eval"].append(searcher.status["pop_best_eval"])


# Genetic Algorithm Searcher
class GA(SearchAlgorithm, SinglePopulationAlgorithmMixin):
    """Elitist genetic algorithm that mutates parents with Gaussian noise.

    Every step keeps the first ``num_elites`` solutions of the sorted
    population untouched. Each remaining slot is refilled with a child: a
    parent drawn at random from the first ``num_parents`` solutions plus
    Gaussian noise on a random subset of coordinates. The same coordinate
    subset is shared by all children of one step.

    The mutation factor, the ratio of perturbed coordinates, and the
    population size can be driven by optional schedulers. Each scheduler must
    expose ``getcurrent(step)`` and is queried in `next_hyperparam` with the
    number of steps taken so far.
    """

    def __init__(
        self,
        problem: Problem,
        popsize: int,  # Total population size n
        num_elites: int,  # Number of elites that survive each generation e
        num_parents: int,  # Number of parents from which to generate children
        mutation_power: float,  # Scale of gaussian noise used to generate children
        pert_space_ratio:float=1, # ratio of perturbation space
        mp_scheduler=None, # scheduler for mutation power
        psr_scheduler=None, # scheduler for pert space ratio
        popsize_scheduler=None, # scheduler for popsize
    ):
        """Set up the searcher and draw an initial population from ``problem``.

        Args:
            problem: The problem whose solutions are evolved.
            popsize: Total population size.
            num_elites: Number of top solutions carried over unchanged.
            num_parents: Number of top solutions children are sampled from.
            mutation_power: Base scale of the Gaussian mutation noise; it is
                expanded to one value per solution coordinate.
            pert_space_ratio: Fraction of coordinates perturbed in each step.
            mp_scheduler: Optional scheduler for the mutation factor.
            psr_scheduler: Optional scheduler for ``pert_space_ratio``.
            popsize_scheduler: Optional scheduler for the population size.
        """
        # Register the problem and the status getters; their results show up
        # in the searcher's status dictionary under the keyword names.
        SearchAlgorithm.__init__(
            self,
            problem,
            pop_best=self._get_pop_best,
            pop_best_eval=self._get_pop_best_eval,
        )
        SinglePopulationAlgorithmMixin.__init__(
            self,
        )

        # Store the hyperparameters
        self._popsize = int(popsize)
        self._num_elites = int(num_elites)
        self._num_parents = int(num_parents)
        # TODO: make heterogeneous std possible
        # One mutation scale per coordinate (all equal for now).
        self._mutation_power = float(mutation_power) * torch.ones(self.problem.solution_length, device=self._problem.device)
        # Global multiplier on `_mutation_power`, updated by `mp_scheduler`.
        self._mutation_factor = 1
        self._steps_taken = 0
        self._pert_space_ratio = pert_space_ratio

        # Number of coordinates perturbed per step; recomputed in `_step`.
        self._num_param_perturbed = math.ceil(pert_space_ratio * self.problem.solution_length)

        # Generate the initial population -- note that this uses the problem's initial bounds as a uniform hyper-cube.
        self._population = self._problem.generate_batch(self._popsize)

        # Copy of the current population's best solution (None until the first step).
        self._pop_best: Optional[Solution] = None

        # Schedulers that limit the perturbation space, reduce the mutation
        # power and increase the popsize over time; each may be None.
        self._mp_scheduler = mp_scheduler
        self._psr_scheduler = psr_scheduler
        self._popsize_scheduler = popsize_scheduler

    def _get_pop_best(self):
        """Status getter: the stored best solution, or None before the first step."""
        return self._pop_best

    def _get_pop_best_eval(self):
        """Status getter: evaluation of the stored best solution.

        Returns None before the first step, when no best solution exists yet,
        instead of failing on the missing solution.
        """
        if self._pop_best is None:
            return None
        return self._pop_best.get_evals()

    def next_hyperparam(self):
        """Refresh the scheduled hyperparameters for the current step count.

        Each configured scheduler is queried with ``self._steps_taken``. The
        population is only ever grown: a scheduled size that is not larger
        than the current one is ignored.
        """
        step = self._steps_taken
        if self._mp_scheduler is not None:
            self._mutation_factor = self._mp_scheduler.getcurrent(step)
        if self._psr_scheduler is not None:
            self._pert_space_ratio = self._psr_scheduler.getcurrent(step)
        if self._popsize_scheduler is not None:
            # Population sizes must be integers, as in __init__.
            new_popsize = int(self._popsize_scheduler.getcurrent(step))
            if new_popsize > self._popsize:
                # `increase_popsize` updates `_popsize` itself; setting it here
                # first would make its "must be larger" check fail every time.
                self.increase_popsize(new_popsize)

    @property
    def population(self):
        """The current population batch."""
        return self._population

    def increase_popsize(self, new_popsize:int):
        """Grow the population to ``new_popsize`` solutions.

        The existing solutions keep their positions at the front of the new
        batch, so elites and parents are unchanged. The additional rows are
        filled with repeated copies of existing solutions so that no row holds
        uninitialised values. The new batch is unevaluated; `_step`
        re-evaluates the whole population after writing the children.

        Raises:
            AssertionError: If ``new_popsize`` is not larger than the current
                population size.
        """
        assert new_popsize > self._popsize, f"New popsize {new_popsize} should be larger than the current popsize {self._popsize}"
        old_values = self._population.access_values()
        old_size = len(old_values)

        self._popsize = int(new_popsize)
        new_population = self._problem.generate_batch(self._popsize, empty=True)
        # Row i of the new batch is taken from row (i mod old_size) of the old one.
        source_rows = torch.arange(self._popsize, device=old_values.device) % old_size
        new_population.access_values()[:] = old_values[source_rows]
        self._population = new_population

    def init_with_baseline(self, base_vec:torch.Tensor, noise_scale_factor:float=1):
        """Re-initialise the population around a baseline parameter vector.

        Row 0 becomes an exact copy of ``base_vec``; every other row is
        ``base_vec`` plus Gaussian noise scaled per coordinate by
        ``self._mutation_power * noise_scale_factor``.

        Args:
            base_vec: 1-D tensor whose length equals the solution length.
            noise_scale_factor: Extra multiplier on the initial noise.

        Raises:
            AssertionError: If ``base_vec`` is not 1-D or has the wrong length.
        """
        assert self._population.access_values().shape[-1] == len(base_vec) and len(base_vec.shape)==1, f"Shape unmatch! Expected: {self._population.access_values().shape[-1]}; Given: {base_vec.shape}"
        num_rows = len(self._population.access_values())
        # initialize with random perturbations of the base model
        pert = self._mutation_power * noise_scale_factor * self.problem.make_gaussian(num_rows, self.problem.solution_length)
        pert[0] = 0 # keep an original copy
        self._population.access_values()[:] = base_vec.repeat(num_rows, 1) + pert

    def _step(self):
        """Run one generation: select parents, mutate, evaluate, and sort."""
        # If this is the very first iteration, this means that we have an unevaluated population.
        if self._steps_taken == 0:
            # Evaluate the population
            self.problem.evaluate(self._population)
            # Sort the population
            self._population = self._population[self._population.argsort()]

        # Select the parents.
        parents = self._population[: self._num_parents]

        # Pick a random parent for each child
        num_children = self._popsize - self._num_elites
        parent_indices = self.problem.make_randint(num_children, n=self._num_parents)
        parent_values = parents.values[parent_indices]

        # Add gaussian noise on a fixed-size random subset of coordinates
        # (an exact count, rather than an independent Bernoulli mask per
        # coordinate). The subset is drawn once and shared by all children.
        num_param_perturbed = math.ceil(self._pert_space_ratio * self.problem.solution_length)
        self._num_param_perturbed = num_param_perturbed
        perturbation = torch.zeros(num_children, self.problem.solution_length, device=parent_values.device)
        selected_idx = np.random.choice(self.problem.solution_length, num_param_perturbed, replace=False)
        perturbation[:, selected_idx] = self._mutation_factor * self._mutation_power[selected_idx] * self.problem.make_gaussian(num_children, num_param_perturbed)

        child_values = parent_values + perturbation

        # Overwrite all the non-elite solutions with the new generation
        self._population.access_values()[self._num_elites :] = child_values

        # Evaluate and sort the new population
        self.problem.evaluate(self._population)
        self._population = self._population[self._population.argsort()]

        # Store a copy of the best solution, for reporting
        # and analysis purposes
        self._pop_best = self._population[0].clone()
        self._steps_taken += 1
     