import re
import os
import json
import math
import torch
import typing
import argparse
import numpy as np
from torch import nn
from tqdm import tqdm
from pathlib import Path
from torch.utils.data import DataLoader, TensorDataset


# ----------- GPU availability check -----------
def check_gpu_availability():
    """Report the CUDA devices visible to torch.

    Prints the device count followed by one line per device name, or a single
    "no GPU" line when torch reports no CUDA devices.
    """
    n_devices = torch.cuda.device_count()
    if n_devices <= 0:
        print("No GPU found on this machine.")
        return
    print(f"Found {n_devices} GPU(s) on this machine.")
    for idx in range(n_devices):
        print(f"GPU {idx}: {torch.cuda.get_device_name(idx)}")

        
# --- Pareto Frontier Calculations ---
def is_pareto_front(recon_losses, sparsity_losses, threshold=50):
    """Return a boolean mask of the (jump-filtered) Pareto-optimal points.

    Both losses are minimised. Point ``i`` is Pareto-optimal when no other point
    ``j`` has ``recon[j] <= recon[i]`` and ``sparsity[j] <= sparsity[i]`` with at
    least one of the two strictly smaller. Exact duplicates therefore do not
    eliminate each other.

    After the dominance test, the indices of the Pareto points are scanned for
    gaps. If two consecutive Pareto indices differ by more than ``threshold``,
    only the Pareto points after the last such gap are kept. Diagnostic lines
    prefixed with ``[Pareto Filter]`` are printed along the way.

    Args:
        recon_losses: 1-D sequence/array of reconstruction losses, one per point.
        sparsity_losses: 1-D sequence/array of sparsity losses, same length.
        threshold: largest allowed index gap between consecutive Pareto points.

    Returns:
        ``np.ndarray`` of dtype bool with one entry per point.
    """
    recon = np.asarray(recon_losses)
    sparsity = np.asarray(sparsity_losses)
    num_points = len(recon)
    is_pareto = np.ones(num_points, dtype=bool)
    # One vectorised dominance test per point. A pure-Python double loop would be
    # O(n^2) interpreter iterations, and this function is re-run on the whole
    # growing history at every checkpoint.
    for i in range(num_points):
        no_worse = (recon <= recon[i]) & (sparsity <= sparsity[i])
        strictly_better = (recon < recon[i]) | (sparsity < sparsity[i])
        if np.any(no_worse & strictly_better):
            is_pareto[i] = False

    pareto_indices = np.where(is_pareto)[0]
    print(f"[Pareto Filter] Initial Pareto point count: {len(pareto_indices)}")

    if len(pareto_indices) < 2:
        print("[Pareto Filter] Not enough Pareto points to filter. Returning all.")
        return is_pareto

    # Gaps between consecutive Pareto indices; a gap larger than `threshold` is a "jump".
    diffs = np.diff(pareto_indices)
    jump_indices = np.where(diffs > threshold)[0]
    print(f"[Pareto Filter] Found {len(jump_indices)} jump(s) (threshold = {threshold})")

    if len(jump_indices) > 0:
        last_jump = jump_indices[-1]
        keep_indices = pareto_indices[last_jump + 1:]
        print(f"[Pareto Filter] Keeping {len(keep_indices)} Pareto points after last jump (index {pareto_indices[last_jump]})")
        is_pareto[:] = False
        is_pareto[keep_indices] = True
    else:
        print(f"[Pareto Filter] No large jump found. Keeping all {len(pareto_indices)} Pareto points.")

    return is_pareto


def rectangle_pt(x):
    """Rectangle kernel: 1 where -0.5 < x < 0.5 (open interval), else 0, in x's dtype/device."""
    inside = torch.logical_and(x > -0.5, x < 0.5)
    return inside.to(x)


# ----------- Custom Operations -----------
class Step(torch.autograd.Function):
    """Heaviside step ``H(x - threshold)`` with a kernel pseudo-gradient for ``threshold``.

    Forward returns 1 where ``x > threshold`` and 0 elsewhere, cast to ``x``'s dtype.
    Backward gives ``x`` a zero gradient. It gives ``threshold`` the estimate
    ``-(1 / bandwidth) * K((x - threshold) / bandwidth) * grad_output`` summed over
    dimension 0, where ``K`` is ``rectangle_pt``.
    """

    @staticmethod
    def forward(ctx, x, threshold, bandwidth=None):
        """Compute the step.

        Args:
            x: pre-activations.
            threshold: thresholds compared elementwise (with broadcasting) against ``x``.
            bandwidth: kernel width for the pseudo-gradient. When ``None`` (the
                default), the module-level ``BANDWIDTH`` is read during backward.
        """
        ctx.save_for_backward(x, threshold)
        ctx.bandwidth = bandwidth
        return (x > threshold).to(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, threshold = ctx.saved_tensors
        # Resolve lazily so the global is only required when no explicit bandwidth was given.
        bandwidth = BANDWIDTH if ctx.bandwidth is None else ctx.bandwidth
        x_grad = 0.0 * grad_output
        threshold_grad = torch.sum(
            -(1.0 / bandwidth) * rectangle_pt((x - threshold) / bandwidth) * grad_output, dim=0
        )
        # Trailing None is the (non-differentiable) slot for `bandwidth`; autograd
        # discards surplus None gradients when `bandwidth` was not passed to apply().
        return x_grad, threshold_grad, None

class JumpReLU(torch.autograd.Function):
    """JumpReLU ``x * H(x - threshold)`` with a kernel pseudo-gradient for ``threshold``.

    Forward keeps ``x`` where ``x > threshold`` and zeroes it elsewhere. Backward
    passes ``grad_output`` through to ``x`` on the kept entries. It gives
    ``threshold`` the estimate
    ``-(threshold / bandwidth) * K((x - threshold) / bandwidth) * grad_output``
    summed over dimension 0, where ``K`` is ``rectangle_pt``.
    """

    @staticmethod
    def forward(ctx, x, threshold, bandwidth=None):
        """Apply the JumpReLU.

        Args:
            x: pre-activations.
            threshold: thresholds compared elementwise (with broadcasting) against ``x``.
            bandwidth: kernel width for the pseudo-gradient. When ``None`` (the
                default), the module-level ``BANDWIDTH`` is read during backward.
        """
        ctx.save_for_backward(x, threshold)
        ctx.bandwidth = bandwidth
        return x * (x > threshold).to(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, threshold = ctx.saved_tensors
        # Resolve lazily so the global is only required when no explicit bandwidth was given.
        bandwidth = BANDWIDTH if ctx.bandwidth is None else ctx.bandwidth
        x_grad = (x > threshold) * grad_output
        threshold_grad = torch.sum(
            -(threshold / bandwidth)
            * rectangle_pt((x - threshold) / bandwidth)
            * grad_output, dim=0
        )
        # Trailing None is the (non-differentiable) slot for `bandwidth`; autograd
        # discards surplus None gradients when `bandwidth` was not passed to apply().
        return x_grad, threshold_grad, None

    
# ----------- SAE model -----------
class Sae(nn.Module):
    """JumpReLU sparse autoencoder with a learnable per-latent threshold.

    Args:
        sae_width: number of latent features.
        activations_size: dimensionality of the input activations.
        threshold_init: initial value of every threshold. Must be > 0, because
            the threshold is stored as ``log_threshold`` and recovered with ``exp``.
        use_pre_enc_bias: if True, subtract ``b_dec`` from the input before encoding.

    Raises:
        ValueError: if ``threshold_init`` is not strictly positive (including NaN).
    """

    def __init__(self, sae_width, activations_size, threshold_init, use_pre_enc_bias):
        super().__init__()
        # log(<= 0) would give a -inf/NaN parameter; at -inf the threshold's
        # gradient is exactly zero, so it would silently never train.
        if not threshold_init > 0:
            raise ValueError(f"threshold_init must be > 0, got {threshold_init!r}")
        self.use_pre_enc_bias = use_pre_enc_bias
        W_enc = torch.empty((activations_size, sae_width))
        W_dec = torch.empty((sae_width, activations_size))
        nn.init.xavier_uniform_(W_enc)
        nn.init.xavier_uniform_(W_dec)
        self.W_enc = nn.Parameter(W_enc)
        self.W_dec = nn.Parameter(W_dec)
        self.b_enc = nn.Parameter(torch.zeros((sae_width,)))
        self.b_dec = nn.Parameter(torch.zeros((activations_size,)))
        self.log_threshold = nn.Parameter(torch.log(torch.full((sae_width,), threshold_init)))

    @property
    def threshold(self):
        """Per-latent positive threshold, ``exp(log_threshold)``."""
        return torch.exp(self.log_threshold)

    def _pre_activations(self, x):
        """Affine encoder output, optionally centring ``x`` by ``b_dec`` first."""
        if self.use_pre_enc_bias:
            x = x - self.b_dec
        return x @ self.W_enc + self.b_enc

    def encode(self, x):
        """Return latent activations: pre-activations above the threshold, else 0.

        Uses a plain comparison mask, so no custom gradient is attached.
        """
        pre_acts = self._pre_activations(x)
        return (pre_acts > self.threshold) * torch.nn.functional.relu(pre_acts)

    def decode(self, acts):
        """Map latent activations back to input space."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, x):
        """Reconstruct ``x`` through the JumpReLU bottleneck.

        Returns:
            ``(x_recon, pre_acts)``. The pre-activations are returned so the
            caller can compute the L0 penalty.
        """
        pre_acts = self._pre_activations(x)
        acts = JumpReLU.apply(pre_acts, self.threshold)
        return self.decode(acts), pre_acts

    
# ----------- Training utility functions -----------
def create_dataloader(dataset, batch_size):
    """Stack equally-shaped tensors into one tensor and serve it in shuffled batches.

    Each item yielded by the returned loader is a 1-element sequence ``(batch,)``.
    """
    return DataLoader(
        TensorDataset(torch.stack(dataset)),
        batch_size=batch_size,
        shuffle=True,
    )

def remove_parallel_component_pt(x, v):
    """Subtract from ``x`` its projection onto the direction of ``v`` (along the last dim).

    ``v`` is normalised with a ``1e-6`` epsilon added to its norm, so the removal is
    approximate and stays finite for zero-length ``v``.
    """
    direction = v / (v.norm(dim=-1, keepdim=True) + 1e-6)
    coeff = torch.einsum("...d,...d->...", x, direction).unsqueeze(-1)
    return x - coeff * direction

def loss_fn_pt(sae, x, sparsity_coefficient, use_pre_enc_bias):
    """Compute the JumpReLU SAE objective for a batch.

    Args:
        sae: the autoencoder; ``sae(x)`` must return ``(x_recon, pre_acts)``.
        x: input batch.
        sparsity_coefficient: weight on the L0 term.
        use_pre_enc_bias: not read here; the model applies its own flag.

    Returns:
        ``(mean total loss, mean squared-error reconstruction, mean L0)``. The third
        value is the unscaled active-latent count, not ``sparsity_coefficient * L0``.
    """
    x_recon, pre_acts = sae(x)
    per_example_recon = ((x - x_recon) ** 2).sum(dim=-1)
    # Step gives a 0/1 mask with a kernel pseudo-gradient w.r.t. the threshold.
    per_example_l0 = Step.apply(pre_acts, sae.threshold).sum(dim=-1)
    total = per_example_recon + sparsity_coefficient * per_example_l0
    return total.mean(), per_example_recon.mean(), per_example_l0.mean()


# ----------- Main training function -----------
def train_pt(args):
    """Train a JumpReLU SAE on cached activations and keep a Pareto-selected checkpoint.

    Loads activations from
    ``data/Representation/{pos}/{model_name}/{dataset}/layer_{layer}.pt`` and trains
    ``Sae`` with Adam and gradient accumulation for up to ``args.max_steps`` steps.
    Each "step" is one full pass over the DataLoader. Every ``save_every`` steps it
    evaluates R^2 / L0 on the full data and re-selects a "balanced" checkpoint on
    the (reconstruction, L0) Pareto front of the checkpoints seen so far.

    Args:
        args: ``argparse.Namespace`` with the fields defined by this file's CLI
            parser (model_name, layer, dataset, pos, sparsity_coefficient,
            use_pre_enc_bias, fix_decoder_norms, max_steps, scaling_factor,
            batch_size, learning_rate, threshold_init, bandwidth, adam_b1, adam_b2).

    Side effects:
        * Sets the module-level ``BANDWIDTH`` that ``Step``/``JumpReLU`` read in backward.
        * Creates ``saved_models/{model_name}/layer_{layer}``. Writes
          ``*_BalancedBest_step*.pt`` state dicts there, deleting the previous one
          written during this run. Writes a ``*_training_log.json`` history when
          training finishes.
    """
    global BANDWIDTH
    BANDWIDTH = args.bandwidth
    device = "cuda" if torch.cuda.is_available() else "cpu"
    check_gpu_availability()

    # Build the input/output paths from the CLI arguments
    save_dir = f"saved_models/{args.model_name}/layer_{args.layer}"
    dataset_path = f"data/Representation/{args.pos}/{args.model_name}/{args.dataset}/layer_{args.layer}.pt"
    print(f"Using activation path: {dataset_path}")
    print(f"Saving model to: {save_dir}")
    os.makedirs(save_dir, exist_ok=True)

    # The file must hold a mapping; only its values (activation tensors) are used.
    # NOTE(review): presumably each value is a 1-D activation vector (shape[0] is
    # taken as the activation size and the values are torch.stack-ed) — confirm.
    # torch.load may unpickle arbitrary objects depending on the torch version's
    # `weights_only` default, so only load trusted files.
    dataset = torch.load(dataset_path)
    dataset = [value for value in dataset.values()]
    activations_size = dataset[0].shape[0]
    sae_width = activations_size * args.scaling_factor
    print(f"SAE width = {sae_width} (scaling_factor={args.scaling_factor})")

    sae = Sae(sae_width, activations_size, args.threshold_init, args.use_pre_enc_bias).to(device)
    print(f"SAE model initialized successfully on device: {device}")
    optimizer = torch.optim.Adam(sae.parameters(), lr=args.learning_rate, betas=(args.adam_b1, args.adam_b2))
    print(f"Optimizer: Adam with learning rate {args.learning_rate}, betas=({args.adam_b1}, {args.adam_b2})")
    dataloader = create_dataloader(dataset, args.batch_size)
    print(f"DataLoader initialized: {len(dataloader)} batches of size {args.batch_size}")
    
    # Full, unshuffled activation matrix used for the periodic R^2 / L0 evaluation
    target_act = torch.stack(dataset)
    print(f"Target_act shape: {target_act.shape}")
    
    # Gradient accumulation: the optimizer steps once every `accumulation_steps`
    # physical batches, i.e. a logical batch of roughly min(4096, len(dataset)) samples.
    # Change the value 4096 in "min(4096, len(dataset))" to get a larger logical batch size.
    accumulation_steps = math.ceil(min(4096, len(dataset)) / args.batch_size)
    print(f"Using physical batch size: {args.batch_size}")
    print(f"Using logical batch size: {accumulation_steps} * {args.batch_size} = {accumulation_steps * args.batch_size}")
    print(f"Gradient accumulation steps: {accumulation_steps}")
    # Number of optimizer updates per step (per pass over the data)
    logical_batch_num = math.ceil(len(dataset) / (accumulation_steps * args.batch_size))
    print(f"Logical batch num: {len(dataset)} / {accumulation_steps * args.batch_size} = {logical_batch_num}")
    
    
    best_balanced_model_path = None
    recon_losses = []  # mean reconstruction loss per evaluation (one entry every save_every steps)
    sparsity_losses = []  # mean L0 per evaluation (loss_fn_pt's unscaled third return value)
    history = []  # one record per step, dumped to JSON at the end

    save_every = 20  # Number of steps between evaluations / Pareto updates
    pareto_step_global = None  # Training step of the currently selected Pareto-balanced point
    max_pareto_step_found = -1  # Last step at which a newer Pareto point entered the save branch
    
    for step in range(args.max_steps):
        sae.train()
        total_recon_loss, total_sparsity_loss, total_loss = 0.0, 0.0, 0.0
        batch_losses = []
        
        # Index (1-based) of the logical batch currently being accumulated
        cur_logical_batch = 1
        optimizer.zero_grad()
        pbar = tqdm(dataloader, desc=f"Step [{step+1}/{args.max_steps}]", leave=False)
        for batch_idx, (batch,) in enumerate(pbar):
            batch = batch.to(device)
            loss, recon_loss, sparsity_loss = loss_fn_pt(
                sae, batch, 
                args.sparsity_coefficient, args.use_pre_enc_bias
            )
            # Divide by the number of physical batches in this logical batch, so the
            # accumulated gradient is their mean. The last logical batch may hold
            # fewer physical batches than accumulation_steps.
            if cur_logical_batch != logical_batch_num:
                denom = accumulation_steps
                region = "MAIN LOGICAL BATCH"
            else:
                denom = len(dataloader) - (logical_batch_num-1) * accumulation_steps
                region = "LAST LOGICAL BATCH"
            (loss / denom).backward()
            
            # NOTE: `sparsity_loss` is loss_fn_pt's mean L0 (not multiplied by
            # sparsity_coefficient), so the "sparsity" totals below are L0 counts.
            total_recon_loss += recon_loss.item()
            total_sparsity_loss += sparsity_loss.item()
            total_loss += loss.item()
            batch_losses.append(loss.item())
            
            pbar.set_postfix({
                "Cur Logical Batch": f"{cur_logical_batch}",
                "Batch Loss": f"{loss.item():.4f}",
                "region": f"{region}",
                "denom": f"{denom}"
            })
            
            del loss, recon_loss, sparsity_loss, batch
            # NOTE(review): flushing the CUDA cache after every batch likely slows GPU training
            torch.cuda.empty_cache()
            
            # Optimizer update every accumulation_steps physical batches, and after the final batch
            if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1 == len(dataloader)):
                cur_logical_batch += 1
                
                if args.fix_decoder_norms:
                    # Drop the gradient component parallel to each decoder row
                    sae.W_dec.grad = remove_parallel_component_pt(sae.W_dec.grad, sae.W_dec.data)
                
                optimizer.step()
                optimizer.zero_grad()
                
                if args.fix_decoder_norms:
                    # Renormalise each decoder row (one per latent) to unit L2 norm
                    sae.W_dec.data = sae.W_dec.data / torch.norm(sae.W_dec.data, dim=-1, keepdim=True)
                torch.cuda.empty_cache()
            
        avg_loss = total_loss / len(dataloader)
        min_loss = min(batch_losses)
        max_loss = max(batch_losses)
        step_record = {
            "step": step + 1,
            "avg_loss": avg_loss,
            "min_loss": min_loss,
            "max_loss": max_loss,
            "avg_recon_loss": total_recon_loss / len(dataloader),
            "avg_sparsity_loss": total_sparsity_loss / len(dataloader),
        }
        
        # Print per-step statistics
        print(f"Step [{step+1}/{args.max_steps}] Summary:")
        print(f"  - Average Loss: {avg_loss:.4f}")
        print(f"  - Average Recon Loss: {total_recon_loss / len(dataloader):.4f}")
        print(f"  - Average Sparsity Loss: {total_sparsity_loss / len(dataloader):.4f}")
        print(f"  - Min Batch Loss: {min_loss:.4f}")
        print(f"  - Max Batch Loss: {max_loss:.4f}")
        print("-" * 50)
        torch.cuda.empty_cache()

        if (step + 1) % save_every == 0:
            sae.eval()
            
            with torch.no_grad():
                total_sq_error = 0.0
                total_l0 = 0.0
                BATCH_SIZE_INFER = 4096  # Adjustable
                for start_idx in range(0, target_act.shape[0], BATCH_SIZE_INFER):
                    end_idx = min(start_idx + BATCH_SIZE_INFER, target_act.shape[0])
                    batch = target_act[start_idx:end_idx]
                    sae_acts_batch = sae.encode(batch.to(device))
                    recon_batch = sae.decode(sae_acts_batch)
                    total_sq_error += torch.sum((recon_batch.cpu() - batch.cpu()) ** 2)  # Accumulate the sum of squared errors
                    total_l0 += (sae_acts_batch.cpu() > 0).sum()  # Accumulate the number of active latents (L0)
                    del batch, sae_acts_batch, recon_batch
                    torch.cuda.empty_cache()
                # R^2 = 1 - MSE / Var, where Var is taken over all elements of target_act
                # (a single scalar variance, not per-dimension).
                r2 = 1 - (total_sq_error / target_act.numel()) / target_act.var()
                avg_l0 = total_l0 / target_act.shape[0]
                
                step_record["r2"] = r2.item()
                step_record["avg_l0"] = avg_l0.item()
                # Calculate and save threshold statistics
                threshold_values = sae.threshold.detach().cpu().numpy()
                step_record["threshold"] = {
                    "mean": float(np.mean(threshold_values)),
                    "max": float(np.max(threshold_values)),
                    "min": float(np.min(threshold_values))
                }
                
                print(f"  - threshold (Step {step+1}): {float(np.mean(threshold_values)):.6f}")
                print(f"  - R2 score (Step {step+1}): {r2.item():.6f}")
                print(f"  - L0 score (Step {step+1}): {avg_l0.item():.6f}")
                
            # Record this evaluation's training-epoch averages
            recon_losses.append(total_recon_loss / len(dataloader))
            sparsity_losses.append(total_sparsity_loss / len(dataloader))
            # Pareto front over all evaluations so far. Indices count evaluations, so
            # an index gap of 1000//save_every corresponds to 1000 training steps.
            recon_losses_np = np.array(recon_losses)
            sparsity_losses_np = np.array(sparsity_losses)
            is_pareto = is_pareto_front(recon_losses_np, sparsity_losses_np, threshold=1000//save_every)
            # Min-max normalise both losses over the Pareto points only
            pareto_recon = recon_losses_np[is_pareto]
            pareto_sparsity = sparsity_losses_np[is_pareto]
            pareto_norm_recon = (pareto_recon - np.min(pareto_recon)) / (np.max(pareto_recon) - np.min(pareto_recon) + 1e-8)
            pareto_norm_sparsity = (pareto_sparsity - np.min(pareto_sparsity)) / (np.max(pareto_sparsity) - np.min(pareto_sparsity) + 1e-8)
            pareto_distances = np.sqrt(pareto_norm_recon**2 + pareto_norm_sparsity**2)
            # Pick the Pareto point closest to the normalised origin (the "balanced" point)
            min_idx = np.argmin(pareto_distances)
            pareto_step_idx = np.arange(len(recon_losses_np))[is_pareto][min_idx]
            # Evaluation index k corresponds to training step (k + 1) * save_every
            pareto_step_global = int((pareto_step_idx + 1) * save_every)
            step_record["pareto_step_global"] = pareto_step_global
            print(f"  - pareto_step_global (Step {step+1}): {pareto_step_global}")
            
            
            # Enter when the selected Pareto point is newer than the last time this branch
            # ran and is less than 1000 steps old.
            # NOTE(review): the weights saved are the CURRENT model (step+1), not the
            # weights from pareto_step_global.
            if pareto_step_global > max_pareto_step_found and (step+1 - pareto_step_global) < 1000:
                if avg_l0.item() < 200:  # Only save once avg L0 drops below 200
                    # Save the model of the current step
                    new_balanced_path = os.path.join(
                        save_dir,
                        f"x{args.scaling_factor}_SprCoef{args.sparsity_coefficient}_LR{args.learning_rate}_Los{avg_loss:.4f}_L{avg_l0.item():.4f}_R{r2.item():.4f}_BalancedBest_step{step+1}.pt"
                    )
                    torch.save(sae.state_dict(), new_balanced_path)
                    print(f"New best balanced model saved at: {new_balanced_path}")
                    # Keep only the latest balanced checkpoint from this run
                    if best_balanced_model_path and os.path.exists(best_balanced_model_path):
                        os.remove(best_balanced_model_path)
                    best_balanced_model_path = new_balanced_path
                    
                max_pareto_step_found = step+1  # Updated even when the L0 gate prevented saving
                
        history.append(step_record)
        
        # Early stopping. Both limits shrink as logical_batch_num (optimizer updates
        # per step) grows, with floors of 1500 and 3000 steps respectively.
        patience = max(int(4000/logical_batch_num), 1500)  # Max steps allowed to lag behind the Pareto-optimal step
        train_at_least = max(int(12000/logical_batch_num), 3000)   # Minimum number of training steps
        if (step + 1) > train_at_least and pareto_step_global is not None and (step + 1) - pareto_step_global >= patience:
            print(f"\n✅ Stopping early at step {step+1}, as it exceeds patience threshold beyond Pareto-optimal step {pareto_step_global}.")
            break

    log_path = os.path.join(save_dir, f"x{args.scaling_factor}_SprCoef{args.sparsity_coefficient}_LR{args.learning_rate}_training_log.json")
    with open(log_path, "w") as f:
        json.dump(history, f, indent=4)
    print(f"Training log saved at: {log_path}")
    print("Training completed successfully.")

    
# ----------- CLI entry point -----------
if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean for argparse.

        Accepts true/false, t/f, yes/no, y/n, 1/0 and on/off (case-insensitive,
        surrounding whitespace ignored). Anything else raises an argparse error,
        so a typo such as ``ture`` is reported instead of silently meaning False.
        """
        normalized = value.strip().lower()
        if normalized in ("true", "t", "yes", "y", "1", "on"):
            return True
        if normalized in ("false", "f", "no", "n", "0", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="Train SAEs with command-line hyperparameters.")
    parser.add_argument("--model_name", type=str, required=True, help="Model name (e.g., google/gemma-2-2b)")
    parser.add_argument("--layer", type=int, required=True, help="Layer number")
    parser.add_argument("--dataset", type=str, choices=["Counterfact"], default="Counterfact", help="Dataset")
    parser.add_argument("--pos", type=str, default="down_proj_input", help="Activation position")
    parser.add_argument("--sparsity_coefficient", type=float, default=0.1, help="Sparsity coefficient")
    parser.add_argument("--use_pre_enc_bias", type=_str2bool, default=True, help="Use pre encoder bias")
    parser.add_argument("--fix_decoder_norms", type=_str2bool, default=True, help="Fix decoder norms")
    parser.add_argument("--max_steps", type=int, default=18000, help="Number of max steps")
    parser.add_argument("--scaling_factor", type=int, default=4, help="Scaling factor")
    parser.add_argument("--batch_size", type=int, default=4096, help="Batch size")
    parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--threshold_init", type=float, default=0.001, help="Threshold init value")
    parser.add_argument("--bandwidth", type=float, default=2.0, help="Bandwidth for custom backward")
    parser.add_argument("--adam_b1", type=float, default=0.9, help="Adam beta1")
    parser.add_argument("--adam_b2", type=float, default=0.999, help="Adam beta2")
    args = parser.parse_args()

    train_pt(args)