"""
ResNet fine-tuning model for delta regression.

End-to-end fine-tuning to predict the signed margin (delta) directly.
Used for image datasets (turkey, inaturalist).
"""

from typing import Optional, Union, List
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm
import wandb

from models.base import BaseModel


class RegressionImageDataset(Dataset):
    """Image dataset for regression fine-tuning.

    Each item is a dict with three entries:

    * ``"pixel_values"``: the image loaded from disk, converted to RGB and
      passed through ``transform`` when one is given.
    * ``"targets"``: the regression target as a float32 scalar tensor.
    * ``"idx"``: the position of the example as a long scalar tensor, so
      callers can look up per-example data (e.g. sample weights).

    Args:
        image_paths: Paths of the images, one per example.
        targets: Regression targets aligned with ``image_paths``. Any
            array-like is accepted; values are copied and stored as float32.
        transform: Optional callable applied to the loaded PIL image.

    Raises:
        ValueError: If ``image_paths`` and ``targets`` differ in length.
    """

    def __init__(
        self,
        image_paths: List[str],
        targets: Union[np.ndarray, List[float]],
        transform=None,
    ):
        # np.array (not ndarray.astype) so plain lists work too; like astype
        # it always copies, so later caller-side mutation cannot leak in.
        targets = np.array(targets, dtype=np.float32)

        # Fail early: a mismatch would otherwise surface as an IndexError deep
        # inside a DataLoader worker, or silently drop the surplus targets.
        if len(image_paths) != len(targets):
            raise ValueError(
                f"image_paths and targets must have the same length, "
                f"got {len(image_paths)} and {len(targets)}"
            )

        self.image_paths = image_paths
        self.targets = targets
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        """Load, convert and (optionally) transform the ``idx``-th example."""
        # The context manager releases the file handle deterministically;
        # convert() returns a new in-memory image that outlives the handle.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')

        # Compare against None explicitly: a valid transform may be falsy
        # (e.g. a callable container that defines __len__ and is empty).
        if self.transform is not None:
            image = self.transform(image)

        return {
            "pixel_values": image,
            "targets": torch.tensor(self.targets[idx], dtype=torch.float32),
            "idx": torch.tensor(idx, dtype=torch.long),
        }


class ResNetFinetuneRegressionModel(BaseModel):
    """
    End-to-end ResNet fine-tuning model for delta regression.

    Predicts the signed margin (delta) directly using MSE loss.
    At inference, classification is done by thresholding at 0.

    Args:
        task: Must be ``'regression'``; any other value raises ``ValueError``.
        backbone: One of ``'resnet18'``, ``'resnet50'`` or ``'resnet101'``.
        epochs: Number of passes over the training data.
        batch_size: Batch size used for both training and prediction.
        lr: Initial learning rate for AdamW.
        weight_decay: Weight decay for AdamW.
        img_size: Side length (pixels) of the square crop fed to the network.
        freeze_backbone: If True, only the new regression head is trained.
        **kwargs: Forwarded unchanged to ``BaseModel.__init__``.
    """

    def __init__(
        self,
        task: str = 'regression',
        backbone: str = 'resnet50',
        epochs: int = 10,
        batch_size: int = 32,
        lr: float = 1e-4,
        weight_decay: float = 1e-4,
        img_size: int = 224,
        freeze_backbone: bool = False,
        **kwargs,
    ):
        if task != 'regression':
            raise ValueError("ResNetFinetuneRegressionModel only supports regression")

        super().__init__(task=task, **kwargs)
        self.backbone = backbone
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.weight_decay = weight_decay
        self.img_size = img_size
        self.freeze_backbone = freeze_backbone

        # Both are populated by fit().
        self.model = None
        self.device = None

    def _build_model(self):
        """Build ResNet model with regression head (single output).

        Returns:
            A torchvision ResNet initialised with ImageNet weights whose
            ``fc`` layer is replaced by ``nn.Linear(in_features, 1)``.

        Raises:
            ValueError: If ``self.backbone`` is not a supported name.
        """
        if self.backbone == 'resnet50':
            model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        elif self.backbone == 'resnet18':
            model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        elif self.backbone == 'resnet101':
            model = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V2)
        else:
            raise ValueError(f"Unknown backbone: {self.backbone}")

        # Optionally freeze backbone. This must happen before the head is
        # replaced so that the freshly created head stays trainable.
        if self.freeze_backbone:
            for param in model.parameters():
                param.requires_grad = False

        # Replace classification head with regression head (single output)
        num_features = model.fc.in_features
        model.fc = nn.Linear(num_features, 1)

        return model

    def _get_transforms(self, train: bool = True):
        """Get image transforms for training or evaluation.

        Args:
            train: If True, return the augmenting pipeline (random resized
                crop + horizontal flip); otherwise the deterministic
                resize + center-crop pipeline.

        Returns:
            A ``transforms.Compose`` ending in ``ToTensor`` and ImageNet
            mean/std normalisation.
        """
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

        if train:
            return transforms.Compose([
                transforms.RandomResizedCrop(self.img_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])

        # Keep the usual 256:224 resize-to-crop ratio for any img_size. A
        # fixed Resize(256) would make CenterCrop pad with black borders as
        # soon as img_size exceeds 256. For the default 224 this is 256.
        resize_size = int(round(self.img_size * 256 / 224))
        return transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(self.img_size),
            transforms.ToTensor(),
            normalize,
        ])

    def fit(
        self,
        X: Union[List[str], np.ndarray],
        y: np.ndarray,
        sample_weight: Optional[np.ndarray] = None,
    ) -> 'ResNetFinetuneRegressionModel':
        """Fine-tune ResNet for regression on the training data.

        Builds a fresh model, trains it with AdamW and a cosine learning-rate
        schedule, and logs the per-epoch loss and learning rate to wandb.

        Args:
            X: List of image paths
            y: Target values (delta values for regression)
            sample_weight: Optional per-example weights; each example's
                squared error is multiplied by its weight before averaging.

        Returns:
            ``self``, to allow call chaining.

        Raises:
            ValueError: If ``X`` is empty, or ``y`` / ``sample_weight`` do
                not contain exactly one value per example.
        """
        # Convert X to list of paths if needed
        if isinstance(X, np.ndarray):
            X = X.tolist()

        # Validate before starting a wandb run so bad input leaves no run behind.
        num_examples = len(X)
        if num_examples == 0:
            raise ValueError("Cannot fit on an empty training set")

        # Flatten so a column vector (n, 1) cannot broadcast against the
        # (batch,) model output into a (batch, batch) loss matrix.
        y = np.asarray(y).reshape(-1)
        if len(y) != num_examples:
            raise ValueError(
                f"X and y have inconsistent lengths: {num_examples} != {len(y)}"
            )

        # Convert sample weights to tensor if provided (flattened for the
        # same broadcasting reason as y).
        if sample_weight is not None:
            sample_weight = torch.as_tensor(
                sample_weight, dtype=torch.float32
            ).reshape(-1)
            if sample_weight.numel() != num_examples:
                raise ValueError(
                    f"X and sample_weight have inconsistent lengths: "
                    f"{num_examples} != {sample_weight.numel()}"
                )

        # Initialize wandb
        wandb.init(
            project="cost-sensitive-learning",
            config={
                "model": "resnet_finetune_regression",
                "backbone": self.backbone,
                "epochs": self.epochs,
                "batch_size": self.batch_size,
                "lr": self.lr,
                "weight_decay": self.weight_decay,
                "img_size": self.img_size,
                "freeze_backbone": self.freeze_backbone,
            },
            reinit=True,
        )

        # try/finally guarantees the wandb run is closed even when training
        # is interrupted or raises.
        try:
            # Setup device
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            # Create dataset
            train_transform = self._get_transforms(train=True)
            train_dataset = RegressionImageDataset(X, y, transform=train_transform)
            train_loader = DataLoader(
                train_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=4,
                # Pinned host memory only helps host-to-GPU copies.
                pin_memory=self.device.type == 'cuda',
            )

            # Build model
            self.model = self._build_model()
            self.model = self.model.to(self.device)

            # Optimizer: only hand over parameters that are actually trained
            # (frozen backbone parameters never receive gradients).
            trainable_params = [
                p for p in self.model.parameters() if p.requires_grad
            ]
            optimizer = torch.optim.AdamW(
                trainable_params,
                lr=self.lr,
                weight_decay=self.weight_decay,
            )

            # LR scheduler
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.epochs
            )

            # Training loop
            self.model.train()
            for epoch in range(self.epochs):
                total_loss = 0.0
                num_batches = 0

                pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.epochs}")
                for batch in pbar:
                    pixel_values = batch["pixel_values"].to(self.device)
                    targets = batch["targets"].to(self.device)
                    # Stays on CPU: used to index the CPU-side weight tensor.
                    idx = batch["idx"]

                    optimizer.zero_grad()

                    # (batch, 1) -> (batch,) to line up with targets.
                    outputs = self.model(pixel_values).squeeze(-1)

                    # MSE loss (per-example)
                    loss = (outputs - targets) ** 2

                    # Apply sample weights if provided
                    if sample_weight is not None:
                        weights = sample_weight[idx].to(self.device)
                        loss = loss * weights

                    loss = loss.mean()
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    num_batches += 1
                    pbar.set_postfix({"loss": f"{total_loss/num_batches:.4f}"})

                # Read the learning rate *before* stepping the scheduler so
                # the logged value is the one this epoch was trained with.
                epoch_lr = scheduler.get_last_lr()[0]
                scheduler.step()

                # Log to wandb
                wandb.log({
                    "epoch": epoch + 1,
                    "train_loss": total_loss / num_batches,
                    "lr": epoch_lr,
                })
        finally:
            wandb.finish()

        self.is_fitted_ = True
        return self

    def predict(self, X: Union[List[str], np.ndarray]) -> np.ndarray:
        """Predict delta values.

        Args:
            X: List (or array) of image paths.

        Returns:
            1-D array with one predicted delta per input path, in input
            order; an empty float32 array when ``X`` is empty.
        """
        self._check_fitted()

        if isinstance(X, np.ndarray):
            X = X.tolist()

        # np.concatenate rejects an empty list, so handle "nothing to
        # predict" explicitly instead of crashing.
        if len(X) == 0:
            return np.empty(0, dtype=np.float32)

        eval_transform = self._get_transforms(train=False)
        # Targets are unused at inference; zeros satisfy the dataset API.
        eval_dataset = RegressionImageDataset(X, np.zeros(len(X)), transform=eval_transform)
        eval_loader = DataLoader(
            eval_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
            # Pinned host memory only helps host-to-GPU copies.
            pin_memory=self.device.type == 'cuda',
        )

        self.model.eval()
        all_preds = []

        with torch.no_grad():
            for batch in eval_loader:
                pixel_values = batch["pixel_values"].to(self.device)
                outputs = self.model(pixel_values).squeeze(-1)
                all_preds.append(outputs.cpu().numpy())

        return np.concatenate(all_preds)

    def predict_class(self, X: Union[List[str], np.ndarray]) -> np.ndarray:
        """Predict class labels by thresholding delta at 0.

        Returns:
            Integer array holding 1 where the predicted delta is >= 0 and
            0 otherwise.
        """
        delta_pred = self.predict(X)
        return (delta_pred >= 0).astype(int)

    def __repr__(self) -> str:
        return f"ResNetFinetuneRegressionModel(backbone={self.backbone!r}, epochs={self.epochs}, batch_size={self.batch_size})"
