import argparse
import json
import pickle as pickle
from collections import defaultdict, Counter
from os.path import dirname, join

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm

from dataset import Dictionary, VQAFeatureDataset
import base_model
from train import train
import utils

from vqa_debias_loss_functions import *


def parse_args(argv=None):
    """Parse the command-line options of the de-biased VQA training script.

    Args:
        argv: Optional sequence of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script did before this parameter existed.

    Returns:
        argparse.Namespace: one attribute per option declared below.
    """
    # The text must be passed as ``description``; the first positional
    # parameter of ArgumentParser is ``prog`` and would replace the program
    # name in the usage line.
    parser = argparse.ArgumentParser(
        description="Train the BottomUpTopDown model with a de-biasing method")

    # Arguments we added
    parser.add_argument(
        '--cache_features', action="store_true",
        help="Cache image features in RAM. Makes things much faster, "
             "especially if the filesystem is slow, but requires at least 48gb of RAM")
    parser.add_argument(
        '--nocp', action="store_true", help="Run on VQA-2.0 instead of VQA-CP 2.0")
    parser.add_argument(
        '--mode', default="none",
        choices=["none", "causal", "adaptive"],
        help="Kind of ensemble loss to use: none=standard VQA model, causal=CF-VQA, adaptive=Adaptive CF-VQA")
    parser.add_argument(
        '--eval_each_epoch', action="store_true",
        help="Evaluate every epoch, instead of at the end")

    # Arguments for the adaptive model
    parser.add_argument(
        '--adaptive_method', default="entropy",
        choices=["entropy", "margin", "ensemble_disagreement"],
        help="Method for uncertainty estimation: 'entropy'=distribution entropy, 'margin'=top-1/2 margin, 'ensemble_disagreement'=variance among ensemble alpha heads")
    parser.add_argument(
        '--uncertainty_hidden_dim', type=int, default=128,
        help="Hidden dimension of the uncertainty estimation network")
    parser.add_argument(
        '--alpha_reg_weight', type=float, default=0.1,
        help="Weight for the alpha regularization loss (to prevent extreme values)")
    parser.add_argument(
        '--alpha_uncertainty_sources', nargs='+', default=['vq_only'],
        choices=['vq_only', 'q_only', 'v_only'],
        help="Sources of uncertainty to combine for alpha generation: "
             "'vq_only': use base VQ logits (original behavior). "
             "'q_only': use question branch logits. "
             "'v_only': use vision branch logits. "
             "Combine multiple sources by listing them (e.g., --alpha_uncertainty_sources q_only vq_only). "
             "The uncertainty module input dim will adapt.")
    parser.add_argument(
        '--num_ensemble_heads', type=int, default=5,
        help="Number of heads for 'ensemble_disagreement' uncertainty estimation.")

    # Arguments from the original model, we leave this default, except we
    # set --epochs to 15 since the model maxes out its performance on VQA 2.0 well before then
    parser.add_argument('--epochs', type=int, default=15)
    parser.add_argument('--num_hid', type=int, default=1024)
    parser.add_argument('--model', type=str, default='baseline0_newatt')
    parser.add_argument('--output', type=str, default='saved_models/exp0')
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--seed', type=int, default=1111, help='random seed')

    # Arguments for Curriculum Learning
    parser.add_argument(
        '--use_curriculum', action="store_true",
        help="Enable curriculum learning strategy.")
    parser.add_argument(
        '--curriculum_metric', default='uncertainty_entropy',
        choices=['bias_max', 'uncertainty_entropy'],
        help="Metric used to determine sample difficulty for curriculum. "
             "'bias_max': easier samples have higher max answer bias (pre-calculated). "
             "'uncertainty_entropy': easier samples have lower prediction entropy from initial model (requires pre-calculation pass).")
    parser.add_argument(
        '--curriculum_pacing', default='linear', choices=['linear'],  # Add more pacing functions later
        help="Pacing function for curriculum learning ('linear': linearly increase data percentage).")
    parser.add_argument(
        '--curriculum_start_percent', type=float, default=0.2,
        help="Initial percentage of easiest data to use in the first epoch.")
    parser.add_argument(
        '--curriculum_end_percent', type=float, default=1.0,
        help="Percentage of data to use in the final epoch.")

    return parser.parse_args(argv)


def _compute_question_type_bias(entries, answer_voc_size):
    """Return ``{question_type: mean answer-score vector}`` computed over `entries`.

    For every question type the scores of each answer label are summed and
    divided by the number of entries of that type, giving a float32 array of
    length `answer_voc_size`.
    """
    # question_type -> answer -> total score
    question_type_to_probs = defaultdict(Counter)
    # question_type -> number of occurrences
    question_type_to_count = Counter()
    for ex in entries:
        ans = ex["answer"]
        q_type = ans["question_type"]
        question_type_to_count[q_type] += 1
        if ans["labels"] is not None:
            for label, score in zip(ans["labels"], ans["scores"]):
                question_type_to_probs[q_type][label] += score

    question_type_to_prob_array = {}
    for q_type, count in question_type_to_count.items():
        prob_array = np.zeros(answer_voc_size, np.float32)
        for label, total_score in question_type_to_probs[q_type].items():
            prob_array[label] += total_score
        prob_array /= count
        question_type_to_prob_array[q_type] = prob_array
    return question_type_to_prob_array


def _build_model(args, train_dset):
    """Build the UPDN backbone and wrap it according to ``args.mode``."""
    # The backbone is always UPDN; ``args.model`` is not consulted here.
    print("Building UPDN (baseline0_newatt) backbone...")
    base_constructor = getattr(base_model, 'build_baseline0_newatt')
    base_model_instance = base_constructor(train_dset, args.num_hid).cuda()

    # Initialize embeddings
    if hasattr(base_model_instance, 'w_emb'):
        base_model_instance.w_emb.init_embedding('data/glove6b_init_300d.npy')

    # MLP parameters for the side branches (Q/V classifiers); both emit one
    # logit per answer candidate.
    num_answers = train_dset.num_ans_candidates
    classif_q = {'input_dim': args.num_hid, 'hidden_dims': [512], 'output_dim': num_answers}
    # Input dim for visual classifier is always v_dim for UPDN backbone
    classif_v = {'input_dim': train_dset.v_dim, 'hidden_dims': [512], 'output_dim': num_answers}

    if args.mode == 'causal':
        print("Building CFVQAModel...")
        return base_model.CFVQAModel(base_model_instance, num_answers, classif_q, classif_v,
                                     fusion_mode='sum', is_va=True).cuda()

    if args.mode == 'adaptive':
        print(f"Building AdaptiveCFVQAModel with {args.adaptive_method} uncertainty estimation...")
        # NOTE(review): --num_ensemble_heads is parsed but not forwarded here;
        # confirm whether AdaptiveCFVQAModel is supposed to receive it.
        return base_model.AdaptiveCFVQAModel(
            base_model_instance,
            num_answers,
            classif_q,
            classif_v,
            fusion_mode='sum',
            is_va=True,
            adaptive_method=args.adaptive_method,
            uncertainty_hidden_dim=args.uncertainty_hidden_dim,
            alpha_reg_weight=args.alpha_reg_weight,
            alpha_uncertainty_sources=args.alpha_uncertainty_sources,
        ).cuda()

    # Use the base UPDN model directly
    print("Building baseline model...")
    return base_model_instance


def _precompute_entropy_scores(model, train_dset, batch_size):
    """Return the prediction entropy of every training sample, in dataset order."""
    print("Pre-calculating uncertainty (entropy) for curriculum learning...")

    # Pick the module that produces the base VQ logits once, not per batch.
    if isinstance(model, base_model.BaseModel):
        backbone = model
    elif hasattr(model, 'base_model'):  # For CFVQA and AdaptiveCFVQA
        backbone = model.base_model
    else:
        raise TypeError("Model type not recognized for uncertainty pre-calculation.")

    model.eval()
    # shuffle=False so that position k in `scores` belongs to train_dset.entries[k]
    temp_loader = DataLoader(train_dset, batch_size, shuffle=False, num_workers=32)
    scores = []
    with torch.no_grad():
        # assumes the dataset yields 4-tuples whose first two items are the
        # visual features and the question tokens -- TODO confirm against dataset.py
        for v, q, _, _ in tqdm(temp_loader, desc="Pre-calc Entropy"):
            logits_vq = backbone(v.cuda(), None, q.cuda(), None, None)["logits"]
            probs = torch.softmax(logits_vq, dim=-1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1)  # [batch_size]
            # One transfer per batch instead of one .item() sync per sample.
            scores.extend(entropy.tolist())
    model.train()
    return scores


def _sort_entries_for_curriculum(args, model, train_dset):
    """Attach a ``difficulty_score`` to every train entry and sort easiest-first."""
    metric = args.curriculum_metric

    entropy_scores = []
    if metric == 'uncertainty_entropy':
        entropy_scores = _precompute_entropy_scores(model, train_dset, args.batch_size)
        print(f"Finished pre-calculating entropy for {len(entropy_scores)} samples.")

    print(f"Assigning and sorting training data based on curriculum metric: {metric}")
    if metric == 'bias_max':
        # Higher max bias -> easier sample -> lower score
        for entry in train_dset.entries:
            entry['difficulty_score'] = -np.max(entry['bias'])
    elif metric == 'uncertainty_entropy':
        # Higher entropy -> harder sample -> higher score; entries without a
        # pre-computed score are pushed to the end.
        num_scored = len(entropy_scores)
        for i, entry in enumerate(train_dset.entries):
            entry['difficulty_score'] = entropy_scores[i] if i < num_scored else float('inf')
    else:
        raise NotImplementedError(f"Curriculum metric '{metric}' not implemented.")

    # Sort entries by difficulty score (easier first)
    train_dset.entries.sort(key=lambda x: x['difficulty_score'])
    print("Training data sorted.")


def main():
    """Script entry point: build datasets and model, then hand over to ``train``.

    Steps: parse the CLI options, load the train/val datasets, attach a
    per-question-type answer prior (``bias``) to every entry, build the model
    selected by ``--mode``, optionally sort the training entries for
    curriculum learning, and call ``train(model, train_dset, eval_dset, args)``.
    Requires a CUDA device (the model is moved with ``.cuda()``).
    """
    args = parse_args()

    # Seed before any weights are created so that model initialisation (and
    # the entropy-based curriculum order derived from it) is controlled by
    # --seed as well.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    dictionary = Dictionary.load_from_file('data/dictionary.pkl')
    cp = not args.nocp

    print("Building train dataset...")
    train_dset = VQAFeatureDataset('train', dictionary, cp=cp,
                                   cache_image_features=args.cache_features)
    print("Building test dataset...")
    eval_dset = VQAFeatureDataset('val', dictionary, cp=cp,
                                  cache_image_features=args.cache_features)

    # Compute the bias:
    # The bias here is just the expected score for each answer/question type
    question_type_to_prob_array = _compute_question_type_bias(
        train_dset.entries, train_dset.num_ans_candidates)

    # Now add a `bias` field to each example. A question type that occurs in
    # the eval split but not in the train split raises KeyError here.
    for ds in [train_dset, eval_dset]:
        for ex in ds.entries:
            ex["bias"] = question_type_to_prob_array[ex["answer"]["question_type"]]

    model = _build_model(args, train_dset)

    # --- Curriculum Learning: score and sort train_dset if enabled ---
    if args.use_curriculum:
        _sort_entries_for_curriculum(args, model, train_dset)

    # Add the loss_fn based on our arguments
    if args.mode == "none":
        model.debias_loss_fn = Plain()
    elif args.mode in ("causal", "adaptive"):
        print("Using counterfactual causal model")
    else:
        raise RuntimeError(args.mode)

    # Make sure the output directory exists before training starts.
    utils.create_dir(args.output)

    model = model.cuda()

    # Re-seed right before training, as the script always did, so the RNG
    # state handed to train() does not depend on how many random numbers
    # model construction or the curriculum pre-pass consumed.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = True

    print("Starting training...")
    train(model, train_dset, eval_dset, args)  # Pass datasets and args


# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
