import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split, WeightedRandomSampler, Dataset
from torchvision import datasets, transforms
from collections import OrderedDict
import numpy as np
import random
from PIL import Image
import wandb
import argparse
import os
import sys
from sklearn.metrics import f1_score
from collections import defaultdict

# Constants
# Module-level configuration shared by the training script below.
DATA_DIR = "./data"  # Directory handed to torchvision as the CIFAR-10 root
BATCH_SIZE = 128  # Mini-batch size for the train/val/test DataLoaders in main()
NUM_EPOCHS = 30  # Epochs trained per model in main()
LEARNING_RATE = 1e-3  # Adam learning rate used in main()
PATCH_SIZE = 3  # Size of the patch (n x n) for contamination
CONTAM_PROB = 0  # NOTE(review): appears unused; the contamination rate comes from --eps
LAMBDA_REG = 0 #5e-3  # Regularization strength. NOTE(review): appears unused; main() sweeps its own lambda grid
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class BalancedCIFAR10(Dataset):
    """CIFAR-10 split reduced to the same fraction of every class.

    The torchvision dataset is loaded once, ``percentage`` of the samples of
    each class are drawn without replacement (using NumPy's global RNG), and
    the selection is shuffled so that classes are not stored contiguously.

    Attributes:
        data: list of raw image arrays taken from the torchvision dataset.
        targets: list of labels aligned with ``data``.
        transform: optional callable applied to the PIL image in
            ``__getitem__``.
    """

    def __init__(self, root='./data', train=True, transform=None, download=True, percentage=0.75):
        """Load CIFAR-10 and keep ``percentage`` (a fraction) of every class.

        Args:
            root: directory torchvision uses to store/download the data.
            train: select the training split (True) or the test split.
            transform: optional callable applied to each PIL image on access.
            download: forwarded to ``datasets.CIFAR10``.
            percentage: fraction of each class to keep; the per-class count
                is ``int(class_size * percentage)``.
        """
        self.percentage = percentage
        self.transform = transform
        self.train = train

        # Only the raw arrays and labels are read from the torchvision
        # dataset, so no transform needs to be attached to it.
        full_dataset = datasets.CIFAR10(root=root, train=train, download=download)
        self.data, self.targets = self._create_balanced_subset(full_dataset)

    def _create_balanced_subset(self, dataset):
        """Return ``(data, targets)`` lists holding the per-class sample.

        ``dataset`` must expose ``targets`` (labels) and ``data`` (images
        indexable by position).
        """
        # Group sample positions by class label.
        targets = np.array(dataset.targets)
        class_indices = defaultdict(list)
        for idx, label in enumerate(targets):
            class_indices[label].append(idx)

        # Draw the same fraction from every class, without replacement.
        selected_indices = []
        for indices in class_indices.values():
            n_samples = int(len(indices) * self.percentage)
            selected_indices.extend(np.random.choice(indices, n_samples, replace=False))

        # Shuffle overall indices to avoid class clustering.
        np.random.shuffle(selected_indices)

        data = [dataset.data[i] for i in selected_indices]
        kept_targets = [targets[i] for i in selected_indices]
        return data, kept_targets

    def __len__(self):
        """Number of retained samples."""
        return len(self.targets)

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is a PIL image unless a
        transform converts it."""
        # Raw arrays are wrapped as PIL images before the transform runs.
        img = Image.fromarray(self.data[index])
        label = self.targets[index]

        if self.transform:
            img = self.transform(img)

        return img, label


class ContaminatedDataset(Dataset):
    """
    Wraps a base dataset, contaminating a fraction of specified classes by drawing a colored patch
    in the lower-right corner of each image.

    The indices to contaminate are drawn once, at construction time, with
    NumPy's global RNG. Each item is returned as
    ``(image, label, contaminated)`` where ``contaminated`` is a bool.
    """
    def __init__(self, base_dataset, classes_to_contaminate, contamination_rate, color_mapping, patch_size=3, transform=None):
        """
        Args:
            base_dataset: dataset exposing ``targets`` and returning
                ``(image, label)`` pairs.
            classes_to_contaminate: iterable of class labels to contaminate.
            contamination_rate: fraction of each listed class that receives
                the patch.
            color_mapping: mapping from class label to the patch colour.
            patch_size: side length of the square patch in pixels.
            transform: optional callable applied to every returned image.
        """
        self.base = base_dataset
        # Alias of `base`: extract_targets() reaches the labels through `.dataset`.
        self.dataset = base_dataset
        self.classes = set(classes_to_contaminate)
        self.rate = contamination_rate
        self.color_map = color_mapping
        self.patch_size = patch_size
        self.transform = transform

        # Extract labels
        labels = np.array(self.base.targets)

        # Precompute which indices to contaminate
        contam_indices = []
        for cls in self.classes:
            cls_idx = np.where(labels == cls)[0]
            k = int(len(cls_idx) * self.rate)
            chosen = np.random.choice(cls_idx, k, replace=False)
            contam_indices.extend(chosen.tolist())
        self.contaminated_set = set(contam_indices)

    def __len__(self):
        """Same length as the wrapped dataset."""
        return len(self.base)

    def __getitem__(self, idx):
        """Return ``(image, label, contaminated)`` for position ``idx``."""
        img, label = self.base[idx]
        contaminated = idx in self.contaminated_set and label in self.classes
        if contaminated:
            arr = np.array(img)  # np.array already returns a fresh copy
            h, w, _ = arr.shape
            # Clamp at 0: a patch larger than the image would otherwise give a
            # negative start index, which wraps around and paints the wrong region.
            x1 = max(w - self.patch_size, 0)
            y1 = max(h - self.patch_size, 0)
            arr[y1:h, x1:w, :] = self.color_map[label]
            img = Image.fromarray(arr)
        if self.transform:
            img = self.transform(img)
        return img, label, contaminated


def _per_sample_losses(model, loader, criterion, device):
    """Evaluate *model* over *loader* and return (losses, flags) arrays in loader order."""
    model.eval()
    losses, flags = [], []
    with torch.no_grad():
        for imgs, labels, batch_flags in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            losses.extend(criterion(model(imgs), labels).cpu().numpy().tolist())
            flags.extend(batch_flags.cpu().numpy().tolist())
    return np.array(losses), np.array(flags, dtype=bool)


def _fraction_flagged_below(losses, flags, threshold):
    """Fraction of flagged samples whose loss is strictly below *threshold* (0.0 if none are flagged)."""
    n_flagged = flags.sum()
    if n_flagged == 0:
        return 0.0
    return (flags & (losses < threshold)).sum() / n_flagged


def train_and_filter_dataset(contaminated_dataset, percentile=25.0, num_epochs=10, batch_size=128, learning_rate=0.001):
    """
    Trains a CNN on the supplied contaminated dataset, computes per-sample losses,
    determines the requested percentile of the loss values, and filters out all
    points whose loss is below that threshold.

    After every epoch the same percentile is evaluated on the training data and
    the fraction of contaminated samples falling below it is printed as a
    diagnostic.

    Parameters:
        contaminated_dataset (torch.utils.data.Dataset):
            A dataset that returns (image, label, contamination_flag).
        percentile (float): Percentile (0-100) of the per-sample losses used as
            the filtering threshold.
        num_epochs (int): Number of training epochs. The argument is honoured
            (it used to be overridden by a hard-coded 5), so the default call
            trains for 10 epochs; 0 skips training entirely.
        batch_size (int): Batch size used in training and loss computation.
        learning_rate (float): Learning rate for the Adam optimizer.

    Returns:
        model (torch.nn.Module): The trained model.
        threshold (float): The loss value at the requested percentile.
        filtered_dataset (torch.utils.data.Subset): Subset of the original dataset
            containing only samples whose loss >= threshold.
        all_losses_np (np.ndarray): An array of loss values computed for all samples,
            in dataset order.
    """
    num_classes = 10
    num_channels = 3
    width = 32
    train_loader = DataLoader(contaminated_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    # Model
    model = nn.Sequential(
        OrderedDict([
            ('conv0', nn.Conv2d(num_channels, width, kernel_size=3, padding=1)),
            ('relu0', nn.ReLU()),
            ('conv1', nn.Conv2d(width, 2*width, kernel_size=3, padding=1)),
            ('relu1', nn.ReLU()),
            ('conv2', nn.Conv2d(2*width, 4*width, kernel_size=3, stride=2, padding=1)),
            ('relu2', nn.ReLU()),
            ('pool0', nn.MaxPool2d(3)),
            ('conv3', nn.Conv2d(4*width, 4*width, kernel_size=3, stride=2, padding=1)),
            ('relu3', nn.ReLU()),
            ('pool1', nn.AdaptiveAvgPool2d(1)),
            ('flatten', nn.Flatten()),
            ('linear', nn.Linear(4*width, num_classes)),
        ])
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Loss & optimizer. reduction='none' keeps one loss value per sample,
    # which is what the percentile filter operates on.
    criterion = nn.CrossEntropyLoss(reduction='none')
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Train and report the contaminated fraction below the threshold each epoch
    for epoch in range(1, num_epochs + 1):
        # --- Training ---
        model.train()
        train_loss_sum = 0.0
        for imgs, labels, _ in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(imgs), labels).mean()
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item() * imgs.size(0)
        avg_train_loss = train_loss_sum / len(train_loader.dataset)
        print(f"Epoch {epoch}/{num_epochs}, Avg Training Loss: {avg_train_loss:.4f}")

        # --- Per-epoch diagnostic at the requested percentile ---
        # The loader is shuffled, but losses and flags come from the same
        # pass, so they stay aligned with each other.
        epoch_losses, epoch_flags = _per_sample_losses(model, train_loader, criterion, device)
        threshold = np.percentile(epoch_losses, percentile)
        frac_cont_below = _fraction_flagged_below(epoch_losses, epoch_flags, threshold)
        print(f"  Threshold ({percentile}th pct): {threshold}")
        print(f"  Fraction of contaminated below threshold: {frac_cont_below:.4f}\n")

    # Compute per-sample losses for the entire dataset, without shuffling so
    # that position i in the arrays is sample i of the dataset.
    full_loader = DataLoader(contaminated_dataset, batch_size=batch_size, shuffle=False)
    all_losses_np, flags_np = _per_sample_losses(model, full_loader, criterion, device)

    # Threshold and contamination statistics both come from this ordered pass
    # (no dependence on variables left over from the last training epoch).
    threshold = np.percentile(all_losses_np, percentile)
    frac_cont_below = _fraction_flagged_below(all_losses_np, flags_np, threshold)
    print(f"Loss Threshold: {threshold:.4f}")
    print(f"Fraction of contaminated below threshold: {frac_cont_below:.4f}\n")

    # Filter out (i.e. remove) points whose loss is below the threshold.
    # We keep only samples with loss >= threshold.
    filtered_indices = np.flatnonzero(all_losses_np >= threshold).tolist()
    print(f"Filtered dataset size: {len(filtered_indices)} out of {len(all_losses_np)} samples.")

    # Create a filtered dataset as a subset using the selected indices.
    filtered_dataset = Subset(contaminated_dataset, filtered_indices)
    return model, threshold, filtered_dataset, all_losses_np

def get_weighted_sampler(subset, num_classes=10):
    """Return a WeightedRandomSampler weighting each sample by its inverse class count.

    Labels are obtained with ``extract_targets(subset)``. Every sample gets the
    weight ``1 / (count_of_its_class + 1e-6)``; the sampler draws
    ``len(subset labels)`` samples with replacement.
    """
    targets = extract_targets(subset)
    per_class = np.bincount(targets, minlength=num_classes)
    # The small epsilon keeps the division finite for classes with no samples.
    inverse_count = 1. / (per_class + 1e-6)
    weights = inverse_count[np.asarray(targets, dtype=np.int64)].tolist()
    return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

def extract_targets(subset):
    """Return the labels of every sample in *subset*, in sample order.

    Any chain of ``Subset`` objects is unwrapped while composing their index
    lists, and wrapper datasets are followed through their ``.dataset``
    attribute until an object exposing ``targets`` (or ``labels``) is found.
    An input that involves no ``Subset`` at all yields all labels.

    Raises:
        AttributeError: if no dataset in the chain has ``targets`` or ``labels``.
    """
    indices = None  # positions into `current`; None means "every sample"
    current = subset
    while True:
        if isinstance(current, Subset):
            inner = current.indices
            # Map the positions collected so far through this Subset's indices.
            indices = list(inner) if indices is None else [inner[i] for i in indices]
            current = current.dataset
            continue

        for attr in ("targets", "labels"):
            if hasattr(current, attr):
                values = getattr(current, attr)
                if indices is None:
                    return list(values)
                return [values[i] for i in indices]

        if hasattr(current, "dataset"):
            # Wrapper without labels of its own (e.g. ContaminatedDataset).
            current = current.dataset
            continue

        raise AttributeError("Underlying dataset does not have 'targets' or 'labels'.")

# Main training script
def _str2bool(value):
    """Parse a command-line boolean such as 'true', 'False', '1' or 'no'."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"1", "true", "t", "yes", "y"}:
        return True
    if lowered in {"0", "false", "f", "no", "n", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def _build_cnn(num_channels, width, num_classes):
    """Build the four-convolution CNN classifier used for every model in the sweep."""
    return nn.Sequential(
        OrderedDict(
            [
                ("conv0", nn.Conv2d(num_channels, 1 * width, kernel_size=3, padding=1)),
                ("relu0", nn.ReLU()),
                ("conv1", nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)),
                ("relu1", nn.ReLU()),
                ("conv2", nn.Conv2d(2 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                ("relu2", nn.ReLU()),
                ("pool0", nn.MaxPool2d(3)),
                ("conv3", nn.Conv2d(4 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                ("relu3", nn.ReLU()),
                ("pool1", nn.AdaptiveAvgPool2d(1)),
                ("flatten", nn.Flatten()),
                ("linear", nn.Linear(4 * width, num_classes)),
            ]
        )
    )


def _evaluate(model, loader, criterion):
    """Return (mean loss, accuracy, labels, predictions) of *model* over *loader*.

    Batches may carry extra fields after (inputs, targets); they are ignored.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    all_labels, all_preds = [], []
    with torch.no_grad():
        for batch in loader:
            inputs, targets = batch[0].to(device), batch[1].to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, targets).item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())
    return loss_sum / total, correct / total, all_labels, all_preds


def main():
    """Run the class-weighted-loss sweep on (optionally contaminated) CIFAR-10.

    Steps: parse the CLI, skip the run if its result file already exists
    (unless --force), build the balanced and contaminated training data,
    optionally pre-filter it with train_and_filter_dataset (when --annp is
    not 100), split 80/20 into train/validation, then train one CNN per
    class-weight exponent and record per-epoch metrics. All histories are
    written to a single .npz file as positional arrays in this order:
    train loss, train acc, val loss, val acc, test loss, test acc,
    test acc at best val loss, test macro-F1.
    """
    parser = argparse.ArgumentParser(description='CIFAR10 with shortcut noise and L2 regularization')
    parser.add_argument('--eps', type=float, default=0.0, help='Contamination ratio (default: 0.0)')
    parser.add_argument('--annp', type=float, default=100.0, help='ANNP threshold percentile (default: 100.0)')
    parser.add_argument('--run', type=int, default=0, help='Run (default: 0)')
    # type=bool would treat any non-empty string (including "False") as True;
    # parse the text explicitly and also accept a bare `--force`.
    parser.add_argument("--force", help="Force overwriting old run", type=_str2bool,
                        nargs='?', const=True, default=False)
    parser.add_argument('--size_percent', type=float, default=100.0, help='Initial reduction of dataset (default: 100.0)')
    args = parser.parse_args()
    print(args)
    eps = args.eps
    annp = args.annp
    run = args.run
    size_percent = args.size_percent
    filename = f"cifar10weightedlossv3_shortcut_pweight_partial_{size_percent}_eps_{eps}_annp_{annp}_run_{run}.npz"
    if os.path.isfile(filename) and (not args.force):
        # File exists
        print(f"File {filename} exists")
        print("=" * 50)
        sys.exit(0)

    # The transform is applied by ContaminatedDataset after the patch is
    # drawn, so the base dataset itself returns untransformed PIL images.
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    base_train = BalancedCIFAR10(root=DATA_DIR, train=True, download=True, transform=None,
                                 percentage=size_percent / 100)

    # Load and contaminate train (and later split validation) data
    classes_to_contaminate = [0, 2, 4]
    color_mapping = {0: (255, 0, 0), 2: (0, 0, 255), 4: (0, 255, 0)}
    train_dataset = ContaminatedDataset(
        base_train,
        classes_to_contaminate,
        eps,
        color_mapping,
        PATCH_SIZE,
        transform
    )
    if annp != 100.0:
        _, _, full_train_dataset, _ = train_and_filter_dataset(train_dataset, percentile=100.-annp)
    else:
        full_train_dataset = train_dataset

    # Split into train and validation sets (80% train, 20% validation)
    train_size = int(0.8 * len(full_train_dataset))
    val_size = len(full_train_dataset) - train_size
    train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

    # Test set is loaded normally without contamination
    test_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, transform=transform, download=True)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # CNN dimensions: 3 input channels, base channel width 32, 10 classes.
    num_channels = 3
    width = 32
    num_classes = 10

    num_models = 10
    lambda_regs = np.linspace(0.0, 2.0, num_models)

    # Inverse class frequencies of the training split. They do not depend on
    # lambda, so they are computed once instead of once per model.
    class_labels = extract_targets(train_dataset)
    class_freqs = np.bincount(class_labels, minlength=num_classes) / len(train_dataset)
    inverse_freqs = 1. / (class_freqs + 1e-6)

    # Per-epoch metric histories, one column per model in the sweep.
    all_train_losses = torch.zeros((NUM_EPOCHS, num_models), device=device)
    all_train_acc = torch.zeros((NUM_EPOCHS, num_models), device=device)
    all_val_losses = torch.zeros((NUM_EPOCHS, num_models), device=device)
    all_val_acc = torch.zeros((NUM_EPOCHS, num_models), device=device)
    all_test_losses = torch.zeros((NUM_EPOCHS, num_models), device=device)
    all_test_acc = torch.zeros((NUM_EPOCHS, num_models), device=device)
    all_test_f1 = torch.zeros((NUM_EPOCHS, num_models), device=device)
    all_best_val_test_acc = torch.zeros(num_models, device=device)

    for i, lambda_reg in enumerate(lambda_regs):
        print(lambda_reg)
        # Initialize wandb
        wandb.init(project="cifar10_contaminated_weighted", config={
            "batch_size": BATCH_SIZE,
            "num_epochs": NUM_EPOCHS,
            "learning_rate": LEARNING_RATE,
            "architecture": "custom CNN as provided",
            "contamination_prob": eps,
            "lambda_reg": lambda_reg
        })

        # `.to(device)` instead of `.cuda()` so the script also runs on CPU-only hosts.
        model = _build_cnn(num_channels, width, num_classes).to(device)

        # Loss and optimizer: class weights are the inverse class frequency
        # raised to lambda_reg (lambda_reg = 0 gives uniform weights).
        weight_tensor = torch.tensor(inverse_freqs ** lambda_reg, dtype=torch.float32).to(device)
        criterion = nn.CrossEntropyLoss(weight=weight_tensor)
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

        best_val_loss = float('inf')
        best_test_acc = 0.0

        # Training loop
        for epoch in range(1, NUM_EPOCHS + 1):
            model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0

            for inputs, targets, _ in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                train_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                train_total += targets.size(0)
                train_correct += predicted.eq(targets).sum().item()

            train_loss /= train_total
            train_acc = train_correct / train_total

            # Validation, then test evaluation at each epoch to track performance
            val_loss, val_acc, _, _ = _evaluate(model, val_loader, criterion)
            test_loss, test_acc, test_all_labels, test_all_preds = _evaluate(model, test_loader, criterion)
            test_f1 = f1_score(test_all_labels, test_all_preds, average="macro")

            # Save best test accuracy based on minimal val loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_test_acc = test_acc

            # Log metrics to wandb
            wandb.log({
                "epoch": epoch,
                "train_loss": train_loss,
                "train_accuracy": train_acc,
                "val_loss": val_loss,
                "val_accuracy": val_acc,
                "test_loss": test_loss,
                "test_accuracy": test_acc,
                "test_f1": test_f1,
            })
            all_train_losses[epoch-1, i] = train_loss
            all_train_acc[epoch-1, i] = train_acc
            all_val_losses[epoch-1, i] = val_loss
            all_val_acc[epoch-1, i] = val_acc
            all_test_losses[epoch-1, i] = test_loss
            all_test_acc[epoch-1, i] = test_acc
            all_test_f1[epoch-1, i] = test_f1

            print(f"Epoch [{epoch}/{NUM_EPOCHS}] "
                f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
                f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

        # Log the best test accuracy (i.e., test accuracy at epoch with minimal validation loss)
        wandb.log({"best_test_accuracy_at_best_val_epoch": best_test_acc})
        all_best_val_test_acc[i] = best_test_acc
        print(f"\nBest test accuracy (at minimal val loss): {best_test_acc:.4f}")

        wandb.finish()

    # Arrays are saved positionally (arr_0 ... arr_7); keep this order stable
    # for any code that reads the result files.
    np.savez(filename, all_train_losses.detach().cpu().numpy(),
             all_train_acc.detach().cpu().numpy(),
             all_val_losses.detach().cpu().numpy(),
             all_val_acc.detach().cpu().numpy(),
             all_test_losses.detach().cpu().numpy(),
             all_test_acc.detach().cpu().numpy(),
             all_best_val_test_acc.detach().cpu().numpy(),
             all_test_f1.detach().cpu().numpy())
    print("=" * 50)
# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

