import os
import sys
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
import argparse
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
import numpy as np
import pathlib
import pickle
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

from models import PruneLlama2ForCausalLM, PruneLlama2DecoderLayer
from pruning import collect_info_reg_llama, help_functions_hn
from pruning.dyn_hypernetwork import dyn_hypernetwork
from lib.dataset_loader import build_wikitext_ids, sample_wikitext_sequences, calculate_perplexity, calculate_perplexity_with_label, read_lrp_file, read_mixed_lrp_file, evaluate_mc_example, calculate_sequence_log_prob

import wandb

def _mean_target_log_prob(logits, ids, prefix_len):
    """Mean next-token log-probability of ``ids`` from token ``prefix_len`` onward.

    ``logits`` are the causal-LM outputs for ``ids`` (indexed as
    ``(batch, seq, vocab)``).  The prediction at position ``t`` scores token
    ``t + 1``, so the first scored token ``prefix_len`` sits at shifted index
    ``prefix_len - 1``.  If ``prefix_len`` is 0 or leaves nothing to score,
    every token is averaged instead.  The result is a 0-dim tensor that keeps
    the autograd graph, so it can be used directly as a training signal.
    """
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = ids[:, 1:].contiguous()
    token_log_probs = F.log_softmax(shift_logits, dim=-1).gather(
        -1, shift_labels.unsqueeze(-1)
    ).squeeze(-1)
    if 0 < prefix_len and prefix_len - 1 < token_log_probs.shape[1]:
        token_log_probs = token_log_probs[:, prefix_len - 1:]
    return token_log_probs.mean()


def compute_contrastive_loss(model, input_ids, tokenizer, device, original_example=None):
    """Return a differentiable multiple-choice loss for a single example.

    Each option is scored by the mean next-token log-probability of its scored
    span, keeping gradients:

    * winogrande (``dataset_name`` contains "winogrande"): the tokens after
      ``context_prefix + option``, i.e. the shared ``target_suffix``;
    * other datasets: the tokens after the shared prompt
      ``"{question} Answer:"``, i.e. the option text itself.  Tokens of the
      shared prompt are excluded because they score identically for every
      option and would only dilute the comparison.

    The option scores are used as logits for a cross-entropy against
    ``original_example["label"]`` (an index into ``options``).

    Falls back to ``calculate_perplexity_with_label(model, input_ids, None, device)``
    when ``original_example`` is missing or has no ``'options'``, or when
    scoring raises (the error is printed).

    Args:
        model: causal LM called as ``model(ids).logits``.
        input_ids: token ids; only used by the fallback path.
        tokenizer: callable returning an object with ``.input_ids``
            (HF-tokenizer style).
        device: device the option token ids are moved to.
        original_example: dict with ``options`` and ``label`` plus either
            ``context_prefix`` / ``target_suffix`` (winogrande) or ``question``.
    """
    if original_example is None or 'options' not in original_example:
        return calculate_perplexity_with_label(model, input_ids, None, device)

    try:
        # ``or ""`` also covers an explicit ``None`` value.
        dataset_name = original_example.get("dataset_name") or ""
        option_log_probs = []

        if "winogrande" in dataset_name.lower():
            ctx_pref = original_example["context_prefix"]
            tgt_suf = original_example["target_suffix"]
            options = original_example["options"]  # e.g. [" option1", " option2"]
            correct_idx = original_example["label"]

            for option in options:
                full_ctx = ctx_pref + option
                ids_full = tokenizer(full_ctx + tgt_suf,
                                     add_special_tokens=False,
                                     return_tensors="pt").input_ids.to(device)
                ctx_len = len(tokenizer(full_ctx, add_special_tokens=False).input_ids)

                # Forward pass with gradients kept (no torch.no_grad here).
                with torch.cuda.amp.autocast(dtype=torch.float16):
                    logits = model(ids_full).logits
                option_log_probs.append(_mean_target_log_prob(logits, ids_full, ctx_len))
        else:
            # Other datasets: ARC, HellaSwag, PIQA, etc.
            question = original_example["question"]
            options = original_example["options"]
            correct_idx = original_example["label"]

            # The prompt is identical for every option: tokenize it once.
            prompt = f"{question} Answer:"
            prompt_len = len(tokenizer(prompt, add_special_tokens=True).input_ids)

            for option_content in options:
                full_text = f"{prompt} {option_content}"
                option_input_ids = tokenizer(full_text, return_tensors="pt").input_ids.to(device)

                with torch.cuda.amp.autocast(dtype=torch.float16):
                    logits = model(option_input_ids).logits
                option_log_probs.append(
                    _mean_target_log_prob(logits, option_input_ids, prompt_len)
                )

        # Stack (not .item()) so gradients flow back through every option score.
        scores = torch.stack(option_log_probs)
        target = torch.tensor(correct_idx, device=device, dtype=torch.long)
        return F.cross_entropy(scores.unsqueeze(0), target.unsqueeze(0))

    except Exception as e:
        # Deliberate best-effort: any scoring failure degrades to the plain LM loss.
        print(f"differentiable contrastive learning calculation failed: {e}")
        return calculate_perplexity_with_label(model, input_ids, None, device)

class PreprocessedDataset(Dataset):
    """Dataset whose LRP/activation features are converted to tensors once, up front.

    Each stored sample is a dict with:

    * ``sample_ids``: token ids with a leading batch dim added to 1-D input
      (converted to ``long`` unless it already was a tensor);
    * ``layer_activations`` / ``input_lrp``: one ``float32`` tensor per
      regularized structure (1-D inputs get a leading dim), optionally
      normalized with :meth:`normalize_tensor_layerwise`;
    * ``label_pos``: ``long`` tensor, or ``None`` when the sample has no label;
    * ``original_example``: passed through untouched when present.

    numpy / list inputs become CPU tensors here; ``__getitem__`` moves them to
    ``device`` on every access.
    """

    def __init__(self, samples_data, param_reg_structures, device, max_samples=None,
                 normalize_lrp=True, normalize_activations=True, data_type_filter=None):
        """Convert the first ``max_samples`` entries of ``samples_data`` to tensors.

        Args:
            samples_data: indexable sequence of dicts with ``"sample_id"``,
                ``"lrp"``, ``"activations"`` and optionally ``"label"`` /
                ``"original_example"``.
            param_reg_structures: regularized structures; only its length is
                used, to bound how many per-structure entries are read.
            device: target device for tensors returned by ``__getitem__``.
            max_samples: only the first ``max_samples`` entries are considered
                (all when ``None``).  The limit is applied *before* filtering.
            normalize_lrp: apply :meth:`normalize_tensor_layerwise` to LRP scores.
            normalize_activations: same for activations.
            data_type_filter: ``None`` keeps everything; ``"supervised"`` keeps
                only samples with a label and an ``original_example`` dict
                containing ``'options'``; ``"unsupervised"`` keeps the rest.

        Raises:
            ValueError: if ``data_type_filter`` is not one of the values above.
        """
        # Previously any other value silently behaved like ``None``.
        if data_type_filter not in (None, "supervised", "unsupervised"):
            raise ValueError(
                "data_type_filter must be None, 'supervised' or 'unsupervised', "
                f"got {data_type_filter!r}"
            )

        self.device = device
        self.normalize_lrp = normalize_lrp
        self.normalize_activations = normalize_activations
        self.samples = []

        print("Preprocessing dataset...")
        total_samples = len(samples_data)

        if max_samples is None:
            actual_samples = total_samples
            print(f"Using all data, total {total_samples} samples")
        elif max_samples >= total_samples:
            actual_samples = total_samples
            print(f"Specified max_samples({max_samples}) is greater than or equal to dataset samples ({total_samples}), using all {total_samples} samples")
        else:
            actual_samples = max_samples
            print(f"Selecting first {actual_samples} samples from {total_samples} samples")

        keep_supervised = data_type_filter == "supervised"
        for idx in tqdm(range(actual_samples)):
            sample_data = samples_data[idx]

            if data_type_filter is not None and self._is_supervised(sample_data) != keep_supervised:
                continue

            sample_ids = sample_data["sample_id"]
            if not isinstance(sample_ids, torch.Tensor):
                sample_ids = self._to_tensor(sample_ids).long()

            label_pos = sample_data.get("label", None)
            if label_pos is not None and not isinstance(label_pos, torch.Tensor):
                label_pos = self._to_tensor(label_pos).long()

            lrp_scores = sample_data["lrp"]
            activations = sample_data["activations"]

            # Only structures present in both the LRP and activation lists are kept.
            num_layers = min(len(param_reg_structures), len(lrp_scores), len(activations))
            layer_activations = [
                self._prepare_feature(activations[k], self.normalize_activations)
                for k in range(num_layers)
            ]
            input_lrp = [
                self._prepare_feature(lrp_scores[k], self.normalize_lrp)
                for k in range(num_layers)
            ]

            if sample_ids.dim() == 1:
                sample_ids = sample_ids.unsqueeze(0)

            sample_dict = {
                'sample_ids': sample_ids,
                'layer_activations': layer_activations,
                'input_lrp': input_lrp,
                'label_pos': label_pos
            }
            if 'original_example' in sample_data:
                sample_dict['original_example'] = sample_data['original_example']
            self.samples.append(sample_dict)

    @staticmethod
    def _is_supervised(sample_data):
        """True when the sample has a label and an ``original_example`` dict with ``'options'``."""
        example = sample_data.get('original_example')
        return (sample_data.get("label", None) is not None
                and isinstance(example, dict)
                and 'options' in example)

    @staticmethod
    def _to_tensor(data):
        """Convert numpy arrays / sequences / tensors to a tensor without a redundant copy.

        ``torch.as_tensor`` avoids the copy (and warning) ``torch.tensor`` makes
        for inputs that already are tensors.
        """
        if isinstance(data, np.ndarray):
            return torch.from_numpy(data)
        return torch.as_tensor(data)

    def _prepare_feature(self, data, normalize):
        """Return ``data`` as a float32 tensor of rank >= 2, normalized if requested."""
        tensor = self._to_tensor(data).float()
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(0)
        if normalize:
            tensor = self.normalize_tensor_layerwise(tensor)
        return tensor

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` with its tensors moved to ``self.device``."""
        sample = self.samples[idx]
        result = {
            'sample_ids': sample['sample_ids'].to(self.device),
            'layer_activations': [act.to(self.device) for act in sample['layer_activations']],
            'input_lrp': [lrp.to(self.device) for lrp in sample['input_lrp']],
            'label_pos': sample['label_pos']
        }
        if 'original_example' in sample:
            result['original_example'] = sample['original_example']

        if result['label_pos'] is not None and isinstance(result['label_pos'], torch.Tensor):
            result['label_pos'] = result['label_pos'].to(self.device)

        return result

    def normalize_tensor_layerwise(self, tensor, eps=1e-8):
        """Z-score ``|tensor|`` along the last dim (population std, clamped at ``eps``).

        Magnitudes are used (``abs`` first), so sign information is discarded.
        Empty tensors are returned unchanged; 1-D inputs keep their shape.
        """
        if tensor.numel() == 0:
            return tensor

        tensor = torch.abs(tensor)

        squeeze_later = tensor.dim() == 1
        if squeeze_later:
            tensor = tensor.unsqueeze(0)

        mean = tensor.mean(dim=-1, keepdim=True)
        std = tensor.std(dim=-1, keepdim=True, unbiased=False)
        std = torch.clamp(std, min=eps)

        normalized_tensor = (tensor - mean) / std

        if squeeze_later:
            normalized_tensor = normalized_tensor.squeeze(0)

        return normalized_tensor

def collate_fn(batch):
    """Combine dataset items into a single batch dictionary.

    Tensor fields are joined along dim 0: ``sample_ids`` directly, and
    ``layer_activations`` / ``input_lrp`` layer by layer, with the number of
    layers read from the first item.  ``label_pos`` and ``original_example``
    are gathered into plain lists in batch order, using ``None`` for items
    that lack the key.  The latter is exposed under the plural key
    ``'original_examples'``.
    """
    merged_ids = torch.cat([entry['sample_ids'] for entry in batch], dim=0)
    labels = [entry.get('label_pos', None) for entry in batch]

    layer_count = len(batch[0]['layer_activations'])

    def concat_layer(field, layer):
        # Join one layer's tensors from every item of the batch.
        return torch.cat([entry[field][layer] for entry in batch], dim=0)

    activations = [concat_layer('layer_activations', k) for k in range(layer_count)]
    relevances = [concat_layer('input_lrp', k) for k in range(layer_count)]
    examples = [entry.get('original_example', None) for entry in batch]

    return {
        'sample_ids': merged_ids,
        'layer_activations': activations,
        'input_lrp': relevances,
        'label_pos': labels,
        'original_examples': examples,
    }


def get_batch_with_contrastive_loss(model, input_ids, label_pos, original_examples, tokenizer, device):
    """Average :func:`compute_contrastive_loss` over every sample of a batch.

    Every sample must have a multiple-choice ``original_example`` dict with an
    ``'options'`` key, and ``tokenizer`` must be given.  Otherwise an exception
    is raised so the caller can fall back to a plain LM loss.

    Args:
        model: causal LM passed through to ``compute_contrastive_loss``.
        input_ids: ``(batch, seq)`` token ids; row ``i`` is forwarded as
            ``input_ids[i:i+1]``.
        label_pos: accepted for interface compatibility; not used here.
        original_examples: list aligned with the batch (anything else is
            treated as "no examples").
        tokenizer: tokenizer used to build option sequences.
        device: device for the option token ids.

    Returns:
        ``(mean_loss, count)``: the mean contrastive loss tensor and the number
        of samples scored (the batch size).

    Raises:
        ValueError: for an empty batch.
        Exception: when a sample has no usable ``original_example`` or its loss
            computation fails; the original error is chained as ``__cause__``.
    """
    batch_losses = []

    for i in range(input_ids.size(0)):
        example = original_examples[i] if isinstance(original_examples, list) else None

        if not (tokenizer is not None and isinstance(example, dict) and 'options' in example):
            raise Exception("Contrastive learning calculation failed, no original_example")

        try:
            loss = compute_contrastive_loss(model, input_ids[i:i + 1], tokenizer, device, example)
        except Exception as err:
            # Chain the cause so the caller's fallback message shows what actually broke.
            raise Exception("Contrastive learning calculation failed") from err
        batch_losses.append(loss)

    if not batch_losses:
        raise ValueError("Cannot compute a contrastive loss for an empty batch")

    return torch.stack(batch_losses).mean(), len(batch_losses)

def _unwrap_module(module):
    """Return ``module.module`` for wrappers exposing ``.module`` (e.g. DataParallel), else ``module``."""
    return module.module if hasattr(module, "module") else module


def _snapshot_state_dict(module):
    """Return a deep copy of ``module.state_dict()``.

    ``state_dict()`` returns *references* to the live tensors, which later
    optimizer steps update in place; without a copy a "best" snapshot would
    silently follow the latest weights.
    """
    import copy
    return copy.deepcopy(module.state_dict())


def _make_train_loader(dataset, args, batch_size_attr):
    """Build a shuffling training DataLoader, or return ``None`` for a missing/empty dataset.

    The batch size is ``args.<batch_size_attr>``, defaulting to ``args.batch_size``.
    """
    if dataset is None or len(dataset) == 0:
        return None
    return DataLoader(
        dataset,
        batch_size=getattr(args, batch_size_attr, args.batch_size),
        shuffle=True,
        collate_fn=collate_fn,
        pin_memory=False,
        num_workers=0
    )


def _select_epoch_dataloader(epoch, supervised_dataloader, unsupervised_dataloader, switch_epoch=2):
    """Pick and announce the dataloader for ``epoch``.

    Epochs ``< switch_epoch`` train on unsupervised data, later epochs on
    supervised data; if the preferred loader is missing the other one is used.
    Returns ``(dataloader, data_type)``, or ``(None, None)`` when no data exists.
    """
    if epoch >= switch_epoch and supervised_dataloader:
        print(f"\nEpoch {epoch+1}: Training supervised data (batch count: {len(supervised_dataloader)})")
        return supervised_dataloader, "supervised"
    if epoch < switch_epoch and unsupervised_dataloader:
        print(f"\nEpoch {epoch+1}: Training unsupervised data (batch count: {len(unsupervised_dataloader)})")
        return unsupervised_dataloader, "unsupervised"
    if supervised_dataloader:
        print(f"\nEpoch {epoch+1}: Falling back to supervised data")
        return supervised_dataloader, "supervised"
    if unsupervised_dataloader:
        print(f"\nEpoch {epoch+1}: Falling back to unsupervised data")
        return unsupervised_dataloader, "unsupervised"
    print(f"\nEpoch {epoch+1}: No data available, skipping")
    return None, None


def _lm_logits(model, input_ids):
    """Run ``model`` on ``input_ids``; return ``output.logits`` if present, else the raw output."""
    output = model(input_ids)
    return output.logits if hasattr(output, 'logits') else output


def _next_token_nll(logits, input_ids):
    """Mean next-token negative log-likelihood of ``input_ids`` under ``logits`` (all positions)."""
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:].contiguous()
    log_probs = F.log_softmax(shift_logits, dim=-1)
    token_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
    return -token_log_probs.mean()


def train_dyn_hypernetwork(
    model,
    hypernetwork,
    param_reg,
    samples_data,
    args,
    supervised_dataset=None,
    unsupervised_dataset=None,
    hn_helper=None
):
    """Train ``hypernetwork`` to predict pruning gates for the frozen ``model``.

    The first two epochs use the unsupervised loader with a next-token NLL
    loss.  Later epochs use the supervised loader with the contrastive
    multiple-choice loss, falling back to next-token NLL if that raises.  If
    the preferred loader is absent, the other one is used.  When
    ``args.reg_weight > 0``, ``args.reg_weight * param_reg(hard_output)`` is
    added.  Only hypernetwork parameters are optimized (AdamW, optional
    cosine schedule, gradient clipping, GradScaler).

    Every ``args.eval_epochs`` epochs, from ``args.start_epoch`` on and
    excluding the final epoch, the gated model is scored with
    :func:`evaluate_hypernetwork`.  The best snapshot is written to
    ``<output_dir>/best_hypernetwork.pt`` and restored into ``hypernetwork``
    at the end.

    Args:
        model: frozen causal LM whose gates are driven by the hypernetwork.
        hypernetwork: module mapping ``(layer_activations, input_lrp)`` to masks.
        param_reg: regularizer with a ``.structures`` attribute.
        samples_data: unused; kept for interface compatibility.
        args: namespace with the training hyper-parameters read below.
        supervised_dataset / unsupervised_dataset: optional datasets.
        hn_helper: gate helper; created from ``param_reg.structures`` if None.

    Returns:
        ``(hypernetwork, best_ppl)``.  ``best_ppl`` stays ``inf`` when no
        evaluation ran.
    """
    device = args.device
    torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    supervised_dataloader = _make_train_loader(supervised_dataset, args, 'supervised_batch_size')
    unsupervised_dataloader = _make_train_loader(unsupervised_dataset, args, 'unsupervised_batch_size')

    # Scheduler / progress horizon uses the longer of the two loaders per epoch.
    total_batches_per_epoch = max(
        (len(loader) for loader in (supervised_dataloader, unsupervised_dataloader) if loader),
        default=0,
    )

    optimizer = torch.optim.AdamW(hypernetwork.parameters(), lr=args.lr, weight_decay=0.05)
    scheduler = None
    if args.use_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=args.max_epochs * total_batches_per_epoch,
            eta_min=args.min_lr
        )

    scaler = GradScaler()

    best_ppl = float('inf')
    best_state_dict = None

    # Only the hypernetwork is trained; the LLM stays frozen.
    for param in model.parameters():
        param.requires_grad = False
    for param in hypernetwork.parameters():
        param.requires_grad = True
    hypernetwork.train()

    if hn_helper is None:
        hn_helper = help_functions_hn(param_reg.structures)

    global_step = 0
    for epoch in range(args.max_epochs):
        epoch_losses = []

        if epoch > 0:
            torch.cuda.empty_cache()
            gc.collect()

        # Epochs 0-1: unsupervised data; epochs >= 2: supervised data (with fallback).
        current_dataloader, data_type = _select_epoch_dataloader(
            epoch, supervised_dataloader, unsupervised_dataloader
        )
        if current_dataloader is None:
            continue

        progress_bar = tqdm(enumerate(current_dataloader), total=len(current_dataloader),
                            desc=f"Epoch {epoch+1}/{args.max_epochs} ({data_type})")

        for batch_idx, batch in progress_bar:
            # Training progress in [0, 1) exposed to the hypernetwork
            # (presumably drives its T_start -> T_end schedule; confirm in dyn_hypernetwork).
            _unwrap_module(hypernetwork)._prog = (
                (epoch * total_batches_per_epoch + batch_idx)
                / (args.max_epochs * total_batches_per_epoch)
            )

            input_ids = batch['sample_ids']
            layer_activations = batch['layer_activations']
            input_lrp = batch['input_lrp']
            label_pos = batch.get('label_pos', [None] * input_ids.size(0))
            original_examples = batch.get('original_examples', [None] * input_ids.size(0))

            optimizer.zero_grad()

            with torch.cuda.amp.autocast(dtype=torch_dtype):
                masks = hypernetwork(layer_activations, input_lrp)
                hn_helper.set_gate_vectors(model, masks)

                # Mean gate value = approximate fraction of structure kept.
                gate_mean = torch.stack([m.mean() for m in masks]).mean().item()

                if data_type == "supervised":
                    # The full-batch LM forward is only needed for the fallback,
                    # so it is run lazily instead of on every supervised step.
                    try:
                        ce_loss, _ = get_batch_with_contrastive_loss(
                            model, input_ids, label_pos, original_examples, tokenizer, device
                        )
                    except Exception as e:
                        print(f"Contrastive learning loss calculation failed, falling back to traditional method: {e}")
                        ce_loss = _next_token_nll(_lm_logits(model, input_ids), input_ids)
                else:
                    # Unsupervised data: full-sequence next-token loss.
                    ce_loss = _next_token_nll(_lm_logits(model, input_ids), input_ids)

                if args.reg_weight > 0:
                    hard_out = _unwrap_module(hypernetwork).hard_output(layer_activations, input_lrp)
                    reg_loss = param_reg(hard_out)
                else:
                    reg_loss = torch.tensor(0.0).to(device)

                total_loss = ce_loss + args.reg_weight * reg_loss

            scaler.scale(total_loss).backward()
            scaler.unscale_(optimizer)  # unscale before clipping so the threshold applies to true grads
            torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()

            if scheduler is not None:
                scheduler.step()

            loss_value = total_loss.item()
            ce_value = ce_loss.item()
            reg_value = reg_loss.item()
            epoch_losses.append(loss_value)

            progress_bar.set_postfix({
                'Type': data_type[:4],
                'Loss': f'{loss_value:.4f}',
                'CE': f'{ce_value:.4f}',
                'Reg': f'{reg_value:.4f}',
                'Keep': f'{gate_mean:.2f}'
            })

            if hasattr(args, 'use_wandb') and args.use_wandb:
                current_lr = optimizer.param_groups[0]['lr'] if args.use_scheduler else args.lr
                wandb.log({
                    "train/total_loss": loss_value,
                    "train/ce_loss": ce_value,
                    "train/reg_loss": reg_value,
                    "train/gate_mean": gate_mean,
                    "train/learning_rate": current_lr,
                    "train/epoch": epoch,
                    "train/global_step": global_step,
                    "train/data_type": data_type
                })

            global_step += 1

        avg_epoch_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0
        print(f"Epoch {epoch+1}/{args.max_epochs} - Avg Loss: {avg_epoch_loss:.4f} - Data Type: {data_type}")

        torch.cuda.empty_cache()
        gc.collect()

        if hasattr(args, 'use_wandb') and args.use_wandb:
            wandb.log({
                "epoch/avg_loss": avg_epoch_loss,
                "epoch/epoch": epoch,
                "epoch/data_type": data_type
            })

        # NOTE(review): the final epoch is never evaluated (epoch < max_epochs - 1),
        # so whenever an earlier evaluation ran, the restored best snapshot
        # replaces the weights learned afterwards.
        if (epoch + 1) % args.eval_epochs == 0 and epoch >= args.start_epoch and epoch < args.max_epochs - 1:
            print("\nStarting evaluation...")
            # NOTE(review): this is the supervised *training* data when it exists
            # (else the unsupervised data); there is no held-out split here.
            eval_dataset = supervised_dataset if supervised_dataset is not None else unsupervised_dataset

            test_ppl = evaluate_hypernetwork(
                model=model,
                hypernetwork=hypernetwork,
                dataset=eval_dataset,
                param_reg=param_reg,
                hn_helper=hn_helper,
                args=args,
                max_eval_samples=None
            )

            print(f"Test PPL: {test_ppl:.2f}")

            if hasattr(args, 'use_wandb') and args.use_wandb:
                wandb.log({
                    "eval/test_ppl": test_ppl,
                    "eval/epoch": epoch
                })

            if test_ppl < best_ppl:
                best_ppl = test_ppl
                best_state_dict = _snapshot_state_dict(_unwrap_module(hypernetwork))

                checkpoint = {
                    'epoch': epoch,
                    'hypernetwork': best_state_dict,
                    'optimizer': optimizer.state_dict(),
                    'test_ppl': best_ppl,
                }
                torch.save(checkpoint, os.path.join(args.output_dir, "best_hypernetwork.pt"))
                print(f"Saved best model, PPL: {best_ppl:.2f}")

            # evaluate_hypernetwork switches to eval mode.
            hypernetwork.train()

    if best_state_dict is not None:
        _unwrap_module(hypernetwork).load_state_dict(best_state_dict)

    print(f"\nTraining completed!")
    print(f"Best PPL: {best_ppl:.2f}")

    return hypernetwork, best_ppl

def evaluate_hypernetwork(model, hypernetwork, dataset, param_reg, hn_helper, args, max_eval_samples=None):
    """Return ``exp(mean per-sample loss)`` of the hypernetwork-gated model on ``dataset``.

    Masks are generated per batch by ``hypernetwork``.  Each sample is then
    evaluated alone, with its own row of gate vectors, via
    ``calculate_perplexity_with_label``.  Per-sample losses are averaged with
    equal weight (not token-weighted).  Gates are switched off again on exit,
    even if evaluation raises.  ``hypernetwork`` and ``model`` are left in
    eval mode.

    Args:
        model: gated causal LM.
        hypernetwork: module producing masks from ``(layer_activations, input_lrp)``.
        dataset: dataset yielding items compatible with ``collate_fn``.
        param_reg: unused; kept for interface compatibility.
        hn_helper: helper used to set gate vectors / gate status on ``model``.
        args: namespace providing ``device`` and ``eval_batch_size``.
        max_eval_samples: if given, evaluate a random subset (``np.random``)
            of at most this many samples.

    Returns:
        0-dim tensor with the perplexity; ``inf`` for an empty evaluation set.
    """
    hypernetwork.eval()
    model.eval()
    device = args.device
    torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

    eval_indices = list(range(len(dataset)))
    if max_eval_samples is not None:
        eval_indices = np.random.choice(eval_indices, size=min(max_eval_samples, len(dataset)), replace=False)

    eval_dataloader = DataLoader(
        torch.utils.data.Subset(dataset, eval_indices),
        batch_size=args.eval_batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    total_loss = 0.0
    try:
        with torch.no_grad():
            for batch in tqdm(eval_dataloader, desc="Evaluation"):
                input_ids = batch['sample_ids']
                layer_activations = batch['layer_activations']
                input_lrp = batch['input_lrp']
                label_pos = batch.get('label_pos', [None] * input_ids.size(0))

                with torch.cuda.amp.autocast(dtype=torch_dtype):
                    masks = hypernetwork(layer_activations, input_lrp)

                # Evaluate samples one at a time so each uses its own gate row.
                for i in range(input_ids.size(0)):
                    single_label_pos = label_pos[i] if isinstance(label_pos, list) else label_pos

                    hn_helper.set_gate_vectors(model, [m[i:i + 1] for m in masks])
                    hn_helper.set_gate_status(model, use_gate=True)

                    loss = calculate_perplexity_with_label(model, input_ids[i:i + 1], single_label_pos, device)
                    total_loss += loss.item()
    finally:
        # Never leave the model gated, even if evaluation fails midway.
        hn_helper.set_gate_status(model, use_gate=False)

    num_samples = len(eval_indices)
    avg_loss = total_loss / num_samples if num_samples > 0 else float('inf')
    return torch.exp(torch.tensor(avg_loss))


def main(args):
    """Run the pipeline: load model and LRP data, train the hypernetwork, evaluate it.

    Creates ``args.output_dir`` if needed and writes ``final_hypernetwork.pt``
    there.  Starts a wandb run when ``args.use_wandb`` is set.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    if args.use_wandb:
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_name,
            tags=args.wandb_tags,
            config=vars(args)
        )

    # NOTE(review): ``device`` falls back to CPU without CUDA, but the training and
    # evaluation helpers read ``args.device`` directly; confirm the two agree.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    print(f"Loading model: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = PruneLlama2ForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.float16,
        device_map=device
    )
    model.config.use_cache = False
    model.eval()

    # p is passed as 1 - target_sparsity (presumably the keep ratio).
    param_reg = collect_info_reg_llama(model, p=1.0-args.target_sparsity, lam=1)

    # Pre-computed LRP scores and activations.
    samples_data = read_mixed_lrp_file()

    print(f"Loaded {len(samples_data)} samples")

    hypernetwork = dyn_hypernetwork(
        t_structures=param_reg.structures,
        lrp_scale=args.lrp_scale,
        base=args.base,
        T_start=args.T_start,
        T_end=args.T_end,
        target_sparsity=args.target_sparsity,
        hidden_dim=args.hidden_dim
    ).to(device)

    hn_helper = help_functions_hn(param_reg.structures)

    # Disjoint per-type datasets drawn from the same first ``max_samples`` entries.
    supervised_dataset = PreprocessedDataset(
        samples_data, param_reg.structures, device,
        max_samples=args.max_samples,
        normalize_lrp=args.normalize_lrp,
        normalize_activations=args.normalize_activations,
        data_type_filter="supervised"
    )

    unsupervised_dataset = PreprocessedDataset(
        samples_data, param_reg.structures, device,
        max_samples=args.max_samples,
        normalize_lrp=args.normalize_lrp,
        normalize_activations=args.normalize_activations,
        data_type_filter="unsupervised"
    )

    print(f"Supervised samples: {len(supervised_dataset)}")
    print(f"Unsupervised samples: {len(unsupervised_dataset)}")

    print("Starting hypernetwork training...")
    trained_hypernetwork, best_ppl = train_dyn_hypernetwork(
        model=model,
        hypernetwork=hypernetwork,
        param_reg=param_reg,
        samples_data=samples_data,
        args=args,
        supervised_dataset=supervised_dataset,
        unsupervised_dataset=unsupervised_dataset,
        hn_helper=hn_helper
    )

    final_state = _final_state_dict(trained_hypernetwork)
    final_path = os.path.join(args.output_dir, "final_hypernetwork.pt")
    torch.save(final_state, final_path)
    print(f"Final model saved to {final_path}")

    print("\nStarting final evaluation...")

    # The per-type datasets are no longer needed.  The merged dataset below holds
    # its own converted copy of the same samples, so keeping all three alive
    # can roughly double peak host memory.
    del supervised_dataset, unsupervised_dataset
    gc.collect()

    full_dataset = PreprocessedDataset(
        samples_data, param_reg.structures, device,
        max_samples=args.max_samples,
        normalize_lrp=args.normalize_lrp,
        normalize_activations=args.normalize_activations
    )

    evaluate_sample_wise(
        model=model,
        hypernetwork=trained_hypernetwork,
        dataset=full_dataset,
        param_reg=param_reg,
        hn_helper=hn_helper,
        args=args,
        tokenizer=tokenizer
    )


def _final_state_dict(module):
    """State dict of ``module``, unwrapping a ``.module`` wrapper if present."""
    return module.module.state_dict() if hasattr(module, "module") else module.state_dict()

def _score_with_and_without_masks(model, hn_helper, masks, score_fn):
    """Call ``score_fn`` once with gates disabled and once with ``masks`` applied.

    Returns:
        ``(original_result, masked_result)``. Gates are left ENABLED on return;
        the caller is responsible for disabling them again.
    """
    hn_helper.set_gate_status(model, use_gate=False)
    original_result = score_fn()
    hn_helper.set_gate_vectors(model, masks)
    hn_helper.set_gate_status(model, use_gate=True)
    masked_result = score_fn()
    return original_result, masked_result


def _percent(part, whole):
    """Return ``part / whole`` as a percentage, or 0.0 when ``whole`` is 0."""
    return part / whole * 100 if whole else 0.0


def evaluate_sample_wise(model, hypernetwork, dataset, param_reg, hn_helper, args, tokenizer=None):
    """Sample-wise evaluation: MC data uses accuracy, continuous text uses PPL.

    Every sample is scored twice: once with the model's gates disabled (original
    model) and once with the hypernetwork's masks applied as gate vectors.

    A sample is treated as multiple-choice (MC) when it has a ``label_pos``, a
    tokenizer was given, and its ``original_example`` is a dict containing an
    ``'options'`` key; it is then scored with ``evaluate_mc_example``. All other
    samples are scored with ``calculate_perplexity``. Text NLL is summed over the
    whole dataset and divided by the total number of predicted tokens at the end
    (corpus-level PPL, not a mean of per-sample PPLs).

    Args:
        model: Model whose gates are toggled through ``hn_helper``.
        hypernetwork: Called as ``hypernetwork(layer_activations, input_lrp)`` to
            produce the masks.
        dataset: Indexable dataset of dict samples with ``'layer_activations'``,
            ``'input_lrp'`` and (for text samples) ``'sample_ids'``.
        param_reg: Not used by this function; kept for interface compatibility.
        hn_helper: Provides ``set_gate_vectors`` and ``set_gate_status``.
        args: Must provide ``device`` and ``seqlen``.
        tokenizer: Required for MC evaluation; without it every sample is scored
            as continuous text.

    Returns:
        dict with ``total_samples``, ``mc_results`` (counts, accuracies,
        per-sample details) and ``text_results`` (count, original/masked PPL,
        improvement = original PPL - masked PPL).
    """
    hypernetwork.eval()
    model.eval()
    device = args.device
    torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

    # MC data statistics
    mc_original_correct = 0
    mc_masked_correct = 0
    mc_total = 0
    mc_detailed_results = []

    # Continuous-text statistics. Original and masked runs see the same tokens,
    # so one token counter serves both NLL accumulators.
    text_original_nll = 0.0
    text_masked_nll = 0.0
    text_tokens = 0
    text_sample_count = 0

    total_samples = 0

    with torch.no_grad():
        for i in tqdm(range(len(dataset)), desc="Sample evaluation"):
            sample = dataset[i]
            total_samples += 1

            label_pos = sample.get('label_pos', None)
            original_example = sample.get('original_example', None)
            is_mc = (
                label_pos is not None
                and tokenizer is not None
                and isinstance(original_example, dict)
                and 'options' in original_example
            )

            try:
                # 1. Generate masks (needed for all samples)
                with torch.autocast(device_type="cuda", dtype=torch_dtype):
                    masks = hypernetwork(sample['layer_activations'], sample['input_lrp'])

                if is_mc:
                    # =============== MC data evaluation (accuracy) ===============
                    try:
                        # Shallow copy so the dataset's own example dict is not mutated.
                        formatted_example = dict(original_example)
                        formatted_example.setdefault("dataset_name", "")
                        original_result, masked_result = _score_with_and_without_masks(
                            model, hn_helper, masks,
                            lambda: evaluate_mc_example(model, tokenizer, formatted_example, device=device),
                        )
                        original_correct = original_result["is_correct"]
                        masked_correct = masked_result["is_correct"]
                    except Exception as e:
                        print(f"MC sample {i} evaluation failed: {e}")
                        continue

                    mc_total += 1
                    if original_correct:
                        mc_original_correct += 1
                    if masked_correct:
                        mc_masked_correct += 1

                    mc_detailed_results.append({
                        "sample_id": i,
                        "original_correct": original_correct,
                        "masked_correct": masked_correct,
                        "improvement": int(masked_correct) - int(original_correct),
                        "dataset": formatted_example.get("dataset_name", ""),
                        "question_type": "MC"
                    })
                else:
                    # =============== Continuous text evaluation (PPL) ===============
                    try:
                        input_ids = sample['sample_ids']
                        limit_length = min(args.seqlen, input_ids.shape[1])
                        # Labels are shifted by one, so a sequence of L tokens yields L - 1 predictions.
                        effective_tokens = limit_length - 1
                        if effective_tokens <= 0:
                            print(f"Text sample {i} skipped: needs at least 2 tokens, got {limit_length}")
                            continue
                        # NOTE(review): assumes calculate_perplexity returns the NLL *summed* over
                        # the predicted tokens (totals are divided by token counts below) — confirm.
                        original_nll, masked_nll = _score_with_and_without_masks(
                            model, hn_helper, masks,
                            lambda: calculate_perplexity(
                                model, input_ids, limit_length=limit_length, device=device
                            ),
                        )
                        original_nll_value = original_nll.item()
                        masked_nll_value = masked_nll.item()
                    except Exception as e:
                        print(f"Text sample {i} PPL evaluation failed: {e}")
                        continue

                    text_original_nll += original_nll_value
                    text_masked_nll += masked_nll_value
                    text_tokens += effective_tokens
                    text_sample_count += 1
            finally:
                # Always leave the model ungated, including after a failed or skipped sample.
                hn_helper.set_gate_status(model, use_gate=False)

    # =============== Result summary and reporting ===============
    print("\n=== Sample Evaluation Results ===")
    print(f"Total samples: {total_samples}")

    # MC data results
    if mc_total > 0:
        mc_original_accuracy = mc_original_correct / mc_total
        mc_masked_accuracy = mc_masked_correct / mc_total
        mc_accuracy_improvement = mc_masked_accuracy - mc_original_accuracy

        print("\n--- MC Data Results ---")
        print(f"MC samples: {mc_total}")
        print(f"Original MC accuracy: {mc_original_accuracy:.4f} ({mc_original_correct}/{mc_total})")
        print(f"Masked MC accuracy: {mc_masked_accuracy:.4f} ({mc_masked_correct}/{mc_total})")
        print(f"MC accuracy improvement: {mc_accuracy_improvement:+.4f} ({mc_accuracy_improvement*100:+.2f}%)")

        n_improved = sum(1 for r in mc_detailed_results if r["improvement"] > 0)
        n_degraded = sum(1 for r in mc_detailed_results if r["improvement"] < 0)
        n_unchanged = sum(1 for r in mc_detailed_results if r["improvement"] == 0)

        print(f"MC improved samples: {n_improved} ({_percent(n_improved, mc_total):.1f}%)")
        print(f"MC degraded samples: {n_degraded} ({_percent(n_degraded, mc_total):.1f}%)")
        print(f"MC unchanged samples: {n_unchanged} ({_percent(n_unchanged, mc_total):.1f}%)")
    else:
        print("--- No MC data ---")
        mc_original_accuracy = 0
        mc_masked_accuracy = 0
        mc_accuracy_improvement = 0

    # Continuous text results: corpus-level PPL = exp(total NLL / total predicted tokens)
    if text_sample_count > 0 and text_tokens > 0:
        # torch.exp yields inf on overflow instead of raising like math.exp.
        avg_original_ppl = torch.exp(torch.tensor(text_original_nll / text_tokens)).item()
        avg_masked_ppl = torch.exp(torch.tensor(text_masked_nll / text_tokens)).item()
        ppl_improvement = avg_original_ppl - avg_masked_ppl  # positive = masked model has lower PPL

        print("\n--- Continuous Text Results (PPL) ---")
        print(f"Text samples: {text_sample_count}")
        print(f"Total tokens: {text_tokens}")
        print(f"Original average PPL: {avg_original_ppl:.2f}")
        print(f"Masked average PPL: {avg_masked_ppl:.2f}")
        print(f"PPL improvement (original - masked): {ppl_improvement:+.2f}")
    else:
        print("--- No continuous text data ---")
        avg_original_ppl = 0.0
        avg_masked_ppl = 0.0
        ppl_improvement = 0.0

    # Data type distribution (guarded against an empty dataset)
    print("\n--- Data Distribution ---")
    print(f"MC data: {mc_total} ({_percent(mc_total, total_samples):.1f}%)")
    print(f"Continuous text: {text_sample_count} ({_percent(text_sample_count, total_samples):.1f}%)")
    print(f"Other/Failed: {total_samples - mc_total - text_sample_count}")

    return {
        "total_samples": total_samples,
        "mc_results": {
            "total": mc_total,
            "original_accuracy": mc_original_accuracy,
            "masked_accuracy": mc_masked_accuracy,
            "improvement": mc_accuracy_improvement,
            "detailed_results": mc_detailed_results
        },
        "text_results": {
            "total": text_sample_count,
            "original_ppl": float(avg_original_ppl),
            "masked_ppl": float(avg_masked_ppl),
            "improvement": float(ppl_improvement)
        }
    }


def str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean options need an explicit parser.

    Args:
        value: A bool (returned unchanged) or a case-insensitive string such as
            ``"true"``/``"false"``, ``"yes"``/``"no"``, ``"1"``/``"0"``, ``"on"``/``"off"``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train dynamic hypernetwork and evaluate")

    # Model parameters
    parser.add_argument("--model_path", type=str, default="xxx/llms/meta/Llama-2-7B-hf",
                        help="Model path")
    parser.add_argument("--device", type=str, default="cuda:1",
                        help="Training device")
    parser.add_argument("--output_dir", type=str, default="xxx/project/DynPrune/llama-2-7b/041/hn",
                        help="Output directory")
    parser.add_argument("--lrp_path", type=str, default="xxx/project/DISP/wikitext/lrp_train_ppl.pkl",
                        help="Pre-calculated LRP scores path")
    # Boolean options use str2bool: `type=bool` would parse "False" as True.
    # A bare flag (no value) means True via nargs="?" + const=True.
    parser.add_argument("--normalize_lrp", type=str2bool, nargs="?", const=True, default=True,
                        help="Whether to normalize LRP scores layer-wise")
    parser.add_argument("--normalize_activations", type=str2bool, nargs="?", const=True, default=False,
                        help="Whether to normalize activations layer-wise")

    # Hypernetwork parameters
    parser.add_argument("--hidden_dim", type=int, default=128,
                        help="Hypernetwork hidden dimension")
    parser.add_argument("--lrp_scale", type=float, default=1.0,
                        help="LRP scaling factor")
    parser.add_argument("--base", type=float, default=0.5,
                        help="Base value")
    parser.add_argument("--T_start", type=float, default=0.5,
                        help="Starting temperature")
    parser.add_argument("--T_end", type=float, default=0.1,
                        help="Ending temperature")
    parser.add_argument("--target_sparsity", type=float, default=0.4,
                        help="Target sparsity")

    # Training parameters
    # NOTE(review): the default start_epoch (8) exceeds the default max_epochs (4);
    # confirm whether in-training evaluation is intentionally disabled by default.
    parser.add_argument("--start_epoch", type=int, default=8,
                        help="Epoch to start evaluation")
    parser.add_argument("--max_epochs", type=int, default=4,
                        help="Maximum training epochs")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Training batch size")
    parser.add_argument("--eval_batch_size", type=int, default=1,
                        help="Evaluation batch size")
    parser.add_argument("--lr", type=float, default=2e-4,
                        help="Learning rate")
    parser.add_argument("--min_lr", type=float, default=1e-5,
                        help="Minimum learning rate")
    # The scheduler is on by default; --no_use_scheduler is the only way to turn it off
    # (a store_true flag with default=True can never become False by itself).
    parser.add_argument("--use_scheduler", action="store_true", default=True,
                        help="Whether to use learning rate scheduler")
    parser.add_argument("--no_use_scheduler", dest="use_scheduler", action="store_false",
                        help="Disable the learning rate scheduler")
    parser.add_argument("--grad_clip", type=float, default=1.0,
                        help="Gradient clipping value")
    parser.add_argument("--eval_epochs", type=int, default=3,
                        help="Evaluate every N epochs")
    parser.add_argument("--seed", type=int, default=58,
                        help="Random seed")

    parser.add_argument("--lam", type=float, default=4.0,
                        help="Regularization strength parameter")
    parser.add_argument("--reg_weight", type=float, default=1.0,
                        help="Regularization weight")

    # Data parameters
    parser.add_argument("--seqlen", type=int, default=2048,
                        help="Sequence length")
    parser.add_argument("--max_samples", type=lambda x: None if x.lower() == 'none' else int(x), default=None,
                        help="Maximum number of samples, can be integer or None")

    # wandb
    parser.add_argument("--use_wandb", type=str2bool, nargs="?", const=True, default=False,
                        help="Whether to use wandb to record experiment")
    parser.add_argument("--wandb_project", type=str, default="hypernetwork-pruning", help="wandb project name")
    parser.add_argument("--wandb_name", type=str, default=None,
                        help="wandb run name")
    parser.add_argument("--wandb_tags", type=str, nargs='+', default=[],
                        help="wandb tags")

    args = parser.parse_args()

    torch.set_float32_matmul_precision('high')
    try:
        main(args)
    finally:
        # Ensure wandb is closed correctly
        if args.use_wandb:
            wandb.finish()