from model.critic import Critic
from utils.logger import log_to_tb_train
from utils.utils import clip_grad_norms
from .rollout import rollout
from utils.utils import set_seed
from dataset.generate_dataset import sample_batch_task, get_train_set
from data_scheduler.plr_scheduler import PLRScheduler
from dataset.cec_test_func import *
import numpy as np
from .task import TaskForTrain
from env import SubprocVectorEnv,DummyVectorEnv

from pbo_env import L2E_env,MadDE,sep_CMA_ES,PSO,DE
from expr.tokenizer import MyTokenizer
import torch
from tqdm import tqdm
from utils.utils import torch_load_cpu, get_inner_model, get_surrogate_gbest, get_surrogate_gbest_subproc
import os

class Data_Memory():
    """Per-batch log of cost trajectories, gaps and generated expressions.

    Every attribute is a plain list that callers append to; ``clear`` empties
    them without replacing the list objects.
    """

    def __init__(self) -> None:
        # Cost histories (teacher / student / baseline).
        self.teacher_cost, self.stu_cost, self.baseline_cost = [], [], []
        # Gap histories.
        self.gap, self.baseline_gap = [], []
        # Generated expressions.
        self.expr = []

    def clear(self):
        """Empty every buffer in place, so existing references stay valid."""
        buffers = (
            self.teacher_cost,
            self.stu_cost,
            self.baseline_cost,
            self.gap,
            self.baseline_gap,
            self.expr,
        )
        for buffer in buffers:
            buffer.clear()

# memory for recording transition during training process
class Memory:
    """Transition buffer that the trainer fills during a rollout window.

    Each attribute is a list with one entry appended per environment step.
    """

    def __init__(self):
        self.actions, self.states, self.logprobs = [], [], []
        self.rewards, self.gap_rewards, self.b_rewards = [], [], []

    def clear_memory(self):
        """Drop all stored transitions in place (list identities are kept)."""
        for buffer in (
            self.actions,
            self.states,
            self.logprobs,
            self.rewards,
            self.gap_rewards,
            self.b_rewards,
        ):
            buffer.clear()

class trainer(object):
    def __init__(self,model,opts) -> None:
        self.actor=model
        self.critic=Critic(opts)
        self.opts=opts
        self.optimizer=torch.optim.Adam([{'params':self.actor.parameters(),'lr':opts.lr}] + 
                                        [{'params':self.critic.parameters(),'lr':opts.lr_critic}])
        # figure out the lr schedule
        self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, opts.lr_decay, last_epoch=-1,)

        if opts.use_cuda:
            # move to cuda
            self.actor.to(opts.device)
            self.critic.to(opts.device)
        # load model to cuda
        # move optimizer's data onto chosen device
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(opts.device)

    def _infer_base_id(self, id_val):
        """
        Map a dynamic curriculum ID (rand_/mut_...) back to its base task ID.
        Works with base IDs that contain underscores (e.g., 'Bent_cigar').
        """
        if not hasattr(self, 'raw_surrogate_gbest') or self.raw_surrogate_gbest is None:
            return id_val
        keys = self.raw_surrogate_gbest.keys()
        s = str(id_val)
        if s in keys:
            return s
        # Prefer longest substring match
        matches = [k for k in keys if k in s]
        if matches:
            return max(matches, key=len)
        return id_val

    def set_training(self):
        torch.set_grad_enabled(True)
        self.actor.train()
        self.critic.train()
    
    def set_evaling(self):
        torch.set_grad_enabled(False)
        self.actor.eval()
        self.critic.eval()

    def merge_action_dicts(self, action_dicts):
        if not action_dicts:
            return {}
        T = len(action_dicts)
        B = action_dicts[0]['seq'].shape[0]
        merged = {}
        merged['seq'] = torch.cat([d['seq'] for d in action_dicts], dim=0)
        merged['c_seq'] = torch.cat([d['c_seq'] for d in action_dicts], dim=0)
        max_steps = max(len(d['x_in']) for d in action_dicts)
        for key in ['x_in', 'mask', 'working_index', 'position', 'c_index', 'filter_index']:
            merged[key] = []
        for i in range(max_steps):
            x_in_list, mask_list, wi_list, pos_list, ci_list, fi_list = [], [], [], [], [], []
            for t in range(T):
                d = action_dicts[t]
                if i < len(d['x_in']):
                    x_in_list.append(d['x_in'][i])
                    mask_list.append(d['mask'][i])
                    wi_list.append(d['working_index'][i] + t * B)
                    pos_list.append(d['position'][i])
                    ci_list.append(d['c_index'][i])
                    fi_list.append(d['filter_index'][i])
            merged['x_in'].append(torch.cat(x_in_list, dim=0))
            merged['mask'].append(torch.cat(mask_list, dim=0))
            merged['working_index'].append(torch.cat(wi_list, dim=0))
            merged['position'].append(torch.cat(pos_list, dim=0))
            merged['c_index'].append(np.concatenate(ci_list, axis=0))
            merged['filter_index'].append(np.concatenate(fi_list, axis=0))
        return merged

    # load model from load_path
    def load(self, load_path):

        assert load_path is not None
        load_data = torch_load_cpu(load_path)

        # load data for actor
        model_actor = get_inner_model(self.actor)
        model_actor.load_state_dict({**model_actor.state_dict(), **load_data.get('actor', {})})

        if not self.opts.test:
            # load data for critic
            model_critic=get_inner_model(self.critic)
            model_critic.load_state_dict({**model_critic.state_dict(), **load_data.get('critic', {})})
            # load data for optimizer
            self.optimizer.load_state_dict(load_data['optimizer'])
            # load data for torch and cuda
            torch.set_rng_state(load_data['rng_state'])
            if self.opts.use_cuda:
                torch.cuda.set_rng_state_all(load_data['cuda_rng_state'])
        # done
        print(' [*] Loading data from {}'.format(load_path))

    # save trained model
    def save(self, epoch):
        print('Saving model and state...')
        torch.save(
            {
                'actor': get_inner_model(self.actor).state_dict(),
                'critic':get_inner_model(self.critic).state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'rng_state': torch.get_rng_state(),
                'cuda_rng_state': torch.cuda.get_rng_state_all(),
            },
            os.path.join(self.opts.save_dir, 'epoch-{}.pt'.format(epoch))
        )

    # inference for training
    def start_training(self,tb_logger):
        opts=self.opts
        
        # parallel vector
        self.vector_env=SubprocVectorEnv if opts.is_linux else DummyVectorEnv
        
        # construct the dataset
        set_seed(42)
        train_set,train_pro_id=get_train_set(self.opts)
        
        # construct parallel environment
        # learning_env
        learning_env_list=[lambda e=train_set[0]: L2E_env(dim=opts.dim,ps=opts.population_size,problem=e,max_x=opts.max_x,min_x=opts.min_x,max_fes=opts.max_fes,boarder_method=opts.boarder_method) for i in range(opts.batch_size)]
        learning_env=self.vector_env(learning_env_list)
        # teacher_env
        if self.opts.teacher=='madde':
            if self.opts.tea_step == 'step':
                madde_maxfes = round((opts.max_fes / opts.population_size) * (4 + 2 * opts.dim * opts.dim) / 2)
            else:
                madde_maxfes = opts.max_fes
            teacher_env_list=[lambda e=copy.deepcopy(train_set[0]): MadDE(dim=opts.dim,problem=e,max_x=opts.max_x,min_x=opts.min_x,max_fes=madde_maxfes) for i in range(opts.batch_size)]
            
        elif self.opts.teacher=='cmaes':
            teacher_env_list=[lambda e=copy.deepcopy(train_set[0]): sep_CMA_ES(dim=opts.dim,problem=e,max_x=opts.max_x,min_x=opts.min_x,max_fes=opts.max_fes,sigma=opts.cmaes_sigma) for i in range(opts.batch_size)]
        elif self.opts.teacher=='pso':
            teacher_env_list=[lambda e=None: PSO(ps=opts.population_size,dim=opts.dim,max_fes=opts.max_fes,min_x=opts.min_x,max_x=opts.max_x,pho=0.2) for i in range(opts.batch_size)]
        elif self.opts.teacher=='de':
            teacher_env_list=[lambda e=None: DE(dim=opts.dim,ps=opts.population_size,min_x=opts.min_x,max_x=opts.max_x,max_fes=opts.max_fes) for i in range(opts.batch_size)]
        else:
            assert True, f'The selecting {self.opts.teacher} teacher is currently not supported!!'
        teacher_env=self.vector_env(teacher_env_list)
        # random_env (for comparison)
        random_env_list=[lambda e=train_set[0]: L2E_env(dim=opts.dim,ps=opts.population_size,problem=e,max_x=opts.max_x,min_x=opts.min_x,max_fes=opts.max_fes,boarder_method=opts.boarder_method) for i in range(opts.batch_size)]
        random_env=self.vector_env(random_env_list)
        
        # curriculum store initialization
        if getattr(opts, 'curriculum', False) or getattr(opts, 'mab_curriculum', False):
            # Provide base task IDs for robust base_id/origin_id parsing (supports underscores).
            self.scheduler = PLRScheduler(opts, base_names=train_pro_id)

        # get surrogate gbest and init cost
        self.surrogate_gbest, self.surrogate_gworst = get_surrogate_gbest_subproc(train_set,train_pro_id,opts.batch_size,seed=999,fes=opts.max_fes, opts=opts)
        self.raw_surrogate_gbest = copy.deepcopy(self.surrogate_gbest)
        self.raw_surrogate_gworst = copy.deepcopy(self.surrogate_gworst)
        for id in self.surrogate_gbest:
            self.surrogate_gbest[id] = 0.0

        task=TaskForTrain(learning_env,teacher_env,random_env,opts.batch_size,opts)
        tokenizer=MyTokenizer()

        update_step=0

        test_ratio_list=[]

        epoch_len=opts.epoch_len

        # begin training 
        for epoch in range(opts.epoch_start,opts.epoch_end):
            self.lr_scheduler.step(epoch)
            
            self.set_training()
            # logging
            print('\n\n')
            print("|",format(f" Training epoch {epoch} ","*^60"),"|")
            print("Training with RNN lr={:.3e} for run {}".format(self.optimizer.param_groups[0]['lr'], opts.run_name) , flush=True)
            # start training
            epoch_step=epoch_len * (opts.max_fes // opts.population_size // opts.skip_step // opts.n_step) * opts.k_epoch
            pbar = tqdm(total = epoch_step, desc = 'training',
                        bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}')
            

            total_gap=0
            
            for b in range(epoch_len):
                instances, ids = None, None
                is_replay = False
                if getattr(opts, 'curriculum', False) or getattr(opts, 'mab_curriculum', False):
                    instances, ids, is_replay = self.scheduler.ask(opts.batch_size)
                
                batch_step,bat_gap,data_memory=self.train_batch(task,update_step,tokenizer,pbar,epoch,tb_logger,instances=instances,ids=ids,is_replay=is_replay)
                
                # Update global step count
                update_step += batch_step

                # Log training dynamics
                if getattr(opts, 'curriculum', False) or getattr(opts, 'mab_curriculum', False):
                    final_stu_costs = data_memory.stu_cost[-1]
                    final_tea_costs = data_memory.teacher_cost[-1]
                    eps = 1e-12
                    log_gaps = [np.log10(max(s, 0) + eps) - np.log10(max(t, 0) + eps) for s, t in zip(final_stu_costs, final_tea_costs)]
                    avg_log_gap = np.mean(log_gaps)
                    pass_rate = np.mean([1.0 if c < 1e-5 else 0.0 for c in final_stu_costs])
                    
                    try:
                        import wandb
                        prefix = "dynamics/replay_" if is_replay else "dynamics/random_"
                        wandb_log_dict = {
                            f"{prefix}avg_log_gap": avg_log_gap,
                            f"{prefix}pass_rate": pass_rate
                        }
                        if hasattr(opts, 'no_wandb_step') and opts.no_wandb_step:
                            wandb.log(wandb_log_dict)
                        else:
                            wandb.log(wandb_log_dict, step=update_step)
                    except ImportError:
                        pass

                # Update curriculum learning score (Tell)
                if getattr(opts, 'curriculum', False) and data_memory is not None:
                    final_lps = []
                    for i in range(len(ids)):
                        final_stu = data_memory.stu_cost[-1][i]
                        final_tea = data_memory.teacher_cost[-1][i]
                        eps = 1e-12
                        log_gap = np.log10(max(final_stu, 0) + eps) - np.log10(max(final_tea, 0) + eps)
                        final_lps.append(log_gap)
                    
                    self.scheduler.tell(ids, final_lps, instances=instances)

                data_memory.clear()
                total_gap+=bat_gap

            # Scheduler periodic maintenance
            if getattr(opts, 'curriculum', False) or getattr(opts, 'mab_curriculum', False):
                self.scheduler.step(epoch)
                self.scheduler.log_metrics(tb_logger, update_step)

            
            avg_gap=total_gap/epoch_len
            pbar.close()
            
            # save model
            if not opts.no_saving and (( opts.checkpoint_epochs != 0 and epoch % opts.checkpoint_epochs == 0) or \
                                       epoch == opts.epoch_end - 1): self.save(epoch)
            
            # rollout
            if epoch%2==0:
                test_ratio=rollout(opts,self,epoch,tb_logger,tokenizer,update_step=update_step)
                test_ratio_list.append(test_ratio)
                if epoch == opts.epoch_start:
                    best_test_ratio=test_ratio
                else:
                    if best_test_ratio<test_ratio:
                        best_test_ratio=test_ratio
                        best_test_epoch=epoch
                        
            if epoch == opts.epoch_start:
                best_test_epoch=opts.epoch_start
                best_avg_gap=avg_gap
                best_gap_epoch=opts.epoch_start
            else:
                if best_avg_gap>avg_gap:
                    best_avg_gap=avg_gap
                    best_gap_epoch=epoch
            
            # log to screen
            print(f'best_test_epoch:{best_test_epoch}')
            print(f'best_test_ratio:{best_test_ratio}')
            print(f'current_avg_gap:{avg_gap}')
            print(f'best_avg_gap:{best_avg_gap}')
            print(f'best_gap_epoch:{best_gap_epoch}')

            try:
                import wandb
                if wandb.run is not None:
                    wandb_log_dict = {
                        'performance/best_test_ratio': best_test_ratio,
                        'performance/best_test_epoch': best_test_epoch,
                        'performance/avg_gap': avg_gap,
                        'performance/best_avg_gap': best_avg_gap,
                        'performance/best_gap_epoch': best_gap_epoch,
                    }
                    if hasattr(opts, 'no_wandb_step') and opts.no_wandb_step:
                        wandb.log(wandb_log_dict)
                    else:
                        wandb.log(wandb_log_dict, step=update_step)
            except ImportError:
                pass

        # close the parallel vector_env
        learning_env.close()
        teacher_env.close()
        random_env.close()
        task.close()
        return update_step

    # training for one batch 
    def train_batch(self,task:TaskForTrain,pre_step,tokenizer,pbar,epoch,tb_logger=None,instances=None,ids=None,is_replay=True):
        import time
        start_time = time.time()
        
        max_step=self.opts.max_fes//(self.opts.population_size*self.opts.skip_step)
        data_memory=Data_Memory()
        memory=Memory()

        # reset
        set_seed()
        
        # sample task for training
        if instances is None:
            instances,ids=sample_batch_task(self.opts)
        
        for i, (instance, id) in enumerate(zip(instances, ids)):
            # Handle dynamic IDs from curriculum learning
            base_id = self._infer_base_id(id)
            instance.best = self.raw_surrogate_gbest.get(base_id, 0.0)
            instance.worst = self.raw_surrogate_gworst.get(base_id, 1.0)
            
            # Ensure surrogate_gbest dict has this new ID
            if id not in self.surrogate_gbest:
                self.surrogate_gbest[id] = 0.0

        tea_pop,stu_population=task.reset(instances)

        baseline_pop=copy.deepcopy(stu_population)

        pop_feature=task.state(stu_population)
        pop_feature=torch.FloatTensor(pop_feature).to(self.opts.device)

        # record the pre_population
        pre_stu_pop=stu_population
        pre_baseline_pop=stu_population

        # record infomation
        data_memory.teacher_cost.append([p.gbest_cost for p in tea_pop])
        data_memory.stu_cost.append([p.gbest_cost for p in stu_population])
        data_memory.baseline_cost.append([p.gbest_cost for p in stu_population])

        gamma = self.opts.gamma
        eps_clip = self.opts.eps_clip
        n_step = self.opts.n_step
        k_epoch=self.opts.k_epoch

        t=0
        total_gap=0
        # for logging
        current_step=pre_step
        
        # record init cost
        init_cost=[p.gworst_cost for p in tea_pop]

        # timers
        timers = {
            'actor_inf': 0,
            'task_step': 0,
            'task_reward': 0,
            'critic_inf': 0,
            'ppo_update': 0
        }

        is_done=False
        while not is_done:
            t_s=t

            rollout_bl_val_detached = []
            rollout_bl_val = []

            while t-t_s < n_step:
                # get feature
                memory.states.append(pop_feature)

                # using model to generate expr
                _t0 = time.time()
                if self.opts.require_baseline:
                    seq,const_seq,log_prob,rand_seq,rand_c_seq,action_dict=self.actor(pop_feature,save_data=True)
                else:
                    seq,const_seq,log_prob,action_dict=self.actor(pop_feature,save_data=True)
                    rand_seq, rand_c_seq = seq, const_seq
                timers['actor_inf'] += time.time() - _t0
                
                # next_pop: population updated by expr, target_pop: teacher
                _t1 = time.time()
                target_pop,next_pop,baseline_pop,expr,is_done=task.step(stu_population,self.opts.skip_step,seq,const_seq,tokenizer,rand_seq,rand_c_seq,baseline_pop)
                timers['task_step'] += time.time() - _t1
                
                # get reward
                _t2 = time.time()
                total_reward,gap_reward,base_reward,gap=task.reward(learning_population=next_pop,target_population=target_pop,reward_method=self.opts.reward_func,
                                                                    base_reward_method=self.opts.b_reward_func,max_step=max_step,epoch=epoch,s_init_cost=init_cost,
                                                                    s_gbest=self.surrogate_gbest,pre_learning_population=pre_stu_pop,ids=ids)
                timers['task_reward'] += time.time() - _t2
                
                gap_reward=torch.FloatTensor(gap_reward).to(self.opts.device)
                base_reward=torch.FloatTensor(base_reward).to(self.opts.device)
                total_reward=torch.FloatTensor(total_reward).to(self.opts.device)

                if torch.any(torch.isnan(total_reward)):
                    print(f'gap_reward:{gap_reward},base_reward:{base_reward},total_reward:{total_reward}')
                    assert False, 'nan in reward!!'

                memory.gap_rewards.append(gap_reward)
                memory.b_rewards.append(base_reward)

                
                total_gap+=np.mean(gap)

                # critic network
                _t3 = time.time()
                baseline_val_detached,baseline_val=self.critic(pop_feature)
                timers['critic_inf'] += time.time() - _t3
                rollout_bl_val_detached.append(baseline_val_detached)
                rollout_bl_val.append(baseline_val)

                # store reward for ppo
                memory.actions.append(action_dict)
                memory.logprobs.append(log_prob)
                memory.rewards.append(total_reward)

                # store data
                data_memory.teacher_cost.append([p.gbest_cost for p in target_pop])
                data_memory.stu_cost.append([p.gbest_cost for p in next_pop])
                data_memory.gap.append(gap)
                data_memory.baseline_cost.append([p.gbest_cost for p in baseline_pop])
                data_memory.expr.append(expr)

                # next step 
                stu_population=next_pop
                pre_stu_pop=next_pop
                pre_baseline_pop=baseline_pop

                t=t+1

                # next state
                pop_feature = task.state(stu_population)
                pop_feature=torch.FloatTensor(pop_feature).to(self.opts.device)

                
                if is_done:
                    # update surrogate gbest
                    for i,id in enumerate(ids):
                        min_cost=min(data_memory.teacher_cost[-1][i],data_memory.stu_cost[-1][i])
                        self.surrogate_gbest[id]=min(self.surrogate_gbest[id],min_cost)
                    break

            t_time=t-t_s
            

            # begin updating network in PPO style
            # Gradient update strategy fix
            buffer_size = len(self.scheduler.levels) if hasattr(self, 'scheduler') else 0
            is_buffer_warm = buffer_size >= self.opts.batch_size
            
            should_update = is_replay or (not is_buffer_warm) or (np.random.rand() >= self.opts.random_stop_grad)
            
            if should_update:
                _t_up = time.time()
                old_actions = self.merge_action_dicts(memory.actions)
                old_states = torch.stack(memory.states).detach().view(-1, self.opts.fea_dim) 
                old_logprobs = torch.stack(memory.logprobs).detach().view(-1)

                # Vectorized Returns calculation
                rewards_tensor = torch.stack(memory.rewards)  # [T, B]
                with torch.no_grad():
                    next_value = self.critic(pop_feature)[0]  # [B]
                
                Reward = torch.zeros_like(rewards_tensor)
                R = next_value
                for t_inv in reversed(range(t_time)):
                    R = rewards_tensor[t_inv] + gamma * R
                    Reward[t_inv] = R
                Reward = Reward.view(-1)

                # Pre-stack rollout baseline values
                rollout_bl_val_detached_stacked = torch.stack(rollout_bl_val_detached).view(-1)
                rollout_bl_val_stacked = torch.stack(rollout_bl_val).view(-1)

                old_value = None
                for _k in range(k_epoch):
                    
                    if _k == 0:
                        logprobs = old_logprobs
                        current_bl_val_detached = rollout_bl_val_detached_stacked
                        current_bl_val = rollout_bl_val_stacked
                    else:
                        # Evaluating actions and values in batch
                        logprobs = self.actor(old_states, fix_action=old_actions)
                        current_bl_val_detached, current_bl_val = self.critic(old_states)
                    
                    logprobs = logprobs.view(-1)
                    current_bl_val_detached = current_bl_val_detached.view(-1)
                    current_bl_val = current_bl_val.view(-1)

                    # Finding the ratio (pi_theta / pi_theta__old):
                    ratios = torch.exp(logprobs - old_logprobs)

                    # Finding Surrogate Loss:
                    advantages = Reward - current_bl_val_detached

                    surr1 = ratios * advantages
                    surr2 = torch.clamp(ratios, 1-eps_clip, 1+eps_clip) * advantages
                    reinforce_loss = -torch.min(surr1, surr2).mean()

                    # define baseline loss
                    if old_value is None:
                        baseline_loss = ((current_bl_val - Reward) ** 2).mean()
                        old_value = current_bl_val.detach()
                    else:
                        vpredclipped = old_value + torch.clamp(current_bl_val - old_value, - eps_clip, eps_clip)
                        v_max = torch.max(((current_bl_val - Reward) ** 2), ((vpredclipped - Reward) ** 2))
                        baseline_loss = v_max.mean()
                    # calculate loss
                    loss = baseline_loss + reinforce_loss
                    
                    # see if loss is nan
                    if torch.isnan(loss):
                        print(f'baseline_loss:{baseline_loss}')
                        print(f'reinforce_loss:{reinforce_loss}')
                        assert False, 'nan found in loss!!'

                    # update gradient step
                    self.optimizer.zero_grad()
                    loss.backward()

                    # Clip gradient norm and get (clipped) gradient norms for logging if needed
                    grad_norms = clip_grad_norms(self.optimizer.param_groups)[0]

                    # perform gradient descent
                    self.optimizer.step()

                    current_step+=1
                    

                    # logging to tensorboard
                    if (not self.opts.no_tb) and (tb_logger is not None):
                        if current_step % self.opts.log_step == 0:
                            log_to_tb_train(tb_logger,self,Reward,grad_norms, memory.rewards,memory.gap_rewards,memory.b_rewards,gap,reinforce_loss,baseline_loss,logprobs,current_step)
                    pbar.update(1)
                timers['ppo_update'] += time.time() - _t_up
            else:
                # If skipping update, still update progress bar
                pbar.update(k_epoch)
                
            memory.clear_memory()

        # Batch end, print statistics
        total_batch_time = time.time() - start_time
        print(f"\n[Batch Timer] Total: {total_batch_time:.2f}s | "
              f"Actor: {timers['actor_inf']:.2f}s | "
              f"TaskStep: {timers['task_step']:.2f}s | "
              f"Reward: {timers['task_reward']:.2f}s | "
              f"Critic: {timers['critic_inf']:.2f}s | "
              f"Update: {timers['ppo_update']:.2f}s")
        
        # return batch step
        return current_step-pre_step,total_gap/t,data_memory
