import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split, WeightedRandomSampler, Dataset
from torchvision import datasets, transforms
from collections import OrderedDict
import numpy as np
import random
from PIL import Image
import wandb
import argparse
import os
import sys
from sklearn.metrics import f1_score
from tqdm import tqdm as tqdm

# Constants
# Root directory of the CIFAR-10 download/cache used by main().
DATA_DIR = "./data"
# Mini-batch size of the train/val/test loaders built in main().
BATCH_SIZE = 128
# Number of epochs each model of the regularization sweep in main() is trained for.
NUM_EPOCHS = 30
# Adam learning rate of the main training loop.
LEARNING_RATE = 1e-3
# NOTE(review): CONTAM_PROB and LAMBDA_REG are not referenced in the visible
# code (main() takes the contamination ratio from --eps and sweeps its own
# lambda values) -- confirm whether they are still needed.
CONTAM_PROB = 0
LAMBDA_REG = 0 # Regularization strength
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ContaminatedDataset(Dataset):
    """Dataset wrapper that stamps a class-specific color patch on some samples.

    For every class in ``classes_to_contaminate`` a fraction
    ``contamination_rate`` of that class's samples is drawn once, at
    construction time, with ``np.random.choice`` (global NumPy RNG). When
    such a sample is fetched, a ``patch_size`` x ``patch_size`` square in the
    bottom-right corner of the image is overwritten with
    ``color_mapping[label]``.

    Args:
        base_dataset: Dataset exposing ``targets`` and yielding
            ``(image, label)`` pairs; contaminated images must convert to an
            ``(H, W, C)`` array via ``np.array``.
        classes_to_contaminate: Iterable of class labels that may be patched.
        contamination_rate: Fraction of each listed class to contaminate.
        color_mapping: Mapping from label to the value written into the patch.
        patch_size: Side length of the square patch, in pixels.
        transform: Optional callable applied to the (possibly patched) image.

    Each item is ``(image, label, index, contaminated)`` where
    ``contaminated`` is a bool telling whether the patch was applied.
    """

    def __init__(self, base_dataset, classes_to_contaminate, contamination_rate, color_mapping, patch_size=3, transform=None):
        # ``dataset`` is an alias of ``base``; both names are kept as attributes.
        self.base = self.dataset = base_dataset
        self.classes = set(classes_to_contaminate)
        self.rate = contamination_rate
        self.color_map = color_mapping
        self.patch_size = patch_size
        self.transform = transform

        # Decide up front which sample indices carry the patch.
        targets = np.array(self.base.targets)
        picked = set()
        for cls in self.classes:
            members = np.nonzero(targets == cls)[0]
            n_pick = int(len(members) * self.rate)
            picked.update(np.random.choice(members, n_pick, replace=False).tolist())
        self.contaminated_set = picked

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, label = self.base[idx]
        contaminated = idx in self.contaminated_set and label in self.classes
        if contaminated:
            pixels = np.array(img).copy()
            height, width, _ = pixels.shape
            top = height - self.patch_size
            left = width - self.patch_size
            # Paint the bottom-right square with this class's color.
            pixels[top:height, left:width, :] = self.color_map[label]
            img = Image.fromarray(pixels)
        if self.transform:
            img = self.transform(img)
        return img, label, idx, contaminated

def _collect_sample_losses(model, loader, criterion, device):
    """Run one evaluation pass and return (per-sample losses, contamination flags) as NumPy arrays."""
    model.eval()
    losses, flags = [], []
    with torch.no_grad():
        for imgs, labels, _, batch_flags in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            batch_losses = criterion(model(imgs), labels)
            losses.extend(batch_losses.cpu().tolist())
            flags.extend(torch.as_tensor(batch_flags).cpu().tolist())
    return np.array(losses, dtype=float), np.array(flags, dtype=bool)


def _fraction_contaminated_below(losses, flags, threshold):
    """Return the share of flagged samples whose loss is strictly below threshold (0.0 if none are flagged)."""
    n_flagged = flags.sum()
    if n_flagged == 0:
        return 0.0
    return float((flags & (losses < threshold)).sum() / n_flagged)


def train_and_filter_dataset(contaminated_dataset, percentile=25.0, num_epochs=10, batch_size=128, learning_rate=0.001, num_workers=2):
    """Train a small CNN and drop the samples with the lowest final loss.

    The dataset must yield ``(image, label, index, contaminated)`` tuples.
    After training, the per-sample cross-entropy loss is computed in dataset
    order; samples whose loss is below the ``percentile``-th percentile are
    removed and the rest are returned as a ``Subset``.

    Args:
        contaminated_dataset: Dataset yielding 4-tuples as described above,
            with 3-channel images and labels in ``[0, 10)``.
        percentile: Percentile of the loss distribution used as threshold.
        num_epochs: Number of training epochs.
        batch_size: Mini-batch size of both data loaders.
        learning_rate: Adam learning rate.
        num_workers: Worker processes of the shuffled training loader.

    Returns:
        Tuple ``(model, threshold, filtered_dataset, all_losses_np)`` where
        ``all_losses_np`` holds one loss per sample in dataset order.

    Raises:
        ValueError: If the dataset is empty.
    """
    if len(contaminated_dataset) == 0:
        raise ValueError("contaminated_dataset must contain at least one sample.")

    num_classes = 10
    num_channels = 3
    width = 32
    train_loader = DataLoader(contaminated_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # Model
    model = nn.Sequential(
        OrderedDict([
            ('conv0', nn.Conv2d(num_channels, width, kernel_size=3, padding=1)),
            ('relu0', nn.ReLU()),
            ('conv1', nn.Conv2d(width, 2*width, kernel_size=3, padding=1)),
            ('relu1', nn.ReLU()),
            ('conv2', nn.Conv2d(2*width, 4*width, kernel_size=3, stride=2, padding=1)),
            ('relu2', nn.ReLU()),
            ('pool0', nn.MaxPool2d(3)),
            ('conv3', nn.Conv2d(4*width, 4*width, kernel_size=3, stride=2, padding=1)),
            ('relu3', nn.ReLU()),
            ('pool1', nn.AdaptiveAvgPool2d(1)),
            ('flatten', nn.Flatten()),
            ('linear', nn.Linear(4*width, num_classes)),
        ])
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Loss & optimizer; reduction='none' keeps one loss value per sample.
    criterion = nn.CrossEntropyLoss(reduction='none')
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(1, num_epochs + 1):
        # --- Training ---
        model.train()
        train_loss_sum = 0.0
        for imgs, labels, _, _ in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(imgs), labels).mean()
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item() * imgs.size(0)
        avg_train_loss = train_loss_sum / len(train_loader.dataset)
        print(f"Epoch {epoch}/{num_epochs}, Avg Training Loss: {avg_train_loss:.4f}")

        # --- Diagnostics: how many contaminated samples fall below the threshold ---
        # Losses and flags come from the same pass, so shuffling keeps them aligned.
        epoch_losses, epoch_flags = _collect_sample_losses(model, train_loader, criterion, device)
        threshold = np.percentile(epoch_losses, percentile)
        frac_cont_below = _fraction_contaminated_below(epoch_losses, epoch_flags, threshold)
        print(f"  Threshold ({percentile}th pct): {threshold}")
        print(f"  Fraction of contaminated below threshold: {frac_cont_below:.4f}\n")

    # Final per-sample losses, computed without shuffling so that position i
    # of the result corresponds to dataset index i.
    full_loader = DataLoader(contaminated_dataset, batch_size=batch_size, shuffle=False)
    all_losses_np, flags_np = _collect_sample_losses(model, full_loader, criterion, device)

    threshold = np.percentile(all_losses_np, percentile)
    frac_cont_below = _fraction_contaminated_below(all_losses_np, flags_np, threshold)
    print(f"Loss Threshold: {threshold:.4f}")
    print(f"Fraction of contaminated below threshold: {frac_cont_below:.4f}\n")

    # Keep only samples with loss >= threshold.
    filtered_indices = np.flatnonzero(all_losses_np >= threshold).tolist()
    print(f"Filtered dataset size: {len(filtered_indices)} out of {len(all_losses_np)} samples.")

    filtered_dataset = Subset(contaminated_dataset, filtered_indices)
    return model, threshold, filtered_dataset, all_losses_np

class SmallMLP(nn.Module):
    """Fully connected classifier: flatten -> 512 -> 256 -> ``num_classes``.

    Args:
        input_dim: Number of features per sample after flattening.
        num_classes: Size of the output (logit) layer.
    """

    def __init__(self, input_dim=3*32*32, num_classes=10):
        super().__init__()
        hidden_sizes = (512, 256)
        stack = [nn.Flatten()]
        fan_in = input_dim
        # Each hidden layer is a Linear followed by a ReLU.
        for fan_out in hidden_sizes:
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
            fan_in = fan_out
        stack.append(nn.Linear(fan_in, num_classes))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the class logits for the batch ``x``."""
        logits = self.model(x)
        return logits

def train_and_filter_dataset_ttll(dataset, percentile=25.0, num_epochs=5, prefilter_model="cnn5"):
    """Filter a dataset by "time to low loss" of a briefly trained prefilter model.

    A small model is trained for ``num_epochs`` epochs. For every sample the
    first epoch in which its training-time loss drops below 0.01 is recorded
    (``num_epochs`` if that never happens). The ``percentile`` percent of the
    samples that reach low loss earliest are removed; the remainder is
    returned as a ``Subset``. Ties between equal scores are resolved by
    ``torch.topk``.

    Args:
        dataset: Dataset yielding ``(image, label, index, contaminated)``
            tuples and exposing a ``contaminated_set`` of sample indices
            (used only for the printed statistics).
        percentile: Percentage of the dataset to remove; the resulting count
            is clamped to ``[0, len(dataset)]``.
        num_epochs: Number of training epochs of the prefilter model.
        prefilter_model: ``"cnn5"`` for the small CNN or ``"ann5"`` for
            ``SmallMLP``.

    Returns:
        Tuple ``(model, threshold, filtered_dataset, scores)``. ``scores`` is
        the negated first-low-loss epoch per sample and ``threshold`` the
        smallest score among the removed samples (``+inf`` when nothing is
        removed).

    Raises:
        AttributeError: If ``prefilter_model`` is not recognized.
    """
    if prefilter_model == "ann5":
        model = SmallMLP().to(device)
    elif prefilter_model == "cnn5":
        num_channels = 3
        width = 32  # base number of convolution channels
        num_classes = 10
        model = nn.Sequential(
            OrderedDict([
                ('conv0', nn.Conv2d(num_channels, width, kernel_size=3, padding=1)),
                ('relu0', nn.ReLU()),
                ('conv1', nn.Conv2d(width, 2*width, kernel_size=3, padding=1)),
                ('relu1', nn.ReLU()),
                ('conv2', nn.Conv2d(2*width, 4*width, kernel_size=3, stride=2, padding=1)),
                ('relu2', nn.ReLU()),
                ('pool0', nn.MaxPool2d(3)),
                ('conv3', nn.Conv2d(4*width, 4*width, kernel_size=3, stride=2, padding=1)),
                ('relu3', nn.ReLU()),
                ('pool1', nn.AdaptiveAvgPool2d(1)),
                ('flatten', nn.Flatten()),
                ('linear', nn.Linear(4*width, num_classes)),
            ])
        ).to(device)
    else:
        raise AttributeError("Prefiltering model not recognized.")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss(reduction='none')

    loader = DataLoader(dataset, batch_size=128, shuffle=True)

    n_samples = len(dataset)
    # ``num_epochs`` acts as the "never reached low loss" sentinel.
    first_low_loss_epoch = torch.full((n_samples,), fill_value=num_epochs, dtype=torch.long)

    for epoch in tqdm(range(num_epochs)):
        model.train()
        for inputs, targets, indices, _ in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.mean().backward()
            optimizer.step()

            # Record the epoch only the first time a sample's loss is low.
            indices = torch.as_tensor(indices).cpu()
            batch_loss = loss.detach().cpu()
            newly_low = (batch_loss < 0.01) & (first_low_loss_epoch[indices] == num_epochs)
            first_low_loss_epoch[indices[newly_low]] = epoch

    num_samples_to_remove = int((percentile / 100.) * n_samples)
    # topk rejects k outside [0, n_samples].
    num_samples_to_remove = max(0, min(num_samples_to_remove, n_samples))

    # Earlier low-loss epoch => larger score => removed first.
    scores = -first_low_loss_epoch.float()
    topk = scores.topk(num_samples_to_remove, largest=True)
    if num_samples_to_remove > 0:
        threshold = torch.min(topk.values)
    else:
        # Nothing is removed, so no score reaches the threshold.
        threshold = torch.tensor(float("inf"))
    selected_indices_set = set(topk.indices.tolist())

    keep_mask = torch.ones(n_samples, dtype=torch.bool)
    keep_mask[topk.indices] = False
    kept_indices = torch.nonzero(keep_mask).flatten().tolist()  # ascending order
    num_samples_to_keep = len(kept_indices)

    contaminated_indices = list(dataset.contaminated_set)
    contaminated_selected = sum(1 for idx in contaminated_indices if idx in selected_indices_set)
    # Avoid dividing by zero when the dataset has no contaminated samples.
    if contaminated_indices:
        contaminated_percentage = contaminated_selected / len(contaminated_indices) * 100
    else:
        contaminated_percentage = 0.0

    strategy = "time_to_low_loss"
    print(f"Strategy: {strategy}")
    print(f"Threshold: {threshold}")
    print(f"Total samples removed: {num_samples_to_remove}")
    print(f"Contaminated samples removed: {contaminated_selected} ({contaminated_percentage:.2f}% of contaminated samples)")
    print(f"Total samples kept: {num_samples_to_keep}")

    filtered_dataset = Subset(dataset, kept_indices)

    return model, threshold, filtered_dataset, scores

def get_weighted_sampler(subset, num_classes=10):
    """Build a sampler that draws each sample with inverse-class-frequency weight.

    Args:
        subset: Dataset whose labels are obtained through ``extract_targets``.
        num_classes: Minimum number of classes counted.

    Returns:
        A ``WeightedRandomSampler`` drawing ``len(subset labels)`` samples
        with replacement.
    """
    targets = extract_targets(subset)
    per_class_count = np.bincount(targets, minlength=num_classes)
    # The small epsilon keeps the weight of an empty class finite.
    inverse_frequency = 1. / (per_class_count + 1e-6)
    weights = [inverse_frequency[target] for target in targets]
    return WeightedRandomSampler(weights=weights, num_samples=len(weights), replacement=True)

def extract_targets(subset):
    """Return the labels of ``subset`` in the order of its samples.

    Nested ``Subset`` wrappers are unwrapped while composing their index
    lists. If the remaining object has a ``dataset`` attribute (a wrapper
    around a base dataset) the labels are read from that inner dataset,
    otherwise from the object itself.

    Args:
        subset: A ``Subset`` (possibly nested), a wrapper exposing
            ``dataset``, or a dataset exposing ``targets`` or ``labels``.

    Returns:
        List with one label per sample of ``subset``.

    Raises:
        AttributeError: If neither ``targets`` nor ``labels`` can be found.
    """
    indices = None
    while isinstance(subset, Subset):
        # Map the indices collected so far through this Subset's own indices.
        indices = subset.indices if indices is None else [subset.indices[i] for i in indices]
        subset = subset.dataset
    # Descend one level through a wrapper; a plain dataset is used as is.
    source = getattr(subset, "dataset", subset)
    if hasattr(source, "targets"):
        all_targets = source.targets
    elif hasattr(source, "labels"):
        all_targets = source.labels
    else:
        raise AttributeError("Underlying dataset does not have 'targets' or 'labels'.")
    if indices is None:
        # No Subset was involved: every sample is selected.
        indices = range(len(all_targets))
    return [all_targets[i] for i in indices]

def _parse_bool_flag(value):
    """Convert a command-line string such as "true"/"0"/"no" to a bool."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _evaluate_classifier(model, loader, criterion):
    """Return (mean loss, accuracy, predictions, labels) of model over loader.

    Only the first two elements of every batch (inputs, targets) are used, so
    loaders yielding extra per-sample fields are accepted as well.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in loader:
            inputs, targets = batch[0].to(device), batch[1].to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, targets).item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())
    return loss_sum / total, correct / total, all_preds, all_labels


# Main training script
def main():
    """Run the contaminated-CIFAR10 experiment and save the metrics to an .npz file.

    Steps: parse the command line, exit early if the result file already
    exists (unless ``--force``), build the contaminated training set,
    optionally prefilter it with ``train_and_filter_dataset_ttll``, split
    80/20 into train/validation, then train one CNN per regularization
    strength. The regularizer penalizes the norm of the input gradient of the
    target logit inside the bottom-right patch. Per-epoch metrics are logged
    to wandb and finally written with ``np.savez`` (positional arrays).
    """
    parser = argparse.ArgumentParser(description='CIFAR10 with shortcut noise and L2 regularization')
    parser.add_argument('--eps', type=float, default=0.0, help='Contamination ratio (default: 0.0)')
    parser.add_argument('--annp', type=float, default=100.0, help='ANNP threshold percentile (default: 100.0)')
    parser.add_argument('--run', type=int, default=0, help='Run (default: 0)')
    # type=bool would turn every non-empty string (even "False") into True;
    # accept "--force" alone as well as "--force true/false".
    parser.add_argument("--force", help="Force overwriting old run", type=_parse_bool_flag,
                        nargs="?", const=True, default=False)
    parser.add_argument("--prefilter_model", help="Type of prefiltering model", type=str, default="cnn5")
    args = parser.parse_args()
    print(args)
    eps = args.eps
    annp = args.annp
    run = args.run
    prefilter_model = args.prefilter_model
    filename = f"cifar10weightedlossv3_shortcut_patchreg_ttll_{prefilter_model}_eps_{eps}_annp_{annp}_run_{run}.npz"
    if os.path.isfile(filename) and not args.force:
        # File exists
        print(f"File {filename} exists")
        print("=" * 50)
        sys.exit(0)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    # The base set returns PIL images; ContaminatedDataset applies the
    # transform after stamping the patch.
    base_train = datasets.CIFAR10(
        root=DATA_DIR, train=True, download=True, transform=None
    )
    # Load and contaminate train (and later split validation) data
    classes_to_contaminate = [0, 2, 4]
    contamination_rate = eps
    color_mapping = {0: (255, 0, 0), 2: (0, 0, 255), 4: (0, 255, 0)}
    patch_size = 3
    train_dataset = ContaminatedDataset(
        base_train,
        classes_to_contaminate,
        contamination_rate,
        color_mapping,
        patch_size,
        transform
    )
    if annp != 100.0:
        _, _, full_train_dataset, _ = train_and_filter_dataset_ttll(train_dataset, percentile=100.-annp, prefilter_model=prefilter_model)
    else:
        full_train_dataset = train_dataset

    # Split into train and validation sets (80% train, 20% validation)
    train_size = int(0.8 * len(full_train_dataset))
    val_size = len(full_train_dataset) - train_size
    train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

    # Test set is loaded normally without contamination
    test_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, transform=transform, download=True)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    num_channels = 3
    width = 32  # base number of convolution channels
    num_classes = 10

    num_models = 9
    # Sweep: no regularization, then log-spaced strengths from 1e-6 to 1e1.
    lambda_regs = torch.cat([torch.tensor([0.0]), torch.logspace(-6, 1, num_models - 1)]).to(device)

    # Inverse-frequency class weights; they depend only on the training
    # split, so they are computed once for the whole sweep.
    class_labels = extract_targets(train_dataset)
    class_counts = np.bincount(class_labels, minlength=num_classes) / len(train_dataset)
    print(np.sum(class_counts))
    weights = 1. / (class_counts + 1e-6)
    weight_tensor = torch.tensor(weights, dtype=torch.float32).to(device)

    all_train_losses = torch.zeros((NUM_EPOCHS, num_models)).to(device)
    all_train_acc = torch.zeros((NUM_EPOCHS, num_models)).to(device)
    all_val_losses = torch.zeros((NUM_EPOCHS, num_models)).to(device)
    all_val_acc = torch.zeros((NUM_EPOCHS, num_models)).to(device)
    all_test_losses = torch.zeros((NUM_EPOCHS, num_models)).to(device)
    all_test_acc = torch.zeros((NUM_EPOCHS, num_models)).to(device)
    all_test_f1 = torch.zeros((NUM_EPOCHS, num_models)).to(device)
    all_best_val_test_acc = torch.zeros(num_models).to(device)
    for i, lambda_reg in enumerate(lambda_regs):
        print(lambda_reg)
        # Initialize wandb
        wandb.init(project="cifar10_contaminated_weighted", config={
            "batch_size": BATCH_SIZE,
            "num_epochs": NUM_EPOCHS,
            "learning_rate": LEARNING_RATE,
            "architecture": "custom CNN as provided",
            "contamination_prob": eps,
            # Log a plain float rather than a device tensor.
            "lambda_reg": lambda_reg.item(),
        })

        model = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", nn.Conv2d(num_channels, 1 * width, kernel_size=3, padding=1)),
                    ("relu0", nn.ReLU()),
                    ("conv1", nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)),
                    ("relu1", nn.ReLU()),
                    ("conv2", nn.Conv2d(2 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                    ("relu2", nn.ReLU()),
                    ("pool0", nn.MaxPool2d(3)),
                    ("conv3", nn.Conv2d(4 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                    ("relu3", nn.ReLU()),
                    ("pool1", nn.AdaptiveAvgPool2d(1)),
                    ("flatten", nn.Flatten()),
                    ("linear", nn.Linear(4 * width, num_classes)),
                ]
            )
        ).to(device)  # GPU when available, otherwise CPU

        # Loss and optimizer
        criterion = nn.CrossEntropyLoss(weight=weight_tensor)
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

        best_val_loss = float('inf')
        best_test_acc = 0.0

        # Training loop
        for epoch in range(1, NUM_EPOCHS + 1):
            model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0

            for inputs, targets, _, _ in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                # Needed to differentiate the logits with respect to the input.
                inputs.requires_grad_(True)

                logits = model(inputs)
                # NOTE(review): the training loss is unweighted while
                # validation/test use the class-weighted criterion --
                # confirm this is intended.
                loss_ce = torch.nn.functional.cross_entropy(logits, targets)

                # Compute gradient w.r.t. input (for target logits);
                # create_graph=True lets the penalty itself be backpropagated.
                scalar_proxy = logits.gather(1, targets.view(-1, 1)).sum()
                grad_input = torch.autograd.grad(scalar_proxy, inputs, create_graph=True)[0]
                # Mask selecting the bottom-right patch where the shortcut is placed
                B, C, H, W = grad_input.shape
                mask = torch.zeros_like(grad_input)
                mask[:, :, H - patch_size:H, W - patch_size:W] = 1.0

                # Masked gradient norm, averaged over the batch
                grad_patch = grad_input * mask
                grad_norm = grad_patch.view(B, -1).norm(p=2, dim=1).mean()

                # Total loss and update
                loss = loss_ce + lambda_reg * grad_norm
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item() * inputs.size(0)
                _, predicted = logits.max(1)
                train_total += targets.size(0)
                train_correct += predicted.eq(targets).sum().item()

            train_loss /= train_total
            train_acc = train_correct / train_total

            # Validation
            val_loss, val_acc, _, _ = _evaluate_classifier(model, val_loader, criterion)

            # Test evaluation at each epoch to track performance
            test_loss, test_acc, test_all_preds, test_all_labels = _evaluate_classifier(
                model, test_loader, criterion)
            test_f1 = f1_score(test_all_labels, test_all_preds, average="macro")

            # Save best test accuracy based on minimal val loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_test_acc = test_acc
            # Log metrics to wandb
            wandb.log({
                "epoch": epoch,
                "train_loss": train_loss,
                "train_accuracy": train_acc,
                "val_loss": val_loss,
                "val_accuracy": val_acc,
                "test_loss": test_loss,
                "test_accuracy": test_acc,
                "test_f1": test_f1,
            })
            all_train_losses[epoch-1, i] = train_loss
            all_train_acc[epoch-1, i] = train_acc
            all_val_losses[epoch-1, i] = val_loss
            all_val_acc[epoch-1, i] = val_acc
            all_test_losses[epoch-1, i] = test_loss
            all_test_acc[epoch-1, i] = test_acc
            all_test_f1[epoch-1, i] = test_f1

            print(f"Epoch [{epoch}/{NUM_EPOCHS}] "
                f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
                f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

        # Log the best test accuracy (i.e., test accuracy at epoch with minimal validation loss)
        wandb.log({"best_test_accuracy_at_best_val_epoch": best_test_acc})
        all_best_val_test_acc[i] = best_test_acc
        print(f"\nBest test accuracy (at minimal val loss): {best_test_acc:.4f}")

        wandb.finish()
    # Arrays are stored positionally (arr_0 ... arr_7) in this order.
    np.savez(filename, all_train_losses.detach().cpu().numpy(),
             all_train_acc.detach().cpu().numpy(),
             all_val_losses.detach().cpu().numpy(),
             all_val_acc.detach().cpu().numpy(),
             all_test_losses.detach().cpu().numpy(),
             all_test_acc.detach().cpu().numpy(),
             all_best_val_test_acc.detach().cpu().numpy(),
             all_test_f1.detach().cpu().numpy())
    print("=" * 50)
# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

