import torch
import random
from qsw import *
import numpy as np
from utils import *
import torch.nn.functional as F
from utils import rand_projections, one_dimensional_Wasserstein_prod

# Pure BOSW Implementation

class LightweightGP:
    """Minimal exact Gaussian-process regressor with an angular RBF kernel.

    The kernel is ``k(a, b) = exp(-0.5 * (acos(a . b) / lengthscale) ** 2)``,
    so inputs are expected to be (approximately) unit-norm rows — presumably
    projection directions on the sphere; the code itself only clamps the dot
    products before ``acos``.

    Targets are standardised (mean / std) before solving and de-standardised
    on prediction.
    """

    def __init__(self, kernel_lengthscale=0.5, base_noise=1e-3):
        self.lengthscale = kernel_lengthscale
        # Initial diagonal jitter; escalated ×10 in ``predict`` if Cholesky fails.
        self.base_noise = base_noise
        self.X = None
        self.y = None
        self.y_mean = 0.0
        self.y_std = 1.0

    def _kernel(self, X1, X2):
        """Return the ``(len(X1), len(X2))`` angular RBF Gram matrix."""
        # Clamp strictly inside [-1, 1] so acos stays finite.
        cos = (X1 @ X2.T).clamp(-0.9999, 0.9999)
        ang = torch.acos(cos)
        return torch.exp(-0.5 * (ang / self.lengthscale) ** 2)

    def _kernel_diag(self, X):
        """Return ``diag(self._kernel(X, X))`` without building the full matrix."""
        cos = (X * X).sum(dim=1).clamp(-0.9999, 0.9999)
        ang = torch.acos(cos)
        return torch.exp(-0.5 * (ang / self.lengthscale) ** 2)

    def fit(self, X, y):
        """Store training data and standardisation stats; adapt the lengthscale.

        Args:
            X: ``(n, d)`` training inputs (rows assumed unit-norm).
            y: ``(n,)`` training targets.
        """
        self.X, self.y = X, y
        self.y_mean = y.mean()
        # torch.std is unbiased and yields NaN for a single sample, which
        # would poison every later prediction; fall back to unit scale.
        if y.numel() > 1:
            self.y_std = y.std() + 1e-6
        else:
            self.y_std = 1.0
        if X.size(0) > 5:
            # Median heuristic on the non-zero pairwise angles, clipped.
            ang = torch.acos((X @ X.T).clamp(-1, 1))
            nz = ang[ang > 0]
            if nz.numel():
                med = nz.median().item()
                self.lengthscale = float(np.clip(0.5 * med, 0.1, 2.0))

    def predict(self, Xtest):
        """Posterior mean and standard deviation at ``Xtest``.

        Returns:
            ``(mu, sd)``: two ``(len(Xtest),)`` tensors in the original target
            scale. Before any data has been fitted the prior
            ``(zeros, ones)`` is returned.

        Raises:
            RuntimeError: if the Gram matrix cannot be Cholesky-factorised
                even after escalating the jitter five times.
        """
        if self.X is None or self.X.size(0) == 0:
            dev = Xtest.device
            return torch.zeros(len(Xtest), device=dev), torch.ones(len(Xtest), device=dev)
        K = self._kernel(self.X, self.X)
        Ks = self._kernel(Xtest, self.X)
        # Only diag(K(Xtest, Xtest)) is needed; avoid the (m, m) matrix.
        kss_diag = self._kernel_diag(Xtest)
        I = torch.eye(K.size(0), device=K.device, dtype=K.dtype)
        jitter = self.base_noise
        for _ in range(5):
            try:
                L = torch.linalg.cholesky(K + jitter * I)
                break
            except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
                jitter *= 10
        else:
            raise RuntimeError(
                f"LightweightGP: Cholesky failed even with jitter {jitter / 10:g}"
            )
        y_norm = ((self.y - self.y_mean) / self.y_std).unsqueeze(1)
        alpha = torch.cholesky_solve(y_norm, L).squeeze(1)
        mu = Ks @ alpha
        v = torch.cholesky_solve(Ks.T, L)
        # Row-wise sum equals diag(Ks @ v) without forming the (m, m) product.
        var = (kss_diag - (Ks * v.T).sum(dim=1)).clamp_min(1e-6)
        return mu * self.y_std + self.y_mean, var.sqrt() * self.y_std

def _f_integrand(pc1, pc2, theta, p=2):
    """Black-box objective for BO: the 1-D Wasserstein term per direction.

    Delegates to ``one_dimensional_Wasserstein_prod`` and returns its result
    cast with ``.float()`` and flattened to one dimension.
    """
    per_direction = one_dimensional_Wasserstein_prod(pc1, pc2, theta, p=p)
    as_float = per_direction.float()
    return as_float.view(-1)

def _acq(mu, sigma, y_hist, kind='ucb', beta=0.7):
    if kind == 'ucb':
        return mu + beta*sigma
    elif kind == 'ei':
        best = y_hist.max()
        z = (mu - best) / (sigma + 1e-12)
        std = torch.distributions.Normal(0,1)
        return (mu - best) * std.cdf(z) + sigma * torch.exp(std.log_prob(z))
    elif kind == 'thompson':
        return torch.normal(mu, sigma + 1e-6)
    else:
        raise ValueError(kind)

@torch.no_grad()
def get_bosw_projections(L, device, pc1, pc2, p=2, seed=None,
                         n_init=None, batch_size=None, n_candidates=4096,
                         acq_kind='ucb', beta=0.7, dim=None):
    """Choose ``L`` projection directions by batched Bayesian optimisation.

    A ``LightweightGP`` is fitted to ``_f_integrand`` evaluated at random
    directions. Each round then draws ``n_candidates`` random directions,
    scores them with ``_acq``, and greedily takes the best ones. After each
    pick it suppresses candidates whose |cosine| with it exceeds 0.999.

    Args:
        L: number of directions to return.
        device: device passed to ``rand_projections``.
        pc1, pc2: point clouds forwarded to ``_f_integrand``.
        p: Wasserstein order forwarded to ``_f_integrand``.
        seed: if not None, passed to ``set_seed``.
        n_init: initial design size (default ``L // 5``), clamped to [8, 32].
        batch_size: points added per round (default ``L // 10``), at least 1.
        n_candidates: random candidates scored per round.
        acq_kind, beta: forwarded to ``_acq``.
        dim: ambient dimension; defaults to ``pc1.shape[1]``.

    Returns:
        The first ``L`` rows of the evaluated directions.
    """
    if seed is not None:
        set_seed(seed)
    d = dim or pc1.shape[1]

    requested_init = L // 5 if n_init is None else n_init
    n_init = max(8, min(32, requested_init))
    requested_batch = L // 10 if batch_size is None else batch_size
    batch_size = max(1, requested_batch)

    thetas = rand_projections(d, n_init, device)
    y = _f_integrand(pc1, pc2, thetas, p=p)
    gp = LightweightGP(0.5)
    gp.fit(thetas, y)

    while thetas.size(0) < L:
        n_new = min(batch_size, L - thetas.size(0))
        pool = rand_projections(d, n_candidates, device)
        mu, sig = gp.predict(pool)
        scores = _acq(mu, sig, y, kind=acq_kind, beta=beta).clone()
        picks = []
        for _ in range(n_new):
            idx = int(torch.argmax(scores).item())
            direction = pool[idx]
            picks.append(direction)
            # Suppress the pick and any (anti-)parallel near-duplicates.
            near_dup = (pool @ direction).clamp(-1, 1).abs() > 0.999
            scores[near_dup] = -1e30
        batch = torch.stack(picks, 0)
        y = torch.cat([y, _f_integrand(pc1, pc2, batch, p=p)], 0)
        thetas = torch.cat([thetas, batch], 0)
        gp.fit(thetas, y)
    return thetas[:L]

# Hybrid BOSW Implementation (ABOSW/ARBOSW)

def get_bosw_champion(L, device, pc1, pc2, p=2, beta=2.0, seed=None, ai='ucb'):
    """Select ``L`` projection directions with a QSW-seeded, lightweight BO loop.

    Steps, as implemented:

    1. ``L <= 10``: return ``get_sqsw_projections(L, device)`` unchanged.
    2. Initialise ``min(12, max(4, L // 5))`` directions from QSW point sets
       (~50% SQSW, ~30% CQSW, remainder DQSW; SQSW only when that count <= 5).
    3. Run at most ``min(3, L // 5)`` BO rounds. Each round scores candidates
       (SQSW / CQSW / uniform-random directions) with UCB on a cheap
       kernel-weighted surrogate and greedily adds up to 3 of them.
    4. Top up any shortfall with SQSW directions.
    5. If ``L <= 100``, apply a few Coulomb-energy refinement steps
       (``fast_local_refinement``).

    Random candidates are drawn with ``rand_projections(3, ...)``, so this
    assumes 3-D point clouds.

    Args:
        L: number of projection directions to produce.
        device: torch device for generated tensors.
        pc1, pc2: point clouds forwarded to ``one_dimensional_Wasserstein_prod``.
        p: Wasserstein order forwarded to the same helper.
        beta: UCB exploration weight.
        seed: if not None, seeds ``torch``, ``numpy`` and ``random``.
        ai: accepted for interface compatibility; unused (acquisition is always UCB).

    Returns:
        The selected directions stacked into one tensor.
    """

    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    @torch.no_grad()
    def one_d_wasserstein_mean(pc1, pc2, theta):
        """Square root of the mean 1-D Wasserstein term over the rows of ``theta``."""
        return one_dimensional_Wasserstein_prod(pc1, pc2, theta, p=p).mean().sqrt()

    # QSW-dominant initialization
    def get_champion_initialization(n_init, device):
        """Return ``n_init`` initial directions.

        ``n_init <= 5``: the SQSW set itself. Otherwise a list of rows taken
        from ~50% SQSW, ~30% CQSW and the remainder DQSW point sets.
        """
        if n_init <= 5:
            # For small counts, use pure SQSW
            return get_sqsw_projections(n_init, device)

        n_sqsw = max(1, int(0.5 * n_init))
        n_cqsw = max(1, int(0.3 * n_init))
        n_diverse = n_init - n_sqsw - n_cqsw

        projections = []

        sqsw_projs = get_sqsw_projections(n_sqsw, device)
        projections.extend([sqsw_projs[i] for i in range(len(sqsw_projs))])

        cqsw_projs = get_cqsw_projections(n_cqsw, device)
        projections.extend([cqsw_projs[i] for i in range(len(cqsw_projs))])

        if n_diverse > 0:
            diverse_projs = get_dqsw_projections(n_diverse, device)
            projections.extend([diverse_projs[i] for i in range(len(diverse_projs))])

        return projections[:n_init]

    # Cheap surrogate model
    class SpeedOptimizedGP:
        """Kernel-weighted surrogate (not an exact GP).

        The mean is a softmax-weighted average of training targets, using
        angular RBF similarities as logits; sigma is a constant ``0.5 * y_std``.
        """

        def __init__(self, kernel_lengthscale=0.3):
            self.lengthscale = kernel_lengthscale
            self.X_train, self.y_train = None, None
            self.y_mean, self.y_std = 0, 1
            self.base_noise = 1e-3

        def fit(self, X, y):
            self.X_train = X
            self.y_train = y
            self.y_mean = y.mean()
            self.y_std = y.std() + 1e-6

        def predict(self, X_test):
            if self.X_train is None:
                return torch.zeros(len(X_test)), torch.ones(len(X_test))

            try:
                angles = torch.acos(torch.clamp(X_test @ self.X_train.T, -0.9999, 0.9999))
                K_test = torch.exp(-0.5 * (angles / self.lengthscale) ** 2)

                # Direct weighted average instead of a Cholesky solve
                weights = torch.softmax(K_test, dim=1)
                mu = weights @ ((self.y_train - self.y_mean) / self.y_std) * self.y_std + self.y_mean

                # Constant (conservative) uncertainty
                sigma = torch.ones_like(mu) * self.y_std * 0.5

                return mu, sigma

            except Exception:
                return torch.zeros(len(X_test)) + self.y_mean, torch.ones(len(X_test)) * self.y_std

    # Objective scored for each selected direction
    def compute_winner_inspired_objective(proj, existing_projections, pc1, pc2):
        """Score ``proj`` given the already-selected directions.

        With no existing directions, returns the raw ``one_d_wasserstein_mean``
        of ``proj``. Otherwise returns a weighted sum of:
        alignment with the next point of an SQSW set (0.4), an inverse
        Coulomb-energy spacing score (0.35) and ``1 / (1 + SW)`` (0.25).

        NOTE(review): the first-point value is on a different scale from the
        later weighted scores; confirm this is intended.
        """
        if len(existing_projections) == 0:
            return one_d_wasserstein_mean(pc1, pc2, proj.unsqueeze(0)).item()

        existing_tensor = torch.stack(existing_projections)

        # 1. Alignment with the corresponding point of an SQSW set of size n+1
        #    (recomputed on every call).
        n_existing = len(existing_projections)
        expected_spiral_points = get_sqsw_projections(n_existing + 1, device)
        if len(expected_spiral_points) > n_existing:
            spiral_target = expected_spiral_points[n_existing]
            spiral_alignment = torch.dot(proj, spiral_target).item()
            spiral_score = (spiral_alignment + 1) / 2
        else:
            spiral_score = 0.5

        # 2. Coulomb energy against the existing directions
        distances = torch.norm(existing_tensor - proj.unsqueeze(0), dim=1)
        coulomb_energy = (1.0 / distances.clamp(min=0.01)).sum().item()
        coulomb_score = 1.0 / (1.0 + 0.1 * coulomb_energy)

        # 3. SW quality of the enlarged set
        candidate_set = existing_projections + [proj]
        proj_tensor = torch.stack(candidate_set)
        sw_score = one_d_wasserstein_mean(pc1, pc2, proj_tensor).item()
        sw_score_norm = 1.0 / (1.0 + sw_score)

        weights = [0.4, 0.35, 0.25]  # [spiral, coulomb, sw]

        return weights[0] * spiral_score + weights[1] * coulomb_score + weights[2] * sw_score_norm

    # Candidate generation
    def generate_fast_candidates(existing_projs, n_candidates, device):
        """Return ~60% SQSW/CQSW and ~40% uniform-random 3-D candidates.

        ``existing_projs`` is currently not used.
        """
        candidates = []

        n_qsw = int(0.6 * n_candidates)

        n_sqsw = n_qsw // 2
        n_cqsw = n_qsw - n_sqsw

        if n_sqsw > 0:
            sqsw_cands = get_sqsw_projections(n_sqsw, device)
            candidates.append(sqsw_cands)

        if n_cqsw > 0:
            cqsw_cands = get_cqsw_projections(n_cqsw, device)
            candidates.append(cqsw_cands)

        n_random = n_candidates - (n_sqsw + n_cqsw)
        if n_random > 0:
            candidates.append(rand_projections(3, n_random, device))

        return torch.cat(candidates) if candidates else rand_projections(3, n_candidates, device)

    selected_projections = []
    gp = SpeedOptimizedGP(kernel_lengthscale=0.25)

    if L <= 10:
        # For small L, use pure SQSW
        return get_sqsw_projections(L, device)

    # QSW-dominant initialization
    n_init = min(12, max(4, L // 5))
    init_projections = get_champion_initialization(n_init, device)

    X_train, y_train = [], []
    for proj in init_projections:
        selected_projections.append(proj)
        obj = compute_winner_inspired_objective(proj, selected_projections[:-1], pc1, pc2)
        X_train.append(proj)
        y_train.append(obj)

    X_train = torch.stack(X_train)
    y_train = torch.tensor(y_train, device=device)

    max_bo_iterations = min(3, L // 5)
    bo_iteration = 0

    while len(selected_projections) < L and bo_iteration < max_bo_iterations:
        gp.fit(X_train, y_train)

        # Small batches for speed
        remaining = L - len(selected_projections)
        batch_size = min(max(1, remaining // 3), 3)

        n_candidates = min(200, 50 * batch_size)
        candidates = generate_fast_candidates(selected_projections, n_candidates, device)

        # Drop candidates within ~0.1 rad of an already-selected direction
        if len(selected_projections) > 0:
            existing_tensor = torch.stack(selected_projections)
            angles = torch.acos(torch.clamp(candidates @ existing_tensor.T, -1, 1))
            min_angles = angles.min(dim=1)[0]
            keep_mask = min_angles > 0.1
            candidates = candidates[keep_mask]

            if len(candidates) < 10:
                candidates = generate_fast_candidates(selected_projections, 50, device)

        mu, sigma = gp.predict(candidates)
        acq_values = mu + beta * sigma  # UCB

        n_select = min(batch_size, remaining)

        for _ in range(n_select):
            best_idx = torch.argmax(acq_values).item()
            selected_point = candidates[best_idx]
            selected_projections.append(selected_point)

            obj = compute_winner_inspired_objective(selected_point, selected_projections[:-1], pc1, pc2)
            X_train = torch.cat([X_train, selected_point.unsqueeze(0)])
            y_train = torch.cat([y_train, torch.tensor([obj], device=device)])

            # Soft diversity penalty around the chosen direction. Scaling by
            # |acq| lowers the score whatever its sign (for positive scores this
            # equals multiplying by 1 - 0.3 * penalty; a plain multiplicative
            # factor would *raise* negative scores).
            angles = torch.acos(torch.clamp(candidates @ selected_point, -1, 1))
            diversity_penalty = torch.exp(-5 * angles)
            acq_values = acq_values - 0.3 * diversity_penalty * acq_values.abs()
            # Never pick the same candidate twice in one batch.
            acq_values[best_idx] = -1e30

        bo_iteration += 1

    # Fill remaining slots with SQSW directions
    while len(selected_projections) < L:
        remaining = L - len(selected_projections)
        additional_projs = get_sqsw_projections(remaining, device)
        for i in range(min(remaining, len(additional_projs))):
            selected_projections.append(additional_projs[i])

    # Short Coulomb refinement for moderate L
    if L <= 100:
        final_tensor = torch.stack(selected_projections)
        refined_projections = fast_local_refinement(final_tensor, device, n_iters=3)
        return refined_projections

    return torch.stack(selected_projections)

def fast_local_refinement(projections, device, n_iters=3):
    """Push directions apart with a few Adam steps on their Coulomb energy.

    Each step minimises ``sum_{i != j} 1 / ||x_i - x_j||`` over the
    row-normalised directions plus ``1000 * sum((||x_i|| - 1) ** 2)``. It then
    re-projects every row onto the unit sphere.

    Args:
        projections: ``(n, d)`` tensor of directions; not modified.
        device: device for the helper identity matrix (should match ``projections``).
        n_iters: number of Adam steps (``lr=0.05``).

    Returns:
        A detached ``(n, d)`` tensor with unit-norm rows.
    """
    # detach() so this also works when `projections` is a non-leaf tensor.
    refined_projs = projections.detach().clone().requires_grad_(True)
    optimizer = torch.optim.Adam([refined_projs], lr=0.05)
    # Large diagonal offset removes each point's self-interaction.
    self_offset = torch.eye(refined_projs.size(0), device=device, dtype=refined_projs.dtype) * 1e6

    # Callers may be inside torch.no_grad(); the energy needs autograd.
    with torch.enable_grad():
        for _ in range(n_iters):
            optimizer.zero_grad()

            normalized_projs = F.normalize(refined_projs, dim=1)

            # Clamp so coincident directions give a large but finite energy
            # instead of inf / NaN.
            pairwise_dists = (torch.cdist(normalized_projs, normalized_projs) + self_offset).clamp_min(1e-6)
            coulomb_energy = (1.0 / pairwise_dists).sum()

            sphere_penalty = torch.sum((torch.norm(refined_projs, dim=1) - 1) ** 2)

            loss = coulomb_energy + 1000 * sphere_penalty
            loss.backward()
            optimizer.step()

            # Project back onto the unit sphere without recording the op.
            with torch.no_grad():
                refined_projs.copy_(F.normalize(refined_projs, dim=1))

    return refined_projs.detach()

# Champion Randomized Version
class ChampionProjectionSampler:
    """Randomised projection sampler built from SQSW / CQSW base directions.

    For each requested direction it picks the SQSW base set with probability
    0.7 (otherwise CQSW). It then takes a uniformly random row of that set and
    applies an independent random rotation. If the chosen base set is empty,
    it uses ``rand_projections(3, 1, device)`` instead.

    NOTE(review): every direction gets its *own* Haar-random rotation, so each
    output row is marginally uniform on the sphere. RQSW-style schemes
    instead rotate a whole point set by one shared rotation; confirm which
    behaviour is intended.
    """

    def __init__(self, optimal_projections, device, L):
        # Stored for API compatibility; not read by `sample`.
        self.optimal_projections = optimal_projections
        self.device = device
        self.L = L

        # Base point sets (at most 50 points each) to resample from.
        self.sqsw_base = get_sqsw_projections(min(50, L), device)
        self.cqsw_base = get_cqsw_projections(min(50, L), device)

    def sample(self, n_samples=None):
        """Draw ``n_samples`` (default ``self.L``) randomly rotated base directions.

        Returns:
            A stacked tensor with one row per sample; an empty ``(0, 3)``
            tensor when ``n_samples <= 0``.
        """
        if n_samples is None:
            n_samples = self.L
        if n_samples <= 0:
            # torch.stack([]) would raise.
            return torch.empty(0, 3, device=self.device)

        samples = []
        for _ in range(n_samples):
            # 70% SQSW-based, 30% CQSW-based.
            base = self.sqsw_base if torch.rand(1) < 0.7 else self.cqsw_base
            samples.append(self._rotated_draw(base))

        return torch.stack(samples)

    def _rotated_draw(self, base):
        """Randomly rotate one uniformly chosen row of ``base`` (random direction if empty)."""
        if len(base) == 0:
            return rand_projections(3, 1, self.device).squeeze()
        base_idx = torch.randint(0, len(base), (1,)).item()
        rotation_matrix = self._sample_random_rotation()
        return rotation_matrix @ base[base_idx]

    def _sample_random_rotation(self):
        """Sample a Haar-uniform rotation in SO(3) via QR of a Gaussian matrix."""
        A = torch.randn(3, 3, device=self.device)
        Q, R = torch.linalg.qr(A)
        # QR is unique only up to column signs; forcing diag(R) > 0 makes Q
        # Haar-distributed on O(3) (Mezzadri, 2007).
        Q = Q * torch.sign(torch.diagonal(R)).unsqueeze(0)
        # In odd dimension, negating Q flips det, mapping O(3) onto SO(3).
        Q = Q * torch.det(Q).sign()
        return Q

    def __call__(self):
        return self.sample()

def get_bosw_adaptive(L, device, pc1, pc2, p=2, beta=0.7, seed=None,
                      ai='ucb', mode='auto', gradient_steps=None, task_type=None):
    """Dispatch between the deterministic and randomised BOSW schemes.

    ``mode='auto'`` resolves to ``'randomized'`` when ``task_type == 'gradient'``
    or ``gradient_steps`` is given, and to ``'deterministic'`` otherwise.

    Returns:
        In randomised mode, a ``ChampionProjectionSampler`` (call it to draw
        ``L`` directions). It is built after running ``get_bosw_champion`` on
        a budget of ``min(L // 3, 40)``. In any other mode, the result of
        ``get_bosw_champion(L, ...)``.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    if mode == 'auto':
        wants_randomized = task_type == 'gradient' or gradient_steps is not None
        mode = 'randomized' if wants_randomized else 'deterministic'

    if mode != 'randomized':
        # Every mode other than 'randomized' takes the deterministic path.
        return get_bosw_champion(L, device, pc1, pc2, p, beta, seed, ai)

    # Gradient-style tasks: run a small selection, then hand back a sampler.
    learning_budget = min(L // 3, 40)
    optimal_projections = get_bosw_champion(learning_budget, device, pc1, pc2, p, beta, seed, ai)
    return ChampionProjectionSampler(optimal_projections, device, L)
    
