import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
from tqdm.auto import tqdm
from accelerate import Accelerator
from torch.utils.data import DataLoader
import wandb
from accelerate.logging import get_logger
import shutil
from torch.optim.lr_scheduler import OneCycleLR, StepLR
import torchvision.models as models
from typing import Dict, Optional
import numpy as np

from data import GlobDataset, ClevrTexDataset, CelebADataset

logger = get_logger(__name__)


class ResNet18Classifier(nn.Module):
    """Single-head image classifier on top of torchvision's ResNet-18.

    Args:
        num_classes: Width of the logit vector produced by the final linear layer.
        input_channels: Number of channels in the input images. When it is not 3,
            the stem convolution is swapped for a new (untrained) 7x7/stride-2 conv
            with that many input channels.
        pretrained: Passed straight through to ``torchvision.models.resnet18``.
    """

    def __init__(self, num_classes: int, input_channels: int = 3, pretrained: bool = True):
        super().__init__()

        net = models.resnet18(pretrained=pretrained)

        if input_channels != 3:
            # Keep the stock stem geometry, only the channel count changes.
            net.conv1 = nn.Conv2d(
                input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
            )

        # Swap the ImageNet head for one sized to this task.
        head_in = net.fc.in_features
        net.fc = nn.Linear(head_in, num_classes)

        self.backbone = net

    def forward(self, x):
        """Map an image batch to raw (unnormalised) class logits."""
        logits = self.backbone(x)
        return logits


class MultiTaskClassifier(nn.Module):
    """ResNet-18 features followed by an MLP head with one logit per attribute.

    Meant for multi-label targets such as CelebA attributes.

    Args:
        num_attributes: Number of output logits.
        input_channels: Number of channels in the input images. When it is not 3,
            the stem convolution is replaced by a new (untrained) 7x7/stride-2 conv.
        pretrained: Passed straight through to ``torchvision.models.resnet18``.
    """

    def __init__(self, num_attributes: int, input_channels: int = 3, pretrained: bool = True):
        super().__init__()

        # Use ResNet18 as backbone
        self.backbone = models.resnet18(pretrained=pretrained)

        # Modify first conv layer if needed
        if input_channels != 3:
            self.backbone.conv1 = nn.Conv2d(
                input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
            )

        # All backbone children except the last one. For torchvision's ResNet the
        # dropped child is ``fc`` and the last kept one is its global average pool.
        self.feature_extractor = nn.Sequential(*list(self.backbone.children())[:-1])

        # ``backbone.fc`` remains registered (so checkpoint keys do not change) but
        # never takes part in forward(). Left trainable, DistributedDataParallel
        # (multi-process accelerate runs, find_unused_parameters=False) errors out
        # waiting for its gradients, so freeze it explicitly.
        self.backbone.fc.requires_grad_(False)

        # Custom classification head
        self.classifier = nn.Sequential(
            # Effectively a no-op on the already pooled features; kept so the
            # submodule indices (and checkpoint keys) of the head stay the same.
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_attributes)
        )

    def forward(self, x):
        """Return one raw logit per attribute for every image in ``x``."""
        features = self.feature_extractor(x)
        return self.classifier(features)


def create_dataset(args, train: bool = True):
    """Build the train or validation split of the dataset selected by ``args.dataset_type``.

    The training split covers the fraction ``[0, args.train_split_portion)`` of the
    data and the validation split covers the rest. Random horizontal flipping is
    only requested for the training split of the "glob" and "clevrtex" datasets.

    Raises:
        ValueError: if ``args.dataset_type`` is not "glob", "clevrtex" or "celeba".
    """
    kind = args.dataset_type
    if kind not in ("glob", "clevrtex", "celeba"):
        raise ValueError(f"Unknown dataset type: {args.dataset_type}")

    split = args.train_split_portion
    portion = (0.0, split) if train else (split, 1.0)

    if kind == "celeba":
        # CelebA is built without a flip option.
        return CelebADataset(
            root=args.dataset_root,
            img_size=args.img_size,
            data_portion=portion,
        )

    if kind == "clevrtex":
        return ClevrTexDataset(
            root=args.dataset_root,
            img_size=args.img_size,
            data_portion=portion,
            random_flip=train and args.random_flip,
        )

    # Only "glob" is left at this point.
    if hasattr(args, 'img_glob'):
        patterns = args.img_glob.split(" ")
    else:
        patterns = ['**/*.png', '**/*.jpg']
    return GlobDataset(
        root=args.dataset_root,
        img_size=args.img_size,
        img_glob=patterns,
        data_portion=portion,
        random_flip=train and args.random_flip,
    )


def compute_accuracy(outputs, labels, task_type="binary"):
    """Return the fraction of correct predictions as a Python float.

    Args:
        outputs: Raw model logits. For "binary"/"multilabel" every logit is passed
            through a sigmoid and thresholded at 0.5; for "multiclass" the arg-max
            over dim 1 is the predicted class.
        labels: Targets. For "binary"/"multilabel" any nonzero value counts as the
            positive class; for "multiclass" they are class indices.
        task_type: One of "binary", "multilabel" or "multiclass".

    Returns:
        Mean element-wise accuracy. For "multilabel" this averages over every
        (sample, attribute) pair.

    Raises:
        ValueError: if ``task_type`` is not recognised.
    """
    if task_type in ("binary", "multilabel"):
        # A single-logit head produces (N, 1) while binary labels are typically
        # (N,). Comparing those directly broadcasts to an (N, N) matrix and gives
        # a meaningless score, so align the logits with the labels first.
        if outputs.shape != labels.shape:
            outputs = outputs.reshape(labels.shape)
        predictions = torch.sigmoid(outputs) > 0.5
        return (predictions == labels.bool()).float().mean().item()

    if task_type == "multiclass":
        predictions = torch.argmax(outputs, dim=1)
        return (predictions == labels).float().mean().item()

    raise ValueError(f"Unknown task type: {task_type}")


def train_epoch(model, dataloader, optimizer, criterion, accelerator, task_type="binary"):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network returning logits for ``batch["pixel_values"]``.
        dataloader: Yields dicts with "pixel_values" and "labels".
        optimizer: Stepped (and its gradients cleared) once per batch.
        criterion: Loss on (logits, labels). Labels are cast to long for
            "multiclass" and to float otherwise.
        accelerator: Used for the backward pass and to show the progress bar only
            on the local main process.
        task_type: "binary", "multilabel" or "multiclass".

    Returns:
        ``(mean_loss, mean_accuracy)`` over the samples this process saw, or
        ``(nan, nan)`` when the dataloader is empty.
    """
    model.train()
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0

    progress_bar = tqdm(dataloader, disable=not accelerator.is_local_main_process)

    for batch in progress_bar:
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]

        outputs = model(pixel_values)

        if task_type == "multiclass":
            loss = criterion(outputs, labels.long())
        else:  # binary or multilabel
            # A single-logit head emits (N, 1) while binary labels are usually
            # (N,); BCEWithLogitsLoss rejects mismatched shapes.
            if outputs.shape != labels.shape:
                outputs = outputs.reshape(labels.shape)
            loss = criterion(outputs, labels.float())

        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

        with torch.no_grad():
            accuracy = compute_accuracy(outputs, labels, task_type)

        # Weight by batch size so a short final batch does not skew the averages.
        batch_loss = loss.item()
        batch_size = labels.shape[0]
        loss_sum += batch_loss * batch_size
        acc_sum += accuracy * batch_size
        seen += batch_size

        progress_bar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{accuracy:.4f}'
        })

    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, acc_sum / seen


def validate_epoch(model, dataloader, criterion, accelerator, task_type="binary"):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Predictions and labels are gathered from all processes with
    ``accelerator.gather_for_metrics``, so every process returns the same metrics,
    computed over the full validation set.

    Args:
        model: Network returning logits for ``batch["pixel_values"]``.
        dataloader: Yields dicts with "pixel_values" and "labels".
        criterion: Loss on (logits, labels). Labels are cast to long for
            "multiclass" and to float otherwise.
        accelerator: Used for gathering and to show the progress bar only on the
            local main process.
        task_type: "binary", "multilabel" or "multiclass".

    Returns:
        ``(mean_loss, mean_accuracy)`` over all gathered samples, or
        ``(nan, nan)`` when the dataloader is empty.
    """
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, disable=not accelerator.is_local_main_process):
            outputs = model(batch["pixel_values"])

            # Collect every process' predictions. gather_for_metrics also drops the
            # samples the distributed sampler duplicated to even out the shards.
            outputs, labels = accelerator.gather_for_metrics((outputs, batch["labels"]))

            if task_type == "multiclass":
                loss = criterion(outputs, labels.long())
            else:
                # Align a (N, 1) single-logit output with (N,) labels.
                if outputs.shape != labels.shape:
                    outputs = outputs.reshape(labels.shape)
                loss = criterion(outputs, labels.float())

            accuracy = compute_accuracy(outputs, labels, task_type)

            batch_size = labels.shape[0]
            loss_sum += loss.item() * batch_size
            acc_sum += accuracy * batch_size
            seen += batch_size

    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, acc_sum / seen


def _build_model_and_criterion(args, sample_labels):
    """Choose task type, head width, model and loss from the dataset type and one label batch.

    Returns:
        Tuple ``(task_type, num_classes, model, criterion)``.
    """
    if args.dataset_type == "celeba":
        num_classes = sample_labels.shape[1]  # Number of attributes
        model = MultiTaskClassifier(num_classes, pretrained=args.pretrained)
        return "multilabel", num_classes, model, nn.BCEWithLogitsLoss()

    if args.dataset_type == "clevrtex":
        # 1-D labels -> one binary target; 2-D labels -> one target per column.
        num_classes = 1 if len(sample_labels.shape) == 1 else sample_labels.shape[1]
        if num_classes == 1:
            model = ResNet18Classifier(num_classes, pretrained=args.pretrained)
            return "binary", num_classes, model, nn.BCEWithLogitsLoss()
        model = MultiTaskClassifier(num_classes, pretrained=args.pretrained)
        return "multilabel", num_classes, model, nn.BCEWithLogitsLoss()

    # glob dataset: infer the task from the label values in the sample batch.
    unique_labels = torch.unique(sample_labels)
    if len(unique_labels) <= 2 and all(label in [0, 1] for label in unique_labels):
        task_type = "binary"
        num_classes = 1
        criterion = nn.BCEWithLogitsLoss()
    else:
        task_type = "multiclass"
        # CrossEntropyLoss needs a logit for every class index, so size the head by
        # the largest label seen rather than by how many distinct labels happened
        # to appear. NOTE(review): only one batch is inspected, so a class id above
        # every label in that batch is still missed; confirm the true class count.
        num_classes = max(int(sample_labels.max().item()) + 1, len(unique_labels))
        criterion = nn.CrossEntropyLoss()
    model = ResNet18Classifier(num_classes, pretrained=args.pretrained)
    return task_type, num_classes, model, criterion


def _build_scheduler(args, optimizer):
    """Return the LR scheduler selected by ``args.scheduler``, or None.

    main() steps the returned scheduler exactly once per epoch.
    """
    if args.scheduler == "onecycle":
        # One scheduler step per epoch, so the cycle must span num_epochs steps.
        return OneCycleLR(
            optimizer,
            max_lr=args.learning_rate,
            epochs=args.num_epochs,
            steps_per_epoch=1
        )
    if args.scheduler == "step":
        return StepLR(optimizer, step_size=30, gamma=0.1)
    return None


def main(args):
    """Train a ResNet-18 based classifier and checkpoint it with accelerate.

    Builds the train/validation splits with ``create_dataset``, infers the task
    from one training batch, trains for ``args.num_epochs`` epochs, and saves
    ``best_model`` (highest validation accuracy) and ``final_model`` states under
    ``args.output_dir``.
    """
    import logging

    # accelerate's get_logger() wraps a stdlib logger. Without a handler at INFO
    # level, every logger.info() call in this script would be silently dropped.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level=logging.INFO,
    )

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        log_with="wandb" if args.use_wandb else None,
        # The scheduler is stepped once per epoch, not with every optimizer step.
        # With the default (True), accelerate would advance it num_processes
        # times per call on multi-GPU runs.
        step_scheduler_with_optimizer=False,
    )

    if accelerator.is_main_process and args.use_wandb:
        wandb.init(
            project=args.wandb_project,
            name=args.experiment_name,
            config=vars(args)
        )

    train_dataset = create_dataset(args, train=True)
    val_dataset = create_dataset(args, train=False)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )

    sample_labels = next(iter(train_dataloader))["labels"]
    task_type, num_classes, model, criterion = _build_model_and_criterion(args, sample_labels)
    logger.info(f"Task type: {task_type}, Number of classes: {num_classes}")

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )
    scheduler = _build_scheduler(args, optimizer)

    model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, val_dataloader
    )
    if scheduler is not None:
        scheduler = accelerator.prepare(scheduler)

    os.makedirs(args.output_dir, exist_ok=True)
    best_val_acc = 0.0

    for epoch in range(args.num_epochs):
        logger.info(f"Epoch {epoch + 1}/{args.num_epochs}")

        train_loss, train_acc = train_epoch(
            model, train_dataloader, optimizer, criterion, accelerator, task_type
        )
        val_loss, val_acc = validate_epoch(
            model, val_dataloader, criterion, accelerator, task_type
        )

        # Average the metrics across processes so every rank logs the same numbers
        # and makes the same "new best" decision below. save_state runs on all
        # ranks and must not diverge.
        synced = accelerator.reduce(
            torch.tensor([train_loss, train_acc, val_loss, val_acc], device=accelerator.device),
            reduction="mean",
        )
        train_loss, train_acc, val_loss, val_acc = synced.tolist()

        # Record the LR actually used during this epoch before advancing it.
        epoch_lr = optimizer.param_groups[0]["lr"]
        if scheduler is not None:
            scheduler.step()

        logs = {
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "lr": epoch_lr
        }

        if accelerator.is_main_process:
            logger.info(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
            logger.info(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

            if args.use_wandb:
                wandb.log(logs)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # save_state is called on every process: it writes per-process RNG state
            # and is a collective operation for sharded (FSDP/DeepSpeed) setups.
            accelerator.save_state(os.path.join(args.output_dir, "best_model"))
            logger.info(f"Saved best model with validation accuracy: {best_val_acc:.4f}")

    accelerator.save_state(os.path.join(args.output_dir, "final_model"))
    logger.info("Training completed!")
    logger.info(f"Best validation accuracy: {best_val_acc:.4f}")


def parse_args():
    """Parse the training script's command-line options.

    Returns:
        The parsed ``argparse.Namespace``. Exits with a usage error if
        ``--train_split_portion`` is not strictly between 0 and 1, since either
        extreme leaves the training or the validation split empty.
    """
    parser = argparse.ArgumentParser(description="Train ResNet18 classifier")

    # Dataset arguments
    parser.add_argument("--dataset_type", type=str, choices=["glob", "clevrtex", "celeba"],
                        default="glob", help="Type of dataset to use")
    parser.add_argument("--dataset_root", type=str, required=True, help="Path to dataset")
    parser.add_argument("--img_glob", type=str, default="**/*.png **/*.jpg",
                        help="Glob patterns for images (space-separated)")
    parser.add_argument("--img_size", type=int, default=224, help="Image size for resizing")
    parser.add_argument("--train_split_portion", type=float, default=0.8,
                        help="Portion of data to use for training")
    parser.add_argument("--random_flip", action="store_true", help="Apply random horizontal flip")

    # Training arguments
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=50, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="Weight decay")
    parser.add_argument("--scheduler", type=str, choices=["onecycle", "step", "none"],
                        default="step", help="Learning rate scheduler")
    parser.add_argument("--mixed_precision", type=str, default="fp16",
                        choices=["no", "fp16", "bf16"], help="Mixed precision training")

    # Model arguments. "--pretrained" alone could never turn pretraining off
    # (store_true with default=True), so "--no_pretrained" writes False to the
    # same destination.
    parser.add_argument("--pretrained", action="store_true", default=True,
                        help="Use pretrained ResNet18")
    parser.add_argument("--no_pretrained", dest="pretrained", action="store_false",
                        help="Train ResNet18 from randomly initialised weights")

    # Output arguments
    parser.add_argument("--output_dir", type=str, default="./classifier_output",
                        help="Output directory for checkpoints")

    # Logging arguments
    parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases logging")
    parser.add_argument("--wandb_project", type=str, default="resnet18-classifier",
                        help="Wandb project name")
    parser.add_argument("--experiment_name", type=str, default=None,
                        help="Experiment name for wandb")

    # System arguments
    parser.add_argument("--num_workers", type=int, default=4, help="Number of dataloader workers")

    args = parser.parse_args()
    if not 0.0 < args.train_split_portion < 1.0:
        parser.error("--train_split_portion must be strictly between 0 and 1")
    return args


if __name__ == "__main__":
    # Script entry point: read the CLI options and hand them to the training loop.
    main(parse_args())