import numpy as np
import copy
from dataset.generate_dataset import sample_batch_task, sample_batch_task_id_cec21, sample_batch_task_id

# CEC21 benchmark class names. Order is significant: the MAB curriculum treats
# the list position as the class index and requests fresh tasks with index + 1.
_CEC21_CLASS_NAMES = [
    'Bent_cigar',
    'Schwefel',
    'bi_Rastrigin',
    'Grie_rosen',
    'Hybrid1',
    'Hybrid2',
    'Hybrid3',
    'Composition1',
    'Composition2',
    'Composition3',
]

import re


def _parse_base_id(id_str: str, class_names: list) -> str:
    """
    Resolve which known class name a task ID refers to.

    Candidates are tried longest-first, so when one class name is a substring
    of another, the longer (more specific) name wins. Among names of equal
    length the original list order is kept (``sorted`` is stable).

    Args:
        id_str: Task ID string, e.g. 'rand_Bent_cigar_123456' or
            'mut_rand_Bent_cigar_123_1_2_999'.
        class_names: Candidate class names.

    Returns:
        The first (longest) class name contained in ``id_str``, or 'unknown'
        if none of them occurs.
    """
    longest_first = sorted(class_names, key=len, reverse=True)
    return next((candidate for candidate in longest_first if candidate in id_str), 'unknown')


def _get_mab_class_names(opts):
    """
    Return the class names used as MAB curriculum arms.

    Only the CEC21 class list is currently supported, so ``opts`` is accepted
    for interface compatibility but not consulted.

    A fresh copy is returned. A caller that mutates its list therefore cannot
    change the module-level ``_CEC21_CLASS_NAMES`` shared by every scheduler.
    """
    return list(_CEC21_CLASS_NAMES)


class PLRScheduler:
    """
    Prioritized Level Replay (PLR) buffer of task instances, with an optional
    multi-armed-bandit (MAB) curriculum over task classes.

    Each buffered level is a dict with keys 'instance', 'id', 'base_id',
    'origin_id', 'lp', 'stale_count' and 'parent_lp'. The per-task score
    ``lp`` (a "log gap" in this file's filtering and plots) and the level's
    staleness drive both replay priority and eviction.
    """

    def __init__(self, opts, base_names=None):
        """
        Args:
            opts: Options object. This method reads ``level_store_capacity``
                (default 1000) and ``mab_curriculum``. When that flag is truthy
                it also reads ``mab_eta``, ``mab_gamma`` and ``mab_epsilon``.
                Other methods read further attributes (``lgpc_gap_min``,
                ``lgpc_gap_max``, ``staleness_coef``, ``mab_logit_clip``,
                ``save_dir``, ``no_wandb_step``).
            base_names: Optional class names used to tag new levels with a
                ``base_id``. If omitted, ``tell`` falls back to the MAB class
                names when the MAB curriculum is enabled.
        """
        self.opts = opts
        self.capacity = getattr(opts, 'level_store_capacity', 1000)
        # Probability of replaying from the buffer instead of sampling fresh tasks.
        self.replay_prob = 0.8
        # Buffered levels; see the class docstring for the dict layout.
        self.levels = []
        self.base_names = list(base_names) if base_names is not None else None

        # MAB Curriculum Settings
        if getattr(opts, 'mab_curriculum', False):
            self.class_names = _get_mab_class_names(opts)
            num_classes = len(self.class_names)
            # Logits start at 0, i.e. a uniform distribution over classes.
            self.mab_logits = np.zeros(num_classes)
            self.mab_probs = np.ones(num_classes) / float(num_classes)
            self.mab_eta = getattr(opts, 'mab_eta', 0.5)
            self.mab_gamma = getattr(opts, 'mab_gamma', 0.95)
            self.mab_epsilon = getattr(opts, 'mab_epsilon', 0.05)
            print(f" [*] MAB Curriculum Enabled (Logits Mode). Eta: {self.mab_eta}, Gamma: {self.mab_gamma}, Epsilon: {self.mab_epsilon}, Classes: {num_classes}, Names: {self.class_names[:3]}...")

        print(" [*] Initializing PLRScheduler (Empty Buffer)")

    # ------------------------------------------------------------------
    # Scoring helpers (shared by replay sampling and eviction)
    # ------------------------------------------------------------------

    @staticmethod
    def _min_max_normalize(values):
        """Linearly rescale ``values`` to [0, 1]; a constant array maps to all zeros."""
        lo, hi = values.min(), values.max()
        if hi > lo:
            return (values - lo) / (hi - lo)
        return np.zeros_like(values)

    def _priority_scores(self, levels):
        """
        PLR score per level: normalized LP plus ``staleness_coef`` (default 0.1)
        times normalized staleness.
        """
        lps = np.array([level['lp'] for level in levels])
        stales = np.array([level['stale_count'] for level in levels])
        staleness_coef = getattr(self.opts, 'staleness_coef', 0.1)
        return self._min_max_normalize(lps) + staleness_coef * self._min_max_normalize(stales)

    @staticmethod
    def _softmax(scores):
        """Numerically stable softmax with temperature 1."""
        exp_scores = np.exp(scores - np.max(scores))
        return exp_scores / exp_scores.sum()

    @staticmethod
    def _parse_origin_id(id_str):
        """
        Return the root random-task id of a (possibly nested) mutant id.

        The expected mutant format is ``'mut_' + parent_id + '_<a>_<b>_<seed>'``,
        e.g. 'mut_rand_Bent_cigar_123_1_2_999' -> 'rand_Bent_cigar_123'.
        Every leading 'mut' layer is peeled together with its three trailing
        fields, so 'mut_mut_rand_X_1_a_b_c_d_e_f' -> 'rand_X_1'.

        Ids without a 'mut_' prefix or without 'rand_', and ids too short for
        that format, are returned unchanged.
        """
        if 'rand_' not in id_str or not id_str.startswith('mut_'):
            return id_str
        parts = id_str.split('_')
        depth = 0
        while depth < len(parts) and parts[depth] == 'mut':
            depth += 1
        core = parts[depth:len(parts) - 3 * depth]
        return '_'.join(core) if core else id_str

    # ------------------------------------------------------------------
    # Sampling
    # ------------------------------------------------------------------

    def ask(self, batch_size):
        """
        Sample a batch of task instances (standard PLR or MAB-based PLR).

        Without the MAB curriculum, a fresh batch comes from
        ``sample_batch_task``. This happens always while the buffer is empty,
        and otherwise with probability ``1 - replay_prob``. All other batches
        are drawn from the buffer with softmax(PLR score) probabilities.

        With the MAB curriculum, one class is drawn per slot from
        ``mab_probs``. On a replay batch, the slot is filled from that class's
        buffered levels (PLR-weighted) when any exist. Otherwise a fresh
        instance of that class is sampled.

        Args:
            batch_size: Number of instances to return. In the non-MAB fresh
                branch the size is decided by ``sample_batch_task(self.opts)``.

        Returns:
            ``(instances, ids, flag)``. ``flag`` is True for replay batches
            and for the forced-fresh batch drawn while the buffer is empty.
            It is False for an optional fresh batch. Replayed instances are
            deep copies of the buffered ones.
        """
        if getattr(self.opts, 'mab_curriculum', False):
            return self._ask_mab(batch_size)
        return self._ask_plr(batch_size)

    def _ask_plr(self, batch_size):
        """Standard PLR sampling over the whole buffer."""
        is_forced_random = not self.levels
        if is_forced_random or np.random.rand() > self.replay_prob:
            instances, ids = sample_batch_task(self.opts)
            ids = [f"rand_{task_id}_{np.random.randint(1_000_000)}" for task_id in ids]
            return instances, ids, is_forced_random

        probs = self._softmax(self._priority_scores(self.levels))
        indices = np.random.choice(len(self.levels), size=batch_size, p=probs, replace=True)
        sampled = [self.levels[i] for i in indices]
        sampled_instances = [copy.deepcopy(level['instance']) for level in sampled]
        sampled_ids = [level['id'] for level in sampled]
        for level in sampled:
            level['stale_count'] = 0
        return sampled_instances, sampled_ids, True

    def _ask_mab(self, batch_size):
        """MAB-guided PLR: choose a class per slot, then replay or sample within it."""
        num_classes = len(self.class_names)
        class_indices = np.random.choice(num_classes, size=batch_size, p=self.mab_probs)
        sampled_instances = []
        sampled_ids = []

        # Batch-level decision for is_replay
        is_forced_random = not self.levels
        is_replay_batch = is_forced_random or (np.random.rand() < self.replay_prob)

        # Group buffered levels by class once. The dicts are shared with
        # self.levels, so staleness resets below are visible to later slots.
        levels_by_class = {}
        for level in self.levels:
            levels_by_class.setdefault(level.get('base_id'), []).append(level)

        for c_idx in class_indices:
            c_name = self.class_names[c_idx]
            class_levels = levels_by_class.get(c_name, [])

            if class_levels and is_replay_batch:
                # Sample within class using PLR
                probs = self._softmax(self._priority_scores(class_levels))
                sel_level = class_levels[np.random.choice(len(class_levels), p=probs)]
                sampled_instances.append(copy.deepcopy(sel_level['instance']))
                sampled_ids.append(sel_level['id'])
                sel_level['stale_count'] = 0
            else:
                # Sample a new instance from this class
                insts, name = sample_batch_task_id(self.opts, c_idx + 1)
                sampled_instances.append(insts[0])
                sampled_ids.append(f"rand_{name}_{np.random.randint(1_000_000)}")

        return sampled_instances, sampled_ids, is_replay_batch

    # ------------------------------------------------------------------
    # Feedback
    # ------------------------------------------------------------------

    def tell(self, ids, lps, instances=None, parent_lps=None):
        """
        Record scores for evaluated tasks and maintain the buffer.

        Tasks whose LP lies outside ``[opts.lgpc_gap_min, opts.lgpc_gap_max]``
        are not stored, and any buffered level with the same id is removed.
        A known id has its LP averaged 50/50 with the new value and its
        staleness reset. An unknown id is appended as a new level. The buffer
        is then trimmed to ``capacity`` by evicting the lowest PLR scores.
        With the MAB curriculum, the class distribution is also updated from
        every ``(id, lp)`` pair, including filtered ones.

        Args:
            ids: Task ids, as returned by ``ask``.
            lps: Score per task, aligned with ``ids``.
            instances: Optional instances aligned with ``ids`` (stored for new levels).
            parent_lps: Optional parent scores aligned with ``ids`` (stored for new levels).
        """
        if instances is None:
            instances = [None] * len(ids)
        if parent_lps is None:
            parent_lps = [None] * len(ids)

        self._update_levels(ids, lps, instances, parent_lps)
        self._evict_over_capacity()
        if getattr(self.opts, 'mab_curriculum', False):
            self._update_mab(ids, lps)

    def _update_levels(self, ids, lps, instances, parent_lps):
        """Apply the LP-range filter, then update known levels or append new ones."""
        # Without explicit base_names, fall back to the MAB class names.
        # MAB replay groups levels by base_id, so an all-'unknown' tagging
        # would mean buffered levels are never replayed.
        base_names = self.base_names or getattr(self, 'class_names', None) or []

        for instance, id_val, lp, p_lp in zip(instances, ids, lps, parent_lps):
            if lp < self.opts.lgpc_gap_min or lp > self.opts.lgpc_gap_max:
                # Out-of-range score: drop this task from the buffer if present.
                self.levels[:] = [level for level in self.levels if level['id'] != id_val]
                continue

            existing = next((level for level in self.levels if level['id'] == id_val), None)
            if existing is not None:
                # Known level: equal-weight moving average of LP, reset staleness.
                existing['lp'] = 0.5 * existing['lp'] + 0.5 * lp
                existing['stale_count'] = 0
                continue

            id_str = str(id_val)
            self.levels.append({
                'instance': instance,
                'id': id_val,
                'base_id': _parse_base_id(id_str, base_names),
                'origin_id': self._parse_origin_id(id_str),
                'lp': lp,
                'stale_count': 0,
                'parent_lp': p_lp,
            })

    def _evict_over_capacity(self):
        """Pop the level with the lowest retention score until the buffer fits ``capacity``."""
        while self.levels and len(self.levels) > self.capacity:
            # Recomputed every pass: min-max normalization depends on the remaining levels.
            retention_scores = self._priority_scores(self.levels)
            self.levels.pop(int(np.argmin(retention_scores)))

    def _update_mab(self, ids, lps):
        """
        Rank-based bandit update of ``mab_logits``, then clip, softmax and
        epsilon-mix into ``mab_probs``.
        """
        num_classes = len(self.class_names)

        # Parse class for each task
        task_info = []  # [(class_idx, lp), ...]
        for id_val, lp in zip(ids, lps):
            base_id = _parse_base_id(str(id_val), self.class_names)
            if base_id in self.class_names:
                task_info.append((self.class_names.index(base_id), lp))

        if task_info:
            # Batch-wide rank utility in [0, 1]; 0.5 when there is a single task.
            lps_batch = np.array([t[1] for t in task_info])
            ranks = np.argsort(np.argsort(lps_batch))
            n = len(lps_batch)
            rank_utilities = ranks / (n - 1) if n > 1 else np.array([0.5])

            # Aggregate rank utility by class
            class_utilities = np.zeros(num_classes)
            class_counts = np.zeros(num_classes)
            for (c_idx, _), utility in zip(task_info, rank_utilities):
                class_utilities[c_idx] += utility
                class_counts[c_idx] += 1

            # Decay every logit, then move observed classes by their centred mean utility.
            self.mab_logits = self.mab_gamma * self.mab_logits
            for i in range(num_classes):
                if class_counts[i] > 0:
                    avg_rank_utility = class_utilities[i] / class_counts[i]
                    self.mab_logits[i] += self.mab_eta * (avg_rank_utility - 0.5)

        logit_clip = getattr(self.opts, 'mab_logit_clip', 2.0)
        self.mab_logits = np.clip(self.mab_logits, -logit_clip, logit_clip)
        softmax_probs = self._softmax(self.mab_logits)

        # Epsilon is a per-class probability floor. Cap it at 1/K so the
        # softmax weight (1 - K*eps) never goes negative.
        epsilon = min(self.mab_epsilon, 1.0 / num_classes)
        self.mab_probs = (1 - num_classes * epsilon) * softmax_probs + epsilon

        print(f"mab_logits: {self.mab_logits}")
        print(f"mab_probs: {self.mab_probs}")

    def get_percentile_lp(self, percentile=0.8):
        """
        Return the LP value at quantile ``percentile`` (in [0, 1]) of the buffer.

        Returns ``opts.lgpc_gap_min`` when the buffer is empty.
        """
        if not self.levels:
            return self.opts.lgpc_gap_min

        lps = np.array([l['lp'] for l in self.levels])
        return np.quantile(lps, percentile)

    # ------------------------------------------------------------------
    # Monitoring
    # ------------------------------------------------------------------

    @staticmethod
    def _active_wandb():
        """Return the wandb module if it is installed and a run is active, else None."""
        try:
            import wandb
        except ImportError:
            return None
        # wandb.log raises when no run has been initialised; logging here is best-effort.
        if getattr(wandb, 'run', None) is None:
            return None
        return wandb

    def _wandb_log(self, wandb, payload, step):
        """Send ``payload`` to wandb, omitting ``step`` when ``opts.no_wandb_step`` is set."""
        if getattr(self.opts, 'no_wandb_step', False):
            wandb.log(payload)
        else:
            wandb.log(payload, step=step)

    def log_metrics(self, tb_logger, step):
        """
        Log buffer diagnostics for ``step``.

        Files are written under ``<opts.save_dir>/analysis``:
        - a row appended to ``buffer_summary.csv``;
        - LP and staleness histograms;
        - a scatter of the first two shift/``O`` coordinates, coloured by LP;
        - a parent-vs-child LP plot and CSV (for levels with a ``parent_lp``);
        - with the MAB curriculum, a row appended to ``mab_logits.csv``.

        Scalars also go to wandb (if importable with an active run) and to
        ``tb_logger`` (if not None). Does nothing when the buffer is empty.

        Args:
            tb_logger: Object with ``add_scalar(tag, value, step)``, or None.
            step: Step index used in file names and log calls.
        """
        if not self.levels:
            return

        import matplotlib.pyplot as plt
        import os
        import pandas as pd

        # Prepare local save path
        analysis_dir = os.path.join(self.opts.save_dir, 'analysis')
        os.makedirs(analysis_dir, exist_ok=True)

        lps = [l['lp'] for l in self.levels]
        stales = [l['stale_count'] for l in self.levels]

        # Distinct task classes and distinct lineages (root random task) in the buffer.
        base_ids = [l.get('base_id', str(l['id'])) for l in self.levels]
        origin_ids = [l.get('origin_id', str(l['id'])) for l in self.levels]
        unique_base_count = len(set(base_ids))
        unique_lineage_count = len(set(origin_ids))

        summary_path = os.path.join(analysis_dir, 'buffer_summary.csv')
        df_summary = pd.DataFrame([{
            'step': step,
            'buffer_size': len(self.levels),
            'unique_base_count': unique_base_count,
            'unique_lineage_count': unique_lineage_count,
            'unique_count': unique_base_count,
            'avg_lp': np.mean(lps),
            'max_lp': np.max(lps),
            'avg_stale': np.mean(stales)
        }])
        # Header only when the file is first created.
        df_summary.to_csv(summary_path, mode='a', header=not os.path.exists(summary_path), index=False)

        # Score Distribution
        plt.figure(figsize=(12, 5))
        plt.subplot(1, 2, 1)
        plt.hist(lps, bins=20, color='skyblue', edgecolor='black')
        plt.title(f'LP Distribution (Step {step})')
        plt.xlabel('Log Gap')
        plt.ylabel('Count')

        plt.subplot(1, 2, 2)
        plt.hist(stales, bins=20, color='salmon', edgecolor='black')
        plt.title(f'Staleness Distribution (Step {step})')
        plt.xlabel('Stale Count')
        plt.ylabel('Count')
        plt.tight_layout()
        plt.savefig(os.path.join(analysis_dir, f'dist_step_{step}.png'))
        plt.close()

        # Task Parameter Scatter
        task_positions = []
        plot_lps = []
        for l in self.levels:
            inst = l['instance']
            pos = None
            if hasattr(inst, 'shift') and inst.shift is not None:
                pos = inst.shift
            elif hasattr(inst, 'O') and inst.O is not None:
                pos = inst.O
            if pos is None:
                continue
            # Flatten so multi-row shift arrays still give a 2-D point, and
            # skip 1-D problems that have no second coordinate.
            flat = np.asarray(pos).ravel()
            if flat.size >= 2:
                task_positions.append(flat[:2])
                plot_lps.append(l['lp'])

        if task_positions:
            task_positions = np.array(task_positions)
            plt.figure(figsize=(7, 6))
            sc = plt.scatter(task_positions[:, 0], task_positions[:, 1],
                             alpha=0.6, c=plot_lps, cmap='viridis', edgecolors='none', s=40)
            plt.colorbar(sc, label='Log Gap (LP)')

            plt.axvline(x=0, color='grey', linestyle='--', alpha=0.3)
            plt.axhline(y=0, color='grey', linestyle='--', alpha=0.3)

            plt.title(f'Task Optima Distribution (Step {step})\n(Proj: Dim 0 vs Dim 1)')
            plt.xlabel('Shift/O Dim 0')
            plt.ylabel('Shift/O Dim 1')
            plt.grid(True, linestyle=':', alpha=0.5)

            margin = 10
            plt.xlim(min(-100, task_positions[:, 0].min() - margin), max(100, task_positions[:, 0].max() + margin))
            plt.ylim(min(-100, task_positions[:, 1].min() - margin), max(100, task_positions[:, 1].max() + margin))

            plt.savefig(os.path.join(analysis_dir, f'param_scatter_step_{step}.png'))
            plt.close()

        # Parent vs. Child Gap
        parent_child_data = [[l['parent_lp'], l['lp']] for l in self.levels if l['parent_lp'] is not None]

        if parent_child_data:
            pc_arr = np.array(parent_child_data)
            plt.figure(figsize=(6, 6))
            plt.scatter(pc_arr[:, 0], pc_arr[:, 1], alpha=0.6, edgecolors='white')
            lims = [min(pc_arr.min(), 0), max(pc_arr.max(), 1)]
            plt.plot(lims, lims, 'r--', alpha=0.7, label='y=x')
            plt.xlabel('Parent Log Gap')
            plt.ylabel('Child Log Gap')
            plt.title(f'Evolution Quality (Step {step})')
            plt.legend()
            plt.grid(True, linestyle='--', alpha=0.6)
            plt.savefig(os.path.join(analysis_dir, f'evolution_step_{step}.png'))
            plt.close()

            pc_df = pd.DataFrame(pc_arr, columns=['parent_lp', 'child_lp'])
            pc_df.to_csv(os.path.join(analysis_dir, f'evolution_data_step_{step}.csv'), index=False)

        wandb = self._active_wandb()

        # MAB Logits (If enabled)
        if getattr(self.opts, 'mab_curriculum', False):
            mab_path = os.path.join(analysis_dir, 'mab_logits.csv')
            mab_data = {}
            for i, name in enumerate(self.class_names):
                mab_data[f"prob_{name}"] = self.mab_probs[i]
                mab_data[f"logit_{name}"] = self.mab_logits[i]

            mab_data['step'] = step
            df_mab = pd.DataFrame([mab_data])
            df_mab.to_csv(mab_path, mode='a', header=not os.path.exists(mab_path), index=False)

            if wandb is not None:
                mab_log = {}
                for i, name in enumerate(self.class_names):
                    mab_log[f"mab/prob_{name}"] = self.mab_probs[i]
                    mab_log[f"mab/logit_{name}"] = self.mab_logits[i]
                self._wandb_log(wandb, mab_log, step)

        # Also log to wandb
        if wandb is not None:
            log_dict = {
                "plr/buffer_size": len(self.levels),
                "plr/unique_base_count": unique_base_count,
                "plr/unique_lineage_count": unique_lineage_count,
                "plr/unique_count": unique_base_count,
                "plr/lp_dist": wandb.Histogram(lps),
                "plr/staleness_dist": wandb.Histogram(stales),
            }
            self._wandb_log(wandb, log_dict, step)

        if tb_logger is not None:
            tb_logger.add_scalar("plr/buffer_size", len(self.levels), step)
            tb_logger.add_scalar("plr/unique_base_count", unique_base_count, step)
            tb_logger.add_scalar("plr/unique_lineage_count", unique_lineage_count, step)
            tb_logger.add_scalar("plr/unique_count", unique_base_count, step)

    def step(self, epoch):
        """
        Periodic maintenance: increment every level's staleness by one.

        ``epoch`` is accepted for interface compatibility and not used.
        """
        for level in self.levels:
            level['stale_count'] += 1

    # ------------------------------------------------------------------
    # Mutation
    # ------------------------------------------------------------------

    def mutate_instance(self, instance, sigma=0.05):
        """
        Return a mutated deep copy of ``instance`` (parameter-space mutation).
        """
        return self.mutate_instance_static((instance, sigma))

    @staticmethod
    def mutate_instance_static(args):
        """
        Picklable entry point for multiprocessing pools.

        Args:
            args: ``(instance, sigma)`` tuple.

        Returns:
            A mutated deep copy of ``instance``; the input is left unchanged.
        """
        instance, sigma = args
        new_inst = copy.deepcopy(instance)
        PLRScheduler._mutate_recursive_static(new_inst, sigma)
        return new_inst

    @staticmethod
    def _mutate_recursive_static(inst, sigma):
        """
        Mutate ``inst`` in place, then recurse into its ``sub_problems``.

        Mutations applied:
        - ``shift`` / ``O`` arrays get N(0, sigma) noise and are clipped to
          +/- 0.8 * ub (or +/- 80 when ``ub`` is missing or not a scalar).
          ``opt``, if present, is set to the new shift.
        - Square ``rotate`` / ``M1`` / ``M2`` matrices get N(0, 0.2 * sigma)
          noise and are re-orthogonalised via SVD.
        """
        def _shift_clip_bound(obj):
            """Clip bound for shifted optima: 80% of a scalar ``ub``, else 80.0."""
            ub = getattr(obj, 'ub', None)
            if ub is not None:
                try:
                    return 0.8 * float(ub)
                except (TypeError, ValueError):
                    # Non-scalar or non-numeric ub: fall back to the default bound.
                    pass
            return 80.0

        # 1. Mutate shift (Shift / O)
        for attr in ('shift', 'O'):
            val = getattr(inst, attr, None)
            if not isinstance(val, np.ndarray):
                continue
            noise = np.random.normal(0, sigma, size=val.shape)
            bound = _shift_clip_bound(inst)
            new_val = np.clip(val + noise, -bound, bound)
            setattr(inst, attr, new_val)
            if hasattr(inst, 'opt'):
                inst.opt = new_val

        # 2. Mutate rotation (Rotate / M1 / M2)
        for attr in ('rotate', 'M1', 'M2'):
            M = getattr(inst, attr, None)
            # Only square matrices are treated as rotations; anything else is left untouched.
            if not (isinstance(M, np.ndarray) and M.ndim == 2 and M.shape[0] == M.shape[1]):
                continue
            noise = np.random.normal(0, sigma * 0.2, size=M.shape)
            try:
                U, _, Vt = np.linalg.svd(M + noise)
            except np.linalg.LinAlgError:
                # SVD did not converge: keep the original matrix.
                continue
            # U @ Vt is the orthogonal matrix nearest to the perturbed one.
            setattr(inst, attr, np.dot(U, Vt))

        # 3. Recursively process sub-problems
        if hasattr(inst, 'sub_problems') and inst.sub_problems is not None:
            subs = inst.sub_problems if isinstance(inst.sub_problems, (list, tuple)) else [inst.sub_problems]
            for sub in subs:
                PLRScheduler._mutate_recursive_static(sub, sigma)

    def _mutate_recursive(self, inst, sigma):
        """
        Kept for backward compatibility; delegates to ``_mutate_recursive_static``.
        """
        self._mutate_recursive_static(inst, sigma)
