import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Dataset, Subset, ConcatDataset
import torchvision
from torchvision import transforms, datasets
import numpy as np
from PIL import Image
from collections import OrderedDict
from scipy.stats import sem
from tqdm import tqdm as tqdm
import torch.nn.functional as F
import torchvision.models as models
import argparse
import os
import sys
import wandb
from collections import defaultdict
import random
import urllib
import zipfile
import shutil

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    argparse's ``type=bool`` is unusable for flags: ``bool("False")`` is
    ``True``, so every non-empty value would switch the option on.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface of the experiment script.
parser = argparse.ArgumentParser(description='TinyImageNet-200 with label noise and Lp regularization')
parser.add_argument('--eps', type=float, default=0.0, help='Contamination ratio (default: 0.0)')
parser.add_argument('--annp', type=float, default=100.0, help='ANNP threshold percentile (default: 100.0)')
parser.add_argument('--run', type=int, default=0, help='Run (default: 0)')
# Accepts a bare `--force` as well as `--force true` / `--force false`.
parser.add_argument("--force", help="Force overwriting old run", type=_str2bool, nargs="?", const=True, default=False)
parser.add_argument('--size_percent', type=float, default=100.0, help='Initial reduction of dataset (default: 100.0)')
parser.add_argument('--comment', type=str, default="", help='Free-form tag inserted into the output filename (default: "")')
args = parser.parse_args()
print(args)

# Unpack the parsed options into the module-level names used by the rest of the script.
eps = args.eps
annp = args.annp
run = args.run
size_percent = args.size_percent
comment = args.comment
# Experiment-wide constants.
P_NORM = 2.0       # exponent p of the Lp weight penalty added to the training loss
NUM_CLASSES = 200  # number of label classes assumed by the noise / subsampling helpers

# Dataset locations. Each value is written once; the lowercase names are the
# aliases read by prepare_data() and the dataset loading code.
BASE_URL = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
ZIP_PATH = "./data/tiny-imagenet-200-v2.zip"
base_url = BASE_URL
raw_zip = ZIP_PATH
extract_dir = "./data/tiny-imagenet-200-v2"

# With no contamination the prefilter is meaningless, so only annp=100 is allowed.
if eps == 0.0 and annp != 100.0:
    print(f"Eps value 0.0 only runs with annp=100.0; received eps={eps}, annp={annp}")
    sys.exit(0)

# Results file for this configuration; skip the whole run if it already exists
# unless --force was given.
filename = f"tinyINv3regL{P_NORM}v5_label_partialv2_{size_percent}{comment}_eps_{eps}_annp_{annp}_run_{run}.npz"
if os.path.isfile(filename) and (not args.force):
    # File exists
    print(f"File {filename} exists")
    print("=" * 50)
    sys.exit(0)


# Hyperparameters
num_classes = NUM_CLASSES  # was a stale literal 10 that contradicted NUM_CLASSES
NUM_EPOCHS = 20
DATA_DIR = "./data"        # root data directory (previously bound twice to different values)
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Download + Extract
def prepare_data():
    """Download TinyImageNet-200 and rearrange ``val`` into per-class folders.

    Reads the module-level ``base_url``, ``raw_zip`` and ``extract_dir``.
    After a successful run, ``<extract_dir>/val/<class>/<image>`` exists for
    every line of ``val_annotations.txt`` and ``val/images`` is removed.

    The function is resumable: an existing valid archive is reused instead of
    downloaded again, and a ``val`` reorganisation that was interrupted by an
    earlier run is completed rather than reported as "already prepared".
    """
    # `import urllib` alone does not guarantee the `request` submodule is loaded.
    import urllib.request

    val_dir = os.path.join(extract_dir, "val")
    images_dir = os.path.join(val_dir, "images")
    ann_file = os.path.join(val_dir, "val_annotations.txt")

    if os.path.isdir(extract_dir):
        if not os.path.isdir(images_dir):
            print(f"Already prepared in {extract_dir}")
            return
        # val/images still present: an earlier run stopped mid-way; finish it below.
    else:
        os.makedirs(os.path.dirname(raw_zip), exist_ok=True)
        # A truncated download fails is_zipfile() and is fetched again.
        if not (os.path.isfile(raw_zip) and zipfile.is_zipfile(raw_zip)):
            print("Downloading TinyImageNet-200")
            urllib.request.urlretrieve(base_url, raw_zip)
        print("Extracting")
        with zipfile.ZipFile(raw_zip, 'r') as z:
            z.extractall(os.path.dirname(extract_dir))
        # The archive unpacks to "tiny-imagenet-200"; rename to the versioned directory.
        old = os.path.join(os.path.dirname(extract_dir), "tiny-imagenet-200")
        if os.path.isdir(old):
            os.rename(old, extract_dir)

    # Move every validation image into a folder named after its class so that
    # an ImageFolder-style loader can read the split.
    with open(ann_file, "r") as f:
        for line in f:
            fields = line.split()
            if len(fields) < 2:
                continue  # blank or malformed line
            img, cls = fields[0], fields[1]
            target = os.path.join(val_dir, cls, img)
            if os.path.exists(target):
                continue  # already moved by an interrupted earlier run
            os.makedirs(os.path.join(val_dir, cls), exist_ok=True)
            shutil.move(os.path.join(images_dir, img), target)
    shutil.rmtree(images_dir)
    print("Done!")

# Make sure the dataset is on disk before anything tries to read it.
prepare_data()

# Per-channel normalisation constants (presumably the dataset's mean/std -- verify).
_CHANNEL_MEAN = [0.4802, 0.4481, 0.3975]
_CHANNEL_STD = [0.2302, 0.2265, 0.2262]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
])


def _load_split(split_name):
    """Open one sub-directory of the extracted dataset as an ImageFolder."""
    return datasets.ImageFolder(os.path.join(extract_dir, split_name), transform=transform)


# Only the original train and val folders are used; the test folder is ignored.
train_orig = _load_split("train")
val_orig = _load_split("val")

# Pool both splits and carve out fresh partitions: 70% + 10% stay together as
# the train/val pool, the remainder becomes the test set.
combined = ConcatDataset([train_orig, val_orig])
total_n = len(combined)
n_train, n_val = (int(fraction * total_n) for fraction in (0.7, 0.1))
n_test = total_n - n_train - n_val

_split_sizes = [n_train + n_val, n_test]
trainval_dataset_full, test_dataset = random_split(combined, _split_sizes)
print("[DEBUG] Datasets loaded")
class ConvBlock(nn.Module):
    """3x3 convolution -> batch norm -> ReLU, optionally followed by 2x2 max-pooling."""

    def __init__(self, in_channels, out_channels, pool=False):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU(inplace=True)
        # Sub-modules are registered positionally, so parameter names stay "block.<i>.*".
        downsample = (nn.MaxPool2d(2),) if pool else ()
        self.block = nn.Sequential(conv, norm, act, *downsample)

    def forward(self, x):
        out = self.block(x)
        return out

class ResidualBlock(nn.Module):
    """Two 3x3 conv / batch-norm stages wrapped by an identity skip connection."""

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            # Channel count is preserved so the output can be added to the input.
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)

        self.block = nn.Sequential(
            conv3x3(),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        residual = self.block(x)
        # The activation is applied after the skip connection is added.
        return F.relu(x + residual)

class ResNet9(nn.Module):
    """Small residual classifier built from ConvBlock and ResidualBlock stages.

    Args:
        in_channels: number of channels of the input images.
        num_classes: size of the output logit vector.
    """

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()
        # Stem
        self.conv1 = ConvBlock(in_channels, 64)
        # Stage 1: widen to 128 channels, halve the resolution, then refine.
        self.conv2 = ConvBlock(64, 128, pool=True)
        self.res1 = ResidualBlock(128)
        # Stage 2: two further widen-and-pool steps, then refine.
        self.conv3 = ConvBlock(128, 256, pool=True)
        self.conv4 = ConvBlock(256, 512, pool=True)
        self.res2 = ResidualBlock(512)
        # Head: global average pooling followed by a linear classifier.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        stages = (
            self.conv1,
            self.conv2,
            self.res1,
            self.conv3,
            self.conv4,
            self.res2,
            self.pool,
        )
        for stage in stages:
            x = stage(x)
        flat = x.view(len(x), -1)
        logits = self.fc(flat)  # [B, num_classes]
        return logits

class NoisyDataset(Dataset):
    """View of a base dataset that keeps the inputs but substitutes the labels.

    Defined at module level (not inside ``add_label_noise``) so that instances
    can be pickled, which DataLoader worker processes require when they are
    started with the ``spawn`` method.
    """

    def __init__(self, base_dataset, noisy_targets):
        self.base = base_dataset
        self.targets = noisy_targets

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x, _ = self.base[idx]
        return x, self.targets[idx]


def add_label_noise(dataset, noise_pct, num_classes=None):
    """Replace the labels of a random fraction of ``dataset`` with wrong classes.

    Args:
        dataset: indexable dataset whose items are ``(x, label)`` pairs.
        noise_pct: fraction of samples to corrupt; ``int(noise_pct * len(dataset))``
            samples are chosen without replacement.
        num_classes: number of classes to draw replacement labels from.
            Defaults to the module-level ``NUM_CLASSES``.

    Returns:
        ``(noisy_dataset, noise_idx)``: a ``NoisyDataset`` wrapping ``dataset``
        and the list of corrupted positions.

    Raises:
        ValueError: if labels must be corrupted but fewer than two classes exist.

    Note:
        Every item of ``dataset`` is fetched once to read its label.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    # pull out all the labels
    n = len(dataset)
    orig_targets = [dataset[i][1] for i in range(n)]

    # decide which to corrupt
    k = int(noise_pct * n)
    if k > 0 and num_classes < 2:
        raise ValueError("label noise needs at least two classes to pick a different label")
    noise_idx = random.sample(range(n), k)

    # build a new list of targets with noise injected
    noisy_targets = list(orig_targets)
    for i in noise_idx:
        true_label = noisy_targets[i]
        # Draw uniformly among all classes except the current one.
        candidates = [c for c in range(num_classes) if c != true_label]
        noisy_targets[i] = random.choice(candidates)

    return NoisyDataset(dataset, noisy_targets), noise_idx

def balanced_subsample(dataset, keep_pct, num_classes=None):
    """Keep the same fraction of samples from every class.

    Args:
        dataset: indexable dataset whose items are ``(x, label)`` pairs with
            integer labels.
        keep_pct: fraction of each class to keep; ``int(count * keep_pct)``
            samples per class are drawn without replacement.
        num_classes: labels ``0 .. num_classes - 1`` are considered; samples
            with any other label are dropped. Defaults to the module-level
            ``NUM_CLASSES``.

    Returns:
        A ``Subset`` of ``dataset`` holding the selected indices, grouped by class.

    Note:
        Every item of ``dataset`` is fetched once to read its label.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    # pull out all the labels
    targets = np.array([dataset[i][1] for i in range(len(dataset))])

    indices = []
    for cls in range(num_classes):
        cls_idx = np.flatnonzero(targets == cls).tolist()
        k = int(len(cls_idx) * keep_pct)
        indices.extend(random.sample(cls_idx, k))
    return Subset(dataset, indices)

# Shrink the train/val pool class by class, then corrupt a fraction `eps` of its labels.
# `noisy_indices` are positions inside the (already subsampled) pool.
_keep_fraction = size_percent / 100.0
trainval_dataset_full = balanced_subsample(trainval_dataset_full, _keep_fraction)
_noise_result = add_label_noise(trainval_dataset_full, eps)
noisy_trainset, noisy_indices = _noise_result
print("[DEBUG] Added noise!")

class IndexedDataset(Dataset):
    """Wrap a dataset so every item also carries its position: ``(x, y, idx)``.

    Kept at module level so DataLoader worker processes can pickle it.
    """

    def __init__(self, base_dataset):
        self.dataset = base_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # pull x, y from underlying dataset
        x, y = self.dataset[idx]
        # return data, label, and the index within this dataset
        return x, y, idx


def train_and_prefilter(noisy_trainset, percentile):
    """Train a ResNet-9 on the noisy data and keep only the low-loss samples.

    After every epoch the per-sample cross-entropy loss of the whole training
    set is computed, and the samples whose loss is at or below the given
    percentile are selected. Diagnostics are printed each epoch; the selection
    made after the final epoch is returned.

    Args:
        noisy_trainset: dataset of ``(x, label)`` pairs to train on and filter.
        percentile: loss percentile (0-100) used as the keep threshold.

    Returns:
        A ``Subset`` of ``noisy_trainset`` with the retained samples.

    Note:
        Reads the module-level ``noisy_indices`` (positions of the corrupted
        samples inside ``noisy_trainset``) for the printed diagnostics only.
    """
    batch_size = 128
    epochs = 10
    learning_rate = 1e-3
    num_classes = 200
    num_channels = 3

    noisy_trainset_idx = IndexedDataset(noisy_trainset)
    trainloader = DataLoader(noisy_trainset_idx, batch_size=batch_size, shuffle=True,  num_workers=8, pin_memory=True)
    # Sequential pass for per-sample losses; built once instead of once per epoch.
    evalloader = DataLoader(noisy_trainset_idx, batch_size=batch_size, shuffle=False)

    # Model: ResNet-9
    model = ResNet9(in_channels=num_channels, num_classes=num_classes)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    criterion = nn.CrossEntropyLoss()
    criterion_no_reduce = nn.CrossEntropyLoss(reduction='none')

    for epoch in range(epochs):
        # Training. train() must be re-enabled every epoch because the loss
        # pass below switches the model to eval(); otherwise BatchNorm would
        # stop updating after the first epoch.
        model.train()
        for inputs, labels, _ in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Compute individual training losses
        model.eval()
        losses = np.zeros(len(noisy_trainset))
        with torch.no_grad():
            for inputs, labels, indices in evalloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                batch_losses = criterion_no_reduce(outputs, labels)
                # Scatter by sample index so `losses` is aligned with the dataset.
                losses[indices.numpy()] = batch_losses.cpu().numpy()

        # Analysis
        threshold = np.percentile(losses, percentile)
        selected_indices = np.where(losses <= threshold)[0]
        contaminated_losses = losses[noisy_indices]
        subset_dataset = Subset(noisy_trainset, selected_indices)
        fraction_above_threshold = np.mean(contaminated_losses > threshold)
        print(f"Subset fraction: {len(subset_dataset)} points out of {len(noisy_trainset)} ({100 * len(subset_dataset) / len(noisy_trainset):.2f}%)")
        print(f"Epoch {epoch}")
        print(f"{percentile:g}th percentile of all losses: {threshold:.4f}")
        print(f"Fraction of contaminated points with loss above threshold: {fraction_above_threshold:.4f}")

    return subset_dataset

# A threshold percentile of exactly 100 means "keep everything", so the
# prefilter model is only trained for any other value.
full_train_dataset = (
    noisy_trainset
    if annp == 100.
    else train_and_prefilter(noisy_trainset=noisy_trainset, percentile=annp)
)
print("Prefiltered dataset!")

# 80/20 train/validation split of whatever survived the prefilter.
train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

# DataLoaders (only the training loader shuffles)
_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=8, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_opts)

# Shape parameters of the CNN trained in the sweep
num_channels = 3
width = 64
num_classes = 200

# Regularisation strengths to sweep, log-spaced.
# NOTE(review): lambda_regs has num_models - 1 entries while the result buffers
# have num_models columns, so the last column is never written by the sweep
# and stays zero -- confirm this is intended.
num_models = 9
lambda_regs = torch.logspace(-6, -2.83, num_models - 1).to(device)  # tinyIN

# Per-epoch metrics, one column per model.
all_train_losses = torch.zeros(NUM_EPOCHS, num_models, device=device)
all_train_acc = torch.zeros(NUM_EPOCHS, num_models, device=device)
all_val_losses = torch.zeros(NUM_EPOCHS, num_models, device=device)
all_val_acc = torch.zeros(NUM_EPOCHS, num_models, device=device)
all_test_losses = torch.zeros(NUM_EPOCHS, num_models, device=device)
all_test_acc = torch.zeros(NUM_EPOCHS, num_models, device=device)

# Test accuracy at the epoch with the lowest validation loss, one entry per model.
all_best_val_test_acc = torch.zeros(num_models, device=device)
def _lp_norm(model, p):
    """Return the Lp norm over all trainable parameters of ``model`` as a 0-dim tensor."""
    powered_sum = sum(
        torch.sum(torch.abs(param) ** p)
        for param in model.parameters()
        if param.requires_grad
    )
    return powered_sum.pow(1.0 / p)


def _evaluate(model, loader, criterion):
    """Return ``(mean loss, accuracy)`` of ``model`` over ``loader`` without gradients."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # criterion returns the batch mean; weight by batch size for a dataset mean.
            loss_sum += criterion(outputs, targets).item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return loss_sum / total, correct / total


# Train one CNN per regularisation strength and record its per-epoch metrics
# in column `i` of the result buffers.
for i, lambda_reg in enumerate(lambda_regs):
    print(lambda_reg)
    # Initialize wandb (plain floats only, so the config serialises cleanly)
    wandb.init(project="tinyIN_contaminated_weighted", config={
        "batch_size": BATCH_SIZE,
        "num_epochs": NUM_EPOCHS,
        "learning_rate": LEARNING_RATE,
        "architecture": "custom CNN as provided",
        "contamination_prob": eps,
        "p_norm": P_NORM,
        "lambda_reg": lambda_reg.item(),
    })

    # Use the configured device instead of a hard-coded .cuda() so CPU runs work.
    model = nn.Sequential(
        OrderedDict(
            [
                ("conv0", nn.Conv2d(num_channels, 1 * width, kernel_size=3, padding=1)),
                ("relu0", nn.ReLU()),
                ("conv1", nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)),
                ("relu1", nn.ReLU()),
                ("conv2", nn.Conv2d(2 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                ("relu2", nn.ReLU()),
                ("pool0", nn.MaxPool2d(3)),
                ("conv3", nn.Conv2d(4 * width, 4 * width, kernel_size=3, stride=2, padding=1)),
                ("relu3", nn.ReLU()),
                ("pool1", nn.AdaptiveAvgPool2d(1)),
                ("flatten", nn.Flatten()),
                ("linear", nn.Linear(4 * width, num_classes)),
            ]
        )
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_val_loss = float('inf')
    best_test_acc = 0.0

    # Training loop
    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            ce_loss = criterion(outputs, targets)
            loss = ce_loss + lambda_reg * _lp_norm(model, P_NORM)
            loss.backward()
            optimizer.step()

            # The reported training loss includes the regularisation term.
            train_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()

        train_loss /= train_total
        train_acc = train_correct / train_total

        # Validation, then test evaluation at each epoch to track performance
        # (both use plain cross-entropy without the penalty).
        val_loss, val_acc = _evaluate(model, val_loader, criterion)
        test_loss, test_acc = _evaluate(model, test_loader, criterion)

        # Save best test accuracy based on minimal val loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_test_acc = test_acc

        # Norm of the weights at the end of the epoch. The previous code took
        # the root of the stale value left over from the last training batch.
        with torch.no_grad():
            lp_norm = _lp_norm(model, P_NORM).item()

        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "train_accuracy": train_acc,
            "val_loss": val_loss,
            "val_accuracy": val_acc,
            "test_loss": test_loss,
            "test_accuracy": test_acc,
            "lp_norm": lp_norm,
        })
        all_train_losses[epoch-1, i] = train_loss
        all_train_acc[epoch-1, i] = train_acc
        all_val_losses[epoch-1, i] = val_loss
        all_val_acc[epoch-1, i] = val_acc
        all_test_losses[epoch-1, i] = test_loss
        all_test_acc[epoch-1, i] = test_acc

        print(f"Epoch [{epoch}/{NUM_EPOCHS}] "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
            f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

    # Log the best test accuracy (i.e., test accuracy at epoch with minimal validation loss)
    wandb.log({"best_test_accuracy_at_best_val_epoch": best_test_acc})
    all_best_val_test_acc[i] = best_test_acc
    print(f"\nBest test accuracy (at minimal val loss): {best_test_acc:.4f}")

    wandb.finish()

# Arrays are stored positionally (arr_0 .. arr_6) in this fixed order:
# train loss, train acc, val loss, val acc, test loss, test acc, best-val test acc.
_results = (
    all_train_losses,
    all_train_acc,
    all_val_losses,
    all_val_acc,
    all_test_losses,
    all_test_acc,
    all_best_val_test_acc,
)
np.savez(filename, *(t.detach().cpu().numpy() for t in _results))
print("=" * 50)
