import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Dataset, Subset
import torchvision
from torchvision import transforms, datasets
import numpy as np
from PIL import Image
from collections import OrderedDict
from scipy.stats import sem
from tqdm import tqdm as tqdm
import torch.nn.functional as F
import torchvision.models as models
import argparse
import os
import sys
import wandb
from collections import defaultdict
from PIL import Image

def _str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    "False") into True, so the value is parsed explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='CIFAR10 with label noise and L2 regularization')
parser.add_argument('--eps', type=float, default=0.0, help='Contamination ratio (default: 0.0)')
parser.add_argument('--annp', type=float, default=100.0, help='ANNP threshold percentile (default: 100.0)')
parser.add_argument('--run', type=int, default=0, help='Run (default: 0)')
# nargs='?' + const=True accepts both a bare "--force" and "--force True/False".
parser.add_argument("--force", help="Force overwriting old run", type=_str2bool, nargs='?', const=True, default=False)
parser.add_argument('--size_percent', type=float, default=100.0, help='Initial reduction of dataset (default: 100.0)')
parser.add_argument('--comment', type=str, default="", help='Suffix appended to the output filename (default: "")')
args = parser.parse_args()
print(args)
eps = args.eps
annp = args.annp
run = args.run
size_percent = args.size_percent
comment = args.comment
P_NORM = 2.0

# Fail fast on values that would otherwise crash much later: eps is the fraction
# of labels flipped (sampled without replacement), size_percent the per-class
# share kept (also sampled without replacement), annp the share kept by the filter.
if not 0.0 <= eps <= 1.0:
    parser.error(f"--eps must be in [0, 1], got {eps}")
if not 0.0 < size_percent <= 100.0:
    parser.error(f"--size_percent must be in (0, 100], got {size_percent}")
if annp <= 0.0:
    parser.error(f"--annp must be > 0, got {annp}")

if eps == 0.0 and annp != 100.0:
    print(f"Eps value 0.0 only runs with annp=100.0; received eps={eps}, annp={annp}")
    sys.exit(0)

# One result file per configuration; skip configurations that already ran.
filename = f"cifar10regL{P_NORM}v5conflearnv2_label_partialv2_{size_percent}{comment}_eps_{eps}_annp_{annp}_run_{run}.npz"
if os.path.isfile(filename) and (not args.force):
    # File exists
    print(f"File {filename} exists")
    print("=" * 50)
    sys.exit(0)


# Hyperparameters
num_classes = 10
NUM_EPOCHS = 40
DATA_DIR = "./data"
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class BalancedCIFAR10(Dataset):
    """Class-balanced random subset of CIFAR-10.

    From every class, ``int(n_class * percentage)`` images are drawn without
    replacement using ``np.random``; the combined selection is then shuffled.
    ``data`` holds the raw images taken from ``datasets.CIFAR10(...).data`` and
    ``targets`` the matching labels. ``__getitem__`` converts an image to a PIL
    image and applies ``transform`` (if any).
    """

    def __init__(self, root='./data', train=True, transform=None, download=True, percentage=0.75):
        self.percentage = percentage
        self.transform = transform
        self.train = train

        source = datasets.CIFAR10(root=root, train=train, download=download, transform=transform)
        self.data, self.targets = self._create_balanced_subset(source)

    def _create_balanced_subset(self, dataset):
        """Return ``(data, targets)`` lists holding a per-class sample of *dataset*.

        *dataset* must expose indexable ``.data`` and ``.targets``.
        """
        labels = np.array(dataset.targets)

        # Bucket positions by label (buckets appear in first-seen label order).
        buckets = defaultdict(list)
        for position, label in enumerate(labels):
            buckets[label].append(position)

        chosen = []
        for members in buckets.values():
            quota = int(len(members) * self.percentage)
            chosen.extend(np.random.choice(members, quota, replace=False))

        # Mix the classes so they are not stored as contiguous runs.
        np.random.shuffle(chosen)

        images = [dataset.data[k] for k in chosen]
        picked_labels = [labels[k] for k in chosen]
        return images, picked_labels

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        raw, target = self.data[index], self.targets[index]
        sample = Image.fromarray(raw)
        if self.transform:
            sample = self.transform(sample)
        return sample, target

class NoisyCIFAR10(Dataset):
    """Dataset over pre-loaded images paired with (possibly corrupted) labels.

    Items are ``(image, label, index)``. ``index`` is the position within this
    dataset, so per-sample statistics can be mapped back to it.
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels, self.transform = data, labels, transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image, target = self.data[idx], self.labels[idx]
        return (self.transform(image) if self.transform else image), target, idx

class ConvBlock(nn.Module):
    """3x3 convolution (no bias) -> BatchNorm -> ReLU, optionally followed by 2x2 max-pooling.

    The padded convolution keeps the spatial size; ``pool=True`` halves it.
    """

    def __init__(self, in_channels, out_channels, pool=False):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU(inplace=True)
        tail = (nn.MaxPool2d(2),) if pool else ()
        self.block = nn.Sequential(conv, norm, act, *tail)

    def forward(self, x):
        out = self.block(x)
        return out

class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with an identity skip and a final ReLU.

    Input and output shapes are identical (``channels`` in, ``channels`` out).
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)

        self.block = nn.Sequential(
            conv3x3(), nn.BatchNorm2d(channels), nn.ReLU(inplace=True),
            conv3x3(), nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        residual = self.block(x)
        return F.relu(x + residual)

class ResNet9(nn.Module):
    """ResNet-9 style classifier.

    Layout: conv stem, two residual stages, global average pooling and a
    linear head. The shape comments assume a 32x32 input; the adaptive pooling
    makes the head independent of the final spatial size.
    """

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        # Construction order fixes both state_dict order and init RNG order.
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128, pool=True)
        self.res1 = ResidualBlock(128)
        self.conv3 = ConvBlock(128, 256, pool=True)
        self.conv4 = ConvBlock(256, 512, pool=True)
        self.res2 = ResidualBlock(512)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        # [B, C, 32, 32] -> 64@32x32 -> 128@16x16 -> 256@8x8 -> [B, 512, 4, 4]
        for stage in (self.conv1, self.conv2, self.res1, self.conv3, self.conv4, self.res2):
            x = stage(x)
        features = self.pool(x).flatten(1)  # [B, 512]
        return self.fc(features)            # [B, num_classes]
    
# ToTensor is applied lazily by NoisyCIFAR10 to the raw image arrays.
transform = transforms.ToTensor()
# Class-balanced subset keeping size_percent % of each class. This also
# downloads/verifies CIFAR-10, so no separate full-dataset load is needed.
trainset_clean = BalancedCIFAR10(train=True, download=True, transform=None, percentage=size_percent / 100)
original_labels = np.array(trainset_clean.targets)

# Introduce uniform label noise: pick int(eps * N) distinct samples and give
# each a label drawn uniformly from the other num_classes - 1 classes.
noisy_labels = original_labels.copy()
num_noisy = int(eps * len(original_labels))
noisy_indices = np.random.choice(len(original_labels), size=num_noisy, replace=False)

for idx in noisy_indices:
    wrong_classes = [c for c in range(num_classes) if c != original_labels[idx]]
    noisy_labels[idx] = np.random.choice(wrong_classes)

noisy_trainset = NoisyCIFAR10(trainset_clean.data, noisy_labels, transform=transform)

def _train_filter_model(dataset, num_epochs, batch_size, lr):
    """Train a fresh ResNet9 on *dataset* (items ``(image, label, index)``) and return it."""
    model = ResNet9(num_classes=num_classes).to(device)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for batches_seen, (images, labels, _) in enumerate(pbar, start=1):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            # Running mean over the batches processed so far in this epoch.
            pbar.set_postfix({'Loss': total_loss / batches_seen})
    return model


def _given_label_confidence(model, dataset, batch_size):
    """Return ``(indices, labels, confidence)`` NumPy arrays covering every sample of *dataset*.

    ``confidence[k]`` is the softmax probability *model* assigns to ``labels[k]``
    (the label stored in the dataset, which may be noisy).
    """
    model.eval()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    all_indices, all_probs, all_labels = [], [], []
    with torch.no_grad():
        for images, labels, indices in tqdm(loader):
            probs = F.softmax(model(images.to(device)), dim=1)
            all_indices.extend(indices.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_indices = np.array(all_indices)
    all_probs = np.array(all_probs)
    all_labels = np.array(all_labels)
    confidence = all_probs[np.arange(len(all_labels)), all_labels]
    return all_indices, all_labels, confidence


def _keep_top_per_class(indices, labels, scores, percentile, n_classes):
    """Return the dataset indices of the top ``percentile`` % scores within each label class.

    Keeps ``int(n_c * percentile / 100)`` samples per class (all of them when
    ``percentile >= 100``). A class whose quota rounds down to 0 is dropped,
    and a warning is printed.
    """
    kept_indices = []
    for c in range(n_classes):
        class_mask = labels == c
        class_indices = indices[class_mask]
        class_scores = scores[class_mask]

        if percentile >= 100.0:
            kept_indices.extend(class_indices.tolist())
            continue

        num_to_keep = int(len(class_indices) * (percentile / 100.0))
        if num_to_keep == 0:
            print(f"Warning: 0 samples kept for class {c}. Consider a higher percentile.")
            continue

        # argpartition on negated scores puts the num_to_keep largest first (unordered).
        top = np.argpartition(-class_scores, num_to_keep - 1)[:num_to_keep]
        kept_indices.extend(class_indices[top].tolist())
    return kept_indices


def train_and_prefilter_conflearn(noisy_trainset, percentile, num_epochs=5, batch_size=128, lr=1e-3):
    """Filter a noisy dataset by a model's confidence in each sample's given label.

    Steps:
    1. Train a fresh ResNet9 on the full noisy dataset.
    2. Score every sample by the softmax probability the model assigns to its
       (noisy) label.
    3. Within each label class, keep the ``percentile`` % highest-scoring
       samples.

    Args:
        noisy_trainset: dataset yielding ``(image, label, index)``. ``index``
            must be the sample's position in the dataset (as ``NoisyCIFAR10``
            returns), because it is used to build the returned ``Subset``.
        percentile: percentage of each class to keep; ``>= 100`` keeps all.
        num_epochs: training epochs for the filter model (default 5).
        batch_size: batch size for training and scoring (default 128).
        lr: Adam learning rate for the filter model (default 1e-3).

    Returns:
        ``Subset`` of *noisy_trainset* containing the kept samples.
    """
    model = _train_filter_model(noisy_trainset, num_epochs, batch_size, lr)
    all_indices, all_labels, confidence_scores = _given_label_confidence(model, noisy_trainset, batch_size)
    kept_indices = _keep_top_per_class(all_indices, all_labels, confidence_scores, percentile, num_classes)

    print(f"Original dataset size: {len(noisy_trainset)}")
    print(f"Filtered dataset size: {len(kept_indices)}")

    return Subset(noisy_trainset, kept_indices)

# Optionally drop the least-confident samples per (noisy) class. Percentiles
# >= 100 keep every sample, so the filter model is trained only when it can
# actually remove something.
if annp < 100.0:
    full_train_dataset = train_and_prefilter_conflearn(noisy_trainset=noisy_trainset, percentile=annp)
else:
    full_train_dataset = noisy_trainset

# 80/20 train/validation split of the (possibly filtered) noisy data, so the
# validation labels carry the same noise as the training labels.
train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

# Test set is loaded normally without contamination
test_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, transform=transform, download=True)

# DataLoaders: train/val batches are (image, label, index); test batches are (image, label).
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# ---------------------------------------------------------------------------
# Sweep: train one small CNN per L_p-regularization strength lambda.
# ---------------------------------------------------------------------------
num_channels = 3
width = 32

# Log-spaced regularization strengths. Result tensors are sized from this grid
# so that every column belongs to a model that is actually trained; an
# unwritten column would otherwise be saved as zeros and look like a real run.
num_lambdas = 8
lambda_regs = torch.logspace(-4, -1.64, num_lambdas).to(device)
num_models = len(lambda_regs)


def _lp_norm(model, p):
    """Return ``(sum over trainable params of sum |w|**p) ** (1/p)`` for *model*.

    Differentiable under autograd; also used under ``torch.no_grad`` for logging.
    """
    total = sum(torch.sum(torch.abs(param) ** p) for param in model.parameters() if param.requires_grad)
    return total.pow(1.0 / p)


def _build_cnn():
    """Return a freshly initialised 4-conv CNN (``num_channels`` in, ``num_classes`` out) on ``device``."""
    return nn.Sequential(
        OrderedDict(
            [
                ("conv0", nn.Conv2d(num_channels, 1 * width, kernel_size=3, padding=1)),
                ("relu0", nn.ReLU()),
                ("conv1", nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)),
                ("relu1", nn.ReLU()),
                ("conv2", nn.Conv2d(2 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                ("relu2", nn.ReLU()),
                ("pool0", nn.MaxPool2d(3)),
                ("conv3", nn.Conv2d(4 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                ("relu3", nn.ReLU()),
                ("pool1", nn.AdaptiveAvgPool2d(1)),
                ("flatten", nn.Flatten()),
                ("linear", nn.Linear(4 * width, num_classes)),
            ]
        )
    ).to(device)


def _train_one_epoch(model, loader, criterion, optimizer, lambda_reg):
    """Run one training epoch and return ``(mean loss, accuracy)``.

    The optimised and reported loss is cross-entropy plus
    ``lambda_reg * _lp_norm(model, P_NORM)``. Batches are ``(inputs, targets, idx)``.
    """
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for inputs, targets, _ in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets) + lambda_reg * _lp_norm(model, P_NORM)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        seen += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    return total_loss / seen, correct / seen


def _evaluate(model, loader, criterion):
    """Return ``(mean cross-entropy, accuracy)`` of *model* on *loader*, without gradients.

    Batches may be ``(inputs, targets)`` or ``(inputs, targets, idx)``; only
    the first two items are used.
    """
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for batch in loader:
            inputs, targets = batch[0].to(device), batch[1].to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item() * inputs.size(0)
            _, predicted = outputs.max(1)
            seen += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return total_loss / seen, correct / seen


# Per-epoch metrics, one column per lambda. They are written one scalar at a
# time and only saved at the end, so they live on the CPU.
all_train_losses = torch.zeros((NUM_EPOCHS, num_models))
all_train_acc = torch.zeros((NUM_EPOCHS, num_models))
all_val_losses = torch.zeros((NUM_EPOCHS, num_models))
all_val_acc = torch.zeros((NUM_EPOCHS, num_models))
all_test_losses = torch.zeros((NUM_EPOCHS, num_models))
all_test_acc = torch.zeros((NUM_EPOCHS, num_models))
all_best_val_test_acc = torch.zeros(num_models)

for i, lambda_reg in enumerate(lambda_regs):
    print(lambda_reg)
    wandb.init(project="cifar10_contaminated_weighted", config={
        "batch_size": BATCH_SIZE,
        "num_epochs": NUM_EPOCHS,
        "learning_rate": LEARNING_RATE,
        "architecture": "custom CNN as provided",
        "contamination_prob": eps,
        "p_norm": P_NORM,
        # Plain float: lambda_reg is a 0-d tensor that may live on the GPU.
        "lambda_reg": lambda_reg.item(),
        "annp": annp,
        "size_percent": size_percent,
        "run": run,
    })

    model = _build_cnn()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_val_loss = float('inf')
    best_test_acc = 0.0

    # Training loop
    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss, train_acc = _train_one_epoch(model, train_loader, criterion, optimizer, lambda_reg)
        val_loss, val_acc = _evaluate(model, val_loader, criterion)
        test_loss, test_acc = _evaluate(model, test_loader, criterion)

        # Model selection on the (noisy) validation loss: report the test
        # accuracy from the epoch with the lowest validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_test_acc = test_acc

        # Norm of the current weights (not the last batch's regularizer term).
        with torch.no_grad():
            lp_norm = _lp_norm(model, P_NORM).item()

        # Log metrics to wandb
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "train_accuracy": train_acc,
            "val_loss": val_loss,
            "val_accuracy": val_acc,
            "test_loss": test_loss,
            "test_accuracy": test_acc,
            "lp_norm": lp_norm,
        })
        all_train_losses[epoch-1, i] = train_loss
        all_train_acc[epoch-1, i] = train_acc
        all_val_losses[epoch-1, i] = val_loss
        all_val_acc[epoch-1, i] = val_acc
        all_test_losses[epoch-1, i] = test_loss
        all_test_acc[epoch-1, i] = test_acc

        print(f"Epoch [{epoch}/{NUM_EPOCHS}] "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
            f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

    # Log the best test accuracy
    wandb.log({"best_test_accuracy_at_best_val_epoch": best_test_acc})
    all_best_val_test_acc[i] = best_test_acc
    print(f"\nBest test accuracy (at minimal val loss): {best_test_acc:.4f}")

    wandb.finish()

# Saved positionally as arr_0 ... arr_6 in this order (train loss, train acc,
# val loss, val acc, test loss, test acc, best-val test acc); keep it stable.
results = (
    all_train_losses,
    all_train_acc,
    all_val_losses,
    all_val_acc,
    all_test_losses,
    all_test_acc,
    all_best_val_test_acc,
)
np.savez(filename, *(t.detach().cpu().numpy() for t in results))
print("=" * 50)