import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split, WeightedRandomSampler, Dataset
from torchvision import datasets, transforms
from collections import OrderedDict
import numpy as np
import random
from PIL import Image
import wandb
import argparse
import os
import sys
from sklearn.metrics import f1_score
from collections import defaultdict

# Constants: experiment-wide settings read by the training code in this script.
DATA_DIR: str = "./data"      # root directory handed to torchvision's CIFAR10 loader
BATCH_SIZE: int = 128         # DataLoader batch size in main()
NUM_EPOCHS: int = 30          # epochs per regularisation strength in main()
LEARNING_RATE: float = 1e-3   # Adam learning rate in main()
CONTAM_PROB: int = 0          # currently unused in this script (main() reads --eps instead)
LAMBDA_REG: int = 0           # Regularization strength; currently unused (main() sweeps its own values)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class BalancedCIFAR10(Dataset):
    """Class-balanced random subset of a CIFAR-10 split.

    From every class, ``int(n_class * percentage)`` samples are drawn without
    replacement using NumPy's global RNG (seed it for reproducibility). The
    selection is then shuffled so classes are interleaved.

    Args:
        root: directory passed to ``torchvision.datasets.CIFAR10``.
        train: load the training split if True, otherwise the test split.
        transform: optional callable applied to each PIL image in ``__getitem__``.
        download: forwarded to ``torchvision.datasets.CIFAR10``.
        percentage: fraction in [0, 1] of each class to keep.
    """

    def __init__(self, root='./data', train=True, transform=None, download=True, percentage=0.75):
        self.percentage = percentage
        self.transform = transform
        self.train = train

        # Only the raw ``data``/``targets`` of the full split are read, so it is
        # loaded without a transform; ``self.transform`` is applied per item instead.
        full_dataset = datasets.CIFAR10(root=root, train=train, download=download, transform=None)
        self.data, self.targets = self._create_balanced_subset(full_dataset)

    def _create_balanced_subset(self, dataset):
        """Select ``int(n_c * self.percentage)`` random samples from every class ``c``.

        Args:
            dataset: object exposing ``targets`` (sequence of labels) and ``data``
                (array indexable by an integer index array).

        Returns:
            (data, targets): ``data`` is an ndarray of the selected samples in
            shuffled order; ``targets`` is the list of matching labels as Python ints.

        Raises:
            ValueError: if ``self.percentage`` is outside [0, 1].
        """
        if not 0.0 <= self.percentage <= 1.0:
            raise ValueError(f"percentage must be in [0, 1], got {self.percentage}")

        targets = np.asarray(dataset.targets)

        # Group indices by class (insertion order = first appearance of each label).
        class_indices = defaultdict(list)
        for idx, label in enumerate(targets.tolist()):
            class_indices[label].append(idx)

        # Sample the same fraction of every class without replacement.
        selected_indices = []
        for indices in class_indices.values():
            n_samples = int(len(indices) * self.percentage)
            selected_indices.extend(np.random.choice(indices, n_samples, replace=False).tolist())

        # Shuffle overall indices to avoid class clustering.
        np.random.shuffle(selected_indices)

        # One vectorised gather instead of a per-sample Python list.
        selected = np.asarray(selected_indices, dtype=np.int64)
        data = np.asarray(dataset.data)[selected]
        return data, targets[selected].tolist()

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        """Return ``(image, label)``; ``image`` is a PIL image, transformed if a transform is set."""
        img = Image.fromarray(self.data[index])
        if self.transform:
            img = self.transform(img)
        return img, self.targets[index]


class ContaminatedDataset(Dataset):
    """Dataset wrapper that plants a label-correlated colour patch on some samples.

    At construction, for every class in ``classes_to_contaminate`` a random
    ``int(n_class * contamination_rate)`` of that class's indices (taken from
    ``base_dataset.targets``) are chosen without replacement via NumPy's global
    RNG. Fetching one of those indices overwrites the bottom-right
    ``patch_size`` x ``patch_size`` pixels with ``color_mapping[label]``.

    Items are ``(image, label, contaminated)`` with ``contaminated`` a bool; the
    optional ``transform`` runs after patching.

    Args:
        base_dataset: indexable dataset with a ``targets`` sequence that yields
            ``(image, label)``; ``np.array(image)`` must be 3-D with channels last.
        classes_to_contaminate: labels eligible for the patch.
        contamination_rate: fraction of each eligible class to patch.
        color_mapping: label -> fill value written into the patch.
        patch_size: side length of the square patch, in pixels.
        transform: optional callable applied to every returned image.
    """

    def __init__(self, base_dataset, classes_to_contaminate, contamination_rate, color_mapping, patch_size=3, transform=None):
        self.base = base_dataset
        # Same object as ``self.base``, also exposed as ``.dataset`` so helpers
        # that unwrap datasets through a ``.dataset`` attribute can reach it.
        self.dataset = base_dataset
        self.classes = set(classes_to_contaminate)
        self.rate = contamination_rate
        self.color_map = color_mapping
        self.patch_size = patch_size
        self.transform = transform

        # Decide once, up front, which sample indices receive the patch.
        label_array = np.array(self.base.targets)
        picked = set()
        for cls in self.classes:
            members = np.flatnonzero(label_array == cls)
            n_pick = int(len(members) * self.rate)
            picked.update(np.random.choice(members, n_pick, replace=False).tolist())
        self.contaminated_set = picked

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, label = self.base[idx]
        contaminated = idx in self.contaminated_set and label in self.classes
        if contaminated:
            pixels = np.asarray(img).copy()
            height, width, _ = pixels.shape
            top, left = height - self.patch_size, width - self.patch_size
            pixels[top:height, left:width, :] = self.color_map[label]
            img = Image.fromarray(pixels)
        if self.transform:
            img = self.transform(img)
        return img, label, contaminated

def _per_sample_losses(model, loader, criterion, device):
    """Evaluate ``model`` over ``loader`` and return per-sample losses and flags.

    ``criterion`` must use ``reduction='none'`` so it yields one loss per sample.
    Batches are ``(images, labels, flags)`` triples. The returned arrays follow the
    loader's iteration order, so losses and flags stay aligned even if it shuffles.

    Returns:
        (losses, flags): float64 ndarray and bool ndarray of equal length.
    """
    model.eval()
    loss_chunks, flag_chunks = [], []
    with torch.no_grad():
        for imgs, labels, flags in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            loss_chunks.append(criterion(model(imgs), labels).cpu().numpy().astype(np.float64))
            flag_chunks.append(flags.cpu().numpy().astype(bool))
    if not loss_chunks:
        return np.empty(0, dtype=np.float64), np.empty(0, dtype=bool)
    return np.concatenate(loss_chunks), np.concatenate(flag_chunks)


def _fraction_flagged_below(losses, flags, threshold):
    """Return the share of flagged samples with loss strictly below ``threshold`` (0.0 if none are flagged)."""
    n_flagged = int(flags.sum())
    if n_flagged == 0:
        return 0.0
    return float(np.count_nonzero(flags & (losses < threshold))) / n_flagged


def train_and_filter_dataset(contaminated_dataset, percentile=25.0, num_epochs=5, batch_size=128,
                             learning_rate=0.001, num_workers=2):
    """Train a probe CNN on ``contaminated_dataset`` and drop its lowest-loss samples.

    After every epoch, and once more at the end, this reports which fraction of the
    contaminated samples (flag True) has a loss below the ``percentile``-th
    percentile of per-sample losses.

    Args:
        contaminated_dataset: dataset yielding ``(image_tensor, label, flag)``
            triples, e.g. a ``ContaminatedDataset`` with a ``ToTensor`` transform.
        percentile: samples whose loss is strictly below this percentile are removed.
        num_epochs: training epochs for the probe. The default of 5 equals the value
            that used to be hard-coded here (the argument was previously ignored).
        batch_size: batch size for training and for loss evaluation.
        learning_rate: Adam learning rate.
        num_workers: worker processes for the shuffled training DataLoader.

    Returns:
        (model, threshold, filtered_dataset, all_losses_np): the trained probe, the
        loss threshold, a ``Subset`` of the samples with loss >= threshold, and the
        per-sample losses in dataset order.
    """
    num_classes = 10
    num_channels = 3
    width = 32  # base number of conv channels
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_loader = DataLoader(contaminated_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    model = nn.Sequential(
        OrderedDict([
            ('conv0', nn.Conv2d(num_channels, width, kernel_size=3, padding=1)),
            ('relu0', nn.ReLU()),
            ('conv1', nn.Conv2d(width, 2*width, kernel_size=3, padding=1)),
            ('relu1', nn.ReLU()),
            ('conv2', nn.Conv2d(2*width, 4*width, kernel_size=3, stride=2, padding=1)),
            ('relu2', nn.ReLU()),
            ('pool0', nn.MaxPool2d(3)),
            ('conv3', nn.Conv2d(4*width, 4*width, kernel_size=3, stride=2, padding=1)),
            ('relu3', nn.ReLU()),
            ('pool1', nn.AdaptiveAvgPool2d(1)),
            ('flatten', nn.Flatten()),
            ('linear', nn.Linear(4*width, num_classes)),
        ])
    )
    model.to(device)

    # Per-sample losses are needed for thresholding, hence reduction='none'.
    criterion = nn.CrossEntropyLoss(reduction='none')
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(1, num_epochs + 1):
        # --- Training ---
        model.train()
        train_loss_sum = 0.0
        for imgs, labels, _ in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(imgs), labels).mean()
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item() * imgs.size(0)
        avg_train_loss = train_loss_sum / len(train_loader.dataset)
        print(f"Epoch {epoch}/{num_epochs}, Avg Training Loss: {avg_train_loss:.4f}")

        # --- Diagnostic: contaminated samples below the loss percentile ---
        epoch_losses, epoch_flags = _per_sample_losses(model, train_loader, criterion, device)
        epoch_threshold = np.percentile(epoch_losses, percentile)
        epoch_frac = _fraction_flagged_below(epoch_losses, epoch_flags, epoch_threshold)
        print(f"  Threshold ({percentile}th pct): {epoch_threshold}")
        print(f"  Fraction of contaminated below threshold: {epoch_frac:.4f}\n")

    # Final per-sample losses in dataset order (no shuffling) so positions map to indices.
    full_loader = DataLoader(contaminated_dataset, batch_size=batch_size, shuffle=False)
    all_losses_np, all_flags_np = _per_sample_losses(model, full_loader, criterion, device)

    threshold = np.percentile(all_losses_np, percentile)
    frac_cont_below = _fraction_flagged_below(all_losses_np, all_flags_np, threshold)
    print(f"Loss Threshold: {threshold:.4f}")
    print(f"Fraction of contaminated below threshold: {frac_cont_below:.4f}\n")

    # Remove samples whose loss is below the threshold; keep those with loss >= threshold.
    filtered_indices = np.flatnonzero(all_losses_np >= threshold).tolist()
    print(f"Filtered dataset size: {len(filtered_indices)} out of {len(all_losses_np)} samples.")

    filtered_dataset = Subset(contaminated_dataset, filtered_indices)
    return model, threshold, filtered_dataset, all_losses_np

def get_weighted_sampler(subset, num_classes=10):
    """Build a sampler that draws every class with roughly equal probability.

    Each sample is weighted by the inverse of its class's count among the labels
    of ``subset`` (``1e-6`` is added so absent classes don't divide by zero). As
    many samples as there are labels are drawn, with replacement.
    """
    label_array = np.asarray(extract_targets(subset), dtype=np.int64)
    inverse_freq = 1. / (np.bincount(label_array, minlength=num_classes) + 1e-6)
    weights = inverse_freq[label_array]
    return WeightedRandomSampler(weights, num_samples=weights.shape[0], replacement=True)

def extract_targets(subset):
    """Return the labels of every sample visible through ``subset``, in order.

    Chains of ``torch.utils.data.Subset`` are unwrapped and their index lists are
    composed. Any non-Subset wrapper that has no labels of its own but exposes a
    ``.dataset`` attribute (e.g. ``ContaminatedDataset``) is also unwrapped. This
    continues until a dataset with ``targets`` or ``labels`` is reached. Such
    wrappers are assumed to be index-preserving (``wrapper[i]`` is ``wrapped[i]``).

    Raises:
        AttributeError: if no dataset in the chain has ``targets`` or ``labels``.
    """
    indices = None
    dataset = subset
    while True:
        if isinstance(dataset, Subset):
            # Map the outer positions through this Subset's own index list.
            indices = dataset.indices if indices is None else [dataset.indices[i] for i in indices]
            dataset = dataset.dataset
        elif hasattr(dataset, "targets") or hasattr(dataset, "labels"):
            break
        elif hasattr(dataset, "dataset"):
            dataset = dataset.dataset
        else:
            raise AttributeError("Underlying dataset does not have 'targets' or 'labels'.")

    labels = dataset.targets if hasattr(dataset, "targets") else dataset.labels
    if indices is None:
        # Not wrapped in any Subset: every sample is visible.
        return list(labels)
    return [labels[i] for i in indices]

# Main training script
def _str2bool(value):
    """Parse a command-line boolean (``true``/``false``, ``yes``/``no``, ``1``/``0``).

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--force False`` used to *enable* overwriting.

    Raises:
        argparse.ArgumentTypeError: for strings that are not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y"):
        return True
    if text in ("", "0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_cnn(num_channels, width, num_classes):
    """Return the 4-conv CNN trained for every regularisation strength.

    For a 32x32 input the spatial size goes 32 -> 32 -> 16 (stride-2 conv)
    -> 5 (3x3 max-pool) -> 3 (stride-2 conv) -> 1 (global average pool).
    """
    return nn.Sequential(
        OrderedDict(
            [
                ("conv0", nn.Conv2d(num_channels, 1 * width, kernel_size=3, padding=1)),
                ("relu0", nn.ReLU()),
                ("conv1", nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)),
                ("relu1", nn.ReLU()),
                ("conv2", nn.Conv2d(2 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                ("relu2", nn.ReLU()),
                ("pool0", nn.MaxPool2d(3)),
                ("conv3", nn.Conv2d(4 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                ("relu3", nn.ReLU()),
                ("pool1", nn.AdaptiveAvgPool2d(1)),
                ("flatten", nn.Flatten()),
                ("linear", nn.Linear(4 * width, num_classes)),
            ]
        )
    )


def _evaluate(model, loader, criterion, device):
    """Return ``(mean_loss, accuracy, predictions, labels)`` of ``model`` over ``loader``.

    Batches may be ``(inputs, targets)`` or ``(inputs, targets, flags)``; only the
    first two entries are used.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in loader:
            inputs, targets = batch[0].to(device), batch[1].to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())
    return total_loss / total, correct / total, all_preds, all_labels


def main():
    """Run the shortcut-contamination experiment and save per-epoch metrics.

    Steps:
    1. Build a class-balanced CIFAR-10 subset and paint colour patches on a
       fraction ``--eps`` of classes 0/2/4.
    2. Optionally drop low-loss samples (``--annp``).
    3. Split 80/20 into train/val.
    4. For each input-gradient penalty strength, train a fresh CNN for
       ``NUM_EPOCHS`` epochs, logging to wandb.

    Metrics are written to an ``.npz`` named after the arguments. An existing
    file is kept unless ``--force`` is true.
    """
    parser = argparse.ArgumentParser(description='CIFAR10 with shortcut noise and L2 regularization')
    parser.add_argument('--eps', type=float, default=0.0, help='Contamination ratio (default: 0.0)')
    parser.add_argument('--annp', type=float, default=100.0, help='ANNP threshold percentile (default: 100.0)')
    parser.add_argument('--run', type=int, default=0, help='Run (default: 0)')
    # _str2bool instead of type=bool, which maps "False" to True.
    parser.add_argument("--force", help="Force overwriting old run", type=_str2bool, default=False)
    parser.add_argument('--size_percent', type=float, default=100.0, help='Initial reduction of dataset (default: 100.0)')
    args = parser.parse_args()
    print(args)
    eps = args.eps
    annp = args.annp
    run = args.run
    size_percent = args.size_percent
    filename = f"cifar10weightedlossv3_shortcut_patchreg_partial_{size_percent}_eps_{eps}_annp_{annp}_run_{run}.npz"
    if os.path.isfile(filename) and not args.force:
        print(f"File {filename} exists")
        print("=" * 50)
        sys.exit(0)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    # Raw PIL images; ToTensor is applied by ContaminatedDataset after patching.
    base_train = BalancedCIFAR10(root=DATA_DIR, train=True, download=True, transform=None,
                                 percentage=size_percent / 100)
    classes_to_contaminate = [0, 2, 4]
    color_mapping = {0: (255, 0, 0), 2: (0, 0, 255), 4: (0, 255, 0)}
    patch_size = 3
    train_dataset = ContaminatedDataset(
        base_train,
        classes_to_contaminate,
        eps,
        color_mapping,
        patch_size,
        transform
    )
    if annp != 100.0:
        # Keep only samples whose probe-model loss is >= the (100 - annp)th percentile.
        _, _, full_train_dataset, _ = train_and_filter_dataset(train_dataset, percentile=100. - annp)
    else:
        full_train_dataset = train_dataset

    # Split into train and validation sets (80% / 20%).
    train_size = int(0.8 * len(full_train_dataset))
    val_size = len(full_train_dataset) - train_size
    train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

    # Test set is loaded normally without contamination.
    test_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    num_channels = 3
    width = 32  # base number of conv channels in the CNN
    num_classes = 10

    # Inverse class-frequency weights for the val/test criterion. The training
    # split is fixed, so this is computed once rather than per model.
    class_labels = extract_targets(train_dataset)
    class_counts = np.bincount(class_labels, minlength=num_classes) / len(train_dataset)
    print(np.sum(class_counts))
    weights = 1. / (class_counts + 1e-6)
    weight_tensor = torch.tensor(weights, dtype=torch.float32).to(device)
    criterion = nn.CrossEntropyLoss(weight=weight_tensor)

    num_models = 9
    # NOTE(review): logspace yields num_models - 1 = 8 strengths (1e-6 ... 1e1), so
    # column num_models - 1 of every results array is never filled and stays 0.
    # Confirm whether a lambda = 0 baseline was intended before changing the layout.
    lambda_regs = torch.logspace(-6, 1, num_models - 1).to(device)
    all_train_losses = torch.zeros((NUM_EPOCHS, num_models)).to(device)
    all_train_acc = torch.zeros((NUM_EPOCHS, num_models)).to(device)
    all_val_losses = torch.zeros((NUM_EPOCHS, num_models)).to(device)
    all_val_acc = torch.zeros((NUM_EPOCHS, num_models)).to(device)
    all_test_losses = torch.zeros((NUM_EPOCHS, num_models)).to(device)
    all_test_acc = torch.zeros((NUM_EPOCHS, num_models)).to(device)
    all_test_f1 = torch.zeros((NUM_EPOCHS, num_models)).to(device)
    all_best_val_test_acc = torch.zeros(num_models).to(device)
    for i, lambda_reg in enumerate(lambda_regs):
        print(lambda_reg)
        wandb.init(project="cifar10_contaminated_weighted", config={
            "batch_size": BATCH_SIZE,
            "num_epochs": NUM_EPOCHS,
            "learning_rate": LEARNING_RATE,
            "architecture": "custom CNN as provided",
            "contamination_prob": eps,
            "lambda_reg": lambda_reg.item(),  # plain float, not a (possibly CUDA) tensor
        })

        model = _build_cnn(num_channels, width, num_classes).to(device)
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

        best_val_loss = float('inf')
        best_test_acc = 0.0

        for epoch in range(1, NUM_EPOCHS + 1):
            model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0

            for inputs, targets, _ in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                # Track gradients w.r.t. the pixels for the input-gradient penalty.
                inputs.requires_grad_(True)

                outputs = model(inputs)
                # NOTE(review): the training term is unweighted cross-entropy, while
                # ``criterion`` (val/test loss and model selection) is class-weighted;
                # confirm this asymmetry is intended.
                loss_ce = torch.nn.functional.cross_entropy(outputs, targets)

                # Gradient of the summed true-class logits w.r.t. the input;
                # create_graph=True lets the penalty itself be back-propagated.
                scalar_proxy = outputs.gather(1, targets.view(-1, 1)).sum()
                grad_input = torch.autograd.grad(scalar_proxy, inputs, create_graph=True)[0]

                # Per-sample L2 norm of the gradient restricted to the bottom-right
                # patch_size x patch_size corner, where ContaminatedDataset paints the patch.
                img_h, img_w = grad_input.shape[-2:]
                grad_patch = grad_input[:, :, img_h - patch_size:img_h, img_w - patch_size:img_w]
                grad_norm = grad_patch.reshape(grad_patch.size(0), -1).norm(p=2, dim=1).mean()

                loss = loss_ce + lambda_reg * grad_norm
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                train_total += targets.size(0)
                train_correct += predicted.eq(targets).sum().item()

            train_loss /= train_total
            train_acc = train_correct / train_total

            val_loss, val_acc, _, _ = _evaluate(model, val_loader, criterion, device)
            test_loss, test_acc, test_all_preds, test_all_labels = _evaluate(model, test_loader, criterion, device)
            test_f1 = f1_score(test_all_labels, test_all_preds, average="macro")

            # Report test accuracy at the epoch with minimal validation loss.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_test_acc = test_acc
            wandb.log({
                "epoch": epoch,
                "train_loss": train_loss,
                "train_accuracy": train_acc,
                "val_loss": val_loss,
                "val_accuracy": val_acc,
                "test_loss": test_loss,
                "test_accuracy": test_acc,
                "test_f1": test_f1,
            })
            all_train_losses[epoch-1, i] = train_loss
            all_train_acc[epoch-1, i] = train_acc
            all_val_losses[epoch-1, i] = val_loss
            all_val_acc[epoch-1, i] = val_acc
            all_test_losses[epoch-1, i] = test_loss
            all_test_acc[epoch-1, i] = test_acc
            all_test_f1[epoch-1, i] = test_f1

            print(f"Epoch [{epoch}/{NUM_EPOCHS}] "
                f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
                f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

        wandb.log({"best_test_accuracy_at_best_val_epoch": best_test_acc})
        all_best_val_test_acc[i] = best_test_acc
        print(f"\nBest test accuracy (at minimal val loss): {best_test_acc:.4f}")

        wandb.finish()
    # Positional arrays (saved as arr_0 ... arr_7) -- order is part of the file format.
    np.savez(filename, all_train_losses.detach().cpu().numpy(),
             all_train_acc.detach().cpu().numpy(),
             all_val_losses.detach().cpu().numpy(),
             all_val_acc.detach().cpu().numpy(),
             all_test_losses.detach().cpu().numpy(),
             all_test_acc.detach().cpu().numpy(),
             all_best_val_test_acc.detach().cpu().numpy(),
             all_test_f1.detach().cpu().numpy())
    print("=" * 50)


if __name__ == '__main__':
    main()

