import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import os
import json
import argparse
from tqdm import tqdm
import logging
from datetime import datetime
from infoclip_plus import InfoCLIPPP
from datasets import get_dataset
from tasks import run_task
from collections import defaultdict


# ------------------------------
# Setup Utilities
# ------------------------------
def setup_logging(save_dir):
    """Create ``save_dir`` and configure logging to a file plus a stream.

    A ``train_<YYYYmmdd_HHMMSS>.log`` file is opened inside ``save_dir`` and
    handed, together with a default ``StreamHandler``, to
    ``logging.basicConfig`` at INFO level.

    Args:
        save_dir: Directory that receives the log file; created if missing.

    Returns:
        The logger named after this module (``__name__``).
    """
    os.makedirs(save_dir, exist_ok=True)

    # Timestamped file name so repeated runs never overwrite each other's log.
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    file_handler = logging.FileHandler(os.path.join(save_dir, f"train_{stamp}.log"))
    stream_handler = logging.StreamHandler()

    logging.basicConfig(
        handlers=[file_handler, stream_handler],
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )
    return logging.getLogger(__name__)


def load_config(config_path):
    """Load a JSON config and fill in the paper's default hyperparameters (§4.1).

    Keys already present in the file are kept untouched; only missing keys
    receive the defaults listed below. Dataset paths (e.g. ``coco_root``) have
    no defaults here and must be supplied by the file.

    Args:
        config_path: Path to a JSON file whose top-level value is an object.

    Returns:
        dict: The parsed config with defaults applied. Missing keys are
        appended in the order of the defaults table.
    """
    # JSON is UTF-8 by specification; reading with an explicit encoding keeps
    # non-ASCII values (e.g. paths) working regardless of the platform locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Built inside the function so every call gets a fresh ``patch_sizes`` list.
    defaults = {
        'img_size': 224,
        'd_model': 768,
        'patch_sizes': [16, 32],
        'num_epochs': 30,
        'lr': 5e-5,
        'weight_decay': 0.05,
        'batch_size': 128,  # Adjust based on GPU memory
        'num_workers': 4,
        'noise_level': 0.0,
        'save_dir': './outputs',
    }
    for key, value in defaults.items():
        config.setdefault(key, value)
    return config


def get_optimizer(model, config):
    """Build the AdamW optimizer and its cosine-annealing LR schedule (§4.1).

    Only parameters with ``requires_grad`` set are handed to the optimizer.

    Args:
        model: Module whose trainable parameters are optimised.
        config: Mapping providing ``lr``, ``weight_decay`` and ``num_epochs``.

    Returns:
        tuple: ``(optimizer, scheduler)``.
    """
    trainable = list(filter(lambda p: p.requires_grad, model.parameters()))

    adamw = optim.AdamW(
        trainable,
        lr=config['lr'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.98),  # Paper's optimizer betas
    )
    # One cosine cycle spanning the configured number of epochs.
    cosine = optim.lr_scheduler.CosineAnnealingLR(
        adamw,
        T_max=config['num_epochs'],
        eta_min=1e-6,
    )
    return adamw, cosine


# ------------------------------
# Training Loop (§4.1)
# ------------------------------
def train_one_epoch(model, train_loader, optimizer, device, epoch, logger, config):
    """Run one training pass over ``train_loader`` and log the averaged losses.

    Args:
        model: Callable as ``model(images=..., texts=..., noise_level=...,
            training_epoch=...)`` returning a 4-tuple whose third element is
            the scalar loss to backpropagate and whose fourth is a mapping of
            named loss components.
        train_loader: Iterable yielding ``(img_tensor, captions, img_id)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the image tensor is moved to.
        epoch: Zero-based epoch index (forwarded to the model and used in logs).
        logger: Logger receiving the end-of-epoch summary.
        config: Mapping providing ``num_epochs`` and ``noise_level``.

    Returns:
        tuple: ``(avg_loss, loss_dict)`` where ``avg_loss`` is the mean total
        loss per batch and ``loss_dict`` maps each component name to its mean
        as a Python float.

    Raises:
        ValueError: If the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    loss_dict = defaultdict(float)
    num_batches = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{config['num_epochs']}")

    # Unpack COCO batch: (img_tensor, caption, img_id)
    for img_tensor, captions, _ in pbar:
        img_tensor = img_tensor.to(device)

        # Forward pass (with DAR annealing)
        _, _, loss, batch_loss = model(
            images=img_tensor, texts=captions,
            noise_level=config['noise_level'], training_epoch=epoch
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate metrics. Read the scalar once, and convert components to
        # plain floats so that tensor-valued entries do not keep their autograd
        # graphs alive for the whole epoch.
        loss_value = loss.item()
        total_loss += loss_value
        for k, v in batch_loss.items():
            loss_dict[k] += float(v)
        num_batches += 1

        # Update progress bar
        pbar.set_postfix({
            'Batch Loss': f"{loss_value:.4f}",
            'LR': f"{optimizer.param_groups[0]['lr']:.6f}"
        })

    # Counting processed batches (rather than calling len()) also works for
    # loaders without __len__, and lets an empty loader fail with a clear error.
    if num_batches == 0:
        raise ValueError(
            "train_loader produced no batches; check the dataset size, "
            "batch_size and drop_last settings"
        )

    # Average metrics
    avg_loss = total_loss / num_batches
    for k in loss_dict:
        loss_dict[k] /= num_batches

    # Log results
    logger.info(f"Epoch {epoch + 1} Summary:")
    logger.info(f"  Average Total Loss: {avg_loss:.4f}")
    for k, v in loss_dict.items():
        logger.info(f"  {k}: {v:.4f}")
    return avg_loss, loss_dict


# ------------------------------
# Main Function
# ------------------------------
def _run_all_tasks(model, val_loaders, device, class_names, save_dir, **zero_shot_kwargs):
    """Run the zero-shot, cross-modal and fine-grained tasks; return results by task name."""
    return {
        'zero_shot': run_task(
            'zero_shot',
            model,
            val_loaders['imagenet'],
            device,
            class_names=class_names,
            **zero_shot_kwargs,
            save_path=save_dir
        ),
        'cross_modal': run_task(
            'cross_modal',
            model,
            val_loaders['coco'],
            device,
            save_path=save_dir
        ),
        'fine_grained': run_task(
            'fine_grained',
            model,
            val_loaders['ade20k'],
            device,
            save_path=save_dir
        )
    }


def main(config, args):
    """Train and/or evaluate InfoCLIP++ according to ``config`` and ``args``.

    Creates a timestamped run directory under ``config['save_dir']``, writes
    the config and log there, builds the model and data loaders, optionally
    restores a checkpoint, and then either runs evaluation only or the
    training loop with validation every 5 epochs and on the last epoch.

    Args:
        config: Mapping with the hyperparameters (``img_size``, ``d_model``,
            ``patch_sizes``, ``num_epochs``, ``batch_size``, ``num_workers``,
            ``save_dir`` plus those read by the optimizer and training step)
            and the dataset paths ``coco_root``, ``coco_ann_file``,
            ``karpathy_split_file``, ``imagenet_root``,
            ``imagenet_classes_path``, ``ade20k_root``, ``ade20k_ann_file``.
        args: Namespace with ``resume`` (checkpoint path or ``None``) and
            ``eval_only`` (bool).

    Raises:
        FileNotFoundError: If ``args.resume`` is set but the file is missing.
    """
    # Initialize
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    save_dir = os.path.join(config['save_dir'], datetime.now().strftime('%Y%m%d_%H%M%S'))
    logger = setup_logging(save_dir)

    # Save config
    with open(os.path.join(save_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=4)

    logger.info(f"Using device: {device}")
    logger.info(f"Save directory: {save_dir}")
    logger.info(f"Config: {json.dumps(config, indent=2)}")

    # Build model
    logger.info("Initializing InfoCLIP++ model...")
    model = InfoCLIPPP(
        img_size=config['img_size'],
        d_model=config['d_model'],
        patch_sizes=config['patch_sizes'],
        max_epochs=config['num_epochs']
    ).to(device)

    # Resume from checkpoint
    start_epoch = 0
    best_val_score = -float('inf')
    checkpoint = None
    if args.resume:
        if not os.path.exists(args.resume):
            raise FileNotFoundError(f"Checkpoint {args.resume} not found")
        logger.info(f"Resuming from checkpoint: {args.resume}")
        # NOTE: torch.load unpickles the file; only resume from checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint.get('epoch', 0) + 1
        # Carry the best score across restarts so a worse post-resume model is
        # not saved as "best". Older checkpoints only have 'val_score' (best
        # checkpoints) or neither key (latest checkpoints).
        best_val_score = checkpoint.get(
            'best_val_score', checkpoint.get('val_score', best_val_score)
        )

    # Build datasets (§4.1)
    logger.info("Loading datasets...")
    # Train: COCO Karpathy train split
    train_dataset = get_dataset(
        'coco',
        root=config['coco_root'],
        ann_file=config['coco_ann_file'],
        karpathy_split_file=config['karpathy_split_file'],
        split='train',
        img_size=config['img_size']
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True
    )

    # Val: ImageNet, COCO Karpathy val, ADE20K
    val_datasets = {
        'imagenet': get_dataset(
            'imagenet',
            root=config['imagenet_root'],
            split='val',
            img_size=config['img_size'],
            class_names_path=config['imagenet_classes_path']
        ),
        'coco': get_dataset(
            'coco',
            root=config['coco_root'],
            ann_file=config['coco_ann_file'],
            karpathy_split_file=config['karpathy_split_file'],
            split='val',
            img_size=config['img_size']
        ),
        'ade20k': get_dataset(
            'ade20k',
            root=config['ade20k_root'],
            ann_file=config['ade20k_ann_file'],
            split='val',
            img_size=config['img_size']
        )
    }

    val_loaders = {
        name: DataLoader(
            ds,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers'],
            pin_memory=True
        ) for name, ds in val_datasets.items()
    }

    # Load class names for zero-shot
    with open(config['imagenet_classes_path'], 'r') as f:
        imagenet_classes = [line.strip() for line in f]
    # Medical prompts (§4.2.1)
    medical_prompts = {
        'MedMNIST2D': lambda cls: f"a CT scan of a {cls}"
    }

    # Optimizer
    optimizer, scheduler = get_optimizer(model, config)
    if checkpoint is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # The scheduler state is only meaningful together with the optimizer
        # state; guard the key so checkpoints without it do not raise KeyError.
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Evaluation-only mode
    if args.eval_only:
        logger.info("Running evaluation only (§4.2)...")
        eval_metrics = _run_all_tasks(
            model,
            val_loaders,
            device,
            imagenet_classes,
            save_dir,
            # No 'ImageNet' entry exists in medical_prompts, so this is None.
            domain_prompts=medical_prompts.get('ImageNet', None)
        )
        with open(os.path.join(save_dir, 'eval_metrics.json'), 'w') as f:
            json.dump(eval_metrics, f, indent=4)
        logger.info("Evaluation completed.")
        return

    # Training loop (§4.1: 30 epochs, val every 5 epochs)
    logger.info("Starting training (§4.1)...")

    for epoch in range(start_epoch, config['num_epochs']):
        # Train
        train_loss, _ = train_one_epoch(model, train_loader, optimizer, device, epoch, logger, config)
        scheduler.step()

        # Validate every 5 epochs or last epoch
        if (epoch + 1) % 5 == 0 or epoch == config['num_epochs'] - 1:
            logger.info("=" * 60)
            logger.info(f"Validation after Epoch {epoch + 1} (§4.2)...")

            # Run all tasks
            val_metrics = _run_all_tasks(
                model, val_loaders, device, imagenet_classes, save_dir
            )

            # Compute validation score (average of key metrics)
            val_score = (
                val_metrics['zero_shot']['top1_accuracy'] +
                val_metrics['cross_modal']['image_to_text']['R@1'] +
                val_metrics['cross_modal']['text_to_image']['R@1'] +
                val_metrics['fine_grained']['mIoU']
            ) / 4

            # Save best model
            if val_score > best_val_score:
                best_val_score = val_score
                checkpoint_path = os.path.join(save_dir, f"best_epoch_{epoch + 1}.pth")
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'val_score': val_score,
                    'metrics': val_metrics
                }, checkpoint_path)
                logger.info(f"Saved best model to {checkpoint_path} (Score: {val_score:.4f})")

            logger.info("=" * 60)

        # Save latest model (includes the running best score so that a resume
        # from this file keeps comparing against it)
        latest_path = os.path.join(save_dir, "latest_model.pth")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_score': best_val_score
        }, latest_path)

    logger.info(f"Training completed! Best validation score: {best_val_score:.4f}")


# ------------------------------
# Entry Point
# ------------------------------
if __name__ == "__main__":
    # Command-line interface: config path, optional checkpoint to resume
    # from, and a switch that skips training.
    cli = argparse.ArgumentParser(description="InfoCLIP++ Training & Evaluation")
    cli.add_argument(
        "--config",
        type=str,
        default="config.json",
        help="Path to config file",
    )
    cli.add_argument(
        "--resume",
        type=str,
        default=None,
        help="Path to checkpoint for resuming",
    )
    cli.add_argument(
        "--eval-only",
        action="store_true",
        help="Run evaluation without training",
    )
    args = cli.parse_args()

    # Read the JSON config (missing keys get defaults) and hand off to main().
    config = load_config(args.config)
    main(config, args)