import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split, Dataset, Subset
import torchvision
from torchvision import transforms, datasets
import numpy as np
from PIL import Image
from collections import OrderedDict
from scipy.stats import sem
from tqdm import tqdm as tqdm
import torch.nn.functional as F
import torchvision.models as models
import argparse
import os
import sys
import wandb
from collections import defaultdict

def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` treats every non-empty string (including "False") as True,
    so the flag value is interpreted explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if lowered in {"0", "false", "f", "no", "n", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='CIFAR10 with label noise and L2 regularization')
parser.add_argument('--eps', type=float, default=0.0, help='Contamination ratio (default: 0.0)')
parser.add_argument('--annp', type=float, default=100.0, help='ANNP threshold percentile (default: 100.0)')
parser.add_argument('--run', type=int, default=0, help='Run (default: 0)')
# `--force` alone, `--force True` and `--force False` are all accepted.
parser.add_argument("--force", help="Force overwriting old run", type=_str2bool, nargs="?", const=True, default=False)
parser.add_argument('--size_percent', type=float, default=100.0, help='Initial reduction of dataset (default: 100.0)')
parser.add_argument('--comment', type=str, default="", help='Free-form tag inserted into the output filename (default: "")')
args = parser.parse_args()
print(args)
eps = args.eps
annp = args.annp
run = args.run
size_percent = args.size_percent
comment = args.comment
P_NORM = 2.0  # order p of the L_p penalty on the weights

# Without contamination only the unfiltered configuration (annp=100) is run.
if eps == 0.0 and annp != 100.0:
    print(f"Eps value 0.0 only runs with annp=100.0; received eps={eps}, annp={annp}")
    sys.exit(0)

# The result file doubles as a "run already done" marker unless --force is set.
filename = f"cifar10humanv2regL{P_NORM}v5_label_partialv2_{size_percent}{comment}_eps_{eps}_annp_{annp}_run_{run}.npz"
if os.path.isfile(filename) and (not args.force):
    # File exists
    print(f"File {filename} exists")
    print("=" * 50)
    sys.exit(0)


# ---------------------------------------------------------------------------
# Global training configuration
# ---------------------------------------------------------------------------
num_classes = 10
NUM_EPOCHS = 40
DATA_DIR = "./data"
BATCH_SIZE = 128
LEARNING_RATE = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class BalancedCIFAR10(Dataset):
    """CIFAR-10 subset that keeps the same fraction of samples from every class.

    Args:
        root: Directory handed to ``datasets.CIFAR10``.
        train: Select the training (True) or test (False) split.
        transform: Optional callable applied to the PIL image in ``__getitem__``.
        download: Forwarded to ``datasets.CIFAR10``.
        percentage: Fraction in [0, 1] of each class to keep.

    Attributes:
        data: List of raw image arrays of the sampled subset.
        targets: List of the corresponding labels taken from the CIFAR-10 dataset.
        noisy_labels: The ``"random_label1"`` entry of ``./data/CIFAR-10_human.pt``.

    Raises:
        ValueError: if ``percentage`` lies outside [0, 1].
    """

    def __init__(self, root='./data', train=True, transform=None, download=True, percentage=0.75):
        if not 0.0 <= percentage <= 1.0:
            raise ValueError(f"percentage must be within [0, 1], got {percentage}")
        self.percentage = percentage
        self.transform = transform
        self.train = train

        # Load full dataset
        full_dataset = datasets.CIFAR10(root=root, train=train, download=download, transform=transform)

        # NOTE(review): the human-annotated labels are loaded but are NOT used
        # for sampling or as targets -- the subset below is built from
        # ``full_dataset.targets``. Confirm whether training on these labels
        # was intended. The unused float copy of the image array and the
        # TensorDataset wrapper that used to be built here were removed.
        # NOTE(review): ``weights_only=False`` unpickles arbitrary objects;
        # only load this file from a trusted source.
        noisy_path = "./data/CIFAR-10_human.pt"
        noisy_dict = torch.load(noisy_path, weights_only=False)
        self.noisy_labels = noisy_dict["random_label1"]

        self.data, self.targets = self._create_balanced_subset(full_dataset)

    def _create_balanced_subset(self, dataset):
        """Sample ``self.percentage`` of every class from ``dataset``.

        ``dataset`` must expose ``targets`` (labels) and ``data`` (indexable
        samples). Returns ``(data, targets)`` as two lists in shuffled order.
        """
        # Group indices by class
        all_targets = np.array(dataset.targets)
        class_indices = defaultdict(list)
        for idx, label in enumerate(all_targets):
            class_indices[label].append(idx)

        # Sample a balanced subset
        selected_indices = []
        for indices in class_indices.values():
            n_samples = int(len(indices) * self.percentage)
            sampled = np.random.choice(indices, n_samples, replace=False)
            selected_indices.extend(sampled)

        # Shuffle overall indices to avoid class clustering
        np.random.shuffle(selected_indices)

        # Extract samples
        data = [dataset.data[i] for i in selected_indices]
        targets = [all_targets[i] for i in selected_indices]

        return data, targets

    def __len__(self):
        """Return the number of sampled items."""
        return len(self.targets)

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is a PIL image, transformed if a transform is set."""
        img, label = self.data[index], self.targets[index]

        # The stored sample is a raw array; convert it to a PIL image.
        from PIL import Image
        img = Image.fromarray(img)

        if self.transform:
            img = self.transform(img)

        return img, label

class NoisyCIFAR10(Dataset):
    """Index-aware dataset pairing samples with externally supplied labels.

    Every item is ``(sample, label, idx)`` so that per-sample statistics can
    be written back to the position they came from.
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels, self.transform = data, labels, transform

    def __len__(self):
        """The dataset length is dictated by the label sequence."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Fetch sample and label at ``idx``; apply the transform when one is set."""
        sample = self.data[idx]
        target = self.labels[idx]
        if not self.transform:
            return sample, target, idx  # include index
        return self.transform(sample), target, idx  # include index

class ConvBlock(nn.Module):
    """3x3 convolution (no bias) -> batch norm -> ReLU, optionally followed by 2x2 max pooling."""

    def __init__(self, in_channels, out_channels, pool=False):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        # The pooling stage is only present when requested.
        downsample = [nn.MaxPool2d(2)] if pool else []
        self.block = nn.Sequential(conv, norm, activation, *downsample)

    def forward(self, x):
        """Run ``x`` through the stacked layers."""
        out = self.block(x)
        return out

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an identity skip connection and a final ReLU."""

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Add the block output to the input and rectify the sum."""
        skip_sum = x + self.block(x)
        return F.relu(skip_sum)

class ResNet9(nn.Module):
    """Small ResNet: four conv blocks, two residual blocks, global average pooling and a linear head.

    The shape comments in ``forward`` hold for 32x32 inputs.
    """

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128, pool=True)
        self.res1 = ResidualBlock(128)
        self.conv3 = ConvBlock(128, 256, pool=True)
        self.conv4 = ConvBlock(256, 512, pool=True)
        self.res2 = ResidualBlock(512)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a batch of images to class logits of shape [B, num_classes]."""
        feature_stages = (
            self.conv1,  # [B, 64, 32, 32]
            self.conv2,  # [B, 128, 16, 16]
            self.res1,   # [B, 128, 16, 16]
            self.conv3,  # [B, 256, 8, 8]
            self.conv4,  # [B, 512, 4, 4]
            self.res2,   # [B, 512, 4, 4]
            self.pool,   # [B, 512, 1, 1]
        )
        for stage in feature_stages:
            x = stage(x)
        flat = x.view(x.size(0), -1)  # [B, 512]
        return self.fc(flat)          # [B, num_classes]
    
# Images are converted to tensors when they are served by NoisyCIFAR10.
transform = transforms.ToTensor()
# BalancedCIFAR10 loads (and, if needed, downloads) CIFAR-10 itself, so the
# separate datasets.CIFAR10 load that was immediately overwritten is dropped.
# transform=None keeps the raw image arrays so they can be re-wrapped below.
trainset_clean = BalancedCIFAR10(train=True, download=True, transform=None, percentage=size_percent / 100)
original_labels = np.array(trainset_clean.targets)
# Re-wrap the sampled subset so that every item also carries its index.
noisy_trainset = NoisyCIFAR10(trainset_clean.data, original_labels, transform=transform)

def _per_sample_losses(model, dataset, batch_size, device):
    """Return the unreduced cross-entropy loss of every sample in ``dataset``.

    ``dataset`` must yield ``(input, label, index)`` triples; the returned
    NumPy array is addressed by that index. The model is switched to eval
    mode and left in that mode.
    """
    criterion_no_reduce = nn.CrossEntropyLoss(reduction='none')
    losses = np.zeros(len(dataset))
    model.eval()
    with torch.no_grad():
        for inputs, labels, indices in DataLoader(dataset, batch_size=batch_size, shuffle=False):
            inputs, labels = inputs.to(device), labels.to(device)
            batch_losses = criterion_no_reduce(model(inputs), labels)
            losses[indices.numpy()] = batch_losses.cpu().numpy()
    return losses


def train_and_prefilter(noisy_trainset, percentile, epochs=7, batch_size=128, learning_rate=1e-3):
    """Train a ResNet-9 on ``noisy_trainset`` and keep only the low-loss samples.

    After every epoch the per-sample losses are recomputed and the samples
    whose loss is at or below the ``percentile``-th percentile are selected;
    the selection made after the final epoch is returned.

    Args:
        noisy_trainset: Dataset yielding ``(input, label, index)`` triples.
        percentile: Percentile (0-100) of the per-sample loss used as cut-off.
        epochs: Number of training epochs (must be >= 1).
        batch_size: Batch size for training and loss evaluation.
        learning_rate: Adam learning rate.

    Returns:
        A ``Subset`` of ``noisy_trainset`` containing the selected samples.

    Raises:
        ValueError: if ``epochs`` is smaller than 1.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    trainloader = DataLoader(noisy_trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    # Model: ResNet-9
    model = ResNet9(in_channels=3, num_classes=10)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Optimizer and Loss
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    subset_dataset = None
    for epoch in range(epochs):
        # The loss evaluation below leaves the model in eval mode, so training
        # mode has to be restored at the start of every epoch; otherwise
        # BatchNorm would run with frozen statistics from epoch 2 onwards.
        model.train()
        for inputs, labels, _ in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Compute individual training losses
        losses = _per_sample_losses(model, noisy_trainset, batch_size, device)

        # Analysis
        threshold = np.percentile(losses, percentile)
        selected_indices = np.where(losses <= threshold)[0]
        subset_dataset = Subset(noisy_trainset, selected_indices)
        print(f"Subset fraction: {len(subset_dataset)} points out of {len(noisy_trainset)} ({100 * len(subset_dataset) / len(noisy_trainset):.2f}%)")
        print(f"Epoch {epoch}")
        print(f"{percentile:g}th percentile of all losses: {threshold:.4f}")

    return subset_dataset

# Only run the loss-based prefilter when a percentile below 100 was requested.
full_train_dataset = (
    noisy_trainset
    if annp == 100.
    else train_and_prefilter(noisy_trainset=noisy_trainset, percentile=annp)
)

# 80/20 train/validation split of whatever survived the prefilter.
train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

# The test split comes straight from torchvision (no index in its items).
test_dataset = datasets.CIFAR10(root=DATA_DIR, train=False, transform=transform, download=True)

# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Dimensions of the CNN trained in the sweep below.
num_channels = 3
width = 32
num_classes = 10

# Regularization strengths, log-spaced between 1e-4 and 10**-1.64.
# NOTE(review): only num_models - 1 strengths are generated, so the last
# column of the per-model result tensors below is never written by the sweep
# and stays zero -- confirm whether that is intended.
num_models = 9
lambda_regs = torch.logspace(-4, -1.64, num_models - 1).to(device)

# Per-epoch metrics, one column per model.
all_train_losses = torch.zeros(NUM_EPOCHS, num_models, device=device)
all_train_acc = torch.zeros(NUM_EPOCHS, num_models, device=device)
all_val_losses = torch.zeros(NUM_EPOCHS, num_models, device=device)
all_val_acc = torch.zeros(NUM_EPOCHS, num_models, device=device)
all_test_losses = torch.zeros(NUM_EPOCHS, num_models, device=device)
all_test_acc = torch.zeros(NUM_EPOCHS, num_models, device=device)

# Test accuracy at the epoch with the lowest validation loss, per model.
all_best_val_test_acc = torch.zeros(num_models, device=device)
def _lp_norm(model, p):
    """Return the L_p norm over all trainable parameters of ``model`` as a scalar tensor."""
    total = sum(torch.sum(torch.abs(param) ** p) for param in model.parameters() if param.requires_grad)
    return total.pow(1.0 / p)


def _evaluate(model, loader, criterion, device):
    """Return ``(mean loss, accuracy)`` of ``model`` over ``loader`` with gradients disabled.

    Batches may be ``(inputs, targets)`` or ``(inputs, targets, index)``;
    anything after the first two entries is ignored.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            inputs, targets = batch[0].to(device), batch[1].to(device)
            outputs = model(inputs)
            loss_sum += criterion(outputs, targets).item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return loss_sum / total, correct / total


# Train one CNN per regularization strength and record its metrics in column i.
for i, lambda_reg in enumerate(lambda_regs):
    print(lambda_reg)
    # Initialize wandb (one run per regularization strength)
    wandb.init(project="cifar10_contaminated_weighted", config={
        "batch_size": BATCH_SIZE,
        "num_epochs": NUM_EPOCHS,
        "learning_rate": LEARNING_RATE,
        "architecture": "custom CNN as provided",
        "contamination_prob": eps,
        "p_norm": P_NORM,
        "lambda_reg": lambda_reg.item(),  # plain float instead of a device tensor
    })

    model = nn.Sequential(
        OrderedDict(
            [
                ("conv0", nn.Conv2d(num_channels, 1 * width, kernel_size=3, padding=1)),
                ("relu0", nn.ReLU()),
                ("conv1", nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)),
                ("relu1", nn.ReLU()),
                ("conv2", nn.Conv2d(2 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                ("relu2", nn.ReLU()),
                ("pool0", nn.MaxPool2d(3)),
                ("conv3", nn.Conv2d(4 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                ("relu3", nn.ReLU()),
                ("pool1", nn.AdaptiveAvgPool2d(1)),
                ("flatten", nn.Flatten()),
                ("linear", nn.Linear(4 * width, num_classes)),
            ]
        )
    ).to(device)  # honours the CPU fallback of `device` instead of requiring CUDA

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_val_loss = float('inf')
    best_test_acc = 0.0

    # Training loop
    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, targets, _ in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            ce_loss = criterion(outputs, targets)
            # Cross-entropy plus lambda times the L_p norm of the weights.
            loss = ce_loss + lambda_reg * _lp_norm(model, P_NORM)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()

        train_loss /= train_total
        train_acc = train_correct / train_total

        # Validation (unregularized cross-entropy)
        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)

        # Test evaluation at each epoch to track performance
        test_loss, test_acc = _evaluate(model, test_loader, criterion, device)

        # Save best test accuracy based on minimal val loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_test_acc = test_acc

        # Norm of the weights at the end of the epoch. It is recomputed from
        # the current parameters (not reused from the last training batch)
        # and logged as a plain float.
        with torch.no_grad():
            lp_norm = _lp_norm(model, P_NORM).item()

        # Log metrics to wandb
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "train_accuracy": train_acc,
            "val_loss": val_loss,
            "val_accuracy": val_acc,
            "test_loss": test_loss,
            "test_accuracy": test_acc,
            "lp_norm": lp_norm,
        })
        all_train_losses[epoch-1, i] = train_loss
        all_train_acc[epoch-1, i] = train_acc
        all_val_losses[epoch-1, i] = val_loss
        all_val_acc[epoch-1, i] = val_acc
        all_test_losses[epoch-1, i] = test_loss
        all_test_acc[epoch-1, i] = test_acc

        print(f"Epoch [{epoch}/{NUM_EPOCHS}] "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
            f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

    # Log the best test accuracy (i.e., test accuracy at epoch with minimal validation loss)
    wandb.log({"best_test_accuracy_at_best_val_epoch": best_test_acc})
    all_best_val_test_acc[i] = best_test_acc
    print(f"\nBest test accuracy (at minimal val loss): {best_test_acc:.4f}")

    wandb.finish()
# Persist all metric tensors as positional arrays, in this fixed order.
metric_tensors = (
    all_train_losses,
    all_train_acc,
    all_val_losses,
    all_val_acc,
    all_test_losses,
    all_test_acc,
    all_best_val_test_acc,
)
np.savez(filename, *[tensor.detach().cpu().numpy() for tensor in metric_tensors])
print("=" * 50)