from __future__ import annotations

import torch
import torch.nn.functional as F
from constraints.constraint import *
from constraints.primal import *

import math
from typing import List, Tuple, Dict


from torch import Tensor
from typing import Callable, List, Tuple, Optional



############# BETA SHEET (TOY EXAMPLE) #############



def projection_beta(target_representation, primal_func, start_index, lr=0.1, stride=4, sections=16, verbose=True, skip_alternate=True):
    """
    Project a representation onto angle/distance constraints, one offset at a time.

    Offsets are enabled incrementally. At every stage an augmented-Lagrangian
    objective covering all offsets enabled so far is minimised with Adam; the
    multiplier and penalty weight are reset when a new offset is added.

    Args:
        target_representation: Tensor to project. NaN entries are zero-filled
            for the optimisation and set back to NaN in the returned tensor.
        primal_func: Callable invoked as
            ``primal_func(x0, trainable, num_fixed=..., num_optimized=...)``;
            its result is scaled by 0.01 and added to the loss.
        start_index: Forwarded to the penalty helpers; the number of candidate
            offsets is ``24 - start_index``.
        lr: Base Adam learning rate; stage ``i`` uses ``lr * 0.9 ** min(i, 9)``.
        stride: Not used by this function.
        sections: Not used by this function.
        verbose: Print progress, periodic diagnostics and a per-stage
            verification of every active offset.
        skip_alternate: If true, only every second offset is processed.

    Returns:
        A detached tensor in which the first 16 entries along dim 0 are the
        zero-filled input, the remainder is the optimised result, and the
        original NaN positions are NaN again.
    """
    # ---- fixed hyper-parameters -------------------------------------
    n_frozen = 16                 # leading entries along dim 0 that never move
    outer_budget = 500            # multiplier updates per stage
    half_outer = outer_budget // 2
    inner_budget_base = 50        # Adam steps per multiplier update (before scaling)
    violation_tol = 1e-5
    mu_cap = 1e4
    angle_tols = [0.0, 0.0, 0.0]
    dist_tols = [0.0, 0.0, 0.0]
    stage_lrs = [lr * (0.9 ** k) for k in range(10)]

    # ---- state --------------------------------------------------------
    missing = target_representation.isnan()
    x0 = torch.nan_to_num(target_representation.clone().detach())

    xi = x0.clone().detach()
    xi.requires_grad = False

    free_part = xi[n_frozen:].clone().detach()
    free_part.requires_grad = True

    # Offsets handled by the staged schedule.
    max_offset = 24 - start_index
    offsets_to_process = list(range(0, max_offset, 2 if skip_alternate else 1))

    if verbose:
        print(f"Processing offsets: {offsets_to_process}")

    # A single optimizer is shared by all stages (its moment estimates carry over).
    optimizer = torch.optim.Adam([free_part], lr=lr)

    def weighted_violations(candidate, offsets, weights, relax):
        """Return the weighted violation sum and the per-offset (angle, distance) pairs."""
        total = torch.tensor(0.0, device=x0.device)
        pairs = []
        for j, (off, w) in enumerate(zip(offsets, weights)):
            a_tol = angle_tols[min(j, len(angle_tols) - 1)] * relax
            d_tol = dist_tols[min(j, len(dist_tols) - 1)] * relax

            a_viol = angle_penalty_with_offset(
                candidate, off, start_index, tolerance=a_tol
            )
            total += w * a_viol

            d_viol = distance_penalty_with_offset(
                candidate, off, start_index, tolerance=d_tol
            )
            total += w * d_viol

            pairs.append((a_viol, d_viol))
        return total, pairs

    for stage, current_offset in enumerate(offsets_to_process):
        if verbose:
            print(f"Optimizing from offset {current_offset}")

        active_offsets = offsets_to_process[:stage + 1]

        # Fresh multiplier / penalty weight for every stage; the starting
        # penalty grows mildly with the stage index.
        lambda_lagrange = torch.tensor(0.0, device=x0.device, dtype=torch.float32)
        mu_penalty = torch.tensor(1.0 + (stage * 0.5), device=x0.device, dtype=torch.float32)

        # Stage-dependent learning rate.
        stage_lr = stage_lrs[min(stage, len(stage_lrs) - 1)]
        for group in optimizer.param_groups:
            group['lr'] = stage_lr

        # Earlier offsets are down-weighted, but never below 0.5.
        stage_weights = [max(0.9 ** (stage - j), 0.5) for j in range(stage + 1)]

        # Dynamically scale the inner-loop budget with the number of stages done.
        inner_budget = int(inner_budget_base * min(1 + stage/3, 2.5))

        for outer_iter in range(outer_budget):
            # Tolerances are relaxed linearly during the second half of the budget.
            if outer_iter > half_outer:
                relaxation_factor = 1.0 + ((outer_iter - half_outer) / half_outer) * 0.5
            else:
                relaxation_factor = 1.0

            for iteration in range(inner_budget):
                optimizer.zero_grad()

                # Full vector = frozen head + current trainable tail.
                xi_combined = torch.cat((xi[:n_frozen], free_part))

                constraint_violation_total, individual_violations = weighted_violations(
                    xi_combined, active_offsets, stage_weights, relaxation_factor
                )

                # Objective: closeness to the input + augmented-Lagrangian
                # penalty + a small L2 regulariser.
                solution_distance = 0.01 * primal_func(
                    x0, free_part, num_fixed=n_frozen, num_optimized=len(free_part)
                )
                regularization = 0.01 * torch.sum(free_part ** 2)
                penalty = (lambda_lagrange * constraint_violation_total) + (0.5 * mu_penalty * constraint_violation_total**2)

                total_loss = solution_distance + penalty + regularization
                total_loss.backward()
                optimizer.step()

                # The violation tested here was measured before the step above.
                if constraint_violation_total < violation_tol:
                    if verbose:
                        print(f"  Offset {current_offset} converged after {iteration+1} iterations")
                    break

            # Multiplier / penalty-weight update.
            with torch.no_grad():
                lambda_lagrange += mu_penalty * constraint_violation_total
                still_violated = bool(abs(constraint_violation_total) > violation_tol)
                if still_violated:
                    mu_penalty = min(mu_penalty * 1.5, mu_cap)

            if not still_violated:
                if verbose:
                    print(f"  All constraints up to offset {current_offset} satisfied")
                break

            # Periodic diagnostics for slow convergence.
            if verbose and (outer_iter + 1) % 50 == 0:
                print(f"  Outer iter {outer_iter+1}, constraint violation: {constraint_violation_total.item()}")
                for offset, (ang_viol, dist_viol) in zip(active_offsets, individual_violations):
                    print(f"    Offset {offset}: Angle violation={ang_viol.item():.6f}, Distance violation={dist_viol.item():.6f}")

        # Re-evaluate every active offset with the helpers' default tolerance.
        if verbose:
            with torch.no_grad():
                xi_verify = torch.cat((xi[:n_frozen], free_part))
                for offset in active_offsets:
                    angle_violation = angle_penalty_with_offset(xi_verify, offset, start_index)
                    dist_violation = distance_penalty_with_offset(xi_verify, offset, start_index)
                    print(f"  Verification - Offset {offset}: Angle: {angle_violation.item()}, Distance: {dist_violation.item()}")

    # Write the optimised tail back and restore the original NaN positions.
    xi[n_frozen:] = free_part.clone().detach()
    xi[missing] = float("nan")

    return xi



####################################################

################### PDZ SETTING ####################


def projection_pdz(
    target_rep: Tensor,                    # (R,4,3) with NaNs allowed
    p_chain: Tensor, 
    primal_func: Callable[[Tensor, Tensor], Tensor],
    gap_pairs_ra: torch.LongTensor,        # (K,4) from helper above
    start_index: int = 94,
    *,
    first_movable_res: int = 0,             
    last_movable_res: Optional[int] = 5,  # New parameter - if None, uses all residues from first_movable_res onwards
    n_constraints: int = 3,
    hot_start: Optional[Tensor] = None,
    reference: Optional[Tensor] = None,
    gap_penalty: bool = False,
    tol: float = 1e-4,
    max_outer: int = 25,
    max_inner: int = 300,
    mu_init: float = 1.0,
    mu_growth: float = 10.0,
    lr: float = 1e-2,
    verbose: bool = False,
) -> Tuple[Tensor, Tuple[List[float], List[float], float]]:
    """
    Augmented-Lagrangian projector that keeps the geometry in (R,4,3)
    format yet remains compatible with the original penalty helpers.

    Only the residue slice ``[first_movable_res:last_movable_res]`` is
    optimised (with Adam); every other residue is carried through unchanged.

    Args:
        target_rep: 3-D coordinate tensor; NaNs are replaced by zeros.
        p_chain: Forwarded unchanged to ``angle_penalty_pdz`` and
            ``distance_penalty_pdz``.
        primal_func: Called as ``primal_func(reference_flat, candidate_flat)``
            on flattened coordinates; its value is the objective term.
        gap_pairs_ra: Currently unused by this function; kept so the call
            signature stays stable.
        start_index: Forwarded unchanged to the angle/distance penalties.
        first_movable_res: First trainable residue index (inclusive).
        last_movable_res: Optional end index (exclusive) for trainable residues.
                         If None, all residues from first_movable_res onwards are trainable.
        n_constraints: Number of offsets ``0..n_constraints-1`` for which an
            angle and a distance constraint are imposed.
        hot_start: Optional initial iterate; defaults to the zero-filled target.
        reference: Optional tensor handed to ``primal_func`` as the reference;
            defaults to the zero-filled target.
        gap_penalty: If true, ``differentiable_break_penalty`` is added as an
            extra constraint; otherwise that term is identically zero.
        tol: Absolute residual threshold used for both loop exits and for
            deciding which penalty weights grow.
        max_outer: Maximum number of multiplier updates.
        max_inner: Maximum number of Adam steps per multiplier update.
        mu_init: Initial penalty weight for every constraint.
        mu_growth: Factor applied to the weight of a constraint whose
            residual still exceeds ``tol``.
        lr: Adam learning rate.
        verbose: Print the residuals after every multiplier update.

    Returns:
        ``(projected, (angle_residuals, distance_residuals, gap_residual))``
        where ``projected`` is detached with NaNs replaced by zeros and the
        residuals are absolute values at the final iterate.

    Raises:
        ValueError: If ``target_rep`` is not 3-dimensional.
    """
    if target_rep.dim() != 3:
        raise ValueError(
            f"target_rep must be 3-D (R, 4, 3); got shape {tuple(target_rep.shape)}"
        )

    dev, dt = target_rep.device, target_rep.dtype

    # One angle and one distance constraint per offset.
    offsets = list(range(n_constraints))
    n_off = len(offsets)

    # ------------ initial state -------------------------------------
    x0          = torch.nan_to_num(target_rep.clone())
    reference   = reference if reference is not None else x0.clone()
    xi          = hot_start.clone().detach() if hot_start is not None else x0.clone().detach()

    # `xi` is not modified until the final write-back, so the frozen views
    # can be taken once. Slicing with a `None` stop means "to the end".
    head  = xi[:first_movable_res]
    tail  = None if last_movable_res is None else xi[last_movable_res:]
    xi_tr = xi[first_movable_res:last_movable_res].clone().detach().requires_grad_(True)

    opt   = torch.optim.Adam([xi_tr], lr=lr)

    # lambda / mu initialisation
    lam_a = torch.zeros(n_off, device=dev, dtype=dt)
    lam_d = torch.zeros(n_off, device=dev, dtype=dt)
    mu_a  = torch.full((n_off,), mu_init, device=dev, dtype=dt)
    mu_d  = torch.full((n_off,), mu_init, device=dev, dtype=dt)
    lam_g = torch.tensor(0.0, device=dev, dtype=dt)
    mu_g  = torch.tensor(mu_init, device=dev, dtype=dt)

    def to_flat_coords(t: Tensor) -> Tensor:
        """(R,4,3) →  (R*4*3,) coordinate vector"""
        return t.reshape(-1)

    def assemble() -> Tensor:
        """Full coordinate tensor with the current trainable slice spliced in."""
        pieces = [head, xi_tr] if tail is None else [head, xi_tr, tail]
        return torch.cat(pieces, dim=0).nan_to_num()

    def gap_violation(coords: Tensor) -> Tensor:
        """Chain-break residual, or an exact zero when the gap penalty is off."""
        if gap_penalty:
            return differentiable_break_penalty(coords, reduction='sum')
        # Created on the coordinates' device/dtype: a default CPU tensor would
        # make the torch.cat in `worst` fail for accelerator inputs.
        return coords.new_zeros(())

    def residuals(coords: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Angle, distance and gap residuals for `coords`."""
        angle = torch.stack([angle_penalty_pdz(coords, p_chain, k, start_index) for k in offsets])
        dist = torch.stack([distance_penalty_pdz(coords, p_chain, k, start_index) for k in offsets])
        return angle, dist, gap_violation(coords)

    def worst(angle: Tensor, dist: Tensor, gap: Tensor) -> Tensor:
        """Largest absolute residual over all constraints."""
        return torch.max(torch.abs(torch.cat([angle, dist, gap.view(1)])))

    c_a = c_d = c_g = None

    # --------------- optimisation loop ------------------------------
    for outer in range(max_outer):
        for inner in range(max_inner):
            xi_full = assemble()
            c_a, c_d, c_g = residuals(xi_full)

            lag = (
                primal_func(to_flat_coords(reference), to_flat_coords(xi_full))
                + torch.dot(lam_a, c_a) + 0.5 * torch.dot(mu_a, c_a**2)
                + torch.dot(lam_d, c_d) + 0.5 * torch.dot(mu_d, c_d**2)
                + lam_g * c_g           + 0.5 * mu_g * c_g**2
            )

            opt.zero_grad()
            lag.backward()
            opt.step()

            # Residuals tested here were measured before the step above.
            if worst(c_a, c_d, c_g) < tol:
                break  # inner

        # ---- multipliers / penalties update ------------------------
        with torch.no_grad():
            # Re-evaluate at the post-step iterate.
            c_a, c_d, c_g = residuals(assemble())

            lam_a += mu_a * c_a
            lam_d += mu_d * c_d
            lam_g += mu_g * c_g

            # Only constraints that are still violated get a larger weight.
            mu_a = torch.where(c_a.abs() > tol, mu_a * mu_growth, mu_a)
            mu_d = torch.where(c_d.abs() > tol, mu_d * mu_growth, mu_d)

            if abs(c_g) > tol:
                mu_g *= mu_growth

            if verbose:
                errs = ", ".join(
                    f"{o}:({ca.item():.2e},{cd.item():.2e})"
                    for o, ca, cd in zip(offsets, c_a, c_d)
                )
                print(f"[outer {outer:02d}]  {errs}  gap:{c_g.item():.2e}")

            if worst(c_a, c_d, c_g) < tol:
                break  # outer

    # The loop body never ran (max_outer <= 0): still report residuals for
    # the starting point instead of failing on undefined names.
    if c_a is None:
        with torch.no_grad():
            c_a, c_d, c_g = residuals(assemble())

    # -------- write back (NaNs are zero-filled, not restored) -------
    xi[first_movable_res:last_movable_res] = xi_tr.detach()

    return xi.detach().nan_to_num(), (
        [abs(v.item()) for v in c_a],
        [abs(v.item()) for v in c_d],
        abs(c_g.item()),
    )





def projection_pdz_beta(
    target_rep: Tensor,                    # (R,4,3) with NaNs allowed
    p_chain: Tensor, 
    primal_func: Callable[[Tensor, Tensor], Tensor],
    gap_pairs_ra: torch.LongTensor,        # (K,4) from helper above
    start_index: int = 94,
    *,
    first_movable_res: int = 0,             
    last_movable_res: Optional[int] = 5,  # New parameter - if None, uses all residues from first_movable_res onwards
    n_constraints: int = 3,
    hot_start: Optional[Tensor] = None,
    reference: Optional[Tensor] = None,
    gap_penalty: bool = False,
    tol: float = 1e-4,
    max_outer: int = 25,
    max_inner: int = 300,
    mu_init: float = 1.0,
    mu_growth: float = 10.0,
    lr: float = 1e-2,
    verbose: bool = False,
) -> Tuple[Tensor, Tuple[List[float], List[float], float]]:
    """
    Augmented-Lagrangian projector that keeps the geometry in (R,4,3)
    format and adds a beta-sheet constraint on top of the angle, distance
    and optional gap constraints.

    Only the residue slice ``[first_movable_res:last_movable_res]`` is
    optimised (with Adam); every other residue is carried through unchanged.
    The beta-sheet term is ``beta_sheet_penalty`` evaluated with the fixed
    settings ``window=(90, 100)``, ``min_len=5``, ``tau_energy=0.15``,
    ``tau_min=0.25``, ``max_shift=1``.

    Args:
        target_rep: 3-D coordinate tensor; NaNs are replaced by zeros.
        p_chain: Forwarded unchanged to ``angle_penalty_pdz`` and
            ``distance_penalty_pdz``.
        primal_func: Called as ``primal_func(reference_flat, candidate_flat)``
            on flattened coordinates; its value is the objective term.
        gap_pairs_ra: Currently unused by this function; kept so the call
            signature stays stable.
        start_index: Forwarded unchanged to the angle/distance penalties.
        first_movable_res: First trainable residue index (inclusive).
        last_movable_res: Optional end index (exclusive) for trainable residues.
                         If None, all residues from first_movable_res onwards are trainable.
        n_constraints: Number of offsets ``0..n_constraints-1`` for which an
            angle and a distance constraint are imposed.
        hot_start: Optional initial iterate; defaults to the zero-filled target.
        reference: Optional tensor handed to ``primal_func`` as the reference;
            defaults to the zero-filled target.
        gap_penalty: If true, ``differentiable_break_penalty`` is added as an
            extra constraint; otherwise that term is identically zero.
        tol: Absolute residual threshold used for both loop exits and for
            deciding which penalty weights grow.
        max_outer: Maximum number of multiplier updates.
        max_inner: Maximum number of Adam steps per multiplier update.
        mu_init: Initial penalty weight for every constraint.
        mu_growth: Factor applied to the weight of a constraint whose
            residual still exceeds ``tol``.
        lr: Adam learning rate.
        verbose: Print the residuals after every multiplier update.

    Returns:
        ``(projected, (angle_residuals, distance_residuals, gap_residual))``
        where ``projected`` is detached with NaNs replaced by zeros and the
        residuals are absolute values at the final iterate. The beta-sheet
        residual takes part in the optimisation and the stopping tests but
        is not part of the returned tuple.

    Raises:
        ValueError: If ``target_rep`` is not 3-dimensional.
    """
    if target_rep.dim() != 3:
        raise ValueError(
            f"target_rep must be 3-D (R, 4, 3); got shape {tuple(target_rep.shape)}"
        )

    dev, dt = target_rep.device, target_rep.dtype

    # One angle and one distance constraint per offset.
    offsets = list(range(n_constraints))
    n_off = len(offsets)

    # ------------ initial state -------------------------------------
    x0          = torch.nan_to_num(target_rep.clone())
    reference   = reference if reference is not None else x0.clone()
    xi          = hot_start.clone().detach() if hot_start is not None else x0.clone().detach()

    # `xi` is not modified until the final write-back, so the frozen views
    # can be taken once. Slicing with a `None` stop means "to the end".
    head  = xi[:first_movable_res]
    tail  = None if last_movable_res is None else xi[last_movable_res:]
    xi_tr = xi[first_movable_res:last_movable_res].clone().detach().requires_grad_(True)

    opt   = torch.optim.Adam([xi_tr], lr=lr)

    # lambda / mu initialisation
    lam_a = torch.zeros(n_off, device=dev, dtype=dt)
    lam_d = torch.zeros(n_off, device=dev, dtype=dt)
    mu_a  = torch.full((n_off,), mu_init, device=dev, dtype=dt)
    mu_d  = torch.full((n_off,), mu_init, device=dev, dtype=dt)
    lam_g = torch.tensor(0.0, device=dev, dtype=dt)
    mu_g  = torch.tensor(mu_init, device=dev, dtype=dt)
    lam_s = torch.tensor(0.0, device=dev, dtype=dt)
    mu_s  = torch.tensor(mu_init, device=dev, dtype=dt)

    def to_flat_coords(t: Tensor) -> Tensor:
        """(R,4,3) →  (R*4*3,) coordinate vector"""
        return t.reshape(-1)

    def assemble() -> Tensor:
        """Full coordinate tensor with the current trainable slice spliced in."""
        pieces = [head, xi_tr] if tail is None else [head, xi_tr, tail]
        return torch.cat(pieces, dim=0).nan_to_num()

    def gap_violation(coords: Tensor) -> Tensor:
        """Chain-break residual, or an exact zero when the gap penalty is off."""
        if gap_penalty:
            return differentiable_break_penalty(coords, reduction='sum')
        # Created on the coordinates' device/dtype: a default CPU tensor would
        # make the torch.cat in `worst` fail for accelerator inputs.
        return coords.new_zeros(())

    def sheet_violation(coords: Tensor) -> Tensor:
        """Beta-sheet residual with this projector's fixed settings."""
        return beta_sheet_penalty(
            coords,
            window=(90, 100),
            min_len=5,
            tau_energy=0.15,
            tau_min=0.25,
            max_shift=1,
        )

    def residuals(coords: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Angle, distance, gap and beta-sheet residuals for `coords`."""
        angle = torch.stack([angle_penalty_pdz(coords, p_chain, k, start_index) for k in offsets])
        dist = torch.stack([distance_penalty_pdz(coords, p_chain, k, start_index) for k in offsets])
        return angle, dist, gap_violation(coords), sheet_violation(coords)

    def worst(angle: Tensor, dist: Tensor, gap: Tensor, sheet: Tensor) -> Tensor:
        """Largest absolute residual over all constraints."""
        return torch.max(torch.abs(torch.cat([angle, dist, gap.view(1), sheet.view(1)])))

    c_a = c_d = c_g = c_s = None

    # --------------- optimisation loop ------------------------------
    for outer in range(max_outer):
        for inner in range(max_inner):
            xi_full = assemble()
            c_a, c_d, c_g, c_s = residuals(xi_full)

            lag = (
                primal_func(to_flat_coords(reference), to_flat_coords(xi_full))
                + torch.dot(lam_a, c_a) + 0.5 * torch.dot(mu_a, c_a**2)
                + torch.dot(lam_d, c_d) + 0.5 * torch.dot(mu_d, c_d**2)
                + lam_g * c_g           + 0.5 * mu_g * c_g**2
                + lam_s * c_s           + 0.5 * mu_s * c_s**2
            )

            opt.zero_grad()
            lag.backward()
            opt.step()

            # Residuals tested here were measured before the step above.
            if worst(c_a, c_d, c_g, c_s) < tol:
                break  # inner

        # ---- multipliers / penalties update ------------------------
        with torch.no_grad():
            # Re-evaluate at the post-step iterate.
            c_a, c_d, c_g, c_s = residuals(assemble())

            lam_a += mu_a * c_a
            lam_d += mu_d * c_d
            lam_g += mu_g * c_g
            lam_s += mu_s * c_s

            # Only constraints that are still violated get a larger weight.
            mu_a = torch.where(c_a.abs() > tol, mu_a * mu_growth, mu_a)
            mu_d = torch.where(c_d.abs() > tol, mu_d * mu_growth, mu_d)
            mu_s = torch.where(c_s.abs() > tol, mu_s * mu_growth, mu_s)

            if abs(c_g) > tol:
                mu_g *= mu_growth

            if verbose:
                errs = ", ".join(
                    f"{o}:({ca.item():.2e},{cd.item():.2e})"
                    for o, ca, cd in zip(offsets, c_a, c_d)
                )
                print(f"[outer {outer:02d}]  {errs}  beta:{c_s.item():.2e}")

            if worst(c_a, c_d, c_g, c_s) < tol:
                break  # outer

    # The loop body never ran (max_outer <= 0): still report residuals for
    # the starting point instead of failing on undefined names.
    if c_a is None:
        with torch.no_grad():
            c_a, c_d, c_g, c_s = residuals(assemble())

    # -------- write back (NaNs are zero-filled, not restored) -------
    xi[first_movable_res:last_movable_res] = xi_tr.detach()

    return xi.detach().nan_to_num(), (
        [abs(v.item()) for v in c_a],
        [abs(v.item()) for v in c_d],
        abs(c_g.item()),
    )


