"""
Noisy Transfer Learning Evaluation

Evaluates representation robustness by training and testing with noisy images.
Supports Gaussian noise and Salt-and-Pepper noise at various intensities.

This evaluation probes the noise tolerance of learned representations,
a key functional property of biological vision systems.

Usage:
    python main_noisy.py --spiking --arch resnet18 --ckpt-path <path> --noise-ratio 0.2 --noise-type gaussian
"""

from __future__ import print_function

import os
import sys
import argparse
import random
import numpy as np
import logging

from torch.utils.tensorboard import SummaryWriter
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets

from util import save_config_file, save_checkpoint, AverageMeter, accuracy
from networks.resnet_ann import SupCEResNet, LinearClassifier
from networks.resnet_snn import SupCEResNetSNN, LinearClassifierSNN, add_dimention_distribute


def parse_option():
    """Parse command-line options and derive dataset/logging settings.

    Besides the raw argparse fields, the returned namespace carries:

    - ``data_folder``, ``img_size``, ``n_cls``: dataset-specific settings
      (only ``cifar10`` is configured in this script).
    - ``log_path``: ``<save_path>/<SNN|ANN>_<arch>_noise<ratio>_<type>``;
      the directory is created if it does not exist.

    Returns:
        argparse.Namespace: the parsed and augmented options.

    Exits (via ``parser.error``) when ``--dataset-name`` is not supported,
    instead of returning a namespace that lacks the dataset attributes.
    """
    parser = argparse.ArgumentParser('Noisy Transfer Learning')

    parser.add_argument('--spiking', action='store_true', help='Use SNN model')
    parser.add_argument('--timesteps', type=int, default=4)
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18')
    parser.add_argument('--dataset-name', type=str, default='cifar10')
    parser.add_argument('-j', '--workers', default=16, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('-b', '--batch-size', default=512, type=int)
    parser.add_argument('--lr', default=0.01, type=float)
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, dest='weight_decay')

    # Noise parameters
    parser.add_argument('--noise-ratio', type=float, default=0.2,
                        choices=[0.0, 0.2, 0.4, 0.6, 0.8],
                        help='Noise intensity')
    parser.add_argument('--noise-type', type=str, default='gaussian',
                        choices=['gaussian', 'salt_pepper'],
                        help='Type of noise to apply')
    parser.add_argument('--seed', type=int, default=42)

    # Other settings
    parser.add_argument('--save-path', type=str, default='./save/Noisy')
    parser.add_argument('--ckpt-path', type=str, required=True)
    parser.add_argument('--gpu-ids', type=str, default='0')

    opt = parser.parse_args()

    if opt.dataset_name == 'cifar10':
        opt.data_folder = "~/projects/Datasets/CIFAR10/"
        opt.img_size = 32
        opt.n_cls = 10
    else:
        # Fail at the command line rather than later with an AttributeError
        # on the missing data_folder / n_cls attributes.
        parser.error(
            f"unsupported dataset '{opt.dataset_name}' (supported: cifar10)")

    # The run directory name encodes model family, architecture and noise setup.
    model_tag = 'SNN' if opt.spiking else 'ANN'
    opt.log_path = os.path.join(
        opt.save_path,
        f'{model_tag}_{opt.arch}_noise{opt.noise_ratio}_{opt.noise_type}')

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(opt.log_path, exist_ok=True)

    return opt


class NoisyImageDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper that adds noise to the images of another dataset.

    Supports:
    - Gaussian noise: additive noise sampled from N(0, noise_ratio), with the
      result clipped to [0, 1]
    - Salt-and-pepper noise: a fraction ``noise_ratio`` of the elements is
      replaced, half with 0 (pepper) and half with 1 (salt)

    Any other ``noise_type`` passes the image through unchanged.

    Both noise models clip to / write the values 0 and 1, so images are
    expected to be in the [0, 1] range when they reach this wrapper.

    Noise is drawn from NumPy's global RNG on every ``__getitem__`` call, so
    reading the same index twice yields different noise realizations.
    """
    def __init__(self, original_dataset, noise_ratio, noise_type):
        """
        Args:
            original_dataset: indexable dataset yielding ``(image, label)``.
            noise_ratio: Gaussian standard deviation, or the total fraction
                of elements corrupted by salt-and-pepper noise.
            noise_type: ``'gaussian'`` or ``'salt_pepper'``.
        """
        self.original_dataset = original_dataset
        self.noise_ratio = noise_ratio
        self.noise_type = noise_type

    def __len__(self):
        return len(self.original_dataset)

    def __getitem__(self, idx):
        """Return ``(noisy_image, label)`` with the image as a float32 tensor."""
        image, label = self.original_dataset[idx]

        if isinstance(image, torch.Tensor):
            image_np = image.numpy()
        else:
            image_np = np.array(image)

        if self.noise_type == 'gaussian':
            # Gaussian noise with standard deviation = noise_ratio
            noise = np.random.normal(0, self.noise_ratio, image_np.shape)
            noisy_image = np.clip(image_np + noise, 0, 1)

        elif self.noise_type == 'salt_pepper':
            # One uniform draw per element keeps the salt and pepper masks
            # disjoint, so exactly noise_ratio of the elements are corrupted
            # on average. Two independent draws would let pepper overwrite
            # salt and corrupt only 1 - (1 - noise_ratio / 2) ** 2 of them.
            noisy_image = image_np.copy()
            draw = np.random.random(image_np.shape)
            half_ratio = self.noise_ratio / 2
            # Pepper (black)
            noisy_image[draw < half_ratio] = 0
            # Salt (white)
            noisy_image[draw >= 1 - half_ratio] = 1
        else:
            noisy_image = image_np

        noisy_image = torch.from_numpy(noisy_image.astype(np.float32))
        return noisy_image, label


def create_noisy_dataset(dataset, noise_ratio, noise_type, seed=42):
    """Wrap ``dataset`` so its images are returned with noise added.

    A ``noise_ratio`` of exactly 0.0 returns ``dataset`` itself, untouched
    and without reseeding anything. Otherwise Python's ``random`` and
    NumPy's global RNG are seeded with ``seed`` before the wrapper is built.
    """
    if noise_ratio != 0.0:
        # Seed both global generators before constructing the noisy wrapper.
        for seed_fn in (random.seed, np.random.seed):
            seed_fn(seed)
        return NoisyImageDataset(dataset, noise_ratio, noise_type)

    return dataset


class _NormalizedDataset(torch.utils.data.Dataset):
    """Apply ``normalize`` to each image of a wrapped ``(image, label)`` dataset."""

    def __init__(self, dataset, normalize):
        self.dataset = dataset
        self.normalize = normalize

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        return self.normalize(image), label


def set_loader(opt):
    """Build the CIFAR-10 train/validation loaders with noisy images.

    The noise wrapper clips to [0, 1] and writes the values 0 and 1, so it
    has to see images in [0, 1] pixel space. Noise is therefore added right
    after ``ToTensor`` and channel normalization is applied afterwards;
    normalizing first would make the clip zero out every below-mean pixel.

    Args:
        opt: options namespace; reads ``data_folder``, ``noise_ratio``,
            ``noise_type``, ``seed``, ``batch_size`` and ``workers``.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)
    normalize = transforms.Normalize(mean=mean, std=std)

    # Geometric augmentation and tensor conversion only; normalization is
    # deferred until after the noise has been added.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    val_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_dataset = datasets.CIFAR10(root=opt.data_folder, transform=train_transform, download=True)
    val_dataset = datasets.CIFAR10(root=opt.data_folder, train=False, transform=val_transform)

    # Apply noise to both train and val (in [0, 1] pixel space) ...
    train_dataset = create_noisy_dataset(train_dataset, opt.noise_ratio, opt.noise_type, opt.seed)
    val_dataset = create_noisy_dataset(val_dataset, opt.noise_ratio, opt.noise_type, opt.seed)

    # ... then standardize. With noise_ratio == 0.0 the datasets come back
    # unwrapped and this reduces to the usual ToTensor + Normalize pipeline.
    train_dataset = _NormalizedDataset(train_dataset, normalize)
    val_dataset = _NormalizedDataset(val_dataset, normalize)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=True,
        num_workers=opt.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=opt.batch_size, shuffle=False,
        num_workers=opt.workers, pin_memory=True)

    print(f"Training with {opt.noise_type} noise at ratio {opt.noise_ratio}")
    return train_loader, val_loader


def set_model(opt):
    """Load the pretrained encoder, freeze it, and create a linear classifier.

    Args:
        opt: options namespace; reads ``spiking``, ``arch``, ``timesteps``,
            ``n_cls`` and ``ckpt_path``.

    Returns:
        tuple: ``(model, classifier, criterion)`` — the frozen encoder wrapped
        in ``DataParallel``, the trainable classifier, and a cross-entropy
        loss, all moved to CUDA.

    Raises:
        RuntimeError: from ``load_state_dict`` when the checkpoint weights do
            not match the encoder's parameter names/shapes.
    """
    if opt.spiking:
        model = SupCEResNetSNN(name=opt.arch, timestep=opt.timesteps,
                               num_classes=opt.n_cls, mode='encoder')
        classifier = LinearClassifierSNN(name=opt.arch, num_classes=opt.n_cls)
    else:
        model = SupCEResNet(name=opt.arch, num_classes=opt.n_cls, mode='encoder')
        classifier = LinearClassifier(name=opt.arch, num_classes=opt.n_cls)

    # Wrapped before loading, so the checkpoint keys must match the
    # DataParallel-wrapped module's parameter names.
    model = torch.nn.DataParallel(model)
    model.cuda()

    # Load pretrained weights.
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from a trusted source.
    checkpoint = torch.load(opt.ckpt_path, map_location="cpu", weights_only=False)
    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        # Neither wrapper key is present: treat the checkpoint as a raw state
        # dict. load_state_dict raises on a mismatch, so the run can no longer
        # continue silently with randomly initialized encoder weights.
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    print(f"Loaded pretrained model from {opt.ckpt_path}")

    # Freeze encoder
    for param in model.parameters():
        param.requires_grad = False

    classifier.cuda()
    criterion = torch.nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    return model, classifier, criterion


def train_epoch(model, classifier, criterion, train_loader, optimizer, opt):
    """Run one optimization pass of the classifier over ``train_loader``.

    The encoder stays in eval mode and is evaluated under ``no_grad``; only
    the classifier receives gradients.

    Returns:
        tuple: ``(average loss, average top-1 accuracy)`` as tracked by the
        ``AverageMeter`` instances, weighted by batch size.
    """
    model.eval()
    classifier.train()

    loss_meter, acc_meter = AverageMeter(), AverageMeter()

    for batch, targets in train_loader:
        batch = batch.cuda()
        targets = targets.cuda()
        n_samples = targets.shape[0]

        # Feature extraction never needs gradients: the encoder is frozen.
        with torch.no_grad():
            if opt.spiking:
                # presumably expands the batch along a timestep axis — confirm
                # against add_dimention_distribute in networks.resnet_snn
                batch = add_dimention_distribute(batch, opt.timesteps)
            feats = model(batch)

        logits = classifier(feats.detach())
        if opt.spiking:
            # Average the SNN classifier output over dim 1.
            logits = logits.mean(1)

        batch_loss = criterion(logits, targets)
        loss_meter.update(batch_loss.item(), n_samples)

        top1_acc, _ = accuracy(logits, targets, topk=(1, 5))
        acc_meter.update(top1_acc[0], n_samples)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    return loss_meter.avg, acc_meter.avg


def validate(model, classifier, val_loader, opt):
    """Measure the classifier's top-1 accuracy on ``val_loader``.

    Both the encoder and the classifier are put in eval mode and every
    forward pass runs under ``no_grad``.

    Returns:
        The batch-size-weighted average top-1 accuracy from ``AverageMeter``.
    """
    model.eval()
    classifier.eval()

    acc_meter = AverageMeter()

    for batch, targets in val_loader:
        batch = batch.cuda()
        targets = targets.cuda()

        with torch.no_grad():
            if not opt.spiking:
                logits = classifier(model(batch))
            else:
                batch = add_dimention_distribute(batch, opt.timesteps)
                # Average the SNN classifier output over dim 1.
                logits = classifier(model(batch)).mean(1)

        top1_acc, _ = accuracy(logits, targets, topk=(1, 5))
        acc_meter.update(top1_acc[0], targets.shape[0])

    return acc_meter.avg


def main():
    """Train a linear classifier on frozen features under noise and report accuracy.

    Parses options, builds the noisy loaders and the frozen encoder, trains
    the classifier with Adam for ``opt.epochs`` epochs, and prints the best
    validation accuracy seen across epochs.
    """
    opt = parse_option()
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids

    logging.basicConfig(
        filename=os.path.join(opt.log_path, 'training.log'),
        level=logging.DEBUG
    )

    # --seed otherwise only reaches Python's random and NumPy; seed torch too
    # so classifier initialization and loader shuffling follow it as well.
    torch.manual_seed(opt.seed)

    train_loader, val_loader = set_loader(opt)
    model, classifier, criterion = set_model(opt)

    optimizer = torch.optim.Adam(classifier.parameters(), opt.lr, weight_decay=opt.weight_decay)

    # The learning rate is held constant for the first epochs and then follows
    # a cosine decay that is stepped once per epoch. T_max must therefore be
    # the number of scheduler steps actually taken, not the number of batches
    # per epoch: with len(train_loader) the decay either stops early or, when
    # there are fewer batches than remaining epochs, runs past its minimum and
    # climbs back up.
    constant_lr_epochs = 10
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, opt.epochs - constant_lr_epochs), eta_min=0)

    print(f"Starting noisy training with {opt.noise_type} noise at ratio {opt.noise_ratio}")

    best_val_acc = 0.0
    for epoch in range(opt.epochs):
        loss, train_acc = train_epoch(model, classifier, criterion, train_loader, optimizer, opt)
        val_acc = validate(model, classifier, val_loader, opt)

        if val_acc > best_val_acc:
            best_val_acc = val_acc

        if epoch >= constant_lr_epochs:
            scheduler.step()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch}: Train Acc={train_acc:.2f}%, Val Acc={val_acc:.2f}%, Best={best_val_acc:.2f}%")

    print(f"\nNoisy evaluation complete")
    print(f"Noise type: {opt.noise_type}, ratio: {opt.noise_ratio}")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")


# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
