"""
ResNet fine-tuning model for cost-sensitive learning.

End-to-end fine-tuning with optional per-example weighting.
Used for image classification on the turkey and iNaturalist datasets.
"""

from typing import Optional, Union, List
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from pathlib import Path
from tqdm import tqdm

from models.base import BaseModel


class ImageDataset(Dataset):
    """Simple image dataset for fine-tuning.

    Each item is a dict with:

    * ``"pixel_values"``: the image converted to RGB, passed through
      ``transform`` when one is given (otherwise the PIL image itself).
    * ``"labels"``: ``int(labels[idx])`` as a ``torch.long`` scalar.
    * ``"idx"``: the item's position in the dataset as a ``torch.long``
      scalar, so callers can look up per-example data (e.g. sample weights)
      even when the DataLoader shuffles.
    """

    def __init__(
        self,
        image_paths: List[str],
        labels: np.ndarray,
        transform=None,
    ):
        # image_paths and labels are indexed in parallel by __getitem__.
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        img_path = self.image_paths[idx]
        # Open inside a context manager so the file handle is always released.
        # Pillow only closes it eagerly for single-frame images, so without this
        # multi-frame files (GIF/TIFF) would leak descriptors in every worker.
        # convert() returns a fully loaded copy, so it stays valid after close.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return {
            "pixel_values": image,
            "labels": torch.tensor(int(self.labels[idx]), dtype=torch.long),
            "idx": torch.tensor(idx, dtype=torch.long),
        }


class ResNetFinetuneModel(BaseModel):
    """
    End-to-end ResNet fine-tuning model.

    Supports weighted cross-entropy for cost-sensitive learning. When
    ``sample_weight`` is passed to :meth:`fit`, each example's cross-entropy
    term is multiplied by its weight. The weighted terms are then averaged over
    the batch, i.e. divided by the batch size, not by the sum of the weights.

    ``X`` is a sequence (list or 1-D array) of image file paths. Labels are
    cast with ``int`` and must be non-negative. The classification head has
    ``max(label) + 1`` outputs, so column ``j`` of :meth:`predict_proba`
    corresponds to label ``j`` and :meth:`predict` returns label values.
    """

    # Number of DataLoader worker processes for both training and inference.
    _NUM_WORKERS = 4

    def __init__(
        self,
        task: str = 'classification',
        backbone: str = 'resnet50',
        epochs: int = 10,
        batch_size: int = 32,
        lr: float = 1e-4,
        weight_decay: float = 1e-4,
        img_size: int = 224,
        freeze_backbone: bool = False,
        **kwargs,
    ):
        if task != 'classification':
            raise ValueError("ResNetFinetuneModel only supports classification")

        super().__init__(task=task, **kwargs)
        self.backbone = backbone
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.weight_decay = weight_decay
        self.img_size = img_size
        self.freeze_backbone = freeze_backbone

        # Both are populated by fit().
        self.model = None
        self.device = None

    def _build_model(self, num_classes: int = 2):
        """Build an ImageNet-pretrained ResNet with a fresh ``num_classes``-way head.

        Raises:
            ValueError: if ``self.backbone`` is not a supported ResNet variant.
        """
        if self.backbone == 'resnet50':
            model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        elif self.backbone == 'resnet18':
            model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        elif self.backbone == 'resnet101':
            model = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V2)
        else:
            raise ValueError(f"Unknown backbone: {self.backbone}")

        # Optionally freeze the backbone. The head created below is a new module,
        # so its parameters remain trainable.
        # NOTE(review): freezing only disables gradients. In train() mode the
        # backbone's BatchNorm running statistics are still updated.
        if self.freeze_backbone:
            for param in model.parameters():
                param.requires_grad = False

        # Replace the classification head.
        num_features = model.fc.in_features
        model.fc = nn.Linear(num_features, num_classes)

        return model

    def _get_transforms(self, train: bool = True):
        """Return ImageNet-normalized transforms for training or evaluation.

        Training uses random resized crops and horizontal flips. Evaluation
        resizes the shorter side, then center-crops to ``img_size``.
        """
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

        if train:
            return transforms.Compose([
                transforms.RandomResizedCrop(self.img_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])

        # Standard 256 -> 224 protocol for the default size. Above 224 the
        # resize scales with the crop, because a fixed Resize(256) followed by
        # CenterCrop(img_size > 256) would zero-pad the image.
        resize_size = max(256, self.img_size * 256 // 224)
        return transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(self.img_size),
            transforms.ToTensor(),
            normalize,
        ])

    @staticmethod
    def _to_path_list(X) -> list:
        """Return ``X`` (array or iterable of image paths) as a plain list."""
        if isinstance(X, np.ndarray):
            return X.tolist()
        return list(X)

    def _make_loader(self, paths: list, labels: np.ndarray, transform, shuffle: bool):
        """Build a DataLoader over ``paths`` using this model's batch size."""
        dataset = ImageDataset(paths, labels, transform=transform)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self._NUM_WORKERS,
            # Pinned host memory only helps host->GPU copies. On CPU it just
            # produces warnings.
            pin_memory=self.device is not None and self.device.type == 'cuda',
        )

    def fit(
        self,
        X: Union[List[str], np.ndarray],
        y: np.ndarray,
        sample_weight: Optional[np.ndarray] = None,
    ) -> 'ResNetFinetuneModel':
        """Fine-tune a ResNet on the training data.

        Args:
            X: image file paths, one per example.
            y: integer class labels (non-negative), aligned with ``X``.
            sample_weight: optional per-example loss weights, aligned with ``X``.

        Returns:
            ``self``.

        Raises:
            ValueError: on empty input, mismatched lengths, or negative labels.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        paths = self._to_path_list(X)
        # Same int cast ImageDataset applies per item. Doing it here lets us
        # size the head from the actual label values.
        labels = np.asarray(y).reshape(-1).astype(np.int64)
        if len(labels) != len(paths):
            raise ValueError(
                f"X and y have different lengths: {len(paths)} vs {len(labels)}"
            )
        if len(labels) == 0:
            raise ValueError("Cannot fit on an empty training set")
        if labels.min() < 0:
            raise ValueError("Labels must be non-negative integers")

        weights = None
        if sample_weight is not None:
            # Flatten so an (n, 1) array cannot broadcast the (B,) loss to (B, B).
            weights = torch.tensor(
                np.asarray(sample_weight, dtype=np.float32).reshape(-1)
            )
            if weights.shape[0] != len(paths):
                raise ValueError(
                    f"sample_weight has length {weights.shape[0]}, "
                    f"expected {len(paths)}"
                )

        train_loader = self._make_loader(
            paths, labels, self._get_transforms(train=True), shuffle=True
        )

        # Head size is max(label) + 1, not the number of distinct labels, so
        # every label is a valid CrossEntropyLoss target. This holds even when
        # labels are not exactly 0..K-1 (e.g. a class absent from this split).
        num_classes = int(labels.max()) + 1
        self.model = self._build_model(num_classes=num_classes).to(self.device)

        # Only optimize parameters that can receive gradients (the head alone
        # when the backbone is frozen).
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(
            trainable_params,
            lr=self.lr,
            weight_decay=self.weight_decay,
        )

        # Cosine decay of the LR over the epochs, stepped once per epoch.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.epochs
        )

        # Per-example loss, so sample weights can be applied before reduction.
        criterion = nn.CrossEntropyLoss(reduction='none')

        self.model.train()
        for epoch in range(self.epochs):
            running_loss = 0.0
            num_batches = 0

            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.epochs}")
            for batch in pbar:
                pixel_values = batch["pixel_values"].to(self.device)
                targets = batch["labels"].to(self.device)

                optimizer.zero_grad()
                loss = criterion(self.model(pixel_values), targets)

                if weights is not None:
                    # "idx" is the dataset position, so the weights line up
                    # with the examples even after shuffling.
                    loss = loss * weights[batch["idx"]].to(self.device)

                loss = loss.mean()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                num_batches += 1
                pbar.set_postfix({"loss": f"{running_loss/num_batches:.4f}"})

            scheduler.step()

        self.is_fitted_ = True
        return self

    def _predict_batches(self, X, postprocess, empty: np.ndarray) -> np.ndarray:
        """Run the model in eval mode over ``X`` and stack per-batch outputs.

        ``postprocess`` maps a batch of logits to a tensor. Its results are
        moved to CPU and concatenated. ``empty`` is returned when ``X`` holds
        no paths, because ``np.concatenate`` rejects an empty list.
        """
        paths = self._to_path_list(X)
        if not paths:
            return empty

        loader = self._make_loader(
            paths,
            np.zeros(len(paths), dtype=np.int64),  # dummy labels; unused here
            self._get_transforms(train=False),
            shuffle=False,  # preserve input order in the output
        )

        self.model.eval()
        chunks = []
        with torch.no_grad():
            for batch in loader:
                logits = self.model(batch["pixel_values"].to(self.device))
                chunks.append(postprocess(logits).cpu().numpy())

        return np.concatenate(chunks)

    def predict(self, X: Union[List[str], np.ndarray]) -> np.ndarray:
        """Predict class labels (argmax of the logits), one per path in ``X``."""
        self._check_fitted()
        return self._predict_batches(
            X,
            lambda logits: torch.argmax(logits, dim=-1),
            empty=np.empty(0, dtype=np.int64),
        )

    def predict_proba(self, X: Union[List[str], np.ndarray]) -> np.ndarray:
        """Predict class probabilities (softmax over the logits).

        Returns an array of shape ``(len(X), num_classes)``.
        """
        self._check_fitted()
        num_classes = self.model.fc.out_features
        return self._predict_batches(
            X,
            lambda logits: torch.softmax(logits, dim=-1),
            empty=np.empty((0, num_classes), dtype=np.float32),
        )

    def __repr__(self) -> str:
        return f"ResNetFinetuneModel(backbone={self.backbone!r}, epochs={self.epochs}, batch_size={self.batch_size})"
