import argparse
import json
import os
from typing import List, Any
import torch
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger, LightningLoggerBase, TensorBoardLogger
from tqdm import tqdm
from transformers import AutoTokenizer, AutoConfig, AdamW
import pytorch_lightning as pl
import torch.nn as nn
import re
from evaluate import evaluate_metrics
from modeling_t5 import T5ForConditionalGeneration  # Custom T5 model with dual LoRA support
from copy_data_loader import prepare_data
from datetime import datetime
import gc
#from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``, so boolean-valued options use this instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options for DST training and evaluation.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Dialogue state tracking model training parameters")

    # Model and training basic configuration
    parser.add_argument("--ckpt", type=str, default="save/models/small", help="Pre-trained model checkpoint path")
    parser.add_argument("--train_batch_size", type=int, default=8, help="Training batch size")
    parser.add_argument("--worker_number", type=int, default=8, help="CPU threads for data loader, Windows systems suggest setting to 0")
    parser.add_argument("--dev_batch_size", type=int, default=8, help="Validation batch size")
    parser.add_argument("--test_batch_size", type=int, default=8, help="Test batch size")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay coefficient")
    parser.add_argument("--seed", type=int, default=3407, help="Random seed")

    # Dataset and domain configuration
    parser.add_argument("--dataset", type=str, default="multiwoz", help="Dataset name")
    parser.add_argument("--except_domain", type=str, default="none", help="Excluded domain (hotel/train/restaurant/attraction/taxi)")
    parser.add_argument("--only_domain", type=str, default="none", help="Only use this domain")
    parser.add_argument("--slot_lang", type=str, default="question", help="Slot description type (none/human/naive/value/question/slottype)")
    parser.add_argument("--max_size", type=int, default=250, help="Maximum tokens for model input")
    parser.add_argument("--fewshot", type=float, default=0.0, help="Data proportion for few-shot experiments")

    # Model training control
    parser.add_argument("--no_freeze", action='store_true', help="Do not freeze any parameters")
    parser.add_argument("--test", action="store_true", help="Run test mode")
    parser.add_argument("--ckpt_best", type=str, help="Best checkpoint path for testing")
    parser.add_argument("--model_name", type=str, default="t5", help="Base model to use (t5/bart)")
    parser.add_argument("--no_early_stop", action='store_true', help="Disable early stopping")
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
    parser.add_argument("--warm_up_steps", type=int, default=3000, help="Warmup steps")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=16, help="Gradient accumulation steps")
    parser.add_argument("--n_epochs", type=int, default=5, help="Training epochs")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping norm")

    # LoRA configuration
    parser.add_argument("--lora_r", type=int, default=8, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA scaling factor")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout rate")
    # _str2bool (not type=bool) so that "--full_frozen False" really yields False.
    parser.add_argument("--full_frozen", type=_str2bool, default=False, help="Whether to fully freeze base model")

    # Prompt configuration
    parser.add_argument("--apply_prompt", type=_str2bool, default=True, help="Whether to apply prompt mechanism")
    parser.add_argument("--use_prompt", action="store_true", help="Use prompt")
    parser.add_argument("--fusion_method", default="mean", type=str, help="Prompt fusion method")
    parser.add_argument("--zero_initialization", default='linear', type=str, help="Initialization method")
    parser.add_argument("--reparametrization_ratio", type=float, default=1, help="Reparameterization ratio")
    parser.add_argument("--add_reparameterization", action="store_true",
                        help="Add reparameterization trick (only recommended for SGD dataset)")

    # Attention head control
    parser.add_argument("--q", action="store_true", help="Apply to Query")
    parser.add_argument("--k", action="store_true", help="Apply to Key")
    parser.add_argument("--v", action="store_true", help="Apply to Value")
    parser.add_argument("--o", action="store_true", help="Apply to Output")

    # Logging
    parser.add_argument("--tb_savedir", type=str, default="tensorboard", help="Tensorboard log directory")

    # Other configuration
    parser.add_argument("--min_delta", type=float, default=0.0, help="Early stopping minimum delta threshold")
    parser.add_argument("--patience", type=int, default=5, help="Early stopping patience")
    parser.add_argument("--desc", default="none", type=str, help="Experiment description")
    parser.add_argument("--saving_dir", type=str, default="save", help="Model save path")
    parser.add_argument("--fix_label", action='store_true', help="Fix labels")

    return parser.parse_args(argv)
    
# Attention projection weight keys such as "...SelfAttention.q.weight":
#   group 1 - the prefix, ending with the "." right before the projection name
#   group 2 - the projection letters (any run of q/k/v/o)
#   group 3 - the literal ".weight" suffix
_ATTN_WEIGHT_RE = re.compile(r'^(.*\.)([qkvo]+)(\.weight)$')


def rename_weights(key):
    """Map a checkpoint key onto the custom LoRA-wrapped layer naming.

    Attention projection weights gain a ``.Con`` segment between the
    projection name and ``.weight`` (``a.b.q.weight`` -> ``a.b.q.Con.weight``).
    Every other key is returned unchanged.
    """
    found = _ATTN_WEIGHT_RE.match(key)
    if found is None:
        return key
    head, proj, tail = found.groups()
    return f"{head}{proj}.Con{tail}"


class DST(pl.LightningModule):
    """Dialogue state tracking LightningModule built around a dual-LoRA T5.

    Besides the seq2seq backbone (``self.model``), the module owns:

    - a global prompt (``final_global_prompt``) that is cross-attended against
      slot-description embeddings;
    - a scalar ``gate`` that is handed to the backbone's LoRA layers.

    During validation and test, the prompt's LoRA contribution is precomputed
    once per epoch for every known slot description (``self.p_bias``) and then
    looked up per example.

    Several attributes are not initialised here and must be assigned from
    outside before the hooks that read them run:

    - ``common_tokens`` (training);
    - ``dev_desc`` (validation);
    - ``test_desc``, ``slot_logger``, ``predictions`` and ``all_slots`` (test);
    - ``tensorboard_logger``.
    """

    def __init__(self, args):
        """Build the tokenizer, backbone and global-prompt modules.

        Args:
            args: Parsed command-line namespace (see ``parse_args``).
        """
        super().__init__()
        self.ckpt = args.ckpt
        self.args = args
        self.except_domain = args.except_domain

        # Backbone config extended with the LoRA / prompt options.
        # NOTE(review): this updated config is not passed to from_pretrained
        # below (the model receives ``args`` instead) - confirm it is used.
        self.config = AutoConfig.from_pretrained(self.ckpt)
        self.config.update({
            "q": args.q,
            "k": args.k,
            "v": args.v,
            "o": args.o,
            "lora_r": args.lora_r,
            "lora_alpha": args.lora_alpha,
            "lora_dropout": args.lora_dropout,
            "apply_prompt": args.apply_prompt,
        })

        self.weight_decay = args.weight_decay
        self.prefix_length = 10  # number of global prompt vectors

        # Tokenizer with the special marker tokens used by the data pipeline.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.ckpt,
            bos_token="[bos]",
            eos_token="[eos]",
            sep_token="[sep]",
        )

        self.model = T5ForConditionalGeneration.from_pretrained(self.ckpt, args=args)

        # Reload the checkpoint with attention projection keys renamed
        # ("*.q.weight" -> "*.q.Con.weight", see rename_weights) so they land in
        # the custom LoRA-wrapped layers. strict=False ignores keys present on
        # only one side (presumably the LoRA-only parameters).
        # The file is read from the --ckpt directory rather than a hard-coded
        # path. weights_only=False unpickles arbitrary objects, so only point
        # --ckpt at trusted checkpoints.
        weights_path = os.path.join(self.ckpt, "pytorch_model.bin")
        t5_weight = torch.load(weights_path, map_location="cpu", weights_only=False)
        state_dict = {rename_weights(key): value for key, value in t5_weight.items()}
        self.model.load_state_dict(state_dict=state_dict, strict=False)
        del t5_weight, state_dict  # drop the extra CPU copy of the weights

        # Grow the embedding matrix for the special tokens added above.
        self.model.resize_token_embeddings(new_num_tokens=len(self.tokenizer))

        # Training control (see optimizer_step)
        self.ff = args.full_frozen  # True: skip the two-phase freezing schedule
        self.phase = 1  # 1 = before the warm-up freeze, 2 = after it
        self.no_freeze = args.no_freeze  # True: never enter phase 2
        self.warm_up_steps = args.warm_up_steps  # global step at which phase 2 starts
        self.lr = args.lr
        self.save_path = ""  # assigned by train(); results are written below it

        # Whether init_global_prompt() has filled final_global_prompt yet.
        self.global_prompts_set = False

        d_model = self.model.config.d_model
        # --reparametrization_ratio is parsed as a float, but tensor and layer
        # sizes must be ints.
        reparam_dim = int(d_model // args.reparametrization_ratio)

        # Low-dimensional prompt plus projection for the reparameterization
        # path.
        # NOTE(review): neither is used by any method of this class, and with
        # --add_reparameterization the global prompt is never initialised
        # (training_step skips init_global_prompt). Confirm this path is complete.
        self.global_prompt = torch.nn.Parameter(
            data=torch.rand(self.prefix_length, reparam_dim),
            requires_grad=False,
        )
        self.reparametrizer = nn.Linear(in_features=reparam_dim, out_features=d_model)

        # Prompt actually used. It is zeros until init_global_prompt() fills it
        # with token embeddings. It is a frozen Parameter, so it is part of the
        # state_dict.
        self.final_global_prompt = torch.nn.Parameter(
            data=torch.zeros(self.prefix_length, d_model),
            requires_grad=False,
        )

        # Projections around the prompt <-> slot-description cross-attention.
        self.g_q = nn.Linear(d_model, d_model, bias=False)  # query (global prompt)
        self.g_k = nn.Linear(d_model, d_model, bias=False)  # key (slot description)
        self.g_v = nn.Linear(d_model, d_model, bias=False)  # value (slot description)
        self.g_o = nn.Linear(d_model, d_model, bias=False)  # output

        self.cross_attn = torch.nn.MultiheadAttention(
            embed_dim=self.model.config.hidden_size,
            num_heads=self.model.config.num_heads,
            batch_first=True,
        )

        # Learnable gate shared with the backbone's LoRA layers (see
        # _attach_gate); initial value 2/3.
        self.gate = nn.Parameter(torch.tensor(2.0 / 3.0), requires_grad=True)

        self.train_bs = args.train_batch_size
        self.add_reparameterization = args.add_reparameterization
        self.scaling = self.args.lora_alpha / self.args.lora_r  # LoRA scaling factor

    def _attach_gate(self):
        """Give every backbone submodule whose ``gate`` is still ``None`` the shared gate."""
        for module in self.model.modules():
            if hasattr(module, 'gate') and module.gate is None:
                module.gate = self.gate

    def _attend_global_prompt(self, prompt_embed, prompt_embed_mask):
        """Cross-attend the global prompt (as query) over slot-description embeddings.

        Args:
            prompt_embed: Slot-description token embeddings from
                ``self.model.shared``, presumably (batch, desc_len, d_model);
                the attention layer is batch-first.
            prompt_embed_mask: Attention mask for ``prompt_embed``. Non-zero
                entries mark real tokens.

        Returns:
            The ``g_o``-projected attention output, with one prompt per
            description: (batch, prefix_length, d_model).
        """
        batch_size = prompt_embed.size(0)
        # One copy of the shared prompt per description. The size comes from
        # the actual batch, so a short final batch still lines up with the keys.
        expanded_prompt = self.final_global_prompt.unsqueeze(0).expand(batch_size, -1, -1)

        query = self.g_q(expanded_prompt.to(prompt_embed.device))
        key = self.g_k(prompt_embed)
        value = self.g_v(prompt_embed)

        # key_padding_mask is True where attention must be blocked (padding).
        padding_mask = ~prompt_embed_mask.to(device=prompt_embed.device, dtype=torch.bool)
        attention_out, _ = self.cross_attn(
            query=query,
            key=key,
            value=value,
            key_padding_mask=padding_mask,
        )
        return self.g_o(attention_out)

    def _precompute_lora_bias(self, global_prompt):
        """Precompute the global prompt's contribution through every SemAdapt-LoRA layer.

        For each backbone module exposing both ``lora_SemAdapt_A`` and
        ``lora_SemAdapt_B``, this sums
        ``prompt @ A^T @ B^T * scaling * (1 - gate)`` over every (A, B) pair.

        Args:
            global_prompt: Attended prompts from ``_attend_global_prompt``,
                with one row per description.

        Returns:
            Detached tensor whose dimension 0 indexes descriptions and whose
            dimension 1 indexes LoRA layers. Stacking requires every layer's
            contribution to have the same shape.
        """
        # NOTE(review): every A is combined with every B (cross product), not
        # pairwise - confirm this matches the backbone's LoRA forward.
        per_layer = []
        for module in self.model.modules():
            if not (hasattr(module, 'lora_SemAdapt_A') and hasattr(module, 'lora_SemAdapt_B')):
                continue
            layer_contribution = 0.0
            for a in module.lora_SemAdapt_A:
                for b in module.lora_SemAdapt_B:
                    contribution = (global_prompt @ a.transpose(0, 1) @ b.transpose(0, 1)) * self.scaling * (1 - self.gate)
                    layer_contribution = layer_contribution + contribution
            per_layer.append(layer_contribution.detach())
        # (layers, descs, ...) -> (descs, layers, ...)
        return torch.stack(per_layer).transpose(0, 1)

    @torch.no_grad()
    def _desc_lora_bias(self, descriptions):
        """Compute ``p_bias`` rows for a list of slot-description strings.

        Row ``i`` of the result belongs to ``descriptions[i]``. Validation and
        generation look rows up with ``descriptions.index(...)``.
        """
        tokenized_desc = self.tokenizer(descriptions, padding=True, return_tensors='pt')
        input_ids = tokenized_desc['input_ids'].to(self.device)
        attention_mask = tokenized_desc['attention_mask'].to(self.device)

        prompt_embed = self.model.shared(input_ids)
        global_prompt = self._attend_global_prompt(prompt_embed, attention_mask)

        self._attach_gate()
        return self._precompute_lora_bias(global_prompt)

    def init_global_prompt(self):
        """Initialise ``final_global_prompt`` from the embeddings of ``common_tokens``.

        The first element of each ``common_tokens`` entry is joined with spaces,
        tokenized without special tokens, and truncated to ``prefix_length``
        tokens. If fewer tokens are produced, the prompt becomes shorter than
        ``prefix_length`` rows.
        """
        assert not self.global_prompts_set  # must only run once

        initial_prompt = " ".join(pair[0] for pair in self.common_tokens)

        global_prompt_tokens = self.tokenizer(
            initial_prompt,
            padding=True,
            return_tensors="pt",
            add_special_tokens=False,
            verbose=False,
        )["input_ids"][0][:self.prefix_length].to(self.device)

        # A 1-D id tensor gives (num_tokens, d_model) embeddings. No squeeze
        # here, so a single-token prompt keeps its 2-D shape.
        with torch.no_grad():
            global_prompt_data = self.model.shared(global_prompt_tokens)

        self.final_global_prompt.data = global_prompt_data
        self.global_prompts_set = True

    def training_step(self, batch, batch_idx):
        """Compute the seq2seq loss with the attended global prompt and log ``train_loss``."""
        # Without reparameterization, the prompt is seeded from token
        # embeddings on the first step.
        if not self.add_reparameterization and not self.global_prompts_set:
            self.init_global_prompt()

        self.model.train()
        self._attach_gate()

        encoder_input = batch['fb_encoder_input'].to(self.device)
        encoder_attn_mask = batch['fb_encoder_attn_mask'].to(self.device)

        prompt_embed = self.model.shared(batch['slot_desc_input'].to(self.device))
        prompt_embed_mask = batch['slot_desc_attn_mask'].to(self.device)

        attended_prompt = self._attend_global_prompt(prompt_embed, prompt_embed_mask)

        model_output = self.model(
            input_ids=encoder_input,
            attention_mask=encoder_attn_mask,
            labels=batch['decoder_output'].to(self.device),
            prompt_embed=prompt_embed,
            prompt_embed_mask=prompt_embed_mask,
            hidden_attention_mask=encoder_attn_mask,
            global_prompt=attended_prompt,
        )

        loss = model_output['loss']
        self.log("train_loss", loss.detach(), on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def on_validation_epoch_start(self) -> None:
        """Precompute ``p_bias`` for every validation slot description (``self.dev_desc``)."""
        self.p_bias = self._desc_lora_bias(self.dev_desc)

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss using the precomputed ``p_bias`` rows."""
        self.model.eval()

        encoder_input = batch['fb_encoder_input'].to(self.device)
        encoder_attn_mask = batch['fb_encoder_attn_mask'].to(self.device)

        prompt_embed = self.model.shared(batch['slot_desc_input'].to(self.device))
        prompt_embed_mask = batch['slot_desc_attn_mask'].to(self.device)

        # Row of p_bias for each example's slot description.
        index = [self.dev_desc.index(desc) for desc in batch['slot_description']]
        p_bias = self.p_bias[index].to(self.device)

        model_output = self.model(
            input_ids=encoder_input,
            attention_mask=encoder_attn_mask,
            labels=batch['decoder_output'].to(self.device),
            prompt_embed=prompt_embed,
            prompt_embed_mask=prompt_embed_mask,
            hidden_attention_mask=encoder_attn_mask,
            p_bias=p_bias,
        )

        loss = model_output['loss']

        # Lightning rejects logging one key twice in a step with different
        # on_step/on_epoch arguments, so the per-step value gets its own key.
        # "val_loss" stays the epoch-level mean (the value a monitor needs).
        self.log("val_step_loss", loss.detach(), on_step=True, on_epoch=False, prog_bar=True)
        self.log("val_loss", loss.detach(), on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def test_epoch_end(self, outputs: List[Any]) -> None:
        """Score the accumulated predictions, write result files and reset the accumulators.

        Per-slot accuracy is stored in ``slot_logger[slot][2]``. Overall metrics
        come from ``evaluate_metrics``. Both are sent to ``tensorboard_logger``
        and written as timestamped JSON under ``<save_path>/results``.
        """
        prefix = "zero-shot"
        save_path = os.path.join(self.save_path, "results")
        os.makedirs(save_path, exist_ok=True)

        # slot_log = [total, correct, accuracy]
        for slot_log in self.slot_logger.values():
            slot_log[2] = (slot_log[1] / slot_log[0]) if slot_log[0] != 0 else 0

        joint_acc_score, F1_score, turn_acc_score = evaluate_metrics(self.predictions, self.all_slots)

        evaluation_metrics = {
            "Joint Acc": joint_acc_score,
            "Turn Acc": turn_acc_score,
            "Joint F1": F1_score,
        }
        slot_acc_metrics = {key: value[2] for key, value in self.slot_logger.items()}

        self.tensorboard_logger.log_metrics(metrics=slot_acc_metrics)
        self.tensorboard_logger.log_metrics(metrics=evaluation_metrics)

        now = datetime.now().strftime("%Y-%m-%d-%H-%M")

        with open(os.path.join(save_path, f"{prefix}_result_{now}.json"), 'w', encoding='utf-8') as f:
            json.dump(evaluation_metrics, f, indent=4)

        print(f"{prefix} result:", evaluation_metrics)

        with open(os.path.join(save_path, f"{prefix}_prediction_{now}.json"), 'w', encoding='utf-8') as f:
            json.dump(self.predictions, f, indent=4)

        # Reset the accumulators for a possible next test run.
        self.slot_logger = {slot_name: [0, 0, 0] for slot_name in self.all_slots}
        self.predictions = {}

        self.log("Joint Acc", joint_acc_score, on_step=False, on_epoch=True, prog_bar=True)

    def on_test_epoch_start(self) -> None:
        """Precompute ``p_bias`` for every test slot description (``self.test_desc``)."""
        self.p_bias = self._desc_lora_bias(self.test_desc)

    def test_step(self, batch, batch_idx):
        """Decode slot values for a batch and accumulate predictions and per-slot hit counts."""
        self.model.eval()

        dst_outputs = self.generate(batch)
        # generate() requests a dict-style output only in --test mode. Otherwise
        # it presumably returns the token-id tensor itself, which has no
        # ``.sequences`` attribute.
        sequences = getattr(dst_outputs, "sequences", dst_outputs)
        value_batch = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)

        for idx, value in enumerate(value_batch):
            value = value.strip()
            dial_id = batch["ID"][idx]
            turn_id = batch["turn_id"][idx]

            if dial_id not in self.predictions:
                self.predictions[dial_id] = {
                    "domain": batch["domains"][idx][0],
                    "turns": {},
                }
            turns = self.predictions[dial_id]["turns"]
            if turn_id not in turns:
                turns[turn_id] = {"turn_belief": batch["turn_belief"][idx], "pred_belief": []}

            # SGD slot names are qualified with their domain.
            if self.args.dataset == "sgd":
                pred_slot = str(batch["domain"][idx]) + '-' + str(batch["slot_text"][idx])
            else:
                pred_slot = str(batch["slot_text"][idx])

            if value != "none":
                turns[turn_id]["pred_belief"].append(pred_slot + '-' + str(value))

            if str(value) == str(batch["value_text"][idx]):
                self.slot_logger[pred_slot][1] += 1  # hits
            self.slot_logger[pred_slot][0] += 1  # total

            if self.args.test:
                # encoder_attentions holds one tensor per encoder layer,
                # presumably (batch, heads, seq, seq). Plot the first and last
                # layers, averaged over heads.
                attentions = dst_outputs.encoder_attentions
                heatmap_name = batch['ID'][idx] + batch['slot_text'][idx] + str(idx)
                for layer in (0, len(attentions) - 1):
                    avg_attn = attentions[layer][idx].detach().cpu().numpy().mean(axis=0)
                    generate_attention_heatmap(avg_attn, layer, heatmap_name)

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure,
        using_native_amp,
        using_lbfgs,
    ):
        """Step the optimizer, switching to the partially frozen phase 2 after warm-up.

        Phase 2 is entered once ``global_step >= warm_up_steps``. It is never
        entered when ``full_frozen`` or ``no_freeze`` is set.
        """
        if (
            not self.ff
            and self.phase == 1
            and not self.no_freeze
            and self.global_step >= self.warm_up_steps
        ):
            self._enter_second_phase()

        optimizer.step(closure=optimizer_closure)

    def _enter_second_phase(self):
        """Freeze the backbone except the first/last encoder block, ``lm_head`` and LoRA params."""
        self.phase = 2
        print("Second phase start")

        for param in self.model.parameters():
            param.requires_grad = False

        for submodule in (self.model.encoder.block[0], self.model.encoder.block[-1], self.model.lm_head):
            for param in submodule.parameters():
                param.requires_grad = True

        for name, param in self.model.named_parameters():
            if "lora_" in name:
                param.requires_grad = True

        # Drop stale gradients of the newly frozen parameters. The optimizer
        # still holds them; an optimizer that only skips ``grad is None`` would
        # otherwise keep applying momentum and weight decay to them.
        for param in self.model.parameters():
            if not param.requires_grad:
                param.grad = None

    def generate(self, batch):
        """Run ``self.model.generate`` for a test batch with the precomputed ``p_bias``.

        ``return_dict_in_generate`` and ``output_attentions`` are enabled only
        when ``--test`` is set. The backbone's return value is passed through
        unchanged.
        """
        self.model.eval()

        encoder_input = batch['fb_encoder_input'].to(self.device)
        encoder_attn_mask = batch['fb_encoder_attn_mask'].to(self.device)

        prompt_embed = self.model.shared(batch['slot_desc_input'].to(self.device))
        prompt_embed_mask = batch['slot_desc_attn_mask'].to(self.device)
        eos_token_id = self.tokenizer.eos_token_id

        # Row of p_bias for each example's slot description.
        index = [self.test_desc.index(desc) for desc in batch['slot_description']]
        p_bias = self.p_bias[index].to(self.device)

        # Attention maps are only needed for the heatmaps drawn in --test mode.
        self.model.config.output_attentions = bool(self.args.test)
        return self.model.generate(
            input_ids=encoder_input,
            attention_mask=encoder_attn_mask,
            eos_token_id=eos_token_id,
            prompt_embed=prompt_embed,
            prompt_embed_mask=prompt_embed_mask,
            hidden_attention_mask=encoder_attn_mask,
            p_bias=p_bias,
            max_length=40,
            return_dict_in_generate=self.model.config.output_attentions,
            output_attentions=self.model.config.output_attentions,
        )

    def configure_optimizers(self):
        """AdamW over the parameters that require gradients when training starts."""
        return AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=self.lr, correct_bias=True, weight_decay=self.weight_decay)


def train(args):
    """Train a DST model, then test its best checkpoint.

    Steps: seed every RNG, build the model, create a timestamped save
    directory ``<saving_dir>/<model_name>-<YYYY-mm-dd-HH-MM>-<except_domain>``
    (plus ``-<desc>`` when ``args.desc != "none"``), dump ``args`` into it, set
    up TensorBoard logging, checkpointing on ``val_loss`` and optional early
    stopping, then run ``trainer.fit`` followed by
    ``trainer.test(ckpt_path='best')``.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).
    """
    # Seed before the model is constructed so that any randomly initialised
    # parameters created in DST.__init__ are reproducible as well; seeding
    # afterwards only covered data preparation and training.
    seed_everything(args.seed)

    # Timestamp keeps each run's save directory unique.
    now = datetime.now().strftime("%Y-%m-%d-%H-%M")

    model = DST(args)

    # Build the run directory name; the optional description is appended last.
    run_name = args.model_name + "-" + now + "-" + args.except_domain
    if args.desc != "none":
        run_name += "-" + args.desc
    save_path = os.path.join(args.saving_dir, run_name)
    model.save_path = save_path

    # Weights & Biases logging is disabled in favour of TensorBoard. To restore:
    # wandb_logger = WandbLogger(name=args.wandb_run_name,
    #                            project=args.wandb_project_name,
    #                            job_type=args.wandb_job_type,
    #                            group=args.wandb_group_name)
    # model.wandb_logger = wandb_logger

    tensorboard_logger = TensorBoardLogger(
        save_dir=args.tb_savedir,
        name=args.except_domain + '_' + args.desc,
    )
    model.tensorboard_logger = tensorboard_logger

    os.makedirs(save_path, exist_ok=True)
    save_args(args, save_path)

    class DataModule(pl.LightningDataModule):
        """Data module that (re)builds every loader through ``prepare_data``.

        Each build also pushes the slot metadata it produced onto the model,
        so the model's descriptions always match the loaders of that build.
        """

        def __init__(self, args, tokenizer, model):
            super().__init__()
            self.args = args
            self.tokenizer = tokenizer
            self.model = model
            self.all_slots = None
            self.global_prompts = None
            self.dev_desc = None
            self.test_desc = None
            # True while the current build has not yet been handed out by
            # train_dataloader(); lets the first call reuse the eager build.
            self._loaders_unused = False

        def prepare_data(self):
            pass

        def setup(self, stage=None):
            pass

        def _generate_data(self):
            """Run ``prepare_data`` and publish its slot metadata on the model."""
            (self.train_loader, self.val_loader, self.test_loader,
             self.all_slots, self.global_prompts,
             self.dev_desc, self.test_desc) = prepare_data(self.args, self.tokenizer)
            self.model.common_tokens = self.global_prompts
            self.model.all_slots = self.all_slots
            self.model.dev_desc = self.dev_desc
            self.model.test_desc = self.test_desc
            self._loaders_unused = True

        def train_dataloader(self):
            # Rebuild on every call except the first one after a build, so the
            # eager build done before fit() is not immediately discarded (which
            # would prepare the data twice and could leave the val loader and
            # the model's descriptions coming from different builds).
            if not self._loaders_unused:
                self._generate_data()
            self._loaders_unused = False
            return self.train_loader

        def val_dataloader(self):
            return self.val_loader

        def test_dataloader(self):
            return self.test_loader

    dm = DataModule(args, model.tokenizer, model)
    dm.setup()
    # Build eagerly: the slot list is needed below, before fit() starts.
    dm._generate_data()

    # Per-slot counters and the prediction store used by the model.
    model.slot_logger = {slot_name: [0, 0, 0] for slot_name in dm.all_slots}
    model.predictions = {}

    # Keep the best checkpoint by val_loss plus the last one.
    checkpoint_callback = ModelCheckpoint(
        filepath=save_path + "/{epoch}-{global_step}-{val_loss:.2f}",
        monitor='val_loss',
        verbose=False,
        save_last=True,
        save_top_k=1,
        mode="min",
    )

    callbacks = []
    if not args.no_early_stop:
        callbacks.append(pl.callbacks.EarlyStopping(
            monitor='val_loss',
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode='min',
        ))

    trainer = Trainer(
        default_root_dir=save_path,
        accumulate_grad_batches=args.gradient_accumulation_steps,
        gradient_clip_val=args.max_norm,
        max_epochs=args.n_epochs,
        callbacks=callbacks,
        checkpoint_callback=checkpoint_callback,
        deterministic=True,
        num_nodes=1,
        # logger=wandb_logger,
        logger=tensorboard_logger,
        gpus=1,
        val_check_interval=0.1,  # validate every 10% of a training epoch
    )

    # NOTE(review): Lightning calls train_dataloader() again (and hence
    # regenerates the data) only when dataloader reloading is enabled on the
    # Trainer; with the arguments above the data is presumably built once per
    # fit. Confirm whether per-epoch regeneration was intended.
    trainer.fit(model, datamodule=dm)
    trainer.test(model, datamodule=dm, ckpt_path='best')



def save_args(args, save_path):
    """Write every attribute of ``args`` to ``<save_path>/args.txt``.

    The file holds one ``name : value`` line per argument, framed by a start
    banner line and an end banner line (the end banner has no trailing
    newline).

    The path is built with ``os.path.join`` so the file lands inside
    ``save_path`` on every OS. A hard-coded backslash separator only works on
    Windows; on POSIX it creates a file literally named ``<dir>\\args.txt``
    next to the directory.

    Args:
        args: An ``argparse.Namespace`` (anything supporting ``vars``).
        save_path: Existing directory to write into.
    """
    lines = ['------------------ start ------------------\n']
    lines.extend(f'{name} : {value!s}\n' for name, value in vars(args).items())
    lines.append('------------------- end -------------------')
    # Explicit UTF-8 so non-ASCII argument values (e.g. --desc) never fail
    # on platforms whose default locale encoding cannot represent them.
    with open(os.path.join(save_path, 'args.txt'), 'w', encoding='utf-8') as f:
        f.writelines(lines)

def test(args):
    """Evaluate a saved DST checkpoint on the test split.

    Builds a fresh model, and loads ``state_dict`` from ``args.ckpt_best``
    non-strictly, printing any key mismatches. It then attaches the slot
    metadata returned by ``prepare_data`` and runs ``Trainer.test`` on the
    test loader, logging to the same TensorBoard location scheme as
    ``train``.

    Args:
        args: Parsed command-line namespace; ``args.ckpt_best`` must point to
            a Lightning checkpoint file containing a ``'state_dict'`` entry.
    """
    model = DST(args)

    # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
    # arbitrary objects (Lightning checkpoints also carry callback state such
    # as EarlyStopping), so only pass checkpoints from a trusted source.
    # torch.serialization.add_safe_globals only affects weights_only=True
    # loads, so it is not needed here.
    # map_location="cpu" decouples loading from the GPU the file was saved on;
    # load_state_dict copies the tensors into the model's own parameters.
    checkpoint_data = torch.load(
        str(args.ckpt_best), map_location="cpu", weights_only=False
    )
    load_result = model.load_state_dict(checkpoint_data['state_dict'], strict=False)
    del checkpoint_data  # release the raw checkpoint before data preparation

    # strict=False tolerates mismatched keys; report them instead of hiding them.
    for label, keys in (
        ("missing from checkpoint", load_result.missing_keys),
        ("unexpected in checkpoint", load_result.unexpected_keys),
    ):
        if keys:
            preview = ", ".join(keys[:10])
            more = " ..." if len(keys) > 10 else ""
            print(f"[test] {len(keys)} state_dict key(s) {label}: {preview}{more}")

    train_loader, val_loader, test_loader, all_slots, global_prompts, \
        dev_desc, test_desc = prepare_data(args, model.tokenizer)

    # Slot metadata the model reads during evaluation.
    model.common_tokens = global_prompts
    model.all_slots = all_slots
    model.dev_desc = dev_desc
    model.test_desc = test_desc

    # Per-slot counters and the prediction store used by the model.
    model.slot_logger = {slot_name: [0, 0, 0] for slot_name in all_slots}
    model.predictions = {}

    tensorboard_logger = TensorBoardLogger(
        save_dir=args.tb_savedir,
        name=args.except_domain + '_' + args.desc,
    )
    model.tensorboard_logger = tensorboard_logger

    trainer = Trainer(
        deterministic=True,
        num_nodes=1,
        gpus=1,
        # Same logger object the model writes to, as in train().
        logger=tensorboard_logger,
    )
    trainer.test(model, test_loader)

def main():
    """Script entry point: parse CLI args, pin the GPU, then test or train.

    ``--test`` requires ``--ckpt_best``; otherwise a full training run starts.
    """
    # torch.multiprocessing.set_start_method('spawn')
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = parse_args()

    # W&B is disabled in favour of TensorBoard; to re-enable its mode switch:
    # os.environ['WANDB_MODE'] = args.wandb_mode
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    if not args.test:
        train(args)
        return

    if not args.ckpt_best:
        raise ValueError("Must specify --ckpt_best parameter when using --test")
    test(args)


# Only run when executed as a script, not when imported.
if __name__ == "__main__":
    main()