"""
COVID-19 X-ray Dataset Processing Module

This module provides a complete pipeline for:
1. Data loading and preprocessing
2. Model fine-tuning (Inception v3)
3. Probability and label extraction for calibration and test sets

Usage:
    from data.Covid_data import get_covid_data
    cal_probs, cal_labels, test_probs, test_labels = get_covid_data()
"""

import glob
import os
import random
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision.models import Inception_V3_Weights


class CovidDataset(Dataset):
    """Custom Dataset for COVID-19 chest X-ray images.

    Images are read lazily from disk on each access, converted to RGB,
    optionally transformed, and returned together with their label.
    """

    def __init__(self, image_paths: List[str], labels: List[int], transform=None):
        """
        Args:
            image_paths: List of image file paths
            labels: List of corresponding labels (same order as image_paths)
            transform: Optional callable applied to each loaded PIL image
        """
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        Load one sample.

        Args:
            idx: Position of the sample in image_paths / labels

        Returns:
            Tuple of (image, label). The image is whatever the transform
            returns; without a transform it is the RGB PIL image itself.
        """
        # Image.open keeps the underlying file open until the image object is
        # closed. convert() returns an independent copy of the pixel data, so
        # the file handle can be released right after the conversion.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert("RGB")

        # Explicit None check: a transform object that happens to be falsy
        # (e.g. one defining __len__) must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, self.labels[idx]


class CovidDataProcessor:
    """Handles data loading, preprocessing, and splitting."""

    # Class-to-index mapping. The keys are also used as the per-class folder
    # names under data_dir when collecting image files.
    CLASS_TO_IDX = {
        "Normal": 0,
        "Viral Pneumonia": 1,
        "COVID": 2,
        "Lung_Opacity": 3
    }

    def __init__(self, data_dir: str, seed: int = 42):
        """
        Args:
            data_dir: Root directory of COVID-19 radiography dataset
            seed: Random seed for reproducibility
        """
        self.data_dir = data_dir
        self.seed = seed
        self._set_seed()

    def _set_seed(self):
        """Seed the Python, NumPy and PyTorch global RNGs and make cuDNN deterministic."""
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)
        # Trade cuDNN autotuning speed for run-to-run determinism.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def load_data(self, subdir: str = "masks") -> List[Tuple[str, int]]:
        """
        Collect every ``*.png`` file of every known class together with its label.

        Files are looked up in ``<data_dir>/<class_name>/<subdir>/``. Classes
        whose folder is missing simply contribute no entries.

        Args:
            subdir: Name of the folder inside each class directory to read from.
                NOTE(review): the default keeps the original behaviour of
                reading the "masks" folder; confirm this is intended and that
                the folder holding the X-ray images themselves is not meant.

        Returns:
            List of (image_path, label) tuples, ordered by class (in
            CLASS_TO_IDX order) and then by path.
        """
        all_data: List[Tuple[str, int]] = []

        for class_name, label_idx in self.CLASS_TO_IDX.items():
            pattern = os.path.join(self.data_dir, class_name, subdir, "*.png")
            # glob returns files in arbitrary, filesystem-dependent order;
            # sorting makes the seeded split identical across machines/runs.
            all_data.extend((img_path, label_idx) for img_path in sorted(glob.glob(pattern)))

        print(f"Total images loaded: {len(all_data)}")
        return all_data

    def split_data(self, all_data: List[Tuple[str, int]],
                   cal_size: float = 0.10, test_size: float = 0.20) -> dict:
        """
        Split data into train, calibration, and test sets.

        The input list is left untouched; a shuffled copy is split in two
        stratified steps (train vs. rest, then calibration vs. test).

        Args:
            all_data: List of (image_path, label) tuples
            cal_size: Proportion of data for calibration
            test_size: Proportion of data for testing

        Returns:
            Dictionary with keys 'train', 'cal' and 'test', each mapping to a
            (paths, labels) pair of lists.

        Raises:
            ValueError: If all_data is empty (or if scikit-learn rejects the
                requested split sizes).
        """
        if not all_data:
            raise ValueError(
                f"No images to split; check that data_dir={self.data_dir!r} "
                "contains the expected class folders."
            )

        # Shuffle a copy with a private RNG: the caller's list is not mutated
        # and the permutation depends only on self.seed, not on how much of
        # the global random stream has been consumed since construction.
        shuffled = list(all_data)
        random.Random(self.seed).shuffle(shuffled)

        # Separate paths and labels
        X = [path for path, _ in shuffled]
        y = [label for _, label in shuffled]

        # First split: train vs. (cal + test)
        holdout_size = cal_size + test_size
        X_train, X_rem, y_train, y_rem = train_test_split(
            X, y, test_size=holdout_size,
            random_state=self.seed, stratify=y
        )

        # Second split: cal vs. test. The remainder is divided so that the
        # calibration share of it equals cal_size / (cal_size + test_size).
        cal_ratio = cal_size / holdout_size
        X_cal, X_test, y_cal, y_test = train_test_split(
            X_rem, y_rem, test_size=(1 - cal_ratio),
            random_state=self.seed, stratify=y_rem
        )

        print(f"Train size: {len(X_train)}")
        print(f"Calibration size: {len(X_cal)}")
        print(f"Test size: {len(X_test)}")

        return {
            'train': (X_train, y_train),
            'cal': (X_cal, y_cal),
            'test': (X_test, y_test)
        }

    def get_transforms(self) -> "Tuple[T.Compose, T.Compose]":
        """
        Get training and evaluation transforms.

        Both pipelines resize to 299x299, convert to a tensor and normalize
        with the same per-channel mean/std. The training pipeline additionally
        applies a random horizontal flip and a random rotation of up to 15
        degrees.

        Returns:
            Tuple of (train_transform, eval_transform)
        """
        train_transform = T.Compose([
            T.Resize((299, 299)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomRotation(degrees=15),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
        ])

        # Evaluation is deterministic: no augmentation, same normalization.
        eval_transform = T.Compose([
            T.Resize((299, 299)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
        ])

        return train_transform, eval_transform

    def create_dataloaders(self, splits: dict, batch_size: int = 32,
                          num_workers: int = 2) -> dict:
        """
        Create DataLoaders for train, calibration, and test sets.

        Args:
            splits: Dictionary with 'train', 'cal' and 'test' entries, each a
                (paths, labels) pair as produced by split_data
            batch_size: Batch size for DataLoaders
            num_workers: Number of worker processes for data loading

        Returns:
            Dictionary of DataLoaders keyed by 'train', 'cal' and 'test'
        """
        train_transform, eval_transform = self.get_transforms()

        # (split name, transform, shuffle): only the training split is
        # augmented and shuffled; cal/test keep a fixed order so extracted
        # probabilities line up with their labels deterministically.
        plan = (
            ('train', train_transform, True),
            ('cal', eval_transform, False),
            ('test', eval_transform, False),
        )

        loaders = {}
        for split_name, transform, shuffle in plan:
            dataset = CovidDataset(*splits[split_name], transform=transform)
            loaders[split_name] = DataLoader(
                dataset, batch_size=batch_size,
                shuffle=shuffle, num_workers=num_workers
            )

        print(f"Train batches: {len(loaders['train'])}")
        print(f"Cal batches: {len(loaders['cal'])}")
        print(f"Test batches: {len(loaders['test'])}")

        return loaders


class InceptionV3Trainer:
    """Handles model creation, training, and evaluation."""

    def __init__(self, num_classes: int = 4, device: Optional[str] = None):
        """
        Args:
            num_classes: Number of output classes
            device: Device to use for training ('cuda' or 'cpu'). When omitted,
                CUDA is used if available, otherwise the CPU.
        """
        self.num_classes = num_classes
        # Always store a torch.device so the attribute has a single type no
        # matter whether the caller passed a string or relied on the default.
        self.device = torch.device(device) if device else torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = None
        self.criterion = None
        self.optimizer = None

    def create_model(self) -> nn.Module:
        """
        Create and configure Inception v3 model.

        Loads ImageNet-pretrained weights, replaces both classification heads
        with num_classes-way linear layers, freezes everything except the main
        ``fc`` head, and sets up the loss and optimizer.

        Returns:
            Configured model (also stored on self.model)
        """
        # Load pretrained Inception v3 (using new weights API)
        model = models.inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1, aux_logits=True)

        # Replace final layers
        model.fc = nn.Linear(model.fc.in_features, self.num_classes)
        model.AuxLogits.fc = nn.Linear(
            model.AuxLogits.fc.in_features, self.num_classes
        )

        # Freeze all parameters except final fc layer (matching notebook Cell 16)
        for param in model.parameters():
            param.requires_grad = False
        for param in model.fc.parameters():
            param.requires_grad = True
        # Note: AuxLogits.fc is replaced but remains frozen (only fc is trained)

        model = model.to(self.device)
        self.model = model

        # Setup loss and optimizer (matching notebook Cell 17)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(model.fc.parameters(), lr=1e-4)

        print(f"Model created and moved to {self.device}")
        return model

    def train(self, train_loader: DataLoader, num_epochs: int = 10) -> None:
        """
        Train the model.

        Args:
            train_loader: DataLoader for training data
            num_epochs: Number of training epochs

        Raises:
            ValueError: If the model has not been created, or if train_loader
                yields no samples.
        """
        if self.model is None:
            raise ValueError("Model not created. Call create_model() first.")

        print(f"\nStarting training for {num_epochs} epochs...")
        # Training mode is what makes the forward pass below return the
        # (main, auxiliary) output pair.
        # NOTE(review): train() also puts the frozen backbone's BatchNorm and
        # Dropout layers into training mode, so BatchNorm running statistics
        # keep updating even though the weights are frozen — confirm intended.
        self.model.train()

        for epoch in range(num_epochs):
            running_loss = 0.0
            correct = 0
            total = 0

            for images, labels in train_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                self.optimizer.zero_grad()

                # Forward pass with auxiliary outputs
                outputs, aux_outputs = self.model(images)

                # Compute combined loss. The optimizer only updates model.fc
                # and the auxiliary head is frozen by create_model, so the
                # 0.4-weighted auxiliary term presumably only affects the
                # reported loss value, not the trained weights.
                loss1 = self.criterion(outputs, labels)
                loss2 = self.criterion(aux_outputs, labels)
                loss = loss1 + 0.4 * loss2

                # Backward pass
                loss.backward()
                self.optimizer.step()

                # Statistics (loss is a batch mean, so weight it by batch size)
                running_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

            if total == 0:
                # Without this guard the averages below would fail with an
                # uninformative ZeroDivisionError.
                raise ValueError(
                    "train_loader produced no samples; cannot train the model."
                )

            epoch_loss = running_loss / total
            epoch_acc = correct / total * 100
            print(f"Epoch {epoch+1}/{num_epochs} - "
                  f"Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.2f}%")

    def extract_probabilities(self, data_loader: DataLoader) -> Tuple[np.ndarray, np.ndarray]:
        """
        Extract softmax probabilities and labels from data.

        Args:
            data_loader: DataLoader for the dataset

        Returns:
            Tuple of (probabilities, labels) as numpy arrays, concatenated
            over all batches in loader order. For a loader that yields no
            batches, empty arrays of shape (0, num_classes) and (0,) are
            returned.

        Raises:
            ValueError: If the model has not been created.
        """
        if self.model is None:
            raise ValueError("Model not created. Call create_model() first.")

        self.model.eval()
        probs_list = []
        labels_list = []

        with torch.no_grad():
            for images, labels in data_loader:
                images = images.to(self.device)

                # Forward pass (only main output during eval)
                outputs = self.model(images)

                # Convert logits to probabilities
                probs = torch.softmax(outputs, dim=1)

                # Move to CPU and store
                probs_list.append(probs.cpu().numpy())
                labels_list.append(labels.cpu().numpy())

        if not probs_list:
            # np.concatenate rejects an empty sequence; return well-shaped
            # empty results instead of crashing on an empty split.
            return (
                np.empty((0, self.num_classes), dtype=np.float32),
                np.empty((0,), dtype=np.int64),
            )

        # Concatenate all batches
        probs_arr = np.concatenate(probs_list, axis=0)
        labels_arr = np.concatenate(labels_list, axis=0)

        return probs_arr, labels_arr

    def save_model(self, path: str) -> None:
        """
        Save model state dict.

        Args:
            path: Destination file for the serialized state dict

        Raises:
            ValueError: If the model has not been created.
        """
        if self.model is None:
            raise ValueError("Model not created.")
        torch.save(self.model.state_dict(), path)
        print(f"Model saved to {path}")

    def load_model(self, path: str) -> None:
        """
        Load model state dict and switch the model to eval mode.

        The model is created first if it does not exist yet.

        Args:
            path: File containing a state dict written by save_model
        """
        if self.model is None:
            self.create_model()
        # Ensure checkpoints saved on GPU can be loaded on CPU-only machines.
        # SECURITY NOTE(review): torch.load unpickles the file; only load
        # checkpoints from trusted sources, or pass weights_only=True if the
        # installed PyTorch version supports it.
        self.model.load_state_dict(torch.load(path, map_location=self.device))
        self.model.eval()
        print(f"Model loaded from {path}")


def get_covid_data(
    data_dir: str = "/path/to/your/covid19-radiography-database/COVID-19_Radiography_Dataset",
    model_path: Optional[str] = None,
    train_model: bool = True,
    num_epochs: int = 10,
    batch_size: int = 32,
    seed: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Main function to get COVID-19 dataset probabilities and labels.

    Runs the full pipeline: load and split the data, build the model, either
    train it or load saved weights, then compute softmax probabilities on the
    calibration and test splits.

    The train/cal/test split is recomputed from data_dir and seed on every
    call. NOTE(review): when loading a checkpoint (train_model=False), the
    calibration/test sets are presumably only unseen by the model if data_dir
    and seed match those used when the checkpoint was trained — confirm.

    Args:
        data_dir: Root directory of the dataset
        model_path: Path to save/load model weights. Optional when training
            (the model is then not saved); required when train_model=False.
        train_model: Whether to train the model (False to load existing)
        num_epochs: Number of training epochs
        batch_size: Batch size for training
        seed: Random seed for reproducibility

    Returns:
        Tuple of (cal_probs, cal_labels, test_probs, test_labels)

    Raises:
        ValueError: If train_model is False and model_path is None.
    """
    # Fail fast on an unusable argument combination, before spending time on
    # scanning the dataset and constructing the pretrained model.
    if not train_model and model_path is None:
        raise ValueError("model_path must be provided when train_model=False")

    banner = "=" * 60
    print(banner)
    print("COVID-19 X-ray Dataset Processing Pipeline")
    print(banner)

    # Step 1: Data Processing
    print("\n[1/4] Loading and preprocessing data...")
    processor = CovidDataProcessor(data_dir, seed=seed)
    all_data = processor.load_data()
    splits = processor.split_data(all_data)
    dataloaders = processor.create_dataloaders(splits, batch_size=batch_size)

    # Step 2: Model Setup
    print("\n[2/4] Setting up model...")
    trainer = InceptionV3Trainer(num_classes=4)
    trainer.create_model()

    # Step 3: Training or Loading
    if train_model:
        print("\n[3/4] Training model...")
        trainer.train(dataloaders['train'], num_epochs=num_epochs)
        if model_path:
            trainer.save_model(model_path)
    else:
        print("\n[3/4] Loading pre-trained model...")
        trainer.load_model(model_path)

    # Step 4: Extract Probabilities
    print("\n[4/4] Extracting probabilities...")
    cal_probs, cal_labels = trainer.extract_probabilities(dataloaders['cal'])
    test_probs, test_labels = trainer.extract_probabilities(dataloaders['test'])

    print(f"\nCalibration set: {cal_probs.shape}, {cal_labels.shape}")
    print(f"Test set: {test_probs.shape}, {test_labels.shape}")

    print("\n" + banner)
    print("Pipeline completed successfully!")
    print(banner)

    return cal_probs, cal_labels, test_probs, test_labels


if __name__ == "__main__":
    # Script entry point: train for 10 epochs, save the weights next to the
    # working directory, and report the shapes of the resulting probabilities.
    cal_probs, cal_labels, test_probs, test_labels = get_covid_data(
        model_path="covid_inception_model.pth",
        num_epochs=10,
    )

    print("\nData ready for conformal prediction!")
    for split_title, split_probs in (("Calibration", cal_probs), ("Test", test_probs)):
        print(f"{split_title} probabilities shape: {split_probs.shape}")
