import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Union
import json
import gc
import traceback
import logging
import time
import numpy as np
import warnings
from collections import defaultdict
# Silence every Python warning for the whole process (including ones raised
# by the imported libraries).
warnings.filterwarnings("ignore")

# Set up logging: INFO and above goes both to a log file in the current
# working directory and to the console stream.
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(levelname)s - %(message)s',
                    handlers=[logging.FileHandler("lgcd_colorectal_qa_debug.log"), logging.StreamHandler()])
logger = logging.getLogger(__name__)

# Set random seeds for reproducibility
# (only the torch and NumPy generators are seeded here).
torch.manual_seed(42)
np.random.seed(42)

def clear_gpu_memory():
    """Collect Python garbage and release cached CUDA memory, then log it.

    The CUDA calls are skipped when no CUDA device is available; the log
    line is written in either case.
    """
    gc.collect()
    cuda = torch.cuda
    if cuda.is_available():
        # Hand cached allocator blocks back to the driver and reclaim
        # memory held through CUDA IPC handles.
        cuda.empty_cache()
        cuda.ipc_collect()
    logger.info("GPU memory cleared")

# GPU setup
# DEVICE is the torch device string used for loading both models and for
# storing the extracted LoRA factors; it falls back to "cpu" without CUDA.
gpu_id = 0  # Changed to 0 for better compatibility
DEVICE = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"

# One sample question per language, keyed by language code.
# main() looks questions up by the keys of MODEL_PAIRS, so an entry here is
# only used when MODEL_PAIRS has the same language code.
COLORECTAL_CANCER_QUESTIONS = {
    'ko': "대장암 증상은 무엇이 있을까요?",
    'zh': "结肠癌的症状有哪些？",
    'de': "Was sind die Symptome von Darmkrebs?",
    'pt': "Quais são os sintomas do câncer colorretal?",
    'ar': "ما هي أعراض سرطان القولون والمستقيم؟",
    'fa': "علائم سرطان کولورکتال چیست؟",
    'ja': "大腸がんの症状は何ですか？",
    'id': "Apa saja gejala kanker kolorektal?",
    'sw': "Dalili za saratani ya utumbo ni zipi?",
    'en': "What are the symptoms of colorectal cancer?"
}

@dataclass
class LGCDConfig:
    """Configuration for LoRA-Gated Contrastive Decoding.

    Bundles the settings read by the LoRA extractor, the confidence gate and
    the token-by-token generation loop defined in this file.
    """
    # LoRA extraction parameters
    lora_rank: int = 32  # Rank for LoRA decomposition (singular components kept per weight)
    lora_alpha: float = 32  # LoRA scaling factor; the extractor stores lora_alpha / rank as the scale
    lora_dropout: float = 0.1  # LoRA dropout rate (NOTE(review): not read by any code visible in this file)

    # Layer group selection parameters
    layer_group: str = 'all'  # 'all', 'lower', 'middle', 'upper'

    # Confidence-based gating parameters
    confidence_threshold: float = 0.7  # Threshold for using base vs LoRA (base is used when confidence is above it)
    confidence_method: str = 'max_prob'  # 'entropy', 'max_prob', 'variance'
    adaptive_threshold: bool = False  # Dynamically adjust threshold from recent confidence values

    # Enhanced contrastive decoding parameters
    contrastive_alpha: float = 0.1  # Weight for base model in contrastive
    contrastive_beta: float = 1.0   # Weight for LoRA model in contrastive

    # Top-K masking parameters
    contrastive_top_k: int = 100     # Apply contrastive only to top-k tokens
    generation_top_k: int = 100      # Top-k for final generation (used on the base-model branch)

    # Generation parameters
    max_length: int = 2048  # Passed to the tokenizer as the truncation length of the prompt
    generation_temperature: float = 0.7  # Logits are divided by this before softmax
    top_k: int = 100  # NOTE(review): not read by any code visible in this file
    top_p: float = 0.9  # Nucleus sampling cutoff; values >= 1.0 skip the nucleus filter

    device: str = DEVICE  # Device on which the extracted LoRA factors are stored

class LoRAExtractor:
    """Approximate the weight difference between two models with low-rank factors.

    For every selected 2-D MLP projection weight, the difference
    ``knowledge - base`` is factorised with a truncated SVD into ``lora_A``
    (rows x rank) and ``lora_B`` (rank x cols), so that
    ``lora_A @ lora_B`` is the best rank-``rank`` approximation of the delta.
    """

    # Same logger object as the module-level ``logger``: logging.getLogger
    # returns one shared instance per logger name.
    _logger = logging.getLogger(__name__)

    # Name fragments of the MLP projection weights that get decomposed.
    TARGET_MODULES = ('gate_proj', 'up_proj', 'down_proj')

    def __init__(self, config: "LGCDConfig"):
        self.config = config
        self.lora_params = {}

    @staticmethod
    def _layer_index(name: str) -> Optional[int]:
        """Return the integer following a ``layers`` path component, if any.

        Example: ``"model.layers.12.mlp.gate_proj.weight"`` -> ``12``.
        """
        parts = name.split('.')
        for position, part in enumerate(parts[:-1]):
            if part == 'layers':
                try:
                    return int(parts[position + 1])
                except ValueError:
                    continue
        return None

    def get_layer_groups(self, layer_names: List[str]) -> List[str]:
        """Return the parameter names belonging to the configured layer group.

        Names are ordered by layer index. For ``'lower'``, ``'middle'`` and
        ``'upper'`` the *distinct layer indices* are split into thirds (the
        upper group absorbs the remainder), so all weights of one layer always
        land in the same group. Any other ``layer_group`` value selects every
        name that carries a layer index.

        Args:
            layer_names: Candidate parameter names.

        Returns:
            The selected names. If no name contains a layer index, the input
            list is returned unchanged.
        """
        indexed = []
        for name in layer_names:
            layer_idx = self._layer_index(name)
            if layer_idx is not None:
                indexed.append((layer_idx, name))

        if not indexed:
            self._logger.warning("No layer numbers found, using all layers")
            return layer_names

        # Stable sort: weights of the same layer keep their original order.
        indexed.sort(key=lambda pair: pair[0])

        group = self.config.layer_group
        if group not in ('lower', 'middle', 'upper'):  # 'all'
            selected_names = [name for _, name in indexed]
            self._logger.info(f"Selected all layers: {len(selected_names)} layers")
            return selected_names

        layer_ids = sorted({layer_idx for layer_idx, _ in indexed})
        third = len(layer_ids) // 3
        bounds = {
            'lower': (0, third),
            'middle': (third, 2 * third),
            'upper': (2 * third, len(layer_ids)),
        }
        start, stop = bounds[group]
        chosen_layers = layer_ids[start:stop]
        chosen_set = set(chosen_layers)
        selected_names = [name for layer_idx, name in indexed if layer_idx in chosen_set]

        layer_span = f"{chosen_layers[0]}-{chosen_layers[-1]}" if chosen_layers else "none"
        self._logger.info(
            f"Selected {group} 1/3 layers: {len(selected_names)} weights "
            f"from {len(chosen_layers)} layers ({layer_span})"
        )
        return selected_names

    def extract_lora_from_knowledge(self, language_model, knowledge_model) -> Dict:
        """Extract low-rank factors of ``knowledge_model - language_model``.

        Only 2-D weights whose names contain one of ``TARGET_MODULES`` and
        that exist in both state dicts are considered; the configured layer
        group narrows the set further.

        Args:
            language_model: Model whose weights act as the base.
            knowledge_model: Model whose weights are approximated relative to
                the base.

        Returns:
            Dict mapping parameter name to a dict with ``lora_A``, ``lora_B``,
            ``scaling`` (``lora_alpha / rank``), ``original_shape``, ``rank``
            and ``explained_variance`` (share of the singular-value sum kept
            by the truncation). The same dict is stored on
            ``self.lora_params``.
        """
        self._logger.info("Extracting LoRA parameters from fine-tuned model...")

        base_state = language_model.state_dict()
        knowledge_state = knowledge_model.state_dict()

        # Candidate MLP projection weights present in both models.
        candidate_names = [
            name for name in knowledge_state
            if name in base_state and any(target in name for target in self.TARGET_MODULES)
        ]
        selected_names = set(self.get_layer_groups(candidate_names))

        lora_params = {}
        explained_variances = []
        ranks = []

        for name, knowledge_param in knowledge_state.items():
            if name not in selected_names or knowledge_param.dim() != 2:
                continue

            base_param = base_state[name]
            if base_param.shape != knowledge_param.shape:
                self._logger.warning(
                    f"Failed to extract LoRA for {name}: shape mismatch "
                    f"{tuple(base_param.shape)} vs {tuple(knowledge_param.shape)}"
                )
                continue

            rank = min(self.config.lora_rank, min(knowledge_param.shape))
            if rank < 1:
                self._logger.warning(f"Failed to extract LoRA for {name}: rank must be >= 1")
                continue

            try:
                # Subtract in float32 so small deltas are not lost to the
                # rounding of a low-precision parameter dtype.
                delta = knowledge_param.float() - base_param.to(knowledge_param.device).float()

                # Reduced SVD: only min(rows, cols) components are computed,
                # which avoids materialising a full square U for tall weights.
                U, S, Vh = torch.linalg.svd(delta, full_matrices=False)

                # Split each kept singular value evenly between both factors.
                sqrt_s = torch.sqrt(S[:rank])
                lora_A = U[:, :rank] * sqrt_s
                lora_B = sqrt_s.unsqueeze(1) * Vh[:rank, :]

                # A zero delta is reproduced exactly; avoid 0/0 -> NaN.
                total = float(S.sum())
                explained = float(S[:rank].sum()) / total if total > 0 else 1.0

                lora_params[name] = {
                    'lora_A': lora_A.to(knowledge_param.dtype).to(self.config.device),
                    'lora_B': lora_B.to(knowledge_param.dtype).to(self.config.device),
                    'scaling': self.config.lora_alpha / rank,
                    'original_shape': delta.shape,
                    'rank': rank,
                    'explained_variance': explained,
                }
            except Exception as e:
                self._logger.warning(f"Failed to extract LoRA for {name}: {e}")
                continue

            explained_variances.append(explained)
            ranks.append(rank)

        self.lora_params = lora_params

        # Log extraction statistics
        if explained_variances:
            avg_variance = np.mean(explained_variances)
            avg_rank = np.mean(ranks)
            self._logger.info(f"LoRA extraction completed for {self.config.layer_group} layers: {len(lora_params)} layers, "
                              f"avg explained variance: {avg_variance:.3f}, avg rank: {avg_rank:.1f}")
        else:
            self._logger.warning("LoRA extraction produced no parameters")

        return lora_params

class ConfidenceEstimator:
    """Estimate how confident a next-token distribution is, for gating decisions."""

    # Number of most recent confidence values kept for the adaptive threshold.
    _HISTORY_LIMIT = 100

    def __init__(self, config: "LGCDConfig"):
        self.config = config
        self.confidence_history = []

    def compute_confidence(self, logits: torch.Tensor, method: Optional[str] = None) -> torch.Tensor:
        """Compute confidence scores for given logits.

        The softmax is taken in float32 regardless of the input dtype: in
        float16 the ``1e-8`` epsilon used by the entropy formula underflows to
        zero, which turns any zero probability into ``0 * -inf = NaN``.

        Args:
            logits: Tensor whose last dimension is the vocabulary.
            method: ``'entropy'``, ``'max_prob'`` or ``'variance'``; defaults
                to ``config.confidence_method``.

        Returns:
            A float32 tensor with the last dimension reduced away.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        if method is None:
            method = self.config.confidence_method
        if method not in ('entropy', 'max_prob', 'variance'):
            raise ValueError(f"Unknown confidence method: {method}")

        probs = F.softmax(logits.float(), dim=-1)

        if method == 'entropy':
            entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=-1)
            vocab_size = logits.shape[-1]
            if vocab_size <= 1:
                # A single-token vocabulary is deterministic; also avoids
                # dividing by log(1) == 0.
                return torch.ones_like(entropy)
            # Normalize entropy to [0, 1] confidence
            return 1.0 - entropy / float(np.log(vocab_size))

        if method == 'max_prob':
            return torch.max(probs, dim=-1)[0]

        # 'variance': spread of the probabilities across the vocabulary.
        variance = torch.var(probs, dim=-1)
        return 1.0 / (1.0 + variance)

    def adaptive_threshold(self, current_confidence: float) -> float:
        """Return the gating threshold, optionally adapted to recent confidences.

        When ``config.adaptive_threshold`` is false the fixed
        ``config.confidence_threshold`` is returned and no history is kept.
        Otherwise ``current_confidence`` is appended to a bounded history and,
        once more than 10 values are available, the threshold becomes
        ``mean - 0.5 * std`` of that history, clamped to ``[0.1, 0.9]``.
        """
        if not self.config.adaptive_threshold:
            return self.config.confidence_threshold

        # Update confidence history, keeping only the most recent values.
        self.confidence_history.append(current_confidence)
        if len(self.confidence_history) > self._HISTORY_LIMIT:
            del self.confidence_history[:-self._HISTORY_LIMIT]

        # Adjust threshold based on recent confidence distribution
        if len(self.confidence_history) > 10:
            mean_confidence = float(np.mean(self.confidence_history))
            std_confidence = float(np.std(self.confidence_history))

            # Adaptive threshold: mean - 0.5 * std
            adaptive_thresh = max(0.1, mean_confidence - 0.5 * std_confidence)
            return min(0.9, adaptive_thresh)

        return self.config.confidence_threshold

class LGCDQAModel:
    """LoRA-Gated Contrastive Decoding Model for QA tasks.

    Generates token by token with the language model. At each step the base
    logits are scored for confidence: when confident, the base distribution is
    sampled (top-k masked); otherwise the extracted LoRA deltas are temporarily
    added to the model weights, a second forward pass is run, and the two sets
    of logits are combined contrastively before sampling.
    """

    def __init__(self, language_model, base_tokenizer, config: "LGCDConfig"):
        self.language_model = language_model
        self.base_tokenizer = base_tokenizer
        self.lgcd_config = config

        # Store LoRA parameters (will be set externally)
        self.lora_params = {}

        # Initialize confidence estimator
        self.confidence_estimator = ConfidenceEstimator(config)

        logger.info("Initialized LGCD QA Model")

    def set_lora_params(self, lora_params: Dict):
        """Set LoRA parameters after extraction (the dict is stored, not copied)."""
        self.lora_params = lora_params
        logger.info(f"Set LoRA parameters for {len(lora_params)} layers")

    def get_top_k_mask(self, logits: torch.Tensor, k: int = None) -> torch.Tensor:
        """Return a boolean mask that is True at the k largest logits per row.

        Args:
            logits: Tensor whose last dimension is the vocabulary.
            k: Number of tokens to keep; defaults to
                ``config.contrastive_top_k``. It is clamped to
                ``[1, vocab_size]`` because ``torch.topk`` raises when ``k``
                exceeds the dimension size, and an empty mask would leave no
                token to sample.
        """
        if k is None:
            k = self.lgcd_config.contrastive_top_k
        k = max(1, min(int(k), logits.shape[-1]))

        _, top_k_indices = torch.topk(logits, k, dim=-1)

        mask = torch.zeros_like(logits, dtype=torch.bool)
        mask.scatter_(-1, top_k_indices, True)
        return mask

    def apply_contrastive_with_topk_masking(self, base_logits: torch.Tensor,
                                          knowledge_logits: torch.Tensor) -> torch.Tensor:
        """Apply contrastive decoding only to top-k tokens from base model.

        For the base model's top-k tokens the result is
        ``base + (beta * knowledge - alpha * base)``; every other token is set
        to ``-inf`` so it receives zero probability.
        """
        top_k_mask = self.get_top_k_mask(base_logits, self.lgcd_config.contrastive_top_k)

        contrastive_delta = (self.lgcd_config.contrastive_beta * knowledge_logits -
                             self.lgcd_config.contrastive_alpha * base_logits)

        return torch.where(
            top_k_mask,
            base_logits + contrastive_delta,
            torch.full_like(base_logits, float('-inf')),
        )

    def apply_lora_delta_to_weight(self, weight: torch.Tensor, layer_name: str) -> torch.Tensor:
        """Return ``weight + scaling * lora_A @ lora_B`` as a new tensor.

        ``weight`` itself is returned untouched when no LoRA entry exists for
        ``layer_name``. The input tensor is never modified in place.
        """
        if layer_name not in self.lora_params:
            return weight

        lora_info = self.lora_params[layer_name]
        lora_delta = lora_info['scaling'] * torch.matmul(lora_info['lora_A'], lora_info['lora_B'])

        # No-op when the factors already live on the weight's device/dtype.
        lora_delta = lora_delta.to(device=weight.device, dtype=weight.dtype)
        return weight + lora_delta

    def get_layer_module(self, model, layer_name: str):
        """Resolve a dotted parameter name to the module that owns it.

        A trailing ``weight`` component is dropped first. Returns ``None`` if
        any path component is missing.
        """
        parts = layer_name.split('.')
        if parts[-1] == 'weight':
            parts = parts[:-1]

        module = model
        for part in parts:
            module = getattr(module, part, None)
            if module is None:
                return None
        return module

    def apply_lora_to_forward_pass(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Run one forward pass with the LoRA deltas added to the model weights.

        Each patched module's ``weight.data`` is rebound to a new tensor
        ``weight + delta``; the original tensor is kept by reference (it is
        never written to, so no clone is needed) and rebound in ``finally``,
        even if the forward pass raises.

        Returns:
            The logits at the last position, ``outputs.logits[:, -1, :]``.
        """
        patched = []  # (module, original weight tensor)

        try:
            for layer_name in self.lora_params:
                module = self.get_layer_module(self.language_model, layer_name)
                if module is None or not hasattr(module, 'weight'):
                    continue

                original_weight = module.weight.data
                modified_weight = self.apply_lora_delta_to_weight(original_weight, layer_name)
                module.weight.data = modified_weight
                patched.append((module, original_weight))

            with torch.no_grad():
                outputs = self.language_model(input_ids)
                return outputs.logits[:, -1, :]

        finally:
            # Restore original weights
            for module, original_weight in patched:
                module.weight.data = original_weight

    def _format_prompt(self, question: str) -> str:
        """Wrap the question in a chat prompt for models named ``*instruct*``."""
        if "instruct" not in self.language_model.config.name_or_path.lower():
            return question

        try:
            messages = [{"role": "user", "content": question}]
            return self.base_tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        except Exception as e:
            # Best-effort fallback to a hand-written Llama-3 style prompt.
            logger.warning(f"Chat template failed, using fallback prompt: {e}")
            return f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

    def _stop_token_ids(self) -> set:
        """Collect every token id that should end generation.

        Combines the tokenizer's ``eos_token_id`` with the model's
        ``generation_config.eos_token_id`` when present; the latter may be a
        single id or a list of ids.
        """
        generation_config = getattr(self.language_model, 'generation_config', None)
        candidates = (
            getattr(self.base_tokenizer, 'eos_token_id', None),
            getattr(generation_config, 'eos_token_id', None),
        )
        stop_ids = set()
        for candidate in candidates:
            if candidate is None:
                continue
            if isinstance(candidate, (list, tuple, set)):
                stop_ids.update(int(token_id) for token_id in candidate)
            else:
                stop_ids.add(int(candidate))
        return stop_ids

    def _sample_next_token(self, final_logits: torch.Tensor) -> torch.Tensor:
        """Sample one token id per row with temperature and nucleus (top-p) filtering."""
        # float32 keeps the cumulative sum used for the nucleus cutoff accurate
        # when the model runs in a low-precision dtype.
        probs = F.softmax(final_logits.float() / self.lgcd_config.generation_temperature, dim=-1)

        top_p = self.lgcd_config.top_p
        if top_p < 1.0:
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

            # Drop tokens once the cumulative mass exceeds top_p; shifting by
            # one keeps the token that crosses the threshold and guarantees
            # the most likely token is never removed.
            sorted_remove = cumulative_probs > top_p
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False

            # Scatter sorted tensors to original indexing
            remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
            probs = probs.masked_fill(remove, 0.0)
            probs = probs / probs.sum(dim=-1, keepdim=True)

        return torch.multinomial(probs, num_samples=1)

    def generate_with_lgcd(self, question: str, max_new_tokens: int = 200) -> Tuple[str, str]:
        """Generate a response using the LGCD gating scheme.

        Args:
            question: The user question (plain text).
            max_new_tokens: Upper bound on the number of sampled tokens.

        Returns:
            ``(response, decision_summary)`` where the summary counts how many
            steps used the base branch and how many the contrastive branch.
            On any exception the error is logged and
            ``("Error: <message>", "error")`` is returned instead of raising.
        """
        try:
            formatted_question = self._format_prompt(question)

            # Chat templates usually emit the BOS token themselves; letting the
            # tokenizer add special tokens again would duplicate it.
            bos_token = getattr(self.base_tokenizer, 'bos_token', None)
            already_has_bos = bool(bos_token) and formatted_question.startswith(bos_token)

            inputs = self.base_tokenizer(
                formatted_question,
                return_tensors="pt",
                truncation=True,
                max_length=self.lgcd_config.max_length,
                add_special_tokens=not already_has_bos
            ).to(self.language_model.device)

            input_ids = inputs['input_ids']
            stop_ids = self._stop_token_ids()
            generated_tokens = []
            decision_path = []

            # Generate tokens one by one using LGCD
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    base_logits = self.language_model(input_ids).logits[:, -1, :]

                    base_confidence = self.confidence_estimator.compute_confidence(base_logits)
                    mean_confidence = base_confidence.mean().item()
                    current_threshold = self.confidence_estimator.adaptive_threshold(mean_confidence)

                    # Gating decision: use base model if confidence is high
                    if mean_confidence > current_threshold:
                        top_k_mask = self.get_top_k_mask(base_logits, self.lgcd_config.generation_top_k)
                        final_logits = base_logits.masked_fill(~top_k_mask, float('-inf'))
                        decision = 'base'
                    else:
                        # Low confidence: use LGCD
                        approximated_knowledge_logits = self.apply_lora_to_forward_pass(input_ids)
                        final_logits = self.apply_contrastive_with_topk_masking(
                            base_logits, approximated_knowledge_logits
                        )
                        decision = 'contrastive'

                    decision_path.append(decision)

                    next_token = self._sample_next_token(final_logits)
                    token_id = next_token.item()

                    # Check for termination
                    if token_id in stop_ids:
                        break

                    generated_tokens.append(token_id)
                    input_ids = torch.cat([input_ids, next_token], dim=-1)

            # Decode response
            response = self.base_tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

            # Create decision summary
            decision_summary = f"Base: {decision_path.count('base')}, Contrastive: {decision_path.count('contrastive')}"

            return response, decision_summary

        except Exception as e:
            logger.exception(f"Error in LGCD generation: {str(e)}")
            return f"Error: {str(e)}", "error"

    def cleanup(self):
        """Drop model/tokenizer references and LoRA tensors, then clear GPU memory.

        Safe to call more than once. Note that ``lora_params`` is cleared in
        place, so a dict passed to ``set_lora_params`` is emptied as well.
        """
        self.language_model = None
        self.base_tokenizer = None
        self.lora_params.clear()
        clear_gpu_memory()

# Model configurations
# Maps a language code to a list of model pairs. In each pair,
# 'language_model' is the model that generates text (main() passes it as the
# base model) and 'knowledge_model' is the model whose weight difference to
# the base is turned into LoRA factors. All values are model identifiers
# passed to `from_pretrained`.
MODEL_PAIRS = {
    'zh': [ 
        {'language_model': 'hfl/llama-3-chinese-8b-instruct', 'knowledge_model': 'meta-llama/Meta-Llama-3-8B-Instruct'},
        {'language_model': 'shenzhi-wang/Llama3-8B-Chinese-Chat', 'knowledge_model': 'meta-llama/Meta-Llama-3-8B-Instruct'},
    ],
    'de': [ 
        {'language_model': 'DiscoResearch/Llama3-DiscoLeo-Instruct-8B-v0.1', 'knowledge_model': 'meta-llama/Meta-Llama-3-8B-Instruct'},
    ],
    'pt': [
        {'language_model': 'rhaymison/gemma-portuguese-luana-2b', 'knowledge_model': 'google/gemma-2b-it'},
    ],
    'ar': [ 
        {'language_model': 'MohamedRashad/Arabic-Orpo-Llama-3-8B-Instruct', 'knowledge_model': 'meta-llama/Meta-Llama-3-8B-Instruct'},
    ],
    'fa': [
        {'language_model': 'PartAI/Dorna-Llama3-8B-Instruct', 'knowledge_model': 'meta-llama/Meta-Llama-3-8B-Instruct'},
    ],
    'ja': [
        {'language_model': 'elyza/Llama-3-ELYZA-JP-8B', 'knowledge_model': 'meta-llama/Meta-Llama-3-8B-Instruct'},
        {'language_model': 'tokyotech-llm/Llama-3-Swallow-8B-Instruct-v0.1', 'knowledge_model': 'meta-llama/Meta-Llama-3-8B-Instruct'},
    ],
    'ko': [
        {'language_model': 'KISTI-KONI/KONI-Llama3-8B-Instruct-20240729', 'knowledge_model': 'meta-llama/Meta-Llama-3-8B-Instruct'},
        {'language_model': 'MLP-KTLim/llama-3-Korean-Bllossom-8B', 'knowledge_model': 'meta-llama/Meta-Llama-3-8B-Instruct'},
    ],
    'id': [
        {'language_model': 'GoToCompany/llama3-8b-cpt-sahabatai-v1-instruct', 'knowledge_model': 'meta-llama/Meta-Llama-3-8B-Instruct'},
    ],
    'sw': [
        {'language_model': 'Jacaranda/UlizaLlama3', 'knowledge_model': 'meta-llama/Meta-Llama-3-8B-Instruct'},
    ],
}

def load_models(language_model_id, knowledge_model_id):
    """Load the tokenizer and causal LM for both model identifiers.

    Tokenizers without a pad token get their EOS token as pad token. Both
    models are loaded in bfloat16 onto ``DEVICE`` and switched to eval mode.

    Returns:
        ``(base_tokenizer, knowledge_tokenizer, language_model, knowledge_model)``
    """
    logger.info(f"Loading tokenizers...")
    base_tokenizer, knowledge_tokenizer = (
        AutoTokenizer.from_pretrained(model_id)
        for model_id in (language_model_id, knowledge_model_id)
    )

    # Fall back to EOS as padding where the tokenizer defines no pad token.
    for tokenizer in (base_tokenizer, knowledge_tokenizer):
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    logger.info(f"Loading models...")
    load_kwargs = {'torch_dtype': torch.bfloat16, 'device_map': DEVICE}
    language_model = AutoModelForCausalLM.from_pretrained(language_model_id, **load_kwargs)
    knowledge_model = AutoModelForCausalLM.from_pretrained(knowledge_model_id, **load_kwargs)

    for loaded in (language_model, knowledge_model):
        loaded.eval()

    logger.info("Models loaded successfully")

    return base_tokenizer, knowledge_tokenizer, language_model, knowledge_model

def extract_lora_and_cleanup_knowledge_model(language_model, knowledge_model, config: LGCDConfig):
    """Extract LoRA factors, then drop this function's knowledge-model reference.

    Note that ``del`` only removes the local name: the model's memory can be
    reclaimed only once the caller has released its own reference too.

    Returns:
        The dict of LoRA parameters produced by ``LoRAExtractor``.
    """
    logger.info("Extracting LoRA parameters...")

    lora_params = LoRAExtractor(config).extract_lora_from_knowledge(
        language_model, knowledge_model
    )

    logger.info("Cleaning up knowledge model from memory...")
    del knowledge_model
    clear_gpu_memory()

    logger.info(f"LoRA extraction completed. Knowledge model removed from memory.")
    return lora_params

def run_lgcd_qa_evaluation(language_model_id: str, knowledge_model_id: str, 
                          language: str, question: str, config: LGCDConfig) -> Dict:
    """Run LGCD QA evaluation for a specific model pair.

    Loads both models, extracts the LoRA factors, releases the knowledge
    model, generates an answer with the language model and returns a result
    record. Resources are released in ``finally`` on every path.

    Args:
        language_model_id: Identifier of the model that generates text.
        knowledge_model_id: Identifier of the model the LoRA delta is taken from.
        language: Language code stored in the result.
        question: Question text passed to generation.
        config: LGCD settings.

    Returns:
        A dict with ``language``, both model ids, ``question``, ``response``,
        ``decision_summary`` and ``status``. ``status`` is ``'success'`` only
        when generation did not report an error; successful loads also carry
        ``config`` and ``lora_layers_count``. Exceptions are logged and
        turned into a ``status='error'`` record rather than raised.
    """
    lgcd_model = None
    language_model = None
    knowledge_model = None
    knowledge_tokenizer = None

    try:
        # Load models
        base_tokenizer, knowledge_tokenizer, language_model, knowledge_model = load_models(
            language_model_id, knowledge_model_id
        )

        # Extract LoRA parameters
        lora_params = extract_lora_and_cleanup_knowledge_model(
            language_model, knowledge_model, config
        )

        # The callee can only drop its own reference; release ours as well so
        # the knowledge model is actually freed before generation starts.
        knowledge_model = None
        knowledge_tokenizer = None
        clear_gpu_memory()

        # Create LGCD QA model (now only needs language model and tokenizer)
        lgcd_model = LGCDQAModel(
            language_model=language_model,
            base_tokenizer=base_tokenizer,
            config=config
        )

        # Set the extracted LoRA parameters
        lgcd_model.set_lora_params(lora_params)

        # Generate response using LGCD
        logger.info(f"Generating LGCD response for {language} using {language_model_id}")
        response, decision_summary = lgcd_model.generate_with_lgcd(question)
        print(response)

        # generate_with_lgcd reports failures through its return value
        # (decision summary "error") instead of raising.
        status = 'error' if decision_summary == "error" else 'success'

        return {
            'language': language,
            'language_model_id': language_model_id,
            'knowledge_model_id': knowledge_model_id,
            'question': question,
            'response': response,
            'decision_summary': decision_summary,
            'config': {
                'confidence_threshold': config.confidence_threshold,
                'contrastive_alpha': config.contrastive_alpha,
                'contrastive_beta': config.contrastive_beta,
                'lora_rank': config.lora_rank,
                'layer_group': config.layer_group
            },
            'status': status,
            # Counted before cleanup, which empties the LoRA dict in place.
            'lora_layers_count': len(lora_params)
        }

    except Exception as e:
        logger.exception(f"Error with LGCD QA for {language}: {str(e)}")
        return {
            'language': language,
            'language_model_id': language_model_id,
            'knowledge_model_id': knowledge_model_id,
            'question': question,
            'response': f"Error: {str(e)}",
            'decision_summary': "error",
            'status': 'error'
        }

    finally:
        # Drop local references first so the cache clearing below can
        # actually return the models' memory.
        language_model = None
        knowledge_model = None
        knowledge_tokenizer = None
        if lgcd_model is not None:
            lgcd_model.cleanup()
        else:
            clear_gpu_memory()

def save_results_to_file(results: List[Dict], filename: str = "lgcd_colorectal_cancer_qa_results.json"):
    """Save results to a timestamped JSON file.

    The current local time (``YYYYMMDD_HHMMSS``) is prefixed to the file
    name component only, so a ``filename`` that includes a directory keeps
    pointing into that directory.

    Args:
        results: Result records to serialise.
        filename: Target file name or path; the timestamp prefix is added to
            its last component.
    """
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    directory, base_name = os.path.split(filename)
    filename = os.path.join(directory, f"{timestamp}_{base_name}")

    # ensure_ascii=False keeps the non-Latin questions/answers readable.
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    logger.info(f"Results saved to {filename}")

def print_results_summary(results: List[Dict]):
    """Print a human-readable report with one section per result record.

    Every record must provide ``language``, ``language_model_id``,
    ``knowledge_model_id``, ``question``, ``status`` and ``response``;
    ``decision_summary`` and ``lora_layers_count`` are printed when present.
    """
    banner = "=" * 80
    rule = "-" * 60

    print("\n" + banner)
    print("LGCD-BASED COLORECTAL CANCER SYMPTOMS - MULTILINGUAL QA RESULTS")
    print(banner)

    required_fields = (
        ("LAM", 'language_model_id'),
        ("PTM", 'knowledge_model_id'),
        ("Question", 'question'),
        ("Status", 'status'),
    )
    optional_fields = (
        ("LGCD choice", 'decision_summary'),
        ("# of LoRA", 'lora_layers_count'),
    )

    for entry in results:
        print(f"\nLanguage: {entry['language'].upper()}")
        for label, key in required_fields:
            print(f"{label}: {entry[key]}")
        for label, key in optional_fields:
            if key in entry:
                print(f"{label}: {entry[key]}")
        print("Final answer:")
        print(rule)
        print(entry['response'])
        print(rule)

def main():
    """Run the LGCD QA evaluation for every configured language/model pair.

    Iterates over ``MODEL_PAIRS``, asks each pair the question configured for
    its language, collects one result record per pair (including error
    records), then saves, prints and summarises the results.
    """
    logger.info("Starting Memory-Optimized LGCD-based Multilingual Colorectal Cancer QA System")

    # LGCD configuration
    config = LGCDConfig(
        confidence_threshold=0.7,
        contrastive_alpha=0.1,
        contrastive_beta=1.0,
        lora_rank=32,
        layer_group='all',
        device=DEVICE,
    )

    all_results = []

    for lang, model_pairs in MODEL_PAIRS.items():
        try:
            question = COLORECTAL_CANCER_QUESTIONS[lang]
        except KeyError:
            logger.warning(f"No question defined for language: {lang}")
            continue

        for pair_number, pair in enumerate(model_pairs, start=1):
            language_model_id = pair['language_model']    # language-specific model, used as base
            knowledge_model_id = pair['knowledge_model']  # model supplying the LoRA delta

            logger.info(f"Testing LGCD for {lang} - Pair {pair_number}: {language_model_id} + {knowledge_model_id}")

            try:
                all_results.append(
                    run_lgcd_qa_evaluation(
                        language_model_id, knowledge_model_id, lang, question, config
                    )
                )
                # Pause between models to give memory cleanup time to settle.
                time.sleep(5)
            except Exception as e:
                logger.error(f"Fatal error with LGCD for {lang}: {str(e)}")
                logger.error(traceback.format_exc())
                all_results.append({
                    'language': lang,
                    'language_model_id': language_model_id,
                    'knowledge_model_id': knowledge_model_id,
                    'question': question,
                    'response': f"Fatal error: {str(e)}",
                    'decision_summary': "fatal_error",
                    'status': 'fatal_error',
                })

            # Clear memory between models
            clear_gpu_memory()

    # Save and display results
    save_results_to_file(all_results)
    print_results_summary(all_results)

    # Summary statistics
    banner = "=" * 80
    successful_runs = len([r for r in all_results if r['status'] == 'success'])
    total_runs = len(all_results)
    languages_tested = {r['language'] for r in all_results}
    pairs_tested = {(r['language_model_id'], r['knowledge_model_id']) for r in all_results}

    print("\n" + banner)
    print(f"SUMMARY: {successful_runs}/{total_runs} successful LGCD runs")
    print(f"Languages tested: {len(languages_tested)}")
    print(f"Model pairs tested: {len(pairs_tested)}")
    print(banner)
    

# Run the full evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()