import sys
sys.path.append(".")

import argparse
import pandas as pd
from tqdm import tqdm
import os, gc
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizer, PreTrainedModel
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.logging import get_logger
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from data_modules.audio import AudioSlowRawDataset, AudioDataset
from helper import create_lr_scheduler, create_optimizer, find_all_linear_names
from sklearn.metrics import f1_score, accuracy_score
import matplotlib.pyplot as plt
from contextlib import nullcontext
import copy

# Module-level logger created with accelerate's `get_logger`; shared by the
# training and evaluation routines in this file.
logger = get_logger(name=__name__)

@dataclass
class TrainingArguments:
    """Configuration for one training / evaluation run.

    Field order is significant because the generated ``__init__`` accepts the
    fields positionally; the first eighteen fields are required, the rest have
    defaults.
    """

    # ---- model and data ----
    model_name: str
    train_csv: str  # training CSV; rows provide `path` and `emo` columns
    eval_csv: str  # evaluation CSV with the same layout as train_csv

    # ---- LoRA adapter ----
    lora_rank: int
    lora_alpha: int
    lora_dropout: float

    # ---- optimisation ----
    learning_rate: float
    weight_decay: float
    optim: str
    lr_scheduler_type: str
    warmup_steps: int

    # ---- run layout and data windowing ----
    output_dir: str
    training_input_length: int
    stride_size: int
    num_train_epochs: int
    gradient_checkpointing: str
    use_flash_attention_2: bool
    logging_steps: int

    # ---- optional settings ----
    # Despite its name, train_epoch multiplies the emotion-token loss by this.
    end_token_weight: float = 100
    plot_loss_curve: bool = True  # whether to plot the loss curve
    classifier_learning_rate: float = 1e-4  # Adam LR of the EmotionClassifier
    templora_inference_lr: float = 5e-5  # AdamW LR of the temporary LoRA used in evaluation
    seed: int = 42  # random seed

class AudioDatasetFromCSV(Dataset):
    """Concatenation of per-recording ``AudioDataset`` objects listed in a CSV.

    Each CSV row must provide a ``path`` column (passed to
    ``AudioSlowRawDataset`` as ``feature_path``) and an ``emo`` column whose
    text is evaluated into an iterable of ``(timestamp, emotion)`` pairs.
    Emotions outside ``VALID_EMOTIONS`` are dropped, codes in
    ``EMOTION_MAPPING`` are renamed (``'exc'`` -> ``'hap'``), and rows left
    with no emotion are skipped. Items are indexed across the kept recordings
    in CSV order.
    """

    # Emotion codes that are kept; every other code is dropped.
    VALID_EMOTIONS = frozenset({'hap', 'sad', 'ang', 'neu', 'exc'})
    # Codes folded into another class before training.
    EMOTION_MAPPING = {'exc': 'hap'}

    def __init__(self, csv_path: str, tokenizer: "PreTrainedTokenizer", prefix_length: int, stride_size: int, is_eval: bool = False):
        """Read ``csv_path`` and eagerly build one dataset per usable row.

        Args:
            csv_path: CSV file with ``path`` and ``emo`` columns.
            tokenizer: Tokenizer forwarded to ``AudioSlowRawDataset``.
            prefix_length: Forwarded to ``AudioSlowRawDataset``.
            stride_size: Forwarded to ``AudioSlowRawDataset``.
            is_eval: Forwarded to both ``AudioSlowRawDataset`` and ``AudioDataset``.
        """
        self.df = pd.read_csv(csv_path)
        self.tokenizer = tokenizer
        self.prefix_length = prefix_length
        self.stride_size = stride_size
        self.is_eval = is_eval
        self.datasets: List["AudioDataset"] = []
        self._load_all_datasets()

    def _load_all_datasets(self):
        """Populate ``self.datasets`` from ``self.df`` (one entry per kept row)."""
        for _, row in self.df.iterrows():
            # SECURITY(review): ``eval`` executes arbitrary code contained in
            # the CSV (with this module's globals in scope). Only load trusted
            # files; switch to ``ast.literal_eval`` once the column is known
            # to contain plain literals only (no reprs such as ``np.float64(...)``).
            emotion_list = eval(row['emo'])

            filtered_emotions = [
                [timestamp, self.EMOTION_MAPPING.get(emotion, emotion)]
                for timestamp, emotion in emotion_list
                if emotion in self.VALID_EMOTIONS
            ]
            if not filtered_emotions:
                continue

            # Chunk the (possibly long) recording.
            raw_dataset = AudioSlowRawDataset(
                feature_path=row['path'],
                emotion_label=filtered_emotions,
                tokenizer=self.tokenizer,
                prefix_length=self.prefix_length,
                stride_size=self.stride_size,
                is_eval=self.is_eval
            )
            raw_dataset.load_from_npy()
            self.datasets.append(AudioDataset(dataset=raw_dataset, is_eval=self.is_eval))

    def __len__(self) -> int:
        """Total number of items across all sub-datasets."""
        return sum(len(dataset) for dataset in self.datasets)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return item ``idx`` of the concatenated datasets.

        Negative indices count from the end, like a Python sequence.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            # Previously a negative index was applied to the *first*
            # sub-dataset only, returning an item from the wrong recording.
            idx += len(self)
            if idx < 0:
                raise IndexError("Index out of bounds")
        for dataset in self.datasets:
            size = len(dataset)
            if idx < size:
                return dataset[idx]
            idx -= size
        raise IndexError("Index out of bounds")

def to_device(obj: Union[Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]], device: torch.device) -> Union[Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]]:
    """Copy a batch dict (or a list of batch dicts) onto ``device``.

    Values exposing a ``.to`` method are moved with ``.to(device=device)``;
    every other value is passed through untouched. New containers are
    returned; the input is not modified.
    """
    def _move_one(mapping):
        moved = {}
        for key, value in mapping.items():
            moved[key] = value.to(device=device) if hasattr(value, "to") else value
        return moved

    if not isinstance(obj, list):
        return _move_one(obj)
    return [_move_one(mapping) for mapping in obj]

def train_epoch(
    model: PreTrainedModel,
    train_dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: Any,
    accelerator: Accelerator,
    epoch: int,
    args: TrainingArguments
) -> Tuple[Union[torch.Tensor, float], List[float], "EmotionClassifier"]:
    """Train ``model`` for one epoch and fit a fresh ``EmotionClassifier`` alongside.

    Every batch must unpack into ``(emotion_batch, audio_batch)``; both are
    dicts whose ``"input_ids"`` are squeezed along dim 0 before use. Per batch:

    * Audio loss: next-token cross-entropy over the audio tokens, with the
      final position trained to predict ``<audio_end>``.
    * Emotion loss: only when token id 38565 occurs in the emotion input.
      The last token is taken as the emotion label, and a cross-entropy over
      the four emotion-token logits at the preceding position is added,
      scaled by ``args.end_token_weight``.
    * Classifier: an ``EmotionClassifier`` is trained with its own Adam
      optimizer on the input embeddings of the emotion prefix. The embeddings
      are computed without gradients, so this does not update ``model``.

    The summed model loss is halved and back-propagated via ``accelerator``.
    ``optimizer`` and ``lr_scheduler`` are stepped once per batch. Progress is
    appended to ``train_logs.txt`` and ``classifier_train_logs.txt`` in
    ``args.output_dir``.

    NOTE(review): relies on the module-level ``tokenizer``. The ``.item()``
    calls on labels/predictions assume one sequence per batch.

    Returns:
        ``(mean_loss, step_losses, classifier)``: the mean per-batch loss
        (a 0-d tensor, or ``0.0`` for an empty loader), the per-step losses,
        and the classifier trained during this epoch.
    """
    # Create the log file with a header on first use.
    log_file = os.path.join(args.output_dir, "train_logs.txt")
    os.makedirs(args.output_dir, exist_ok=True)
    if not os.path.exists(log_file):
        with open(log_file, "w") as f:
            f.write("Training Logs\n")
            f.write("=============\n\n")

    model.train()
    total_loss = 0
    emotion_preds = []
    true_emotion_labels = []
    step_losses = []

    # Emotion token id -> readable name. The dict order defines the class
    # index used both by the restricted LM loss and by the classifier targets.
    emotion_id_to_name = {
        tokenizer.convert_tokens_to_ids("<emo_hap>"): "happy",
        tokenizer.convert_tokens_to_ids("<emo_sad>"): "sad",
        tokenizer.convert_tokens_to_ids("<emo_ang>"): "angry",
        tokenizer.convert_tokens_to_ids("<emo_neu>"): "neutral"
    }
    classifier_emotion_id_to_name = {
        0: "happy",
        1: "sad",
        2: "angry",
        3: "neutral"
    }
    emotion_ids = list(emotion_id_to_name.keys())
    # Looked up once instead of on every batch.
    audio_end_id = tokenizer.convert_tokens_to_ids("<audio_end>")

    progress_bar = tqdm(
        total=len(train_dataloader),
        desc=f"Epoch {epoch}",
        disable=not accelerator.is_local_main_process
    )

    # A new classifier and optimizer are created for every epoch.
    classifier = EmotionClassifier(model.config.hidden_size).to(accelerator.device)
    classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=args.classifier_learning_rate)
    classifier_criterion = torch.nn.CrossEntropyLoss()
    classifier_train_log = os.path.join(args.output_dir, "classifier_train_logs.txt")
    classifier_preds = []
    classifier_true_labels = []

    for step, batch in enumerate(train_dataloader):
        batch = to_device(batch, accelerator.device)
        emotion_batch, audio_batch = batch
        batch_loss = 0

        # ---- Audio next-token prediction ----
        audio_input_ids = audio_batch["input_ids"].squeeze(0)
        # The target for position i is token i + 1; the last position targets <audio_end>.
        audio_labels = torch.cat([
            audio_input_ids[:, 1:],
            torch.full((audio_input_ids.size(0), 1), audio_end_id,
                       dtype=audio_input_ids.dtype, device=audio_input_ids.device)
        ], dim=1)

        audio_outputs = model(
            input_ids=audio_input_ids,
            attention_mask=torch.ones_like(audio_input_ids)
        )
        # audio_labels is already aligned with the logits, so the loss is
        # computed here. Passing it as `labels=` lets Hugging Face causal-LM
        # heads shift it a second time. Every position would then be trained
        # to predict the token two steps ahead.
        audio_logits = audio_outputs.logits
        audio_loss = torch.nn.functional.cross_entropy(
            audio_logits.float().reshape(-1, audio_logits.size(-1)),
            audio_labels.reshape(-1)
        )
        batch_loss += audio_loss

        # ---- Emotion prediction (only for sequences containing id 38565) ----
        input_ids = emotion_batch["input_ids"].squeeze(0)
        has_end_token = 38565 in input_ids

        if has_end_token:
            emotion_input_ids = input_ids[:, :-1]  # everything except the label
            emotion_labels = input_ids[:, -1]      # the emotion label token

            emotion_outputs = model(
                input_ids=emotion_input_ids,
                attention_mask=torch.ones_like(emotion_input_ids),
                output_hidden_states=True
            )

            # Restrict the last position's logits to the four emotion tokens.
            last_token_logits = emotion_outputs.logits[:, -1, :]
            emotion_logits = last_token_logits[:, emotion_ids]
            true_emotion_idx = emotion_ids.index(emotion_labels.item())

            pred_emotion_idx = torch.argmax(emotion_logits, dim=1)
            pred_emotion_id = emotion_ids[pred_emotion_idx.item()]
            true_emotion_id = emotion_labels.item()

            emotion_preds.append(emotion_id_to_name[pred_emotion_id])
            true_emotion_labels.append(emotion_id_to_name[true_emotion_id])

            if len(emotion_preds) % 100 == 0:
                logger.info(f"Recent prediction: {emotion_preds[-1]}, True: {true_emotion_labels[-1]}")

            emotion_loss = torch.nn.CrossEntropyLoss()(
                emotion_logits,
                torch.tensor([true_emotion_idx], device=emotion_logits.device)
            )
            batch_loss += emotion_loss * args.end_token_weight

            # Classifier input: token embeddings only, with no gradient into the model.
            with torch.no_grad():
                hidden_states = model.get_input_embeddings()(emotion_input_ids)

            classifier.train()
            logits = classifier(hidden_states)
            classifier_loss = classifier_criterion(
                logits,
                torch.tensor([true_emotion_idx], device=logits.device)
            )
            classifier_optimizer.zero_grad()
            classifier_loss.backward()
            classifier_optimizer.step()

            classifier_pred = torch.argmax(logits, dim=1).item()
            classifier_preds.append(classifier_emotion_id_to_name[classifier_pred])
            classifier_true_labels.append(emotion_id_to_name[true_emotion_id])

            if step % args.logging_steps == 0:
                classifier_correct = sum(p == l for p, l in zip(classifier_preds, classifier_true_labels))
                classifier_accuracy = classifier_correct / len(classifier_preds)

                with open(classifier_train_log, "a") as f:
                    f.write(f"Epoch {epoch}, Step {step}:\n")
                    f.write(f"Classifier Loss: {classifier_loss.item():.4f}\n")
                    f.write(f"Classifier Accuracy: {classifier_accuracy:.4f}\n")
                    f.write("-" * 50 + "\n")

        # Average of the audio term and the (weighted) emotion term.
        loss = batch_loss / 2

        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

        loss_value = loss.item()
        total_loss += loss.detach().float()
        step_losses.append(loss_value)

        current_lr = optimizer.param_groups[0]["lr"]
        progress_bar.set_postfix({
            "loss": f"{loss_value:.4f}",
            "lr": f"{current_lr:.2e}"
        })
        progress_bar.update(1)

        if step > 0 and step % args.logging_steps == 0:
            recent_losses = step_losses[-args.logging_steps:]
            avg_loss = sum(recent_losses) / len(recent_losses)

            correct = sum(p == l for p, l in zip(emotion_preds, true_emotion_labels))
            accuracy = correct / len(emotion_preds) if emotion_preds else 0

            logger.info(f"Epoch {epoch}, Step {step}:")
            logger.info(f"Average Loss (last {args.logging_steps} steps) = {avg_loss:.4f}, lr = {current_lr:.2e}")
            logger.info(f"Emotion Prediction Accuracy: {accuracy:.4f}")

            recent_preds = list(zip(emotion_preds, true_emotion_labels))[-5:]
            logger.info("Recent predictions:")
            for pred, label in recent_preds:
                logger.info(f"Predicted: {pred}, True: {label}")

            with open(log_file, "a") as f:
                f.write(f"\nEpoch {epoch}, Step {step}:\n")
                f.write(f"Average Loss (last {args.logging_steps} steps): {avg_loss:.4f}\n")
                f.write(f"Learning Rate: {current_lr:.2e}\n")
                f.write(f"Accuracy: {accuracy:.4f}\n")
                f.write("Recent predictions:\n")
                for pred, label in recent_preds:
                    f.write(f"Predicted: {pred}, True: {label}\n")

    progress_bar.close()

    # Guard against an empty loader; the old code divided by zero here.
    mean_loss = total_loss / max(len(train_dataloader), 1)

    if emotion_preds:
        final_correct = sum(p == l for p, l in zip(emotion_preds, true_emotion_labels))
        final_accuracy = final_correct / len(emotion_preds)

        logger.info(f"Epoch {epoch} completed:")
        logger.info(f"Average loss: {mean_loss:.4f}")
        logger.info(f"Final emotion prediction accuracy: {final_accuracy:.4f}")

        with open(log_file, "a") as f:
            f.write(f"\nEpoch {epoch} Summary:\n")
            f.write(f"Average loss: {mean_loss:.4f}\n")
            f.write(f"Final emotion prediction accuracy: {final_accuracy:.4f}\n")
            f.write("="*50 + "\n")

    if classifier_preds:
        classifier_final_correct = sum(p == l for p, l in zip(classifier_preds, classifier_true_labels))
        classifier_final_accuracy = classifier_final_correct / len(classifier_preds)

        logger.info(f"Classifier final accuracy: {classifier_final_accuracy:.4f}")

        with open(log_file, "a") as f:
            f.write(f"Classifier final accuracy: {classifier_final_accuracy:.4f}\n")

    return mean_loss, step_losses, classifier

def verify_model_consistency(model, epoch, stage="before", output_dir: Optional[str] = None):
    """Fingerprint every parameter of ``model`` and append it to a log.

    For each named parameter, the mean, standard deviation, L2 norm, shape and
    ``requires_grad`` flag are recorded. All parameters are included, with no
    LoRA filtering. ``compare_model_states`` can diff two fingerprints.
    Note that ``std`` of a single-element parameter is NaN.

    Args:
        model: Any ``torch.nn.Module``.
        epoch: Epoch number; used only in the log header.
        stage: Free-form label written to the log (e.g. "before"/"after").
        output_dir: Directory holding ``model_verification.log``. Defaults to
            the module-level ``args.output_dir`` when omitted.

    Returns:
        Dict mapping parameter name -> dict of the statistics above.
    """
    if output_dir is None:
        output_dir = args.output_dir
    verify_log_file = os.path.join(output_dir, "model_verification.log")

    # Cheap statistics stand in for a full copy of the weights.
    model_state = {}
    with torch.no_grad():
        for name, param in model.named_parameters():
            model_state[name] = {
                "mean": param.mean().item(),
                "std": param.std().item(),
                "norm": param.norm().item(),
                "shape": tuple(param.shape),
                "requires_grad": param.requires_grad
            }

    with open(verify_log_file, "a") as f:
        f.write(f"\nEpoch {epoch} - {stage} evaluation:\n")
        f.write("="*50 + "\n")
        for name, stats in model_state.items():
            f.write(f"Layer: {name}\n")
            for stat_name, value in stats.items():
                f.write(f"  {stat_name}: {value}\n")
        f.write("\n")

    return model_state

def compare_model_states(state1, state2, epoch, output_dir: Optional[str] = None):
    """Compare two fingerprints produced by ``verify_model_consistency``.

    Numeric statistics match when they differ by at most 1e-6 (absolute) or
    are both NaN. A value that becomes or stops being NaN is a mismatch.
    Non-numeric values (e.g. ``shape``) must compare equal. Layers or
    statistics present in only one state also count as mismatches. Every
    finding is appended to ``model_verification.log``.

    Args:
        state1: Reference fingerprint (layer name -> {stat name -> value}).
        state2: Fingerprint checked against ``state1``.
        epoch: Epoch number; used only in the log header.
        output_dir: Directory of the log file. Defaults to the module-level
            ``args.output_dir`` when omitted.

    Returns:
        ``True`` if every layer and statistic matches, else ``False``.
    """
    import math

    if output_dir is None:
        output_dir = args.output_dir
    verify_log_file = os.path.join(output_dir, "model_verification.log")
    tolerance = 1e-6
    is_consistent = True

    def _numbers_match(a, b):
        # NaN never compares equal (and |NaN - x| > tol is False), so handle it
        # explicitly; otherwise a parameter turning NaN would pass silently.
        if math.isnan(a) or math.isnan(b):
            return math.isnan(a) and math.isnan(b)
        return abs(a - b) <= tolerance

    # utf-8 so the check/cross marks can be written on any platform.
    with open(verify_log_file, "a", encoding="utf-8") as f:
        f.write(f"\nEpoch {epoch} - Model Consistency Check:\n")
        f.write("="*50 + "\n")

        for name, stats1 in state1.items():
            if name not in state2:
                f.write(f"❌ Layer {name} missing in second state!\n")
                is_consistent = False
                continue
            stats2 = state2[name]

            for stat_name, val1 in stats1.items():
                if stat_name not in stats2:
                    f.write(f"❌ Stat {stat_name} of {name} missing in second state!\n")
                    is_consistent = False
                    continue
                val2 = stats2[stat_name]

                if isinstance(val1, (float, int)) and isinstance(val2, (float, int)):
                    matches = _numbers_match(val1, val2)
                else:
                    matches = val1 == val2
                if not matches:
                    f.write(f"❌ Mismatch in {name} - {stat_name}: {val1} != {val2}\n")
                    is_consistent = False

        # Layers that only appear in the second state were previously ignored.
        for name in state2:
            if name not in state1:
                f.write(f"❌ Layer {name} missing in first state!\n")
                is_consistent = False

        if is_consistent:
            f.write("✓ Model states are consistent!\n")
        else:
            f.write("❌ Model states are different!\n")
        f.write("\n")

    return is_consistent

class EmotionClassifier(torch.nn.Module):
    """Small Transformer-encoder classifier over a sequence of embeddings.

    The batch-first input ``[batch, seq_len, embedding_dim]`` passes through
    ``num_transformer_layers`` encoder layers. It is then mean-pooled over the
    sequence dimension and mapped to ``num_classes`` logits by a two-layer
    MLP. All parameters are created in ``dtype`` (bfloat16 by default), so
    inputs must use the same dtype.
    """

    def __init__(self, embedding_dim, num_classes=4, num_transformer_layers=1,
                 nhead=8, ffn_multiplier=2, dropout=0.1, hidden_dim=256,
                 dtype=torch.bfloat16):
        """
        Args:
            embedding_dim: Feature size of the input (``d_model``); must be
                divisible by ``nhead``.
            num_classes: Number of output logits.
            num_transformer_layers: Number of stacked encoder layers.
            nhead: Attention heads per encoder layer.
            ffn_multiplier: Encoder feed-forward width as a multiple of
                ``embedding_dim``.
            dropout: Dropout used in the encoder layers and in the MLP head.
            hidden_dim: Width of the MLP head's hidden layer.
            dtype: Parameter dtype for every sub-module.
        """
        super().__init__()

        transformer_layer = torch.nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=ffn_multiplier * embedding_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True,  # inputs are (batch, seq, feature)
            dtype=dtype
        )
        self.transformer_encoder = torch.nn.TransformerEncoder(
            transformer_layer,
            num_layers=num_transformer_layers
        )

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, hidden_dim, dtype=dtype),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, num_classes, dtype=dtype)
        )

    def forward(self, embeddings):
        """Map ``[batch, seq_len, embedding_dim]`` embeddings to ``[batch, num_classes]`` logits."""
        transformer_output = self.transformer_encoder(embeddings)
        # Mean-pool over the sequence dimension.
        pooled = torch.mean(transformer_output, dim=1)
        return self.classifier(pooled)

def evaluate(model: PreTrainedModel, eval_dataloader: DataLoader, accelerator: Accelerator, epoch: int) -> Tuple[float, float, float]:
    """Evaluate emotion prediction three ways and return the Temp LoRA metrics.

    A deep copy of ``model`` has its LoRA weights merged and all parameters
    frozen, while the original ``model`` is parked on the CPU. For every batch
    containing the end token (id 38565), the token right after the last end
    token is the ground-truth emotion. It is predicted by:

    1. the module-level ``classifier`` applied to the input embeddings;
    2. the merged model with adapters disabled, using the four emotion-token
       logits at the end-token position;
    3. a temporary LoRA ("Temp LoRA"). It is re-created whenever a batch
       holds exactly one end token (treated as the start of a new recording).
       It then takes one AdamW step per batch on the newest utterance, using
       up to ``prefix_length`` tokens of history as context.

    Per-sample and per-recording ("long audio") results are appended to log
    files in ``args.output_dir``. Parameter fingerprints of ``model`` taken
    before and after are compared, and a warning is logged if they differ.

    NOTE(review): depends on the module-level globals ``args``, ``tokenizer``
    and ``classifier``, and assumes one sequence per batch.

    Returns:
        ``(accuracy, unweighted_accuracy, weighted_f1)`` of the Temp LoRA
        predictions.
    """
    # Fingerprint the live model to check that evaluation leaves it untouched.
    initial_state = verify_model_consistency(model, epoch, "before")

    # Metric logs: merged base model, Temp LoRA and classifier. The Temp LoRA
    # metric log used to be a name bound only inside the training branch,
    # which raised NameError when that branch never ran.
    base_log_file = os.path.join(args.output_dir, "eval_logs_base.txt")
    templora_log_file = os.path.join(args.output_dir, "eval_logs_templora.txt")
    classifier_eval_log = os.path.join(args.output_dir, "classifier_eval_logs.txt")
    # Per-step Temp LoRA training losses.
    templora_train_log_file = os.path.join(args.output_dir, "templora_training_logs.txt")
    # Last prediction of every recording.
    long_audio_log = os.path.join(args.output_dir, "long_audio_predictions.txt")

    base_preds, base_labels = [], []
    templora_preds, templora_labels = [], []
    classifier_preds, classifier_true_labels = [], []

    # Bound before `try` so the clean-up in `finally` can never hit an
    # unbound name and mask the original exception.
    eval_model = None
    temp_lora = None
    temp_optimizer = None

    def _safe_accuracy(labels, preds):
        # accuracy_score is not meaningful on empty inputs.
        return accuracy_score(labels, preds) if labels else 0.0

    try:
        # 1. Copy the model, then park the original on the CPU to save memory.
        logger.info("Creating model copy for evaluation...")
        with torch.no_grad():
            eval_model = copy.deepcopy(model)
            model.to('cpu')

        # 2. Merge the trained LoRA weights into the copy's base layers.
        logger.info("Merging LoRA weights for evaluation...")
        eval_model.merge_and_unload()

        # 3. Freeze everything; Temp LoRA adds its own trainable parameters.
        for param in eval_model.parameters():
            param.requires_grad = False
        eval_model.eval()
        # train_epoch leaves the classifier in train mode. Disable its dropout
        # so evaluation predictions are deterministic.
        classifier.eval()

        end_last = 0         # end-token position of the previous utterance (0 = none yet)
        audio_counter = 0    # number of recordings started so far
        prefix_length = 768  # max history tokens given to Temp LoRA

        emotion_id_to_name = {
            tokenizer.convert_tokens_to_ids("<emo_hap>"): "happy",
            tokenizer.convert_tokens_to_ids("<emo_sad>"): "sad",
            tokenizer.convert_tokens_to_ids("<emo_ang>"): "angry",
            tokenizer.convert_tokens_to_ids("<emo_neu>"): "neutral"
        }
        classifier_emotion_id_to_name = {
            0: "happy",
            1: "sad",
            2: "angry",
            3: "neutral"
        }
        emotion_ids = list(emotion_id_to_name.keys())
        end_token_id = 38565  # Audio end token ID

        # Last prediction of each recording, per method.
        long_audio_results = {
            'classifier': {'preds': [], 'labels': []},
            'base': {'preds': [], 'labels': []},
            'templora': {'preds': [], 'labels': []}
        }

        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            emotion_batch = to_device(batch, accelerator.device)
            input_ids = emotion_batch["input_ids"].squeeze(0)

            if end_token_id not in input_ids:
                # Nothing to predict. Previously the Temp LoRA section below ran
                # anyway and raised NameError or reused the last batch's end_pos.
                continue

            end_pos_list = torch.where(input_ids == end_token_id)[1]  # column indices
            end_pos = end_pos_list[-1].item()
            # The emotion label immediately follows the last end token.
            true_id = input_ids[0, end_pos + 1].item()
            eval_input_ids = input_ids[:, :end_pos + 1]

            # ---- Classifier and merged base model (adapters disabled) ----
            with torch.no_grad(), eval_model.disable_adapter():
                embeddings = eval_model.get_input_embeddings()(eval_input_ids)

                classifier_logits = classifier(embeddings)
                classifier_pred = torch.argmax(classifier_logits, dim=1).item()
                true_emotion_idx = emotion_ids.index(true_id)
                classifier_preds.append(classifier_emotion_id_to_name[classifier_pred])
                classifier_true_labels.append(classifier_emotion_id_to_name[true_emotion_idx])

                outputs = eval_model(
                    input_ids=eval_input_ids,
                    attention_mask=torch.ones_like(eval_input_ids)
                )
                last_token_logits = outputs.logits[:, -1, :]
                emotion_logits = last_token_logits[0, emotion_ids]
                pred_id = emotion_ids[torch.argmax(emotion_logits).item()]

                base_preds.append(emotion_id_to_name.get(pred_id, f"unknown_{pred_id}"))
                base_labels.append(emotion_id_to_name.get(true_id, f"unknown_{true_id}"))

            # ---- Temp LoRA ----
            # A batch with a single end token starts a new recording. Also
            # create the adapter if none exists yet, so a loader that does not
            # start on a recording boundary cannot hit `None`.
            is_new_audio = len(end_pos_list) == 1
            if is_new_audio or temp_lora is None:
                audio_counter += 1
                if temp_lora is not None:
                    # Drop the previous recording's adapter wrapper and optimizer.
                    temp_lora = None
                    temp_optimizer = None
                    torch.cuda.empty_cache()

                lora_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM,
                    inference_mode=False,
                    r=64,
                    lora_alpha=64,
                    lora_dropout=0.05,
                    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
                )
                temp_lora = get_peft_model(eval_model, lora_config)
                temp_lora.to(accelerator.device)

                temp_optimizer = torch.optim.AdamW(
                    temp_lora.parameters(),
                    lr=args.templora_inference_lr,
                    weight_decay=0
                )
                end_last = 0

            if end_last > 0:
                # One step on the newest utterance, with history as context.
                context_start = max(0, end_last - prefix_length)
                train_input_ids = input_ids[:, context_start:end_pos]
                # The target for input position i is token i + 1. Clone because
                # the history part is masked in place below.
                train_labels = input_ids[:, context_start + 1:end_pos + 1].clone()

                history_length = end_last - context_start
                if history_length > 0:
                    train_labels[:, :history_length] = -100

                temp_lora.train()
                outputs = temp_lora(
                    input_ids=train_input_ids,
                    attention_mask=torch.ones_like(train_input_ids)
                )
                # train_labels is already aligned with the logits, so the loss
                # is computed here. Passing it as `labels=` would make the
                # causal-LM head shift it a second time.
                train_logits = outputs.logits
                loss = torch.nn.functional.cross_entropy(
                    train_logits.float().reshape(-1, train_logits.size(-1)),
                    train_labels.reshape(-1),
                    ignore_index=-100
                )
                loss.backward()
                temp_optimizer.step()
                temp_optimizer.zero_grad()

                with open(templora_train_log_file, "a") as f:
                    f.write(f"Epoch {epoch}, Audio #{audio_counter}, ")
                    f.write(f"Context [{context_start}:{end_pos}], ")
                    f.write(f"Loss: {loss.item():.4f}\n")

                logger.info(f"Temp LoRA training - Audio #{audio_counter}, Loss: {loss.item():.4f}")

            end_last = end_pos

            # Predict in eval mode (no LoRA dropout) and without building an
            # autograd graph.
            context_start = max(0, end_pos - prefix_length)
            eval_input_ids = input_ids[:, context_start:end_pos + 1]
            temp_lora.eval()
            with torch.no_grad():
                outputs = temp_lora(
                    input_ids=eval_input_ids,
                    attention_mask=torch.ones_like(eval_input_ids)
                )

            last_token_logits = outputs.logits[0, -1, :]
            emotion_logits = last_token_logits[emotion_ids]
            pred_id = emotion_ids[torch.argmax(emotion_logits).item()]

            templora_preds.append(emotion_id_to_name.get(pred_id, f"unknown_{pred_id}"))
            templora_labels.append(emotion_id_to_name.get(true_id, f"unknown_{true_id}"))

            if is_new_audio and len(classifier_preds) > 1:
                # The previous entries are the last predictions of the previous recording.
                long_audio_results['classifier']['preds'].append(classifier_preds[-2])
                long_audio_results['classifier']['labels'].append(classifier_true_labels[-2])
                long_audio_results['base']['preds'].append(base_preds[-2])
                long_audio_results['base']['labels'].append(base_labels[-2])
                long_audio_results['templora']['preds'].append(templora_preds[-2])
                long_audio_results['templora']['labels'].append(templora_labels[-2])

        # The final recording has no successor, so record it explicitly.
        if len(classifier_preds) > 0:
            long_audio_results['classifier']['preds'].append(classifier_preds[-1])
            long_audio_results['classifier']['labels'].append(classifier_true_labels[-1])
            long_audio_results['base']['preds'].append(base_preds[-1])
            long_audio_results['base']['labels'].append(base_labels[-1])
            long_audio_results['templora']['preds'].append(templora_preds[-1])
            long_audio_results['templora']['labels'].append(templora_labels[-1])

        with open(long_audio_log, "a") as f:
            f.write(f"\nEpoch {epoch} - Long Audio Predictions:\n")
            f.write("="*50 + "\n")

            for method, results in long_audio_results.items():
                accuracy = _safe_accuracy(results['labels'], results['preds'])
                f.write(f"\n{method.capitalize()} Results:\n")
                f.write(f"Accuracy: {accuracy:.4f}\n")
                f.write("Predictions:\n")

                for idx, (pred, label) in enumerate(zip(results['preds'], results['labels']), 1):
                    f.write(f"Audio {idx:3d} - Predicted: {pred:10s} | True: {label:10s} | "
                           f"{'✓' if pred == label else '✗'}\n")

            f.write("\n" + "="*50 + "\n")

            logger.info(f"Long Audio Results for Epoch {epoch}:")
            for method, results in long_audio_results.items():
                accuracy = _safe_accuracy(results['labels'], results['preds'])
                logger.info(f"{method.capitalize()} Accuracy: {accuracy:.4f}")

    finally:
        # Release the evaluation copies and restore the original model.
        temp_lora = None
        temp_optimizer = None
        eval_model = None
        torch.cuda.empty_cache()
        model.to(accelerator.device)

        final_state = verify_model_consistency(model, epoch, "after")
        is_consistent = compare_model_states(initial_state, final_state, epoch)
        if not is_consistent:
            logger.warning("Model parameters changed during evaluation!")

    def calculate_metrics(preds, labels, log_file, method_name):
        """Log accuracy, UA (mean per-class accuracy) and weighted F1; return them."""
        accuracy = _safe_accuracy(labels, preds)

        # Per-sample prediction log next to `log_file`.
        predictions_log_file = os.path.join(os.path.dirname(log_file), f"{method_name}_predictions.log")
        with open(predictions_log_file, "a") as f:
            f.write(f"\nEpoch {epoch} - {method_name} Predictions:\n")
            f.write("="*50 + "\n")
            for idx, (pred, label) in enumerate(zip(preds, labels), 1):
                f.write(f"Sample {idx:4d} - Predicted: {pred:10s} | True: {label:10s} | "
                        f"{'✓' if pred == label else '✗'}\n")
            f.write("\n")

        class_accuracies = {}
        for emotion in set(labels):
            mask = [l == emotion for l in labels]
            class_preds = [p for p, m in zip(preds, mask) if m]
            class_labels = [l for l, m in zip(labels, mask) if m]
            if class_labels:
                class_accuracies[emotion] = accuracy_score(class_labels, class_preds)
        # Unweighted accuracy; 0.0 instead of ZeroDivisionError with no samples.
        ua = sum(class_accuracies.values()) / len(class_accuracies) if class_accuracies else 0.0

        wf1 = f1_score(labels, preds, average='weighted')

        with open(log_file, "a") as f:
            f.write(f"\nEpoch {epoch} - {method_name}:\n")
            f.write(f"Accuracy: {accuracy:.4f}\n")
            f.write(f"UA: {ua:.4f}\n")
            f.write(f"WF1: {wf1:.4f}\n")
            f.write("Class-wise accuracies:\n")
            for emotion, acc in class_accuracies.items():
                f.write(f"{emotion}: {acc:.4f}\n")
            f.write("Predictions:\n")
            for pred, label in zip(preds, labels):
                f.write(f"Predicted: {pred}, True: {label}\n")

        logger.info(f"Epoch {epoch} - {method_name}:")
        logger.info(f"Accuracy: {accuracy:.4f}")
        logger.info(f"UA: {ua:.4f}")
        logger.info(f"WF1: {wf1:.4f}")
        logger.info("Class-wise accuracies:")
        for emotion, acc in class_accuracies.items():
            logger.info(f"{emotion}: {acc:.4f}")

        return accuracy, ua, wf1

    calculate_metrics(base_preds, base_labels, base_log_file, "Base Model")
    templora_metrics = calculate_metrics(templora_preds, templora_labels, templora_log_file, "Temp LoRA")
    calculate_metrics(classifier_preds, classifier_true_labels, classifier_eval_log, "Classifier")
    return templora_metrics

def _str2bool(value: Union[str, bool]) -> bool:
    """Convert a command-line boolean spelling (``"False"``, ``"yes"``, ``"1"``...) to ``bool``.

    argparse with ``default=True`` and no ``type`` keeps the raw string the user
    typed, and every non-empty string (including ``"False"``) is truthy; this
    converter makes such flags behave as real booleans.

    Args:
        value: The raw command-line token, or an already-parsed bool.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv: Optional[List[str]] = None) -> "TrainingArguments":
    """Parse command-line options into a ``TrainingArguments`` instance.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, as before.

    Returns:
        A ``TrainingArguments`` with one value per dataclass field. Every field
        has an argparse default here, so these defaults take precedence over the
        dataclass defaults (e.g. ``end_token_weight`` defaults to 1, not 100).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default='your_llama2_32k_model_path')
    parser.add_argument("--train_csv", type=str, default='your_iemocap_data_path/toy_pro.csv')
    parser.add_argument("--eval_csv", type=str, default='your_iemocap_data_path/toy_pro.csv')
    parser.add_argument("--lora_rank", type=int, default=64)
    parser.add_argument("--lora_alpha", type=int, default=64)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--lr_scheduler_type", type=str, default="constant_with_warmup")
    # NOTE(review): "torch_adaw_fused" looks like a misspelling of an AdamW
    # variant; it must match a name create_optimizer accepts — confirm.
    parser.add_argument("--optim", type=str, default="torch_adaw_fused")
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--output_dir", type=str, default='your_output_logs_path/toy_5e-5')
    parser.add_argument("--training_input_length", type=int, default=1024)
    # Per the original author this has no effect: the data is stepped through
    # by the audio end token rather than by a fixed stride.
    parser.add_argument("--stride_size", type=int, default=256)
    parser.add_argument("--num_train_epochs", type=int, default=20)
    # Kept as a string (TrainingArguments declares str) but lower-cased so that
    # "True"/"TRUE" also match the "true" comparison used to enable it.
    parser.add_argument("--gradient_checkpointing", type=str.lower, default="false")
    parser.add_argument("--use_flash_attention_2", action="store_true")
    parser.add_argument("--logging_steps", type=int, default=100)
    parser.add_argument("--end_token_weight", type=float, default=1)
    # Parsed into a real bool: previously "--plot_loss_curve False" produced the
    # truthy string "False". A bare "--plot_loss_curve" means True.
    parser.add_argument("--plot_loss_curve", type=_str2bool, nargs="?", const=True, default=True,
                        help="Whether to save a training loss curve (true/false)")
    parser.add_argument("--classifier_learning_rate", type=float, default=1e-4,
                        help="Learning rate for the emotion classifier")
    parser.add_argument("--templora_inference_lr", type=float, default=5e-5,
                        help="Learning rate for Temp LoRA during inference")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility")
    args = parser.parse_args(argv)
    return TrainingArguments(**{f.name: getattr(args, f.name) for f in fields(TrainingArguments)})

def _loss_to_float(loss: Any) -> float:
    """Return a recorded step loss as a plain Python ``float``.

    Python numbers are converted directly; anything else goes through
    ``.item()``, which handles one-element torch tensors (on any device) and
    NumPy scalars alike.
    """
    if isinstance(loss, (int, float)):
        return float(loss)
    return loss.item()


def _smooth_losses(losses: List[float], window_size: int = 50) -> pd.Series:
    """Return the trailing moving average of ``losses`` over ``window_size`` steps.

    ``min_periods=1`` averages over the steps seen so far for the first
    ``window_size - 1`` points instead of emitting NaN, so runs shorter than the
    window still get a smoothed curve. The explicit float dtype keeps an empty
    input numeric (pandas 2 defaults an empty Series to object dtype).
    """
    return pd.Series(losses, dtype="float64").rolling(window_size, min_periods=1).mean()


def _save_loss_curve(step_losses: List[Any], output_dir: str, window_size: int = 50) -> str:
    """Plot raw and smoothed per-step training loss and save ``loss_curve.png``.

    Args:
        step_losses: Per-step losses as collected from ``train_epoch``
            (Python floats or scalar tensors).
        output_dir: Directory for the image; created if it does not exist.
        window_size: Moving-average window for the smoothed curve.

    Returns:
        The path of the written PNG.
    """
    train_losses = [_loss_to_float(loss) for loss in step_losses]
    smoothed_losses = _smooth_losses(train_losses, window_size)

    # savefig does not create missing directories; fail-safe after a long run.
    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, 'loss_curve.png')

    plt.figure(figsize=(10, 6))
    try:
        plt.plot(train_losses, 'lightgray', alpha=0.3, label='origin Loss')
        plt.plot(smoothed_losses, 'b', label='smooth Loss', linewidth=2)
        plt.title('train Loss')
        plt.xlabel('step')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close()
    return out_path


if __name__ == '__main__':
    import logging

    # accelerate's get_logger() wraps a stdlib logger; with no logging
    # configuration the root level is WARNING, so every logger.info(...) below
    # (including the per-epoch metrics) would be silently dropped.
    # basicConfig() is a no-op if logging has already been configured.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level=logging.INFO,
    )

    if torch.cuda.is_available():
        current_device = torch.cuda.current_device()
        print(f"Current GPU ID: {current_device}")
    args = parse_args()

    # Seed Python's `random`, NumPy and torch (CPU and all CUDA devices) at once;
    # the previous hand-rolled helper left Python's `random` unseeded.
    set_seed(args.seed)
    # Prefer reproducible cuDNN kernels over autotuned ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    accelerator = Accelerator(
        log_with="tensorboard",
        project_dir=args.output_dir,
        gradient_accumulation_steps=1
    )

    accelerator.wait_for_everyone()

    # Initialize tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name,
        trust_remote_code=True,
        use_fast=False
    )
    tokenizer.padding_side = "right"

    # Add 6561 audio tokens + 4 emotion tokens + 1 audio end token
    num_audio_tokens = 6561
    emotion_tokens = ["<emo_hap>", "<emo_sad>", "<emo_ang>", "<emo_neu>"]
    end_token = ["<audio_end>"]

    audio_tokens = [f"<audio_{i}>" for i in range(num_audio_tokens)]
    new_tokens = audio_tokens + emotion_tokens + end_token
    tokenizer.add_tokens(new_tokens)

    model_kwargs = {
        "pretrained_model_name_or_path": args.model_name,
        "trust_remote_code": True,
        "use_cache": False if args.use_flash_attention_2 else None,
        "torch_dtype": torch.bfloat16,
        "device_map": "cuda"
    }
    if args.use_flash_attention_2:
        model_kwargs["attn_implementation"] = "flash_attention_2"

    model = AutoModelForCausalLM.from_pretrained(**model_kwargs)

    # Adjust model embedding layer to match new vocabulary size
    model.resize_token_embeddings(len(tokenizer))

    if args.gradient_checkpointing == "true":
        logger.info("Enabling gradient checkpointing")
        model.gradient_checkpointing_enable()

    # Initialize LoRA
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=find_all_linear_names(model=model)
    )
    model.enable_input_require_grads()
    model = get_peft_model(model=model, peft_config=lora_config)
    model.print_trainable_parameters()

    # Initialize datasets
    train_dataset = AudioDatasetFromCSV(
        csv_path=args.train_csv,
        tokenizer=tokenizer,
        prefix_length=args.training_input_length,
        stride_size=args.stride_size,
        is_eval=False
    )
    eval_dataset = AudioDatasetFromCSV(
        csv_path=args.eval_csv,
        tokenizer=tokenizer,
        prefix_length=args.training_input_length,
        stride_size=args.stride_size,
        is_eval=True
    )  # Contains four emotion long audio data, about twenty pieces

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    # Initialize optimizer and scheduler
    optimizer = create_optimizer(model=model, args=args)
    num_update_steps_per_epoch = len(train_dataloader)
    max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    lr_scheduler = create_lr_scheduler(
        num_training_steps=max_train_steps,
        optimizer=optimizer,
        args=args
    )

    # Prepare everything with accelerator
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # Training loop
    # NOTE(review): best-checkpoint tracking is not implemented; best_ua,
    # best_model_path and all_eval_losses are never updated in the loop below.
    best_ua = float('-inf')
    best_model_path = None
    all_train_losses = []
    all_eval_losses = []

    logger.info("Starting training...")
    for epoch in range(args.num_train_epochs):
        train_loss, step_losses, classifier = train_epoch(
            model=model,
            train_dataloader=train_dataloader,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            accelerator=accelerator,
            epoch=epoch,
            args=args
        )
        all_train_losses.extend(step_losses)

        accuracy, ua, wf1 = evaluate(model, eval_dataloader, accelerator, epoch)
        logger.info(f"Epoch {epoch}: train_loss = {train_loss:.4f}")
        logger.info(f"Metrics - Accuracy: {accuracy:.4f}, UA: {ua:.4f}, WF1: {wf1:.4f}")

    # Plot loss curve at the end of training (local main process only)
    if args.plot_loss_curve and accelerator.is_local_main_process:
        _save_loss_curve(all_train_losses, args.output_dir)
