import os
import sys
from tqdm import tqdm
import json
import torch
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM
from qwen_vl_utils import process_vision_info
import argparse

# System message placed first in the chat transcripts of both pathways
# (the image+text Vision Pathway and the text-only Textual Pathway).
SYSTEM_PROMPT: str = "You are a helpful assistant."

def setup_logging(log_path):
    """Configure root logging to write to *log_path* (UTF-8) and to stdout.

    Records at INFO and above are formatted as ``time - LEVEL - message``.
    Configuration goes through ``logging.basicConfig``, which leaves an
    already-configured root logger untouched.

    Args:
        log_path: File that receives the log output; its directory must exist.

    Returns:
        The logger named after this module.
    """
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    sinks = [
        logging.FileHandler(log_path, encoding='utf-8'),
        logging.StreamHandler(sys.stdout),
    ]
    logging.basicConfig(level=logging.INFO, format=log_format, handlers=sinks)
    module_logger = logging.getLogger(__name__)
    return module_logger

def validate_paths(args):
    """Validate input paths and create the output and log directories.

    Args:
        args: Parsed CLI namespace. Reads ``dataset_path``, ``image_base_path``,
            ``vl_model_path``, ``llm_model_path``, ``output_path`` and
            ``log_path``.

    Raises:
        FileNotFoundError: If the dataset file or the image base path does not
            exist, or if a model path neither exists locally nor has the
            ``org/name`` shape of a Hugging Face Hub repo id.
    """
    import re

    def looks_like_hub_repo_id(path):
        # Hub ids such as "Qwen/Qwen2.5-VL-3B-Instruct" (the CLI defaults) are
        # resolved by from_pretrained and need not exist on disk. Relative
        # ("./x") and absolute paths never match because the first character
        # must be a word character.
        return re.fullmatch(r"\w[\w.-]*/[\w.-]*\w", path) is not None

    if not os.path.exists(args.dataset_path):
        raise FileNotFoundError(f"Dataset file not found: {args.dataset_path}")

    if not os.path.exists(args.image_base_path):
        raise FileNotFoundError(f"Image base path not found: {args.image_base_path}")

    for label, model_path in (("VL", args.vl_model_path), ("LLM", args.llm_model_path)):
        if not os.path.exists(model_path) and not looks_like_hub_repo_id(model_path):
            raise FileNotFoundError(f"{label} model path not found: {model_path}")

    # Create the results and log directories. A bare file name has an empty
    # dirname (the current directory); os.makedirs("") would raise, so skip it.
    for target_path in (args.output_path, args.log_path):
        parent_dir = os.path.dirname(target_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

def parse_args(argv=None):
    """Parse command-line options for the mixed-pathway math solver.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the usual CLI
            behaviour; a list allows programmatic use and testing.

    Returns:
        argparse.Namespace holding data paths, model paths, sampling
        hyperparameters, device placement and subset options.
    """
    parser = argparse.ArgumentParser(description='Multimodal Math Problem Solver (ADM)')
    
    # Data paths
    parser.add_argument('--image_base_path', type=str, 
                        default='./data/mathverse/images',
                        help='Base path for images')
    parser.add_argument('--dataset_path', type=str,
                        default='./data/mathverse/testmini.json',
                        help='Path to dataset JSON file')
    parser.add_argument('--exp_name', type=str, required=True,
                        help='Experiment name for organizing results')
    parser.add_argument('--output_path', type=str,
                        default=None,
                        help='Path to save results (auto-generated if not specified)')
    parser.add_argument('--log_path', type=str,
                        default=None,
                        help='Path to save log file (auto-generated if not specified)')
    
    # Model paths (local directories or Hugging Face Hub repo ids)
    parser.add_argument('--vl_model_path', type=str,
                        default="Qwen/Qwen2.5-VL-3B-Instruct",
                        help='Path to Vision Pathway (with image)')
    
    parser.add_argument('--llm_model_path', type=str,
                        default="Qwen/Qwen2.5-VL-3B-Instruct",
                        help='Path to Textual Pathway (without image)')
    
    # Model hyperparameters 
    parser.add_argument('--alpha', type=float, default=0.51,
                        help='Weight factor for mixed sequence generation')
    parser.add_argument('--max_new_tokens', type=int, default=2048,
                        help='Maximum number of new tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.1,
                        help='Sampling temperature')
    parser.add_argument('--top_p', type=float, default=0.95,
                        help='Top-p sampling parameter')
    parser.add_argument('--prefill', type=str, default="",
                        help='Prefill text for generation')
    
    # Additional options
    parser.add_argument('--device_0', type=str, default="cuda:0",
                        help='Device for Vision Pathway')
    parser.add_argument('--device_1', type=str, default="cuda:1",
                        help='Device for Textual Pathway')
    parser.add_argument('--caption_path', type=str, 
                        default='./data/results_caption_qwen2_5vl_3b.json',
                        help='Path to caption results file')
    parser.add_argument('--subset_config_path', type=str,
                        default='./data/mathverse/subset_config.json',
                        help='Path to subset configuration file (optional)')
    parser.add_argument('--use_subset', action='store_true',
                        help='Use subset of problems defined in subset_config_path')
    
    return parser.parse_args(argv)

def load_models(vl_model_path, llm_model_path, device_0, device_1):
    """Load the two Qwen2.5-VL pathways together with their processors.

    Both processors are created with ``min_pixels=3136`` and
    ``max_pixels=1000000``. Both models are loaded in bfloat16 and put in eval
    mode:
      * Vision Pathway (``vl_model_path``, moved to ``device_0``): eager
        attention with ``output_attentions=True``, so callers can read
        attention weights.
      * Textual Pathway (``llm_model_path``, moved to ``device_1``):
        FlashAttention-2 without attention outputs.

    Returns:
        Tuple ``(vl_model, vl_processor, llm_model, llm_processor)``.

    Raises:
        Re-raises any loading error after logging it.
    """
    log = logging.getLogger(__name__)

    # (tag, checkpoint, device, attention backend, return attentions?)
    pathway_specs = (
        ("VL", vl_model_path, device_0, "eager", True),
        ("LLM", llm_model_path, device_1, "flash_attention_2", False),
    )

    loaded = []
    try:
        for tag, checkpoint, device, attn_backend, want_attentions in pathway_specs:
            log.info(f"Loading {tag} model from: {checkpoint}")
            processor = AutoProcessor.from_pretrained(checkpoint, max_pixels=1000000, min_pixels=3136)
            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                checkpoint,
                torch_dtype=torch.bfloat16,
                attn_implementation=attn_backend,
                output_attentions=want_attentions,
            )
            model = model.eval().to(device)
            log.info(f"{tag} model loaded successfully on {device}")
            loaded.extend((model, processor))
    except Exception as err:
        log.error(f"Error loading models: {str(err)}")
        raise

    return tuple(loaded)

def load_dataset(dataset_path, caption_path):
    """Load the dataset and captions, grouped by problem index.

    Returns a dict keyed by ``item['problem_index']``. Each value may contain:
      * ``vd``, ``answer``, ``question_for_eval``, ``image``: copied from the
        problem's 'Vision Dominant' item;
      * ``metadata``: from the last item seen for that problem index;
      * ``td``: built when a 'Text Dominant' item exists. It is the Vision
        Dominant question prefixed by the caption at position
        ``int(problem_index) - 1`` of the caption file, or only
        ``"Question: ..."`` when that caption index is out of range.

    A problem with a 'Text Dominant' item but no 'Vision Dominant' item gets no
    ``td`` (a warning is logged) instead of aborting the whole load. Such
    entries lack ``vd``/``td`` and are skipped by the caller.

    Raises:
        FileNotFoundError, json.JSONDecodeError, or any other error while
        reading or parsing the inputs; each is logged and re-raised.
    """
    logger = logging.getLogger(__name__)
    
    try:
        logger.info(f"Loading dataset from: {dataset_path}")
        with open(dataset_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        logger.info(f"Loading captions from: {caption_path}")
        with open(caption_path, 'r', encoding='utf-8') as f:
            captions = json.load(f)
        
        pidx2data = {}
        
        # Pass 1: collect Vision Dominant fields; metadata is last-one-wins.
        for item in data:
            pidx = item['problem_index']
            version = item['problem_version']
            entry = pidx2data.setdefault(pidx, {})
            if version == 'Vision Dominant':
                entry['vd'] = item['question']
                entry['answer'] = item['answer']
                entry['question_for_eval'] = item['question_for_eval']
                entry['image'] = item['image']
            entry['metadata'] = item['metadata']
        
        # Pass 2: build the Textual Pathway prompt (caption + VD question).
        for item in data:
            if item['problem_version'] != 'Text Dominant':
                continue
            pidx = item['problem_index']
            entry = pidx2data[pidx]
            if 'vd' not in entry:
                logger.warning(
                    f"No Vision Dominant question for problem {pidx}; "
                    f"skipping its Text Dominant prompt"
                )
                continue
            # Captions are assumed to be stored in 1-based problem-index order.
            caption_idx = int(pidx) - 1
            # Reject negative indices too: Python would silently wrap them
            # around to the end of the caption list.
            if 0 <= caption_idx < len(captions):
                caption_text = captions[caption_idx]["generated_text"]
                entry['td'] = f"As shown in the image: {caption_text}Question: {entry['vd']}"
            else:
                logger.warning(f"Caption index {caption_idx} out of range for problem {pidx}")
                entry['td'] = f"Question: {entry['vd']}"
        
        logger.info(f"Loaded {len(pidx2data)} problems from dataset")
        return pidx2data
        
    except FileNotFoundError as e:
        logger.error(f"File not found: {str(e)}")
        raise
    except json.JSONDecodeError as e:
        logger.error(f"JSON decode error: {str(e)}")
        raise
    except Exception as e:
        logger.error(f"Error loading dataset: {str(e)}")
        raise

def _map_ratio_to_range(r: float) -> float:
    """Linearly map an image-attention ratio to an additive offset for alpha.

    r = 0 maps to -0.01 and r = 0.1 maps to +0.01 (slope 0.2). Inputs outside
    [0, 1] are still mapped linearly, but "out of range" is printed.
    """
    lower_bound = -0.01
    upper_bound = 0.01
    k = (upper_bound - lower_bound) / 0.1
    if not 0 <= r <= 1:
        print("out of range")
    return lower_bound + k * r


def _top_p_filter(probs, top_p):
    """Restrict a probability tensor to its top-p nucleus and renormalize.

    Along the last dimension, keeps the smallest set of highest-probability
    entries whose cumulative mass reaches ``top_p`` (always at least one).
    ``top_p >= 1.0`` returns ``probs`` unchanged.
    """
    if top_p >= 1.0:
        return probs
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # Drop a token only when the mass of the strictly more likely tokens
    # already reaches top_p. The token that crosses the threshold is kept,
    # matching standard nucleus sampling.
    remove = (cumulative_probs - sorted_probs) >= top_p
    remove[..., 0] = False
    sorted_probs = sorted_probs.masked_fill(remove, 0.0)
    filtered = torch.zeros_like(probs).scatter(-1, sorted_indices, sorted_probs)
    return filtered / filtered.sum(dim=-1, keepdim=True)


def _image_attention_ratio(attention, first_image_token_index, last_image_token_index):
    """Return image-token attention relative to attention over all keys.

    Both terms are the mean over the last (key) dimension, summed over dim 1
    (heads). The image slice covers ``first_image_token_index`` through
    ``last_image_token_index`` inclusive. Layout assumed to be
    (batch, heads, query, key), per the HF ``attentions`` convention. Batch and
    query must both have size 1, because the sums are read with ``.item()``.
    """
    image_part = attention[..., first_image_token_index:last_image_token_index + 1].mean(dim=-1).sum(dim=1)
    all_part = attention.mean(dim=-1).sum(dim=1)
    return image_part.item() / all_part.item()


@torch.no_grad()
def generate_mixed_sequence(
        vl_model, vl_processor,
        llm_model, llm_processor,
        question_vl, question_llm,
        image,
        device_0, device_1,
        prefill="<think>\nOkay",
        alpha=0.51,
        max_new_tokens=2048,
        temperature=0.1,
        top_p=0.95,
        eos_token_id=None
    ):
    """Generate text by sampling from a mixture of Vision and Textual Pathways.

    The Vision Pathway receives the system prompt, ``image`` and ``question_vl``.
    The Textual Pathway receives the system prompt and ``question_llm`` only.
    Both models consume the same token stream in lock-step, each with its own
    KV cache; inputs go to ``device_0`` and ``device_1`` respectively.

    Each sampled token is drawn from ``a * p_vl + (1 - a) * p_llm``. Both
    distributions are temperature-scaled softmaxes, and the mixture is
    top-p filtered. For the very first sampled token, ``a = alpha``. For later
    sampled tokens, ``a = 0.9 * (alpha + offset(ratio)) + 0.1 * a_prev``, where
    ``ratio`` comes from the Vision Pathway's first-layer attention on the
    image tokens relative to all tokens.

    Args:
        vl_model, vl_processor: Vision Pathway model and processor.
        llm_model, llm_processor: Textual Pathway model and processor. The
            latter's tokenizer also handles prefill and final decoding.
        question_vl: User text shown to the Vision Pathway alongside ``image``.
        question_llm: User text shown to the Textual Pathway.
        image: Image reference passed to ``process_vision_info`` (a file path
            in this script).
        device_0, device_1: Devices for the Vision and Textual Pathway inputs.
        prefill: Text whose tokens are forced at the start of the output; an
            empty string disables forcing.
        alpha: Base weight of the Vision Pathway distribution in the mixture.
        max_new_tokens: Maximum number of tokens to emit (at least one is
            always emitted).
        temperature: Softmax temperature applied to both pathways' logits.
        top_p: Nucleus-sampling threshold; ``>= 1.0`` disables filtering.
        eos_token_id: Token id, or a list/tuple/set of ids, that stops
            generation. Defaults to the Textual Pathway tokenizer's
            ``eos_token_id``.

    Returns:
        The decoded text, with prefill included and special tokens skipped.
    """
    vl_model.rope_deltas = None
    llm_model.rope_deltas = None

    # ---- Vision Pathway input: system prompt + image + question ----
    vl_messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": SYSTEM_PROMPT,
                },
            ]
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image
                },
                {
                    "type": "text",
                    "text": question_vl
                },
            ],
        }
    ]
    question_vl = vl_processor.apply_chat_template(
        vl_messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, _ = process_vision_info(vl_messages)
    input_data = vl_processor(
        text=[question_vl],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )

    # Image tokens lie strictly between the vision start/end markers.
    tokens = vl_processor.tokenizer.convert_ids_to_tokens(input_data['input_ids'][0])
    first_image_token_index = tokens.index("<|vision_start|>") + 1
    last_image_token_index = tokens.index("<|vision_end|>") - 1

    input_data = input_data.to(device_0)

    # ---- Textual Pathway input: system prompt + question (no image) ----
    llm_messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": SYSTEM_PROMPT,
                },
            ]
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": question_llm
                },
            ],
        }
    ]
    question_llm = llm_processor.apply_chat_template(
        llm_messages, tokenize=False, add_generation_prompt=True
    )
    input_ids_llm = llm_processor(
        text=[question_llm],
        padding=True,
        return_tensors="pt",
    )['input_ids'].to(device_1)

    # Normalize the stop condition to a set of ids (a single id stays supported).
    if eos_token_id is None:
        eos_token_id = llm_processor.tokenizer.eos_token_id
    if eos_token_id is None:
        eos_ids = set()
    elif isinstance(eos_token_id, (list, tuple, set, frozenset)):
        eos_ids = set(eos_token_id)
    else:
        eos_ids = {eos_token_id}

    def get_mixed_probs(logits_vl, logits_llm, temperature, top_p, ratio, pre_ratio):
        """Mix both pathways' next-token distributions and apply top-p.

        ``ratio == -1`` means "no attention signal": ``alpha`` is used as-is.
        Otherwise the weight is ``alpha`` shifted by the mapped ratio, then
        smoothed against ``pre_ratio`` (the previous weight).

        Returns:
            ``(probabilities on device_0, mixing weight actually used)``.
        """
        p_vl = torch.softmax(logits_vl / temperature, dim=-1)
        p_llm = torch.softmax(logits_llm / temperature, dim=-1).to(device_0)

        # Zero-pad the smaller vocabulary so both distributions align by id.
        vl_vocab_size = logits_vl.shape[-1]
        llm_vocab_size = logits_llm.shape[-1]
        if vl_vocab_size > llm_vocab_size:
            p_llm_padded = torch.zeros_like(p_vl)
            p_llm_padded[..., :llm_vocab_size] = p_llm
            p_llm = p_llm_padded
        elif llm_vocab_size > vl_vocab_size:
            p_vl_padded = torch.zeros_like(p_llm)
            p_vl_padded[..., :vl_vocab_size] = p_vl
            p_vl = p_vl_padded

        if ratio == -1:
            alpha_in_use = alpha
        else:
            gamma = 0.9
            alpha_in_use = alpha + _map_ratio_to_range(ratio)
            alpha_in_use = gamma * alpha_in_use + (1 - gamma) * pre_ratio

        p_mix = alpha_in_use * p_vl + (1 - alpha_in_use) * p_llm
        p_mix = torch.clamp(p_mix, min=1e-10)
        p_mix = p_mix / p_mix.sum(dim=-1, keepdim=True)
        return _top_p_filter(p_mix, top_p), alpha_in_use

    # ---- Prompt forward passes (fill both KV caches) ----
    outputs_vl = vl_model(
        input_ids=input_data['input_ids'],
        pixel_values=input_data['pixel_values'],
        image_grid_thw=input_data['image_grid_thw'],
        use_cache=True
    )
    outputs_llm = llm_model(
        input_ids=input_ids_llm,
        use_cache=True
    )
    past_key_values_vl = outputs_vl.past_key_values
    past_key_values_llm = outputs_llm.past_key_values

    prefill_tokens = []
    if prefill:
        prefill_tokens = llm_processor.tokenizer(prefill, return_tensors="pt")["input_ids"][0].tolist()

    # First token: forced from the prefill, or sampled with plain alpha.
    if prefill_tokens:
        next_token = torch.tensor([[prefill_tokens[0]]], device=device_0)
    else:
        p_mix, _ = get_mixed_probs(
            outputs_vl.logits[:, -1, :],
            outputs_llm.logits[:, -1, :],
            temperature,
            top_p,
            -1,
            -1
        )
        next_token = torch.multinomial(p_mix, num_samples=1)

    # The next cache slot for each model is the prompt length.
    cache_position_vl = torch.tensor([input_data['input_ids'].shape[1]], device=device_0)
    cache_position_llm = torch.tensor([input_ids_llm.shape[1]], device=device_1)
    generated_tokens = [next_token.cpu().item()]

    # Stop immediately if the very first token is already EOS.
    if generated_tokens[0] in eos_ids:
        return llm_processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    pre_ratio = alpha  # previous mixing weight, carried through the moving average

    for idx in range(max_new_tokens - 1):
        # Feed the last token to both pathways.
        outputs_vl = vl_model(
            input_ids=next_token.to(device_0),
            past_key_values=past_key_values_vl,
            use_cache=True,
            cache_position=cache_position_vl
        )
        outputs_llm = llm_model(
            input_ids=next_token.to(device_1),
            past_key_values=past_key_values_llm,
            use_cache=True,
            cache_position=cache_position_llm
        )
        past_key_values_vl = outputs_vl.past_key_values
        past_key_values_llm = outputs_llm.past_key_values

        if prefill_tokens and idx < len(prefill_tokens) - 1:
            # Still forcing prefill tokens; the attention ratio is not needed.
            next_token = torch.tensor([[prefill_tokens[idx + 1]]], device=device_0)
        else:
            ratio = _image_attention_ratio(
                outputs_vl.attentions[0],  # first layer's attention weights
                first_image_token_index,
                last_image_token_index,
            )
            p_mix, pre_ratio = get_mixed_probs(
                outputs_vl.logits[:, -1, :],
                outputs_llm.logits[:, -1, :],
                temperature,
                top_p,
                ratio,
                pre_ratio
            )
            next_token = torch.multinomial(p_mix, num_samples=1)

        token_id = next_token.item()
        generated_tokens.append(token_id)

        cache_position_vl += 1
        cache_position_llm += 1

        if token_id in eos_ids:
            break

    return llm_processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

def main():
    """Run the end-to-end generation experiment.

    Steps:
      * parse CLI arguments and default the output/log paths to
        ``results/<exp_name>/``;
      * validate paths, set up logging, and load both pathways and the dataset;
      * optionally restrict processing to the problem indices in the subset
        config;
      * generate one response per problem and write all results as JSON to
        ``output_path``.

    Per-problem failures are logged with their traceback and skipped. The CUDA
    cache is cleared after every problem, whether it succeeded or failed. Any
    other error is logged (or printed if logging is not yet set up) and
    re-raised.
    """
    logger = None
    try:
        args = parse_args()
        
        # Auto-generate output and log paths based on exp_name
        if args.output_path is None:
            args.output_path = f"results/{args.exp_name}/result_mathverse.json"
        if args.log_path is None:
            args.log_path = f"results/{args.exp_name}/log_mathverse.txt"
        
        # Validate paths (also creates the results/log directories)
        validate_paths(args)
        
        # Setup logging
        logger = setup_logging(args.log_path)
        logger.info("Starting multimodal math problem solver (S-ATM)")
        logger.info(f"Experiment name: {args.exp_name}")
        logger.info(f"Alpha: {args.alpha}")
        logger.info(f"Temperature: {args.temperature}")
        logger.info(f"Top-p: {args.top_p}")
        logger.info(f"Max new tokens: {args.max_new_tokens}")
        if args.use_subset:
            logger.info(f"Using subset mode with config: {args.subset_config_path}")
        else:
            logger.info("Using full dataset")
        
        # Load models
        vl_model, vl_processor, llm_model, llm_processor = load_models(
            args.vl_model_path, args.llm_model_path, args.device_0, args.device_1
        )

        # Load dataset
        dataset = load_dataset(args.dataset_path, args.caption_path)
        
        # Load subset configuration if specified. None means "process everything".
        selected_problem_indices = None
        if args.use_subset and args.subset_config_path and os.path.exists(args.subset_config_path):
            logger.info(f"Loading subset configuration from: {args.subset_config_path}")
            with open(args.subset_config_path, 'r', encoding='utf-8') as f:
                subset_config = json.load(f)
            selected_problem_indices = set(subset_config['selected_problem_indices'])
            logger.info(f"Will process subset with {len(selected_problem_indices)} problems")
        elif args.use_subset:
            logger.warning("Subset requested but configuration file not found, using full dataset")

        results = []
        # Compare against None: an empty subset means 0 problems, not the full dataset.
        if selected_problem_indices is not None:
            total_problems = len(selected_problem_indices)
        else:
            total_problems = len(dataset)
        logger.info(f"Starting generation for {total_problems} problems")
        
        processed_count = 0
        for pidx, item_dict in tqdm(list(dataset.items()), desc="Generating responses"):
            # Skip if using subset and this problem is not selected
            if selected_problem_indices is not None and pidx not in selected_problem_indices:
                continue
            try:
                vd_question = item_dict.get("vd")
                td_question = item_dict.get("td")
                image_path = item_dict.get("image")
                answer = item_dict.get("answer")
                question_for_eval = item_dict.get("question_for_eval")
                metadata = item_dict.get("metadata")
                
                if vd_question is None or td_question is None:
                    logger.warning(f"Skipping problem {pidx}: missing questions")
                    continue
                
                if image_path is None:
                    logger.warning(f"Skipping problem {pidx}: missing image path")
                    continue
                
                full_image_path = os.path.join(args.image_base_path, image_path)
                if not os.path.exists(full_image_path):
                    logger.warning(f"Image not found: {full_image_path}")
                    continue
                
                # Generate response
                gen_text = generate_mixed_sequence(
                    vl_model, vl_processor,
                    llm_model, llm_processor, 
                    vd_question, td_question, full_image_path,
                    args.device_0, args.device_1,
                    prefill=args.prefill,
                    alpha=args.alpha,
                    max_new_tokens=args.max_new_tokens,
                    temperature=args.temperature,
                    top_p=args.top_p
                )
                
                # Log detailed results
                logger.info(f"\n{'='*60}")
                logger.info(f"Problem {pidx}")
                logger.info(f"Image: {image_path}")
                logger.info(f"Question: {td_question}")
                logger.info(f"{'-'*30} Generated Response {'-'*30}")
                logger.info(f"{gen_text}")
                logger.info(f"{'-'*30} Ground Truth {'-'*30}")
                logger.info(f"{answer}")
                logger.info(f"{'='*60}\n")
                
                results.append({
                    "pidx": pidx,
                    "generated_text": gen_text,
                    "answer": answer,
                    "question": question_for_eval,
                    "metadata": metadata,
                    "image_path": image_path
                })
                
                processed_count += 1
                    
            except Exception as e:
                # Keep going with the next problem, but record the traceback.
                logger.error(f"Error processing problem {pidx}: {str(e)}", exc_info=True)
            finally:
                # Release cached GPU memory even after a failed generation (e.g. OOM).
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        
        # Save results
        logger.info(f"Saving {len(results)} results to {args.output_path}")
        with open(args.output_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=4)

        logger.info("Generation completed successfully")
        logger.info(f"Processed {processed_count} problems")
        # The subset may have been requested but not loaded (missing config file).
        if selected_problem_indices is not None:
            logger.info(f"Used subset mode with {len(selected_problem_indices)} selected problems")
        logger.info(f"Results saved to: {args.output_path}")
        logger.info(f"Log saved to: {args.log_path}")
        
    except Exception as e:
        if logger is not None:
            logger.error(f"Fatal error: {str(e)}")
        else:
            print(f"Fatal error: {str(e)}")
        raise

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()