import os
import argparse
import torch
import numpy as np
from PIL import Image
import json
import matplotlib.pyplot as plt
from datetime import datetime
from typing import List, Dict, Any, Optional
from tqdm import tqdm

from src.utils import (
    load_models,
    generate_counterfactuals
)
from src.data.dataset import CelebADataset
from src.models.classifier import FaceAttributeClassifier


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    The three options ``--mask_attributes``, ``--save_all`` and
    ``--calculate_fid`` default to ``True``. Each has a ``--no_*`` counterpart
    that stores ``False`` into the same destination, so the behaviour can
    actually be turned off from the command line; the positive flags are kept
    so existing invocations continue to work.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Generate counterfactual explanations")

    # Model paths
    parser.add_argument("--checkpoint_path", type=str, required=True,
                        help="Path to the diffusion model checkpoint")
    parser.add_argument("--classifier_path", type=str, required=True,
                        help="Path to the classifier model checkpoint")
    parser.add_argument("--use_slots", action="store_true",
                        help="Whether to use slot models")

    # Dataset settings
    parser.add_argument("--data_path", type=str, required=True,
                        help="Path to CelebA dataset")
    parser.add_argument("--img_size", type=int, default=256,
                        help="Image size for processing")
    parser.add_argument("--split", type=str, default="valid",
                        choices=["train", "valid", "test"],
                        help="Dataset split to use")
    parser.add_argument("--num_images", type=int, default=16,
                        help="Number of counterfactual examples to generate per attribute")

    # Diffusion parameters
    parser.add_argument("--guidance_scale", type=float, default=1.0,
                        help="Classifier guidance scale")
    parser.add_argument("--conditioning_scale", type=float, default=1.0,
                        help="Slot conditioning scale (only for slot models)")
    parser.add_argument("--target_step", type=int, default=500,
                        help="Target diffusion timestep for reconstruction")
    parser.add_argument("--num_steps", type=int, default=200,
                        help="Number of diffusion sampling steps")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="Batch size for processing")
    parser.add_argument("--mask_attributes", action="store_true", default=True,
                        help="Whether to mask attributes for focused editing")
    # store_true with default=True cannot be negated; this flag can.
    parser.add_argument("--no_mask_attributes", action="store_false", dest="mask_attributes",
                        help="Disable attribute masking")

    # Output settings
    parser.add_argument("--output_dir", type=str, default="counterfactuals_output",
                        help="Directory to save results")
    parser.add_argument("--save_all", action="store_true", default=True,
                        help="Whether to save all counterfactuals or just those with correct predictions")
    parser.add_argument("--no_save_all", action="store_false", dest="save_all",
                        help="Disable --save_all")
    parser.add_argument("--calculate_fid", action="store_true", default=True,
                        help="Whether to calculate FID score between original and counterfactual images")
    parser.add_argument("--no_calculate_fid", action="store_false", dest="calculate_fid",
                        help="Skip the FID computation")

    # Hardware settings
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run generation on")
    parser.add_argument("--weight_dtype", type=str, default="float32",
                        choices=["float32", "float16", "bfloat16"],
                        help="Data type for model weights")

    # Specific attributes to focus on (optional)
    parser.add_argument("--attributes", nargs='+', type=str, default=None,
                        help="Specific attributes to modify (if not specified, uses default preferred list)")

    return parser.parse_args(argv)


def get_weight_dtype(dtype_str):
    """Translate a dtype name into the matching ``torch`` dtype object.

    Args:
        dtype_str: One of ``"float32"``, ``"float16"`` or ``"bfloat16"``.

    Returns:
        The ``torch`` dtype whose attribute name equals ``dtype_str``.

    Raises:
        ValueError: If ``dtype_str`` is not one of the supported names.
    """
    # Each supported name is also the attribute name of the dtype on ``torch``.
    for supported in ("float32", "float16", "bfloat16"):
        if dtype_str == supported:
            return getattr(torch, supported)
    raise ValueError(f"Unsupported dtype: {dtype_str}")


def _image_fingerprint(img):
    """Return an in-process hash of the raw pixel bytes of ``img``.

    Used only to recognise the same original image across attributes within
    one call; the value is not stable across interpreter runs.
    """
    if isinstance(img, torch.Tensor):
        # detach/float so grad-tracking and bfloat16 tensors can reach NumPy.
        return hash(img.detach().cpu().float().numpy().tobytes())
    # Anything else (e.g. a PIL image) is converted through NumPy.
    return hash(np.array(img).tobytes())


def _write_image(img, path):
    """Write ``img`` to ``path``; tensors are converted to 8-bit images first.

    Tensors are assumed to be channel-first (C, H, W) with values scaled to
    [0, 1] -- that is what the permute and the ``* 255`` below imply.
    Non-tensor inputs must provide a ``save(path)`` method.
    """
    if isinstance(img, torch.Tensor):
        pixels = img.detach().cpu().float().permute(1, 2, 0).numpy()
        # Clip before the cast: out-of-range floats would otherwise wrap
        # around when converted to uint8 instead of saturating.
        pixels = np.clip(pixels, 0.0, 1.0)
        Image.fromarray((pixels * 255).astype(np.uint8)).save(path)
    else:
        img.save(path)


def _json_default(value):
    """Convert NumPy/torch values to plain Python objects for ``json.dump``."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def save_counterfactual_images(results, metrics, output_dir):
    """Save counterfactual images to the specified output directory structure.

    Every distinct original image gets its own ``image_<k>`` sub-directory
    holding ``original.png`` plus one ``<success>_<attribute>.png`` file per
    counterfactual, where ``<success>`` is 1 when ``target_success`` is truthy
    and 0 otherwise. The per-image prediction details are additionally written
    to ``image_stats.json`` in ``output_dir``.

    A missing or empty ``change_info`` no longer aborts the run: the
    counterfactual is still written, named after the ``results`` key, and is
    recorded as unsuccessful; no ``image_stats`` entry is created for it.

    Args:
        results: Dictionary mapping attribute names to lists of tuples (original_image, counterfactual_image, change_info)
        metrics: Dictionary containing success/failure statistics for each attribute
            (not used here; accepted so the call signature stays unchanged)
        output_dir: Directory to save results

    Returns:
        Tuple ``(saved_images, image_stats)``. ``saved_images`` maps each image
        directory to a list of per-counterfactual summary dicts;
        ``image_stats`` maps each image directory to ``{attribute: details}``.
    """
    os.makedirs(output_dir, exist_ok=True)

    saved_images = {}
    # fingerprint of the original image -> directory assigned to it
    original_image_dirs = {}
    image_stats = {}

    for attr, attr_results in results.items():
        for orig_img, cf_img, change_info in attr_results:
            # Normalise so every lookup below is safe for None / empty input.
            info = change_info or {}

            img_hash = _image_fingerprint(orig_img)
            img_dir = original_image_dirs.get(img_hash)
            if img_dir is None:
                # First time this original is seen: give it the next directory.
                image_counter = len(original_image_dirs) + 1
                img_dir = os.path.join(output_dir, f"image_{image_counter}")
                os.makedirs(img_dir, exist_ok=True)
                original_image_dirs[img_hash] = img_dir
                _write_image(orig_img, os.path.join(img_dir, "original.png"))
                image_stats[img_dir] = {}

            target_success = info.get('target_success', False)
            prediction_matches = info.get('prediction_matches', False)

            if info:
                attr_stats = {
                    'gt_label': info.get('gt_label'),
                    'orig_prediction': info.get('orig_prediction'),
                    'orig_probability': info.get('orig_probability'),
                    'orig_logit': info.get('orig_logit'),
                    'orig_matches_gt': prediction_matches,
                    'target_label': info.get('target_label'),
                    'cf_prediction': info.get('cf_prediction'),
                    'cf_probability': info.get('cf_probability'),
                    'cf_logit': info.get('cf_logit'),
                    'cf_matches_target': target_success
                }
                # Full predictions over all attributes, when the caller supplied them.
                if 'orig_all_preds' in info:
                    attr_stats['all_attributes'] = {
                        'orig_all_preds': info.get('orig_all_preds'),
                        'orig_all_probs': info.get('orig_all_probs'),
                        'orig_all_logits': info.get('orig_all_logits'),
                        'cf_all_preds': info.get('cf_all_preds'),
                        'cf_all_probs': info.get('cf_all_probs'),
                        'cf_all_logits': info.get('cf_all_logits'),
                        'gt_all_labels': info.get('gt_all_labels')
                    }
                image_stats[img_dir][attr] = attr_stats

            # Fall back to the results key when change_info carries no name.
            attr_name = info.get('attr_name', attr)
            is_success = 1 if target_success else 0
            cf_filename = f"{is_success}_{attr_name}.png"
            _write_image(cf_img, os.path.join(img_dir, cf_filename))

            saved_images.setdefault(img_dir, []).append({
                'attr_name': attr_name,
                'original_value': info.get('original'),
                'modified_value': info.get('modified'),
                'gt_label': info.get('gt_label'),
                'orig_prediction': info.get('orig_prediction'),
                'orig_probability': info.get('orig_probability', 0.0),
                'orig_logit': info.get('orig_logit', 0.0),
                'orig_matches_gt': prediction_matches,
                'target_label': info.get('target_label'),
                'cf_prediction': info.get('cf_prediction'),
                'cf_probability': info.get('cf_probability', 0.0),
                'cf_logit': info.get('cf_logit', 0.0),
                'cf_matches_target': target_success,
                'success': target_success,
                'filename': cf_filename
            })

    # The stats hold whatever the caller put into change_info, which may
    # include NumPy or torch values; _json_default turns those into plain
    # Python objects instead of letting the dump fail after all images were
    # already written.
    stats_json_path = os.path.join(output_dir, "image_stats.json")
    with open(stats_json_path, 'w', encoding="utf-8") as f:
        json.dump(image_stats, f, indent=2, default=_json_default)

    return saved_images, image_stats


def _percent(count, total):
    """Return ``count / total`` as a percentage, or 0 when ``total`` is not positive."""
    return (count / total) * 100 if total > 0 else 0


def save_statistics(args, metrics, fid_scores, saved_images, output_dir, image_stats=None):
    """Save statistics and hyperparameters to a text file.

    The report is written to ``stats.txt`` inside ``output_dir``. It is
    assembled in memory first and written in one go, so an error while
    formatting does not leave a truncated report behind. The file is always
    UTF-8 encoded because the report contains the non-ASCII arrow character.

    Args:
        args: Command-line arguments
        metrics: Dictionary containing success/failure statistics for each attribute
            (each value must provide ``success_count``, ``total_attempts``,
            ``failure_count`` and ``unchanged_count``)
        fid_scores: Dictionary containing overall FID score; ``None``, an
            empty dict, or a ``None``/NaN ``'overall'`` value simply omits the
            FID line
        saved_images: Mapping of saved images for detailed stats
        output_dir: Directory to save results
        image_stats: Detailed per-image, per-attribute prediction statistics
            (currently not used by this function)
    """
    stats_path = os.path.join(output_dir, "stats.txt")
    lines = []

    # Header with timestamp
    lines.append("Counterfactual Generation Report\n")
    lines.append(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

    # Hyperparameters
    lines.append("=== Hyperparameters ===\n")
    lines.append(f"Model checkpoint: {args.checkpoint_path}\n")
    lines.append(f"Classifier: {args.classifier_path}\n")
    lines.append(f"Using slots: {args.use_slots}\n")
    lines.append(f"Guidance scale: {args.guidance_scale}\n")
    if args.use_slots:
        lines.append(f"Slot conditioning scale: {args.conditioning_scale}\n")
    lines.append(f"Target diffusion step: {args.target_step}\n")
    lines.append(f"Number of sampling steps: {args.num_steps}\n")
    lines.append(f"Masking attributes: {args.mask_attributes}\n")
    lines.append(f"Weight data type: {args.weight_dtype}\n\n")

    # Overall statistics: success is measured against total attempts,
    # failure against the unchanged count, summed over all attributes.
    lines.append("=== Overall Statistics ===\n")
    total_success = sum(m['success_count'] for m in metrics.values())
    total_attempts = sum(m['total_attempts'] for m in metrics.values())
    total_failure = sum(m['failure_count'] for m in metrics.values())
    total_unchanged = sum(m['unchanged_count'] for m in metrics.values())
    overall_success_rate = _percent(total_success, total_attempts)
    overall_failure_rate = _percent(total_failure, total_unchanged)

    lines.append(f"Overall success rate: {overall_success_rate:.2f}% ({total_success}/{total_attempts})\n")
    lines.append(f"Overall failure rate: {overall_failure_rate:.2f}% ({total_failure}/{total_unchanged})\n")

    # Overall FID score, if one was computed. The truthiness check guards
    # against fid_scores being None (presumably possible when FID calculation
    # is disabled -- depends on generate_counterfactuals).
    if fid_scores and 'overall' in fid_scores:
        overall_fid = fid_scores['overall']
        if overall_fid is not None and not np.isnan(overall_fid):
            lines.append(f"Overall FID score: {overall_fid:.4f}\n")

    lines.append("\n")

    # Per-attribute statistics
    lines.append("=== Per-Attribute Statistics ===\n")
    for attr, attr_metrics in metrics.items():
        success_rate = _percent(attr_metrics['success_count'], attr_metrics['total_attempts'])
        failure_rate = _percent(attr_metrics['failure_count'], attr_metrics['unchanged_count'])

        lines.append(f"\nAttribute: {attr}\n")
        lines.append(f"  Success rate: {success_rate:.2f}% ({attr_metrics['success_count']}/{attr_metrics['total_attempts']})\n")
        lines.append(f"  Failure rate: {failure_rate:.2f}% ({attr_metrics['failure_count']}/{attr_metrics['unchanged_count']})\n")

    # Details of every generated counterfactual, grouped by image directory
    lines.append("\n\n=== Generated Counterfactuals ===\n")
    for img_dir, counterfactuals in saved_images.items():
        dir_name = os.path.basename(img_dir)
        lines.append(f"\n{dir_name}:\n")
        for cf in counterfactuals:
            orig_pred_status = "MATCH" if cf['orig_matches_gt'] else "MISMATCH"
            cf_pred_status = "SUCCESS" if cf['cf_matches_target'] else "FAILED"

            lines.append(f"  {cf['filename']} - {cf['attr_name']}: {cf['original_value']} → {cf['modified_value']}\n")
            lines.append(f"    GT: {cf['gt_label']} | Orig Pred: {cf['orig_prediction']} ({orig_pred_status}, prob: {cf['orig_probability']:.3f}) | "
                         f"Target: {cf['target_label']} | CF Pred: {cf['cf_prediction']} ({cf_pred_status}, prob: {cf['cf_probability']:.3f})\n")

    # Explicit UTF-8: the locale default encoding may be unable to encode "→".
    with open(stats_path, 'w', encoding="utf-8") as f:
        f.writelines(lines)


def main():
    """Run the full counterfactual generation pipeline from the command line.

    Steps: parse the arguments, prepare the output directory, load the
    dataset, the attribute classifier and the diffusion models, generate the
    counterfactuals, then write the images, ``image_stats.json`` and
    ``stats.txt`` into a timestamped sub-directory of ``--output_dir``.
    """
    torch.set_float32_matmul_precision('high')
    args = parse_args()

    # Honour the requested device; only fall back to CPU when a CUDA device
    # was requested but CUDA is not available. A non-CUDA request is no
    # longer overridden just because CUDA is missing.
    device = args.device
    if device.startswith("cuda") and not torch.cuda.is_available():
        print("CUDA is not available; falling back to CPU.")
        device = "cpu"
    weight_dtype = get_weight_dtype(args.weight_dtype)

    print(f"Using device: {device}, Weight dtype: {args.weight_dtype}")

    # Create the output directory before any expensive work so an unwritable
    # location fails immediately rather than after generation has finished.
    # The timestamp in the directory name therefore marks the start of the run.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(args.output_dir, f"counterfactuals_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    # Load dataset
    # NOTE(review): --split is parsed but not used here; the dataset portion
    # is fixed to (0.9, 1). Confirm how splits should map onto data_portion.
    print("Loading dataset...")
    dataset = CelebADataset(
        root=args.data_path,
        img_size=args.img_size,
        data_portion=(0.9, 1),
    )

    # Load classifier (same image size as the dataset so the inputs match)
    print("Loading classifier...")
    classifier = FaceAttributeClassifier(
        input_channels=3,
        num_attributes=40,  # CelebA has 40 attributes
        image_size=args.img_size,
    )
    # NOTE(review): torch.load deserialises the checkpoint file; only point
    # --classifier_path at checkpoints from a trusted source.
    classifier.load_state_dict(torch.load(args.classifier_path, map_location=device)['model_state_dict'])
    classifier.to(device).eval()

    # Load diffusion models
    print("Loading diffusion models...")
    vae, unet, scheduler, backbone, slot_attn = load_models(
        args.checkpoint_path,
        args.use_slots,
        device=device,
        weight_dtype=weight_dtype
    )

    # Move UNet to device and dtype
    unet = unet.to(device=device, dtype=weight_dtype)

    # Generate counterfactuals
    print("\nGenerating counterfactual explanations...")
    results, metrics, fid_scores = generate_counterfactuals(
        val_dataset=dataset,
        unet=unet,
        vae=vae,
        scheduler=scheduler,
        classifier=classifier,
        target_step=args.target_step,
        num_inference_steps=args.num_steps,
        guidance_scale=args.guidance_scale,
        conditioning_scale=args.conditioning_scale,
        backbone=backbone,
        slot_attn=slot_attn,
        weight_dtype=weight_dtype,
        n_images=args.num_images,
        batch_size=args.batch_size,
        preferred_attrs=args.attributes,
        device=device,
        visualize_attn=False,  # No need for visualization for bulk generation
        mask_attributes=args.mask_attributes,
        save_all_modifications=args.save_all,
        calculate_fid=args.calculate_fid
    )

    # Save images together with the per-image statistics
    print("\nSaving counterfactual images...")
    saved_images, image_stats = save_counterfactual_images(results, metrics, output_dir)

    # Save the human-readable report
    print("Saving statistics...")
    save_statistics(args, metrics, fid_scores, saved_images, output_dir, image_stats)

    print(f"\nCounterfactual generation complete! Results saved to: {output_dir}")
    print(f"Detailed image statistics saved to: {os.path.join(output_dir, 'image_stats.json')}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
