import numpy as np
import torch
import ray
import os
import sys
import time
import copy
from tqdm import tqdm
from collections import deque, defaultdict
from utils.utils import set_seed
from utils.logger import log_to_test_with_teacher, gen_overall_tab
from pbo_env import L2E_env, MadDE, sep_CMA_ES, PSO, DE
from dataset.generate_dataset import sample_batch_task_id_cec21

# ==========================================
# 1. Ray Actor Definition (handles environment simulation)
# ==========================================

@ray.remote
class RayRolloutWorker:
    """Ray actor that advances a student, a baseline and a teacher optimizer on one task.

    The student and the baseline are both ``L2E_env`` instances; the teacher is
    selected by ``opts.teacher`` ('madde', 'cmaes', 'pso' or 'de'). The actor
    keeps the latest population object returned by each environment and reports
    cost statistics after every ``reset`` / ``step`` call.
    """

    def __init__(self, worker_id, opts):
        """Store the configuration and build environments for ``opts.dim``.

        Raises:
            NotImplementedError: if ``opts.teacher`` is not a supported teacher.
        """
        self.worker_id = worker_id
        self.opts = opts
        # Initialize independent environment instances and the teacher
        self._build_envs(opts.dim)

        # Maintain population states (filled in by reset())
        self.student_pop = None
        self.baseline_pop = None
        self.teacher_pop = None

    def _make_l2e_env(self, dim):
        """Create one L2E_env (used for both student and baseline) of dimension ``dim``."""
        opts = self.opts
        return L2E_env(dim=dim, ps=opts.population_size, problem=None,
                       max_x=opts.max_x, min_x=opts.min_x, max_fes=opts.max_fes,
                       boarder_method=opts.boarder_method)

    def _make_teacher_env(self, dim):
        """Create the teacher optimizer named by ``opts.teacher`` for dimension ``dim``."""
        opts = self.opts
        if opts.teacher == 'madde':
            return MadDE(dim=dim, max_fes=opts.max_fes, problem=None, min_x=opts.min_x, max_x=opts.max_x)
        if opts.teacher == 'cmaes':
            return sep_CMA_ES(dim=dim, problem=None, max_x=opts.max_x, min_x=opts.min_x, max_fes=opts.max_fes, sigma=opts.cmaes_sigma)
        if opts.teacher == 'pso':
            return PSO(ps=opts.population_size, dim=dim, max_fes=opts.max_fes, min_x=opts.min_x, max_x=opts.max_x, pho=0.2)
        if opts.teacher == 'de':
            return DE(dim=dim, ps=opts.population_size, min_x=opts.min_x, max_x=opts.max_x, max_fes=opts.max_fes)
        raise NotImplementedError(f"Teacher {opts.teacher} not supported")

    def _build_envs(self, dim):
        """(Re)create the student, baseline and teacher environments for ``dim``."""
        self.student_env = self._make_l2e_env(dim)
        self.baseline_env = self._make_l2e_env(dim)
        self.teacher_env = self._make_teacher_env(dim)

    def reset(self, task):
        """Reset all three environments on ``task`` and return the initial payload.

        Args:
            task: dict providing at least ``'instance'`` (the problem object) and
                ``'global_id'`` (integer used to derive the seed).

        Returns:
            The dict built by ``_get_payload`` with ``'done'`` set to False.
        """
        instance = task['instance']
        # Use fixed seed based on global_id for reproducibility
        set_seed(9999 + task['global_id'])

        # Rebuild the environments when the problem's dimension differs from
        # the one they were created with (falls back to opts.dim when the
        # instance has no `dim` attribute).
        problem_dim = getattr(instance, 'dim', self.opts.dim)
        if problem_dim != self.student_env.dim:
            self._build_envs(problem_dim)

        # Each environment gets its own copy so they do not share problem state
        self.student_env.problem = copy.deepcopy(instance)
        self.baseline_env.problem = copy.deepcopy(instance)
        self.teacher_env.problem = copy.deepcopy(instance)

        self.student_pop = self.student_env.reset()
        self.baseline_pop = self.baseline_env.reset()
        self.teacher_pop = self.teacher_env.reset()

        return self._get_payload(is_done=False)

    def step(self, expr_stu, expr_base, skip_step):
        """Advance student, baseline and teacher once and return the new payload.

        Args:
            expr_stu: value passed as ``'expr'`` to the student environment.
            expr_base: value passed as ``'expr'`` to the baseline environment.
            skip_step: value passed as ``'skip_step'`` to both L2E environments;
                also determines the teacher's budget for this call.

        Returns:
            The dict built by ``_get_payload``; ``'done'`` is the done flag
            reported by the student environment only.
        """
        # 1. Student step
        self.student_pop, _, done, _ = self.student_env.step({
            'base_population': self.student_pop,
            'expr': expr_stu,
            'skip_step': skip_step
        })
        # 2. Baseline step
        self.baseline_pop, _, _, _ = self.baseline_env.step({
            'base_population': self.baseline_pop,
            'expr': expr_base,
            'skip_step': skip_step
        })
        # 3. Teacher step: budget expressed either in steps or in function evaluations
        if self.opts.tea_step == 'step':
            ts = skip_step
            # NOTE(review): 'glpso' is rejected by the constructor, so this
            # halving branch cannot currently trigger in this class.
            if self.opts.teacher == 'glpso' and ts != 1:
                ts = ts // 2
            t_action = {'skip_step': ts}
        else:
            t_action = {'fes': skip_step * self.opts.population_size}

        self.teacher_pop, _, _, _ = self.teacher_env.step(t_action)

        return self._get_payload(is_done=done)

    def _get_payload(self, is_done):
        """Bundle the student observation and per-optimizer cost statistics.

        Returns:
            dict with ``'obs'`` (student feature encoding), ``'stu_stats'``,
            ``'base_stats'``, ``'tea_stats'`` (summary dicts over ``c_cost``),
            ``'stu_cost'``, ``'base_cost'``, ``'tea_cost'`` (each population's
            ``gbest_cost``) and ``'done'``.
        """
        # Older option objects may lack `fea_mode`; default to 'full'
        fea_mode = getattr(self.opts, 'fea_mode', 'full')
        obs = self.student_pop.feature_encoding(fea_mode)

        def get_pop_stats(pop):
            # Summary statistics over the population's current costs
            costs = pop.c_cost
            return {
                'min': np.min(costs),
                'max': np.max(costs),
                'mean': np.mean(costs),
                'median': np.median(costs),
                'std': np.std(costs),
                'gbest': pop.gbest_cost
            }

        return {
            'obs': obs,
            'stu_stats': get_pop_stats(self.student_pop),
            'base_stats': get_pop_stats(self.baseline_pop),
            'tea_stats': get_pop_stats(self.teacher_pop),
            'stu_cost': self.student_pop.gbest_cost,
            'base_cost': self.baseline_pop.gbest_cost,
            'tea_cost': self.teacher_pop.gbest_cost,
            'done': is_done
        }

# ==========================================
# 2. Helper Functions
# ==========================================

def decode_and_parse(seq, const_seq, tokenizer):
    """Turn one generated sequence (plus its constants) into an infix expression.

    The sequence is first converted to a prefix form together with its
    constants via ``get_prefix_with_consts`` (start index 0), every prefix
    element is decoded with ``tokenizer.decode``, and the result is handed to
    ``prefix_to_infix``.

    Returns:
        A ``(success, infix)`` pair as produced by ``prefix_to_infix``.
    """
    # Imported lazily, exactly as the original did
    from expr.expression import get_prefix_with_consts, prefix_to_infix

    prefix_tokens, prefix_consts = get_prefix_with_consts(seq, const_seq, 0)
    decoded_tokens = []
    for token in prefix_tokens:
        decoded_tokens.append(tokenizer.decode(token))
    ok, infix_expr = prefix_to_infix(decoded_tokens, prefix_consts, tokenizer)
    return ok, infix_expr

# ==========================================
# 3. Main Rollout Function
# ==========================================

def _new_worker_state(task):
    """Return an empty per-task record that accumulates one worker's rollout history."""
    return {
        'meta': task,
        'stu_costs': [], 'tea_costs': [], 'base_costs': [],
        'stu_stats_history': [], 'tea_stats_history': [], 'base_stats_history': []
    }


def _log_to_wandb(opts, metrics, update_step):
    """Log ``metrics`` to the active wandb run; no-op if wandb is missing or no run is active."""
    try:
        import wandb
    except ImportError:
        return
    if wandb.run is None:
        return
    if getattr(opts, 'no_wandb_step', False):
        wandb.log(metrics)
    else:
        wandb.log(metrics, step=update_step)


def rollout(opts, agent, epoch, tb_logger, tokenizer, update_step=None, testing=False):
    """Evaluate the agent on CEC21 problems 1-10 with an asynchronous Ray pipeline.

    Ray actors simulate the student / baseline / teacher optimizers while the
    driver batches the returned student observations through ``agent.actor``
    and dispatches the decoded expressions back to the actors.

    Args:
        opts: run configuration (reads ``dim``, ``batch_size``, ``device``,
            ``require_baseline``, ``skip_step``, ``data_saving_dir``, ``no_tb``,
            ``log_dir`` and optionally ``no_wandb_step``).
        agent: object exposing ``set_evaling()`` and ``actor(obs_tensor)``.
        epoch: current epoch; used for output paths and as the TensorBoard step.
        tb_logger: TensorBoard logger, or a falsy value to disable it.
        tokenizer: tokenizer forwarded to ``decode_and_parse``.
        update_step: step passed to ``wandb.log`` unless ``opts.no_wandb_step`` is set.
        testing: forces the detailed per-problem log regardless of ``epoch``.

    Returns:
        Fraction of all queued tasks on which the student's final cost is
        strictly lower than the baseline's, or ``0.0`` if no task completed.
        Tasks lost to a worker failure still count in the denominator.
    """
    if not ray.is_initialized():
        ray.init(ignore_reinit_error=True)

    agent.set_evaling()
    need_log = (epoch % 5 == 0 or testing)

    # 1. Prepare all test tasks
    task_queue = deque()
    total_tasks = 0

    # CEC21 test
    for bat_id, p_id in enumerate(range(1, 11)):
        # Using a problem-specific seed for fair comparison across runs
        instances, p_name = sample_batch_task_id_cec21(dim=opts.dim, batch_size=opts.batch_size, problem_id=p_id, seed=999+p_id)
        for inst in instances:
            task_queue.append({
                'instance': inst,
                'p_id': p_id,
                'p_name': p_name,
                'bat_id': bat_id,
                'global_id': total_tasks
            })
            total_tasks += 1

    # 2. Initialize Workers
    # A worker only receives a follow-up task after finishing one, so actors
    # beyond the number of tasks would stay idle forever; do not spawn them.
    max_workers = 64
    num_workers = min(max_workers, total_tasks)
    workers = [RayRolloutWorker.remote(i, opts) for i in range(num_workers)]

    worker_states = [None] * num_workers
    completed_results = []
    future_to_worker = {}

    pbar = tqdm(total=total_tasks, desc=f"Epoch {epoch} Async Rollout")
    try:
        # Launch initial Reset tasks
        for i, worker in enumerate(workers):
            if not task_queue:
                break
            task = task_queue.popleft()
            worker_states[i] = _new_worker_state(task)
            future_to_worker[worker.reset.remote(task)] = i

        # 3. Async pipeline loop
        # The inference buffer is flushed at the end of every iteration, so it
        # is always empty here; pending futures alone decide when to stop and
        # the poll timeout is a constant.
        while future_to_worker:
            # A. Check completed CPU simulations
            ready_refs, _ = ray.wait(list(future_to_worker), num_returns=len(future_to_worker), timeout=0.005)

            obs_buffer = []
            for ref in ready_refs:
                w_idx = future_to_worker.pop(ref)
                try:
                    res = ray.get(ref)
                except Exception as e:
                    # Best effort: the task is dropped and this worker receives
                    # no further work; remaining tasks go to the other workers.
                    print(f"Worker {w_idx} crashed: {e}")
                    continue

                state = worker_states[w_idx]
                # Record cost sequences
                state['stu_costs'].append(res['stu_cost'])
                state['tea_costs'].append(res['tea_cost'])
                state['base_costs'].append(res['base_cost'])

                # Record detailed statistics
                state['stu_stats_history'].append(res['stu_stats'])
                state['tea_stats_history'].append(res['tea_stats'])
                state['base_stats_history'].append(res['base_stats'])

                if not res['done']:
                    # Task not finished, feature enters inference buffer
                    obs_buffer.append((w_idx, res['obs']))
                    continue

                # Task finished, record and try to assign new task
                completed_results.append(state)
                pbar.update(1)
                if task_queue:
                    new_task = task_queue.popleft()
                    worker_states[w_idx] = _new_worker_state(new_task)
                    future_to_worker[workers[w_idx].reset.remote(new_task)] = w_idx
                else:
                    worker_states[w_idx] = None

            if not obs_buffer:
                continue

            # B. Batch GPU inference
            batch_indices = [x[0] for x in obs_buffer]
            batch_obs = np.array([x[1] for x in obs_buffer])
            obs_tensor = torch.FloatTensor(batch_obs).to(opts.device)

            with torch.no_grad():
                if opts.require_baseline:
                    seq, const_seq, _, rand_seq, rand_c_seq = agent.actor(obs_tensor)
                else:
                    # Without a separate baseline output, the baseline
                    # environment replays the student's expression.
                    seq, const_seq, _ = agent.actor(obs_tensor)
                    rand_seq, rand_c_seq = seq, const_seq

            # C. Decode and dispatch next Step
            for i, w_idx in enumerate(batch_indices):
                # Expressions that fail to parse fall back to the literal "x"
                succ, expr_stu = decode_and_parse(seq[i], const_seq[i], tokenizer)
                expr_stu = expr_stu if succ else "x"

                succ_r, expr_base = decode_and_parse(rand_seq[i], rand_c_seq[i], tokenizer)
                expr_base = expr_base if succ_r else "x"

                step_ref = workers[w_idx].step.remote(expr_stu, expr_base, opts.skip_step)
                future_to_worker[step_ref] = w_idx
    finally:
        # Release the progress bar even if the pipeline raised
        pbar.close()

    # 4. Post-processing and statistics
    if not completed_results:
        print("Error: No rollout tasks completed successfully.")
        return 0.0

    collect_dict = {}
    total_outperform = 0
    grouped = defaultdict(list)
    for res in completed_results:
        grouped[res['meta']['p_id']].append(res)
        if res['stu_costs'][-1] < res['base_costs'][-1]:
            total_outperform += 1

    path = os.path.join(opts.data_saving_dir, f'epoch_{epoch}', 'test')
    detailed_dir = os.path.join(path, 'detailed_stats')
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs;
    # creating the nested directory also creates `path`.
    os.makedirs(detailed_dir, exist_ok=True)

    # Save detailed metrics CSV
    import pandas as pd

    for res in completed_results:
        p_name = res['meta']['p_name']
        g_id = res['meta']['global_id']
        for alg_key, alg_full in [('stu', 'student'), ('tea', 'teacher'), ('base', 'baseline')]:
            df_history = pd.DataFrame(res[f'{alg_key}_stats_history'])
            csv_name = f"{p_name}_id{g_id}_{alg_full}.csv"
            df_history.to_csv(os.path.join(detailed_dir, csv_name), index_label='step')

    for p_id in sorted(grouped.keys()):
        group = grouped[p_id]
        p_name = group[0]['meta']['p_name']

        final_stu = [g['stu_costs'][-1] for g in group]
        final_tea = [g['tea_costs'][-1] for g in group]
        final_base = [g['base_costs'][-1] for g in group]

        mean_stu = np.mean(final_stu)
        mean_tea = np.mean(final_tea)
        mean_base = np.mean(final_base)

        collect_dict[p_name] = {
            'teacher': {'mean': mean_tea, 'std': np.std(final_tea)},
            'random_model': {'mean': mean_base, 'std': np.std(final_base)},
            'student': {'mean': mean_stu, 'std': np.std(final_stu)}
        }

        if need_log:
            # Data format conversion: transpose the per-run cost lists
            log_stu = np.array([g['stu_costs'] for g in group]).T.tolist()
            log_tea = np.array([g['tea_costs'] for g in group]).T.tolist()
            log_base = np.array([g['base_costs'] for g in group]).T.tolist()
            log_to_test_with_teacher(log_tea, log_base, log_stu, epoch, 0, p_id,
                                     tb_logger.file_writer.get_logdir() if tb_logger else opts.log_dir, logged=True)

        if not opts.no_tb and tb_logger:
            tb_logger.add_scalars(f'performance/cost/{p_name}',
                                  {'student': mean_stu, 'baseline': mean_base, 'teacher': mean_tea},
                                  epoch)

        _log_to_wandb(opts, {
            f'performance/cost/{p_name}/student': mean_stu,
            f'performance/cost/{p_name}/baseline': mean_base,
            f'performance/cost/{p_name}/teacher': mean_tea,
        }, update_step)

    gen_overall_tab(collect_dict, path)

    # Denominator is every queued task, so tasks lost to worker failures
    # count as "not outperforming".
    test_outperform_ratio = total_outperform / total_tasks
    if not opts.no_tb and tb_logger:
        tb_logger.add_scalar('performance/test_outperform_ratio', test_outperform_ratio, epoch)

    _log_to_wandb(opts, {'performance/test_outperform_ratio': test_outperform_ratio}, update_step)

    return test_outperform_ratio
