
import os
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import time
import uuid
from typing import List, Dict
from tqdm import tqdm

# ==========================================
# 0. A100 Configuration
# ==========================================
# Device policy: fall back to the CPU when CUDA is unavailable; otherwise use
# the third GPU on hosts exposing more than three devices and the first GPU
# on smaller hosts.
if not torch.cuda.is_available():
    DEVICE = torch.device("cpu")
else:
    device_count = torch.cuda.device_count()
    DEVICE = torch.device("cuda:2" if device_count > 3 else "cuda:0")

print(f"Exhaustive Model Search on {DEVICE}")
torch.set_default_dtype(torch.float32)

# Global fitting hyper-parameters, looked up by key elsewhere in this file.
CONFIG = dict(
    LR=0.05,            # Adam step size for the RL fit
    EPOCHS=300,         # optimisation steps for the RL fit
    RESTARTS=100,       # random restarts per agent
    BATCH_SIZE=200000,  # not referenced by the code visible in this file
    EPSILON=1e-8,       # small constant added before taking logs
)

# ==========================================
# 1. Generate all possible model combinations
# ==========================================
def generate_exhaustive_universe():
    """Enumerate every RL model variant as a list of specification dicts.

    A variant is one learning mechanism ('None', 'Sym' or 'Asym') combined
    with any subset of the three optional components (bias, inertia,
    perception), giving 3 * 2**3 = 24 variants. Variants are ordered by
    learning mechanism first, then by the component bit-code 0..7, where
    bit 0 enables bias, bit 1 inertia and bit 2 perception.

    Returns:
        A list of dicts with keys:
          'name'   -- e.g. 'RL_Sym_Bias_Perc'
          'params' -- free parameter names; always starts with 'beta'
          'meta'   -- {'learning': str, 'has_bias': bool,
                       'has_inert': bool, 'has_perc': bool}
    """
    # Learning mechanism -> the learning-rate parameters it contributes.
    alpha_params = {
        'None': [],
        'Sym': ['alpha'],
        'Asym': ['alpha_pos', 'alpha_neg'],
    }
    # Optional components as (name tag, meta key, parameter name). The position
    # in this list is the bit that switches the component on.
    switches = [
        ('Bias', 'has_bias', 'theta'),
        ('Inert', 'has_inert', 'lambda'),
        ('Perc', 'has_perc', 'R_perc'),
    ]

    universe = []
    for learn_type, alphas in alpha_params.items():
        for code in range(2 ** len(switches)):
            enabled = [bool(code & (1 << bit)) for bit in range(len(switches))]
            active = [sw for sw, on in zip(switches, enabled) if on]

            meta = {'learning': learn_type}
            for (_, meta_key, _), on in zip(switches, enabled):
                meta[meta_key] = on

            universe.append({
                'name': "_".join(['RL', learn_type] + [tag for tag, _, _ in active]),
                'params': ['beta'] + alphas + [param for _, _, param in active],
                'meta': meta,
            })

    return universe

# Names of the two non-RL baseline models.
STATIC_MODEL_NAMES = ['M0_Bias', 'M1_WSLS']
# Every RL variant, enumerated once at import time.
RL_MODEL_UNIVERSE = generate_exhaustive_universe()

print(f"Generated {len(RL_MODEL_UNIVERSE)} RL Combinations + 2 Static Models.")

# ==========================================
# 2. Core computation engine (Mega Parallel Kernel)
# ==========================================
class ExhaustiveParallelRW(nn.Module):
    """Fit every RL variant for every agent and restart in one batched pass.

    All parameter tensors have shape (B, R, M): B agents, R random restarts
    and M model variants taken from ``RL_MODEL_UNIVERSE``. Every variant owns
    a full set of raw parameters; per-variant 0/1 masks neutralise the ones
    a variant does not use, so all variants share a single code path.
    """

    def __init__(self, batch_size: int, restarts: int):
        """Create raw parameters and the per-variant masks.

        Args:
            batch_size: number of agents (B).
            restarts: number of random initialisations per agent (R).
        """
        super().__init__()
        self.B = batch_size
        self.R = restarts
        self.configs = RL_MODEL_UNIVERSE
        self.M = len(self.configs)

        # Unconstrained parameters, one (B, R, M) tensor per name. Parameters a
        # variant does not use are still allocated and masked out later.
        all_params = ['beta', 'alpha_pos', 'alpha_neg', 'theta', 'lambda', 'R_perc']
        self.raw_params = nn.ParameterDict()
        for p in all_params:
            self.raw_params[p] = nn.Parameter(torch.randn(self.B, self.R, self.M, device=DEVICE) * 0.1)

        # Per-variant switches stored as (1, 1, M) float buffers so they
        # broadcast against the (B, R, M) parameters.
        metas = [cfg['meta'] for cfg in self.configs]
        flags = {
            'mask_theta': [m['has_bias'] for m in metas],
            'mask_lambda': [m['has_inert'] for m in metas],
            'mask_r_perc': [m['has_perc'] for m in metas],
            'use_sym_alpha': [m['learning'] == 'Sym' for m in metas],   # alpha_neg tied to alpha_pos
            'use_no_alpha': [m['learning'] == 'None' for m in metas],   # both alphas forced to 0
        }
        for name, values in flags.items():
            buf = torch.tensor([float(v) for v in values], device=DEVICE).view(1, 1, self.M)
            self.register_buffer(name, buf)

    def get_constrained_params(self):
        """Map raw parameters to their constrained ("physical") values.

        Returns:
            Tuple ``(beta, a_pos, a_neg, theta, lam, r_perc)``, each (B, R, M):
              beta   -- softplus, positive
              a_pos  -- sigmoid in (0, 1); 0 for 'None' variants
              a_neg  -- sigmoid in (0, 1); equals a_pos for 'Sym' variants,
                        0 for 'None' variants
              theta  -- unbounded; 0 where bias is disabled
              lam    -- tanh in (-1, 1); 0 where inertia is disabled
              r_perc -- softplus where perception is enabled, else exactly 1.0
        """
        beta = nn.functional.softplus(self.raw_params['beta'])

        raw_pos = torch.sigmoid(self.raw_params['alpha_pos'])
        raw_neg = torch.sigmoid(self.raw_params['alpha_neg'])

        # Symmetric variants reuse the positive learning rate for negative
        # prediction errors.
        a_neg = self.use_sym_alpha * raw_pos + (1 - self.use_sym_alpha) * raw_neg
        a_pos = raw_pos

        # Non-learning variants get both learning rates set to zero.
        learns = 1 - self.use_no_alpha
        a_pos = a_pos * learns
        a_neg = a_neg * learns

        theta = self.raw_params['theta'] * self.mask_theta
        lam = torch.tanh(self.raw_params['lambda']) * self.mask_lambda

        rp_opt = nn.functional.softplus(self.raw_params['R_perc'])
        r_perc = (1.0 * (1 - self.mask_r_perc)) + (rp_opt * self.mask_r_perc)

        return beta, a_pos, a_neg, theta, lam, r_perc

    def forward(self, actions, rewards, is_ling, forgone, prev_acts):
        """Return the summed negative log-likelihood for every (agent, restart, variant).

        Args:
            actions: per-trial chosen action (0 or 1); reshaped to (B, 1, 1, T).
            rewards: per-trial reward; same layout.
            is_ling: per-trial flag; where it and the reward both exceed 0.5
                the reward is replaced by ``r_perc`` before the value update.
            forgone: accepted for interface compatibility; no variant
                currently reads it.
            prev_acts: per-trial previous action, with -1 meaning "no previous
                action" (inertia is switched off for that trial).

        Returns:
            Tensor of shape (B, R, M): NLL summed over all T trials in the
            input. Every timestep present in the input is scored.
        """
        # 1. Broadcastable (B, 1, 1, T) views of the data. reshape (rather
        #    than view) also accepts non-contiguous column slices.
        actions = actions.reshape(self.B, 1, 1, -1)
        rewards = rewards.reshape(self.B, 1, 1, -1)
        is_ling = is_ling.reshape(self.B, 1, 1, -1)
        prev_acts = prev_acts.reshape(self.B, 1, 1, -1)

        # Integer index tensors are needed by gather/scatter; convert once.
        act_long = actions.long()
        prev_long = prev_acts.long()

        # 2. Constrained parameters with a trailing singleton axis: (B, R, M, 1).
        beta, a_pos, a_neg, theta, lam, r_perc = [
            p.unsqueeze(-1) for p in self.get_constrained_params()
        ]

        # 3. State lives on the same device/dtype as the parameters.
        full_shape = (self.B, self.R, self.M, 1)
        Q = torch.zeros(self.B, self.R, self.M, 2, device=beta.device, dtype=beta.dtype)
        total_nll = torch.zeros(self.B, self.R, self.M, device=beta.device, dtype=beta.dtype)
        seq_len = actions.shape[3]

        # The bias is time-invariant: 0 on action 0, theta on action 1.
        bias_add = torch.cat([torch.zeros_like(theta), theta], dim=-1)

        # 4. Sequential pass over trials (Q at trial t depends on trial t-1).
        for t in range(seq_len):
            curr_act = act_long[..., t:t + 1]      # (B, 1, 1, 1)
            curr_rew = rewards[..., t:t + 1]
            curr_ling = is_ling[..., t:t + 1]
            curr_prev = prev_long[..., t:t + 1]

            # --- Decision ---
            logits = Q + bias_add

            # Inertia: add lam to the logit of the previously chosen action,
            # skipped when there is no previous action (-1).
            valid_prev = (curr_prev != -1)
            safe_prev = torch.clamp(curr_prev, 0, 1)
            target_idx = safe_prev.expand(*full_shape)
            inertia_add = torch.zeros_like(logits).scatter(3, target_idx, lam)
            logits = logits + inertia_add * valid_prev.float()

            # log_softmax is numerically stable on its own; adding a constant
            # to every logit would not change its output.
            log_probs = torch.log_softmax(logits * beta, dim=3)

            act_idx = curr_act.expand(*full_shape)
            chosen_lp = log_probs.gather(3, act_idx)
            total_nll = total_nll - chosen_lp.squeeze(-1)

            # --- Learning ---
            # Rewarded trials with the flag set are valued at r_perc instead
            # of the raw reward.
            is_amp = (curr_ling > 0.5) & (curr_rew > 0.5)
            eff_r = torch.where(is_amp, r_perc, curr_rew)

            q_chosen = Q.gather(3, act_idx)
            pe = eff_r - q_chosen
            lr = torch.where(pe >= 0, a_pos, a_neg)

            # Only the chosen action's value is updated.
            delta = torch.zeros_like(Q).scatter(3, act_idx, lr * pe)
            Q = Q + delta

        return total_nll

class MegaStaticModels(nn.Module):
    """Two non-learning baseline models, fitted for all agents and restarts at once.

    Model 0 chooses action 1 with a fixed probability. Model 1 predicts the
    previous action after a reward above 0.5 and the opposite action
    otherwise, deviating from that prediction with a lapse probability.
    """

    def __init__(self, batch_size, restarts):
        """Allocate one raw ``theta`` and ``epsilon`` per (agent, restart)."""
        super().__init__()
        shape = (batch_size, restarts, 1)
        self.theta = nn.Parameter(torch.randn(*shape, device=DEVICE) * 0.1)
        self.epsilon = nn.Parameter(torch.randn(*shape, device=DEVICE) * 0.1)

    def get_physical_params(self):
        """Return ``(theta * 5, sigmoid(epsilon) * 0.5)``.

        The first value is the Model 0 logit; the second is the Model 1 lapse
        probability, bounded to (0, 0.5).
        """
        bias_logit = self.theta * 5.0
        lapse = torch.sigmoid(self.epsilon) * 0.5
        return bias_logit, lapse

    def forward(self, actions, rewards, prev_acts):
        """Return ``(nll_m0, nll_m1)``, each summed over the trial axis.

        Args:
            actions: (N, T) chosen actions, 0 or 1.
            rewards: (N, T) rewards.
            prev_acts: (N, T) previous actions, -1 where there is none.
        """
        eps = CONFIG['EPSILON']
        acts = actions.unsqueeze(1)
        prevs = prev_acts.unsqueeze(1)
        bias_logit, lapse = self.get_physical_params()

        # Model 0: constant probability of action 1.
        n_steps = acts.shape[2]
        p_one = torch.sigmoid(bias_logit).expand(-1, -1, n_steps)
        lik_m0 = torch.where(acts == 1, p_one, 1.0 - p_one)
        nll_m0 = -torch.log(lik_m0 + eps).sum(dim=2)

        # Model 1: the previous trial's reward decides stay vs. shift.
        last_reward = torch.roll(rewards, 1, 1)
        last_reward[:, 0] = 0  # the first trial has no preceding reward
        rewarded = last_reward.unsqueeze(1) > 0.5
        predicted = torch.where(rewarded, prevs, 1 - prevs)
        follows_rule = (acts == predicted)
        lik_m1 = torch.where(follows_rule, 1.0 - lapse, lapse)
        # Trials without a previous action are scored at chance.
        chance = torch.tensor(0.5, device=DEVICE)
        lik_m1 = torch.where(prevs == -1, chance, lik_m1)
        nll_m1 = -torch.log(lik_m1 + eps).sum(dim=2)

        return nll_m0, nll_m1


# ==========================================
# 3. Data processing
# ==========================================
class GlobalDataset:
    """Load every session CSV under a directory into one dense tensor.

    Attributes:
        data: list of (max_len, 5) float32 matrices, one per accepted file.
        meta: list of dicts with keys 'file', 'group', 'agent'.
        group_ids: integer group code per accepted file (see ``group_map``).
        group_map: group name -> integer code; files of other groups are skipped.
        full_tensor: (N, max_len, 5) tensor on ``DEVICE``. Columns are
            0 action (1.0 for 'Compliance', else 0.0), 1 reward,
            2 is_ling flag, 3 forgone reward, 4 previous action (-1 on the
            first trial).
        group_tensor: (N,) tensor of group codes on ``DEVICE``.

    NOTE: sessions shorter than ``max_len`` are zero-padded, and the padding
    rows are indistinguishable from real trials with action 0 and reward 0.
    """

    # Groups whose sessions get is_ling = 1.0: every mapped group except 'Sycophancy'.
    _LINGUISTIC_GROUPS = frozenset([
        'Stimulus', 'Magnitude', 'Optimism', 'Punishment',
        'Authority', 'Threat', 'Regret', 'Baseline',
    ])

    def __init__(self, data_dir, max_len=50):
        """Read and encode all ``*.csv`` files found recursively under ``data_dir``.

        Files with fewer than 10 rows, an unmapped group, or that cannot be
        read/encoded are skipped; unreadable files are reported after loading.

        Args:
            data_dir: root directory to search.
            max_len: number of trials kept per session (longer sessions are
                truncated, shorter ones zero-padded).
        """
        # The recursive pattern already matches CSVs at any depth; sorting
        # makes the row order reproducible across runs and platforms.
        files = sorted(glob.glob(os.path.join(data_dir, "**", "*.csv"), recursive=True))

        self.data = []
        self.meta = []
        self.group_ids = []
        self.group_map = {
            'Baseline': 0, 'Optimism': 1, 'Punishment': 2, 'Stimulus': 3,
            'Magnitude': 4, 'Authority': 5, 'Threat': 6, 'Sycophancy': 7, 'Regret': 8
        }

        print(f"Loading {len(files)} files...")
        failures = []
        for f in tqdm(files):
            try:
                df = pd.read_csv(f)
                if len(df) < 10:
                    continue
                group = str(df['group'].iloc[0]) if 'group' in df.columns else 'Unknown'
                if group not in self.group_map:
                    continue
                mat = self._encode_frame(df, group, max_len)
                agent = df['model'].iloc[0] if 'model' in df.columns else 'Unknown'
            except Exception as exc:
                # Best-effort loading: one malformed file must not abort the
                # run, but it is recorded instead of vanishing silently.
                # KeyboardInterrupt/SystemExit still propagate.
                failures.append((f, exc))
                continue

            self.data.append(mat)
            self.meta.append({'file': os.path.basename(f), 'group': group, 'agent': agent})
            self.group_ids.append(self.group_map[group])

        if failures:
            print(f"Skipped {len(failures)} file(s) that could not be read or encoded:")
            for path, exc in failures[:10]:
                print(f"  {path}: {type(exc).__name__}: {exc}")
            if len(failures) > 10:
                print(f"  ... and {len(failures) - 10} more")

        # Keep a well-formed (0, max_len, 5) tensor even when nothing loaded.
        if self.data:
            stacked = np.stack(self.data)
        else:
            stacked = np.zeros((0, max_len, 5), dtype=np.float32)
        self.full_tensor = torch.tensor(stacked, device=DEVICE)
        self.group_tensor = torch.tensor(self.group_ids, dtype=torch.long, device=DEVICE)

    @staticmethod
    def _encode_frame(df, group, max_len):
        """Encode one session DataFrame as a (max_len, 5) float32 matrix.

        Requires 'action' and 'reward' columns; 'forgone_reward' is optional
        (missing column or NaN entries become 0.0).
        """
        act = df['action'].apply(lambda x: 1 if str(x).strip() == 'Compliance' else 0).values.astype(np.float32)
        rew = df['reward'].values.astype(np.float32)
        if 'forgone_reward' in df.columns:
            forgone = df['forgone_reward'].fillna(0.0).values.astype(np.float32)
        else:
            forgone = np.zeros_like(rew)

        is_ling = 1.0 if group in GlobalDataset._LINGUISTIC_GROUPS else 0.0
        # Previous action per trial; -1 marks "no previous action" on trial 0.
        prev = np.concatenate([[-1], act[:-1]])

        mat = np.zeros((max_len, 5), dtype=np.float32)
        cur = min(len(act), max_len)
        mat[:cur] = np.stack(
            [act[:cur], rew[:cur], np.full(cur, is_ling), forgone[:cur], prev[:cur]],
            axis=1,
        )
        return mat

    def get_validity_mask(self):
        """Return an (N, 2 + M_RL) boolean mask of models eligible per session.

        Column 0 and 1 are the static models, followed by one column per RL
        variant in ``RL_MODEL_UNIVERSE`` order. Every model is currently
        eligible for every session, so the mask is all True.

        A group-based restriction (perception variants only for the
        'Stimulus' and 'Magnitude' groups, codes 3 and 4) was sketched for
        this method but is disabled.
        """
        n_rows = len(self.data)
        n_models = 2 + len(RL_MODEL_UNIVERSE)
        return torch.ones(n_rows, n_models, dtype=torch.bool, device=DEVICE)

# ==========================================
# 4. Main Execution Logic
# ==========================================
def main():
    """Fit all models to every session, pick winners by BIC and write a CSV report.

    Steps:
      1. Load the dataset and slice it into per-feature (N, T) tensors.
      2. Fit the two static models (100 Adam steps) and keep the best restart.
      3. Fit all RL variants (CONFIG['EPOCHS'] Adam steps) and keep the best
         restart per (session, variant).
      4. Compute BIC = 2 * NLL + k * ln(n_obs), with n_obs the trial-axis
         length of the tensor, and select the lowest-BIC eligible model.
      5. Write one row per (session, eligible model) to
         ``logs/fit_analysis/ExhaustiveResults_<time>_<uuid>.csv``.

    Raises:
        RuntimeError: if the dataset contains no usable sessions.
    """
    data_path = r"YOUR_DATA_PATH"
    dataset = GlobalDataset(data_path)

    # Fail early with a clear message instead of a cryptic indexing error.
    if len(dataset.data) == 0:
        raise RuntimeError(f"No usable CSV sessions found under {data_path!r}; nothing to fit.")

    full_data = dataset.full_tensor
    b_act = full_data[:, :, 0]
    b_rew = full_data[:, :, 1]
    b_ling = full_data[:, :, 2]
    b_forgone = full_data[:, :, 3]
    b_prev = full_data[:, :, 4]

    N = full_data.shape[0]
    n_obs = full_data.shape[1]  # trials per session as stored in the tensor
    R = CONFIG['RESTARTS']

    print(f"\nStarting Exhaustive Search (Agents={N}, Models={2 + len(RL_MODEL_UNIVERSE)})...")

    # 1. Static models
    m_static = MegaStaticModels(N, R).to(DEVICE)
    opt_static = optim.Adam(m_static.parameters(), lr=0.1)
    for _ in range(100):
        opt_static.zero_grad()
        n0, n1 = m_static(b_act, b_rew, b_prev)
        (n0.sum() + n1.sum()).backward()
        opt_static.step()
    with torch.no_grad():
        n0, n1 = m_static(b_act, b_rew, b_prev)
        bn0, idx0 = n0.min(dim=1)
        bn1, idx1 = n1.min(dim=1)
        phys_theta, phys_eps = m_static.get_physical_params()
        # reshape(N) keeps a 1-D result even when N == 1 (a bare squeeze()
        # would collapse it to a 0-d tensor and break per-row indexing).
        best_theta = phys_theta.gather(1, idx0.view(N, 1, 1)).reshape(N)
        best_eps = phys_eps.gather(1, idx1.view(N, 1, 1)).reshape(N)

    # 2. RL variants
    m_rl = ExhaustiveParallelRW(N, R).to(DEVICE)
    opt_rl = optim.Adam(m_rl.parameters(), lr=CONFIG['LR'])

    for _ in tqdm(range(CONFIG['EPOCHS']), desc="RL Training"):
        opt_rl.zero_grad()
        nll = m_rl(b_act, b_rew, b_ling, b_forgone, b_prev)
        nll.sum().backward()
        opt_rl.step()

    with torch.no_grad():
        final_nll = m_rl(b_act, b_rew, b_ling, b_forgone, b_prev)
        best_nll_rl, best_idx = final_nll.min(dim=1)
        phys = m_rl.get_constrained_params()

        # Pick, per (session, variant), the parameters of the best restart.
        idx_exp = best_idx.unsqueeze(1)  # (N, 1, M)
        best_params_rl = {}
        keys = ['beta', 'alpha_pos', 'alpha_neg', 'theta', 'lambda', 'R_perc']
        for k, v in zip(keys, phys):
            best_params_rl[k] = v.gather(1, idx_exp).squeeze(1)

    # 3. BIC
    k_static = torch.tensor([1, 1], device=DEVICE)
    k_rl_list = [len(m['params']) for m in RL_MODEL_UNIVERSE]
    k_rl = torch.tensor(k_rl_list, device=DEVICE)

    static_nll = torch.stack([bn0, bn1], dim=1)
    bic_static = 2 * static_nll + k_static * np.log(n_obs)
    bic_rl = 2 * best_nll_rl + k_rl * np.log(n_obs)

    all_bics = torch.cat([bic_static, bic_rl], dim=1)
    all_nlls = torch.cat([static_nll, best_nll_rl], dim=1)

    # Ineligible models get BIC = +inf so they can never win.
    valid_mask = dataset.get_validity_mask()
    masked_bics = torch.where(valid_mask, all_bics, torch.tensor(float('inf'), device=DEVICE))
    min_bic, winner_idx = torch.min(masked_bics, dim=1)

    # 4. Export
    model_names = STATIC_MODEL_NAMES + [m['name'] for m in RL_MODEL_UNIVERSE]

    w_idx = winner_idx.cpu().numpy()
    valid_cpu = valid_mask.cpu().numpy()
    nll_cpu = all_nlls.cpu().numpy()
    bic_cpu = all_bics.cpu().numpy()

    th_cpu = best_theta.cpu().numpy()
    ep_cpu = best_eps.cpu().numpy()
    rl_p_cpu = {k: v.cpu().numpy() for k, v in best_params_rl.items()}

    results = []
    print("Generating Report...")
    for i in tqdm(range(N)):
        meta = dataset.meta[i]
        wid = w_idx[i]

        for m_idx, m_name in enumerate(model_names):
            if not valid_cpu[i, m_idx]:
                continue

            row = {
                **meta,
                'model': m_name,
                'nll': float(nll_cpu[i, m_idx]),
                'bic': float(bic_cpu[i, m_idx]),
                'is_winner': bool(m_idx == wid),
            }

            if m_idx == 0:
                row['theta'] = float(th_cpu[i])
            elif m_idx == 1:
                row['epsilon'] = float(ep_cpu[i])
            else:
                rl_idx = m_idx - 2
                for k, v in rl_p_cpu.items():
                    row[k] = float(v[i, rl_idx])
            results.append(row)

    df = pd.DataFrame(results)

    # Parameters a model does not have are filled with their neutral value.
    defaults = {
        'theta': 0.0, 'lambda': 0.0, 'R_perc': 1.0, 'epsilon': 0.0, 'beta': 0.0,
        'alpha_pos': 0.0, 'alpha_neg': 0.0
    }
    df.fillna(defaults, inplace=True)

    print("Saving Report...")
    # Create the output directory up front so a missing folder cannot discard
    # the results of the whole fit.
    out_dir = "logs/fit_analysis"
    os.makedirs(out_dir, exist_ok=True)
    timestamp = str(time.time())
    report_fp = f"{out_dir}/ExhaustiveResults_{timestamp}_{str(uuid.uuid4())}.csv"
    df.to_csv(report_fp, index=False)

# Script entry point: run the fitting pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
