# example usage: python vision_begin.py generate --model loop --dataset imagenet-val --batch_size 128 --num_workers 16 --device cuda:0 --val_json_dir ./data/imagenet_frozen_status_v0
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import time
import argparse
import timm
import os
from datasets import load_dataset
from pathlib import Path
import pandas as pd
import json
import urllib.request
import PIL
import random
import re
import io
import sys
from dataclasses import dataclass
from collections import defaultdict
from datetime import datetime
import torch.nn.functional as F
import copy
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator
from nvidia.dali.plugin.base_iterator import LastBatchPolicy

# timm model identifiers iterated by `generate --model loop` (see main_generate).
# Grouped by architecture family; the order matches the original flat list.
ALL_MODELS = [
    # ResNet / ResNeXt / Wide-ResNet
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'resnext50_32x4d', 'resnext101_32x8d',
    'wide_resnet50_2', 'wide_resnet101_2',
    # DenseNet
    'densenet121', 'densenet169', 'densenet201', 'densenet161',
    # Inception and mobile-oriented networks
    'inception_v3',
    'mnasnet_100',
    'mobilenetv2_100', 'mobilenetv3_large_100', 'mobilenetv3_small_100',
    # EfficientNetV2
    'efficientnetv2_rw_s', 'efficientnetv2_rw_m', 'efficientnetv2_rw_l',
    # RegNet
    'regnetx_002', 'regnetx_004', 'regnetx_006', 'regnetx_008',
    'regnety_008', 'regnety_016', 'regnety_032',
    # ConvNeXt
    'convnext_tiny', 'convnext_small', 'convnext_base', 'convnext_large',
    # Transformers
    'vit_small_patch16_224', 'vit_base_patch16_224',
    'deit_small_patch16_224', 'deit_base_patch16_224',
    'swin_tiny_patch4_window7_224', 'swin_base_patch4_window7_224',
    # Currently excluded: tf_efficientnet_b0 ... tf_efficientnet_b8.
]


def create_dataloader(dataset_name, split, batch_size, data_dir, incorrect_val_samples: tuple[list[Path], list[int]] = None, data_size = -1,
                      device_id = 0, num_threads = 32):
    """Build a DALI GPU iterator over an ImageNet-style directory.

    Args:
        dataset_name: Currently unused; kept for call-site compatibility.
        split: 'train' reads ``<data_dir>/train`` with random-crop / flip
            augmentation; any other value builds the deterministic evaluation
            pipeline over ``incorrect_val_samples``.
        batch_size: Samples per DALI batch.
        data_dir: Root directory containing the ``train`` sub-folder.
        incorrect_val_samples: ``(paths, labels)`` for the evaluation split.
            DALI reads ``paths``; each sample's DALI label is its *position* in
            the list, and the returned ``labels_dict`` maps that position back
            to ``labels[position]``. Required when ``split != 'train'``.
        data_size: Training only. ``> 0`` fixes the number of samples per
            epoch; ``<= 0`` makes one epoch a full pass of the file reader.
        device_id: CUDA device index the pipeline runs on (default 0).
        num_threads: CPU threads used by DALI (default 32).

    Returns:
        ``(dali_iter, labels_dict)``; ``labels_dict`` is ``None`` for training.
        Images come out as NHWC float tensors, ImageNet-normalized.

    Raises:
        ValueError: if the evaluation split is requested without samples.
    """
    is_train = split == 'train'
    # ImageNet mean/std scaled to the 0-255 range of the decoded uint8 images.
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    pipe = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=device_id, seed=12 + device_id)

    if is_train:
        labels_dict = None
        train_root = os.path.join(data_dir, 'train')
        with pipe:
            jpegs, labels = fn.readers.file(name='Reader', file_root=train_root, random_shuffle=True)
            images = fn.decoders.image_random_crop(jpegs, device='mixed', output_type=types.RGB,
                                                   random_area=[0.08, 1.0], random_aspect_ratio=[0.75, 1.333])
            images = fn.resize(images, resize_x=256, resize_y=256)
            mirror = fn.random.coin_flip(probability=0.5)
            # No crop position is given, so DALI takes a centred 224x224 crop.
            images = fn.crop_mirror_normalize(images, device='gpu', dtype=types.FLOAT,
                                              output_layout=types.NHWC, crop=(224, 224),
                                              mean=mean, std=std, mirror=mirror)
            pipe.set_outputs(images, labels.gpu())
        # With size=-1 and no reader_name the DALI iterator runs until the
        # pipeline raises StopIteration, which the (wrapping) file reader never
        # does. So either fix the epoch size or let the iterator query the reader.
        epoch_kwargs = {'size': data_size} if data_size > 0 else {'reader_name': 'Reader'}
        last_batch_policy = LastBatchPolicy.DROP
    else:
        if incorrect_val_samples is None:
            raise ValueError("incorrect_val_samples=(paths, labels) is required for the evaluation split")
        paths, sample_labels = incorrect_val_samples
        labels_dict = dict(enumerate(sample_labels))
        with pipe:
            # DALI's `files` argument takes strings, so convert any Path objects.
            jpegs, labels = fn.readers.file(name='Reader', random_shuffle=False,
                                            files=[str(p) for p in paths],
                                            labels=list(range(len(labels_dict))))
            images = fn.decoders.image(jpegs, device='mixed', output_type=types.RGB)
            images = fn.resize(images, resize_shorter=256)
            images = fn.crop_mirror_normalize(images, device='gpu', dtype=types.FLOAT,
                                              output_layout=types.NHWC, crop=(224, 224),
                                              mean=mean, std=std)
            pipe.set_outputs(images, labels)
        epoch_kwargs = {'reader_name': 'Reader'}
        last_batch_policy = LastBatchPolicy.PARTIAL

    pipe.build()
    dali_iter = DALIGenericIterator(
        [pipe],
        output_map=['images', 'labels'],
        last_batch_policy=last_batch_policy,
        auto_reset=True,
        **epoch_kwargs,
    )
    return dali_iter, labels_dict

def _fetch_imagenet_n0_to_idx(timeout=60):
    """Download the ImageNet class index and map each class folder name (``v[0]``) to its int index."""
    url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
    # Bounded timeout so a stalled download fails instead of hanging forever.
    with urllib.request.urlopen(url, timeout=timeout) as r:
        class_idx = json.load(io.TextIOWrapper(r, encoding="utf-8"))
    return {v[0]: int(k) for k, v in class_idx.items()}


def get_dataloaders(batch_size = 256, model_name = None, val_num_samples = -1, train_size = -1):
    """Create the DALI train iterator and an evaluation iterator over ImageNet val.

    When ``model_name`` is given, evaluation is restricted to the val images that
    model misclassified, taken from the newest
    ``./data/imagenet_frozen_status/val_<model_name>_*.json`` and optionally
    truncated to the first ``val_num_samples`` of them. Otherwise every val
    image is used (``val_num_samples`` is ignored).

    Returns:
        ``(train_loader, val_loader, labels_dict)`` as produced by ``create_dataloader``.

    Raises:
        FileNotFoundError: if no frozen-status JSON exists for ``model_name``.
    """
    imagenet_dir = Path('/storage/ANON/datasets/pytorch_imagenet_data')
    imagenet_n0_to_idx = _fetch_imagenet_n0_to_idx()

    if model_name is not None:
        json_files = sorted(Path('./data/imagenet_frozen_status/').glob(f'val_{model_name}_*.json'))
        if not json_files:
            raise FileNotFoundError(f"No json files found for {model_name}")
        # File names end in a %Y-%m-%d_%H-%M-%S timestamp, so the last one is the newest.
        with open(json_files[-1], 'r') as f:
            val_df = pd.DataFrame(json.load(f))
        val_df = val_df[~val_df['is_correct']].reset_index(drop=True)
        # [class_folder, file_name] of each misclassified image.
        file_names = [p.split('/')[-2:] for p in val_df['image_path']]
        if 0 < val_num_samples < len(file_names):
            print(f'Using only {val_num_samples} out of {len(file_names)} incorrect samples')
            file_names = file_names[:val_num_samples]
        else:
            print(f'Using all {len(file_names)} incorrect samples')
        paths = [(imagenet_dir / 'val' / cls / name).with_suffix('.jpeg') for cls, name in file_names]
        labels = [imagenet_n0_to_idx[cls] for cls, _ in file_names]
    else:
        # Sorted so each image's position (its DALI label / result key) is stable
        # across runs; raw glob order depends on the filesystem.
        paths = sorted((imagenet_dir / 'val').glob('**/*.jpeg'))
        labels = [imagenet_n0_to_idx[p.parent.name] for p in paths]

    train_loader, _ = create_dataloader('imagenet', 'train', batch_size, imagenet_dir, data_size=train_size)
    val_loader, labels_dict = create_dataloader('imagenet', 'val', batch_size, imagenet_dir,
                                                incorrect_val_samples=(paths, labels), data_size=-1)
    return train_loader, val_loader, labels_dict


def parse_args():
    """Parse the command line for the 'generate', 'overfit' and 'normal-train' sub-commands.

    Exits via ``parser.error`` (status 2) when 'overfit' gets neither
    ``--target_overfit_prob`` nor ``--target_overfit_loss``: without one of them
    the overfitting loop has no stopping criterion except a NaN.
    """
    parser = argparse.ArgumentParser(description='Vision Classifier Training')
    # required=True makes argparse itself reject a missing/unknown sub-command.
    subparsers = parser.add_subparsers(dest='mode', required=True)

    generate_subparser = subparsers.add_parser('generate')
    generate_subparser.add_argument('--model', type=str, default='resnet50', help='Pretrained model from timm (default: resnet50)')
    generate_subparser.add_argument('--dataset', type=str, default='imagenet-val', help='Dataset to use (default: imagenet-val)')
    generate_subparser.add_argument('--batch_size', type=int, default=128, help='Batch size (default: 128)')
    generate_subparser.add_argument('--data_dir', type=str, default='./data', help='Data directory (default: ./data)')
    generate_subparser.add_argument('--num_workers', type=int, default=16, help='Number of data loading workers (default: 16)')
    generate_subparser.add_argument('--device', type=str, default='cuda:0', help='Device to use (cuda/cpu, default: cuda:0)')
    generate_subparser.add_argument('--val_json_dir', type=str, default=None, help='Directory to save val json (default: None)')

    overfit_subparser = subparsers.add_parser('overfit')
    overfit_subparser.add_argument('--model', type=str, default='resnet50', help='Pretrained model from timm (default: resnet50)')
    overfit_subparser.add_argument('--start_from_idx', type=int, default=0, help='Start from index (default: 0)')
    overfit_subparser.add_argument('--end_to_idx', type=int, default=None, help='End to index (default: None)')
    overfit_subparser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
    overfit_subparser.add_argument('--target_overfit_prob', type=float, default=None, help='Target overfit probability (default: None)')
    overfit_subparser.add_argument('--target_overfit_loss', type=float, default=None, help='Target overfit loss (default: None)')
    overfit_subparser.add_argument('--num_workers', type=int, default=16, help='Number of data loading workers (default: 16)')
    overfit_subparser.add_argument('--device', type=str, default='cuda:0', help='Device to use (cuda/cpu, default: cuda:0)')

    normal_train_subparser = subparsers.add_parser('normal-train')
    normal_train_subparser.add_argument('--model', type=str, default='resnet18', help='Pretrained model from timm (default: resnet18)')
    normal_train_subparser.add_argument('--val_num_samples', type=int, default=-1, help='Number of samples to use for validation (default: -1)')
    normal_train_subparser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
    normal_train_subparser.add_argument('--epochs', type=int, default=10, help='Number of epochs (default: 10)')
    normal_train_subparser.add_argument('--num_workers', type=int, default=16, help='Number of data loading workers (default: 16)')
    normal_train_subparser.add_argument('--device', type=str, default='cuda:0', help='Device to use (cuda/cpu, default: cuda:0)')
    normal_train_subparser.add_argument('--train_size', type=int, default=32, help='Number of samples to use for training (default: 32)')
    normal_train_subparser.add_argument('--no-pretrained', action='store_true', help='Train from randomly initialized weights instead of pretrained ones (default: False)')

    args = parser.parse_args()
    if args.mode == 'overfit' and args.target_overfit_prob is None and args.target_overfit_loss is None:
        parser.error('overfit requires --target_overfit_prob and/or --target_overfit_loss')
    return args

# Standard ImageNet per-channel normalization constants, in R, G, B order
# (images are converted to RGB before these are applied).
OFFICIAL_IMAGENET_MEAN = [
    0.485,  # R
    0.456,  # G
    0.406,  # B
]
OFFICIAL_IMAGENET_STD = [
    0.229,  # R
    0.224,  # G
    0.225,  # B
]

@dataclass
class Trainer:
    """Everything ``evaluate`` needs to score one model on one test loader."""

    model: nn.Module
    # Iterable of batches; each batch is a list of
    # (image_tensor, label_name, label_idx, image_path) tuples.
    testloader: DataLoader
    criterion: nn.Module
    device: torch.device
    # Directory for the per-sample val JSON; None (the CLI default) skips writing it.
    val_json_dir: 'str | None'
    # Used in the output file name val_<model_name>_<timestamp>.json.
    model_name: str

def get_data_transforms(dataset_name, input_size=224):
    """Return ``(transform_train, transform_test)`` torchvision pipelines for a dataset.

    Both pipelines end with ToTensor + ImageNet mean/std normalization.

    Args:
        dataset_name: 'imagenet-val' or 'objectnet'.
        input_size: Side length of the square crop fed to the model.

    Raises:
        ValueError: for any other ``dataset_name``.
    """
    if dataset_name not in ('imagenet-val', 'objectnet'):
        raise ValueError(f"Unknown dataset: {dataset_name}")

    def _normalized(*steps):
        # Append tensor conversion and ImageNet normalization to the given PIL-level steps.
        return transforms.Compose([
            *steps,
            transforms.ToTensor(),
            transforms.Normalize(mean=OFFICIAL_IMAGENET_MEAN, std=OFFICIAL_IMAGENET_STD),
        ])

    if dataset_name == 'imagenet-val':
        # No augmentation: the "train" pipeline is deterministic as well.
        # NOTE(review): Resize(input_size) scales the shorter side to input_size
        # before the input_size center crop, unlike Resize(256) in the objectnet
        # branch and resize_shorter=256 in the DALI eval pipeline. The saved val
        # JSON (and so which samples count as "incorrect") depends on this —
        # confirm it is intended.
        transform_train = _normalized(transforms.Resize(input_size), transforms.CenterCrop(input_size))
        transform_test = _normalized(transforms.Resize(input_size), transforms.CenterCrop(input_size))
    else:  # 'objectnet'
        transform_train = _normalized(
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        )
        transform_test = _normalized(transforms.Resize(256), transforms.CenterCrop(input_size))

    return transform_train, transform_test


class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over a sequence of sample records (dicts or pandas rows).

    Each record must provide 'image_path', 'label' (class name) and 'label_idx'
    (class index). The transformed image is cached back into the record under
    'image', so later accesses skip decoding. Items are
    ``(image, label_name, label_idx, image_path)``.

    NOTE: with a DataLoader using worker processes, the cache is filled inside
    the workers, not in the parent's records.
    """

    def __init__(self, dataset, transform=None, preload_and_transform_all=False):
        self.dataset = dataset
        self.transform = transform
        if preload_and_transform_all:
            self.preload_and_transform_all()

    def __len__(self):
        return len(self.dataset)

    def _load_image(self, image_path):
        """Read ``image_path`` as RGB (closing the file) and apply ``self.transform``."""
        import PIL.Image  # `import PIL` alone does not guarantee the Image submodule is loaded

        # convert() returns a fully loaded copy, so the file can be closed right away.
        with PIL.Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

    def preload_and_transform_all(self):
        """Decode and transform every record up front, caching the result under 'image'.

        NOTE(review): this mutates the records yielded by iterating ``self.dataset``;
        if that iteration yields copies (e.g. some dataset libraries), nothing is cached.
        """
        for sample in tqdm(self.dataset, desc='Preloading and transforming all images'):
            sample['image'] = self._load_image(sample['image_path'])

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        if 'image' in sample:
            image = sample['image']
        else:
            image = self._load_image(sample['image_path'])
            sample['image'] = image
        return image, sample['label'], sample['label_idx'], sample['image_path']


def my_load_dataset(dataset_name, transform_train, transform_test, batch_size, num_workers, preload_and_transform_all=False):
    """Build ``(trainloader, testloader)`` for 'imagenet-val' or 'objectnet'.

    imagenet-val: every image file under ``./data/imagenet-val/val/<class>/`` is a
    test sample, in random (unseeded) order. There is no train split, so
    ``trainloader`` is None.

    objectnet: rows 0-999 of the HF 'test' split form the train set and rows
    1000-1999 form the test set.

    Both loaders use an identity ``collate_fn``, so a batch is a plain list of
    ``(image, label_name, label_idx, image_path)`` tuples.

    Raises:
        ValueError: for an unknown ``dataset_name``.
    """
    if dataset_name == 'imagenet-val':
        print("Loading ImageNet validation dataset...")
        _IMAGENET_CLASS_INDEX_URL = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
        # Bounded timeout so a stalled download fails instead of hanging forever.
        with urllib.request.urlopen(_IMAGENET_CLASS_INDEX_URL, timeout=60) as r:
            class_idx = json.load(io.TextIOWrapper(r, encoding="utf-8"))
        imagenet_n0_to_idx = {v[0]: int(k) for k, v in class_idx.items()}
        imagenet_path = Path('./data/imagenet-val/val/')
        imagenet = []
        for class_dir in imagenet_path.iterdir():
            # Skip stray top-level files (e.g. .DS_Store); iterdir() on a file raises.
            if not class_dir.is_dir():
                continue
            for image_path in class_dir.iterdir():
                if not image_path.is_file() or image_path.name == '.DS_Store':
                    continue
                imagenet.append({
                    'image_path': image_path,
                    'label': class_dir.name,
                    'label_idx': imagenet_n0_to_idx[class_dir.name],
                })
        random.shuffle(imagenet)

        print(f"ImageNet-val: {len(imagenet)} test")
        trainset = None
        testset = CustomDataset(imagenet, transform=transform_test, preload_and_transform_all=preload_and_transform_all)

    elif dataset_name == 'objectnet':
        print("Loading ObjectNet dataset...")
        objectnet = load_dataset("timm/objectnet-in1k")
        train_data = objectnet['test'].select(range(1000))
        test_data = objectnet['test'].select(range(1000, 2000))
        trainset = CustomDataset(train_data, transform=transform_train, preload_and_transform_all=preload_and_transform_all)
        testset = CustomDataset(test_data, transform=transform_test, preload_and_transform_all=preload_and_transform_all)

    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x, num_workers=num_workers) if trainset is not None else None
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, collate_fn=lambda x: x, num_workers=num_workers)

    return trainloader, testloader

def create_model(model_name, pretrained=True):
    """Instantiate ``model_name`` via timm and assert it has a 1000-way classifier head."""
    net = timm.create_model(model_name, pretrained=pretrained)
    print(f"Created model: {model_name} with pretrained={pretrained}")
    n_classes = net.num_classes
    assert n_classes == 1000, f"Model {model_name} has {n_classes} classes, expected 1000"
    return net

def evaluate(trainer: "Trainer"):
    """Evaluate ``trainer.model`` on ``trainer.testloader`` and optionally dump per-sample results.

    Each batch must be a list of ``(image_tensor, label_name, label_idx, image_path)``
    tuples, where ``image_path`` is a ``Path``. For every sample a record is
    collected with:
      * the target and predicted class, and whether they match;
      * target/predicted logits and softmax probabilities;
      * the full logit vector.
    When ``trainer.val_json_dir`` is set, the records are written to
    ``<val_json_dir>/val_<model_name>_<YYYY-mm-dd_HH-MM-SS>.json``.

    Returns:
        ``(test_loss, test_acc)``: mean of the per-batch losses, and accuracy in percent.

    Raises:
        AssertionError: if accuracy is <= 20% once more than 1000 samples have been
            seen (apparently a sanity check against mis-wired preprocessing/labels).
        ValueError: if the loader yields no batches.
    """
    model = trainer.model
    criterion = trainer.criterion
    device = trainer.device

    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    n_batches = 0
    class_correct = defaultdict(int)
    class_total = defaultdict(int)
    val_json = []
    asserted = False

    with torch.no_grad():
        pbar = tqdm(trainer.testloader, desc='Evaluating')
        for batch in pbar:
            inputs = torch.stack([s[0] for s in batch]).to(device)
            label_names = [s[1] for s in batch]
            targets = torch.tensor([s[2] for s in batch]).to(device)
            image_paths = [s[3] for s in batch]

            outputs = model(inputs)
            test_loss += criterion(outputs, targets).item()
            n_batches += 1

            softmax = F.softmax(outputs, dim=-1)
            _, predicted = outputs.max(1)
            rows = torch.arange(targets.size(0))
            # Convert to Python once per batch instead of one .item() (and one
            # device sync) per value per sample.
            target_list = targets.tolist()
            predicted_list = predicted.tolist()
            # No squeeze(): it would turn a batch of one into an unindexable 0-d tensor.
            correct_list = predicted.eq(targets).tolist()
            targets_logits = outputs[rows, targets].tolist()
            predicted_logits = outputs[rows, predicted].tolist()
            targets_softmax = softmax[rows, targets].tolist()
            predicted_softmax = softmax[rows, predicted].tolist()
            all_logits = outputs.cpu().tolist()

            total += len(target_list)
            correct += sum(correct_list)

            for i, target in enumerate(target_list):
                # Keyed by the int class index: tensors hash by identity, so tensor
                # keys never aggregate and are never found by int lookups.
                class_correct[target] += correct_list[i]
                class_total[target] += 1
                val_json.append({
                    'image_path': image_paths[i].as_posix(),
                    'label': label_names[i],
                    'target': target,
                    'predicted': predicted_list[i],
                    'is_correct': correct_list[i],
                    'targets_logits': targets_logits[i],
                    'predicted_logits': predicted_logits[i],
                    'targets_softmax': targets_softmax[i],
                    'predicted_softmax': predicted_softmax[i],
                    'logits': all_logits[i],
                })
            pbar.set_postfix({
                # Running mean over the batches processed so far.
                'Loss': f'{test_loss / n_batches:.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })
            if len(val_json) > 1000 and not asserted:
                assert (100.*correct/total) > 20, f"accuracy is too low acc={100.*correct/total}"
                asserted = True

    if n_batches == 0:
        raise ValueError('testloader yielded no batches')
    test_loss /= n_batches
    test_acc = 100. * correct / total

    # Print per-class accuracy (sample of first 10 classes)
    print('\nPer-class accuracy (first 10 classes):')
    for cls in range(min(10, len(class_total))):
        if class_total.get(cls, 0) > 0:
            print(f'Class {cls}: {100 * class_correct[cls] / class_total[cls]:.2f}%')
    if trainer.val_json_dir:
        out_dir = Path(trainer.val_json_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        with open(out_dir / f'val_{trainer.model_name}_{now}.json', 'w') as f:
            json.dump(val_json, f)
    return test_loss, test_acc

def main_generate(args=None, testloader=None):
    """Score a pretrained timm model on the test set and record per-sample results.

    With ``--model loop`` the test loader is built once and this function calls
    itself for every entry of ``ALL_MODELS``. A failure for one model is printed
    and that model is skipped. The loop overwrites ``args.model`` as it goes.
    """
    if args is None:
        args = parse_args()
    if args.device is not None:
        device = torch.device(args.device)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    transform_train, transform_test = get_data_transforms(args.dataset, 224)
    print(f"Loading {args.dataset} dataset...")
    if testloader is None:
        testloader = my_load_dataset(
            args.dataset, transform_train, transform_test,
            args.batch_size, args.num_workers
        )[1]

    if args.model == 'loop':
        # Re-enter once per model, sharing the already-built test loader.
        for name in ALL_MODELS:
            try:
                print(f"Generating {name}...")
                args.model = name
                main_generate(args, testloader)
            except Exception as e:
                print(f"Error generating {name}: {e}")
        return

    model = create_model(args.model, pretrained=True).to(device)
    criterion = nn.CrossEntropyLoss()

    print(f"Model: {args.model}")
    print(f"Dataset: {args.dataset}")
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {n_params:,}")

    trainer = Trainer(
        model=model,
        model_name=args.model,
        testloader=testloader,
        criterion=criterion,
        device=device,
        val_json_dir=args.val_json_dir,
    )
    print("Generating val json...")
    evaluate(trainer)

def add_to_json(model_name, version, json_data, root='data/imagenet_overfit_status'):
    """Append ``json_data`` to the list in ``<root>/<model_name>_<version>.json``.

    The file holds ``{'overfit_data': [...]}``. It is created on first use,
    together with any missing parent directories, and rewritten in full on
    every call.

    Args:
        model_name: Model identifier used in the file name.
        version: Version tag used in the file name.
        json_data: JSON-serializable record to append.
        root: Output directory (defaults to the original hard-coded location).
    """
    p = Path(root) / f'{model_name}_{version}.json'
    p.parent.mkdir(parents=True, exist_ok=True)
    if p.exists():
        with open(p, 'r') as f:
            data = json.load(f)
    else:
        data = {'overfit_data': []}
    data['overfit_data'].append(json_data)
    # Serialize before opening for write, so a non-serializable record cannot
    # truncate the existing file.
    data_json = json.dumps(data, indent=4)
    with open(p, 'w') as f:
        f.write(data_json)

def dist_in_weight_space(model1, model2):
    """Measure how far ``model1``'s parameters are from ``model2``'s.

    Parameters are paired in ``.parameters()`` order (``zip`` stops at the
    shorter model). Let ``d = p1 - p2`` for each pair of parameter tensors.

    The existing keys aggregate per-tensor norms of ``d``:

    * ``l1_dist``: sum of per-tensor L1 norms. This equals the global L1 distance.
    * ``l2_dist``: sum of per-tensor L2 norms. This is NOT the global L2 distance.
    * ``linf_dist``: sum of per-tensor L-inf norms.
    * ``*_per_param``: the above divided by the number of parameter *tensors*
      (not the number of scalar parameters).

    ``l2_dist_global`` and ``linf_dist_global`` give the true L2 and L-inf
    distances over all parameters concatenated.

    Computed without autograd. Models with no parameters yield all zeros.
    """
    l1_dist = l2_dist = linf_dist = 0
    l2_norms = []
    linf_norms = []
    with torch.no_grad():
        for p1, p2 in zip(model1.parameters(), model2.parameters()):
            diff = p1 - p2
            l1 = diff.norm(1)
            l2 = diff.norm(2)
            linf = diff.norm(float('inf'))
            l1_dist += l1
            l2_dist += l2
            linf_dist += linf
            l2_norms.append(l2)
            linf_norms.append(linf)
        P_count = len(l2_norms)
        if P_count == 0:
            return {key: 0.0 for key in (
                'l1_dist', 'l2_dist', 'linf_dist',
                'l1_dist_per_param', 'l2_dist_per_param', 'linf_dist_per_param',
                'l2_dist_global', 'linf_dist_global',
            )}
        l2_global = torch.stack(l2_norms).pow(2).sum().sqrt()
        linf_global = torch.stack(linf_norms).max()
    return {
        'l1_dist': l1_dist.item(),
        'l2_dist': l2_dist.item(),
        'linf_dist': linf_dist.item(),
        'l1_dist_per_param': (l1_dist / P_count).item(),
        'l2_dist_per_param': (l2_dist / P_count).item(),
        'linf_dist_per_param': (linf_dist / P_count).item(),
        'l2_dist_global': l2_global.item(),
        'linf_dist_global': linf_global.item(),
    }

def overfit_single_item(model, model_ref, device, dataloader, target_overfit_prob, target_overfit_loss, lr, idx, model_name, max_iters=None):
    """Fine-tune ``model`` with plain SGD on a single sample until a stop criterion is hit.

    Stops on the first of:
      * the target-class softmax reaching ``target_overfit_prob``;
      * the loss dropping to ``target_overfit_loss``;
      * the target softmax becoming NaN;
      * ``max_iters`` optimizer steps having been taken (if set).
    Per-iteration metrics, timing and the weight-space distance to ``model_ref``
    are appended via ``add_to_json(model_name, 'v11', ...)``.

    Args:
        model: Model to train in place (callers pass a fresh copy of ``model_ref``).
        model_ref: Unmodified reference model used for the distance measurement.
        device: Device for the model and the sample.
        dataloader: Its first batch's first element must be
            ``(image_tensor, label_name, label_idx, image_path)``; only that sample is used.
        target_overfit_prob: Optional softmax threshold for stopping.
        target_overfit_loss: Optional loss threshold for stopping.
        lr: SGD learning rate.
        idx: Sample identifier, recorded in the JSON and the progress bar.
        model_name: Recorded in the JSON and used for the output file name.
        max_iters: Optional cap on the number of optimizer steps.

    Raises:
        ValueError: if no stop criterion is given (the loop could otherwise never end).
    """
    if target_overfit_prob is None and target_overfit_loss is None and max_iters is None:
        raise ValueError('overfit_single_item needs target_overfit_prob, target_overfit_loss '
                         'or max_iters; otherwise the loop only stops on NaN')
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    # Train mode on purpose: normalization/dropout layers behave as in training
    # while fitting this single sample.
    model.train()
    model.to(device)
    inputs, label_name, label, image_path = next(iter(dataloader))[0]
    start_time = datetime.now()
    json_data = {
        'version': '0.3',
        'idx': idx,
        'lr': lr,
        'target_overfit_prob': target_overfit_prob,
        'target_overfit_loss': target_overfit_loss,
        'start_time': start_time.strftime("%Y-%m-%d_%H-%M-%S"),
        'image_path': image_path,
        'label_name': label_name,
        'label': label,
        'model_name': model_name,
        'overfit_data': [],
    }
    inputs = inputs.unsqueeze(0).to(device)
    label = torch.tensor([label]).to(device)
    i = 0  # number of optimizer steps taken so far
    pbar = tqdm(desc=f'Overfitting idx={idx}')

    def _show(loss_item, target_softmax):
        pbar.set_postfix({'Loss': f'{loss_item:.4f}', 'Target Softmax': f'{target_softmax:.4f}', 'idx': idx})

    while True:
        outputs = model(inputs)
        probs = F.softmax(outputs, dim=-1)
        pred_idx = probs.argmax(dim=-1)
        is_correct = (pred_idx == label).item()
        target_softmax = probs[0, label].item()
        predicted_softmax = probs[0, pred_idx].item()
        loss = criterion(outputs, label)
        loss_item = loss.item()
        json_data['overfit_data'].append({
            'target_softmax': target_softmax,
            'predicted_softmax': predicted_softmax,
            'loss': loss_item,
            'is_correct': is_correct,
            'iteration': i,
        })
        if np.isnan(target_softmax):
            stop = 'NAN!'
        elif (target_overfit_prob is not None and target_softmax >= target_overfit_prob) \
                or (target_overfit_loss is not None and loss_item <= target_overfit_loss):
            stop = 'DONE!'
        elif max_iters is not None and i >= max_iters:
            stop = 'MAX_ITERS!'
        else:
            stop = None
        if stop is not None:
            pbar.set_description(f'{stop} idx={idx}')
            _show(loss_item, target_softmax)
            pbar.close()
            break
        i += 1
        _show(loss_item, target_softmax)
        pbar.update(1)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    end_time = datetime.now()
    json_data['dist_in_weight_space'] = dist_in_weight_space(model, model_ref)
    json_data['duration_s'] = (end_time - start_time).total_seconds()
    json_data['end_time'] = end_time.strftime("%Y-%m-%d_%H-%M-%S")
    add_to_json(model_name, 'v11', json_data)
    
def to_image(t, path='tmp.png', bgr=False):
    """Save the first image of a CHW image batch ``t`` as a PNG for quick inspection.

    Values are clipped to [0, 1] and scaled to uint8, so pass an un-normalized
    tensor; ImageNet-normalized inputs will be heavily clipped.

    Args:
        t: Tensor whose ``t[0]`` is a (3, H, W) image.
        path: Output file (default 'tmp.png').
        bgr: Set when the channels are stored B, G, R; they are then reversed to
            R, G, B. Pipelines in this file produce RGB, so the default does no
            reordering. (The old ``[0, 2, 1]`` permutation swapped G and B, which
            is not a BGR-to-RGB conversion.)

    Returns:
        The saved ``PIL.Image.Image``.
    """
    import numpy as np
    from PIL import Image

    arr = np.clip(t[0].detach().cpu().numpy(), 0, 1)
    arr = (arr * 255).astype(np.uint8)
    # (C, H, W) -> (H, W, C), the layout PIL expects.
    arr = np.transpose(arr, (1, 2, 0))
    if bgr:
        arr = arr[:, :, ::-1]
    image = Image.fromarray(np.ascontiguousarray(arr), 'RGB')
    image.save(path)
    return image

def main_overfit(args=None):
    """Overfit the pretrained model on each of its misclassified val samples, one at a time.

    Reads the newest ``./data/imagenet_frozen_status/val_<model>_*.json`` and keeps
    the incorrect rows. It truncates them to ``end_to_idx`` first, then drops the
    first ``start_from_idx`` rows. Each remaining row is fitted via
    ``overfit_single_item`` on a fresh deep copy of the reference model.
    """
    if args is None:
        args = parse_args()
    assert args.mode == 'overfit'
    dataset_name = 'imagenet-val'
    _, transform_test = get_data_transforms(dataset_name, 224)
    print(f"Loading {dataset_name} dataset...")

    json_files = list(Path('./data/imagenet_frozen_status/').glob(f'val_{args.model}_*.json'))
    assert len(json_files) > 0, f"No json files found for {args.model}"
    newest = max(json_files)
    with open(newest, 'r') as f:
        frozen = pd.DataFrame(json.load(f))
    wrong = frozen[~frozen['is_correct']].reset_index(drop=True)
    wrong['label_idx'] = wrong['target']
    # The end bound is applied before the start bound (matters for negative indices).
    if args.end_to_idx is not None:
        wrong = wrong.iloc[:args.end_to_idx]
    wrong = wrong.iloc[args.start_from_idx:]

    model_ref = create_model(args.model, pretrained=True).to(args.device)
    for idx, row in wrong.iterrows():
        loader = DataLoader(CustomDataset([row], transform=transform_test), batch_size=1, collate_fn=lambda x: x)
        overfit_single_item(copy.deepcopy(model_ref), model_ref, args.device, loader,
                            args.target_overfit_prob, args.target_overfit_loss, args.lr, idx, args.model)

def train_single_epoch(model, train_loader, optimizer, device):
    """Run one training epoch over a DALI iterator and return loss statistics.

    Each batch from ``train_loader`` must be a one-element list holding a dict
    with NHWC ``'images'`` and integer ``'labels'``. Images are moved to
    ``device`` and permuted to NCHW; labels are flattened to 1-D and moved to
    ``device`` as int64.

    Returns:
        dict with:
          * ``'loss'``: per-batch losses;
          * ``'final_loss'``: mean of the batch losses, or NaN if no batch was seen;
          * ``'image_count'``: number of images processed.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    running_loss = 0.0
    num_batches = 0
    train_data = {
        'loss': [],
        'final_loss': None,
        'image_count': 0,
    }
    pbar = tqdm(train_loader, total=len(train_loader), desc='Training')
    for batch in pbar:
        assert len(batch) == 1, f'Batch isnt 1? len(batch)={len(batch)}'
        # .to(device) is a no-op when DALI already produced tensors on `device`.
        inputs = batch[0]['images'].to(device).permute(0, 3, 1, 2)
        # reshape(-1) instead of squeeze(): keeps a 1-D target for a batch of one.
        labels = batch[0]['labels'].reshape(-1).to(device=device, dtype=torch.long)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss_value = loss.item()
        running_loss += loss_value
        num_batches += 1
        train_data['loss'].append(loss_value)
        train_data['image_count'] += inputs.shape[0]
        pbar.set_postfix({'Training Loss': f'{running_loss / num_batches:.4f}'})
    # Divide by batches actually processed, not the iterator's reported length.
    train_data['final_loss'] = running_loss / num_batches if num_batches else float('nan')
    return train_data

def evaluate_model(model, val_loader, labels_dict, device):
    """Evaluate on a DALI val iterator; per-sample results are keyed by sample position.

    Each batch must be a one-element list holding a dict with NHWC ``'images'``
    and ``'labels'``. The labels are positions into the evaluation list, and
    ``labels_dict`` maps each position to its class index.

    Returns:
        dict with:
          * per-position ``'target_softmax'``, ``'idx_losses'`` and ``'is_correct'``;
          * ``'image_count'``;
          * ``'final_loss'``: mean of the batch-mean losses;
          * ``'final_acc'``: accuracy in percent.
        Both finals are NaN if no batch was seen. ``'loss'`` stays an empty list
        (kept for schema compatibility).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction='none')
    running_loss = 0.0
    num_batches = 0
    val_data = {
        'loss': [],
        'target_softmax': {},
        'idx_losses': {},
        'is_correct': {},
        'image_count': 0,
    }
    total_correct = 0
    total_samples = 0
    pbar = tqdm(val_loader, total=len(val_loader), desc='Evaluating')
    # No gradients are needed; without no_grad every batch would build an autograd graph.
    with torch.no_grad():
        for batch in pbar:
            assert len(batch) == 1, f'Batch isnt 1? len(batch)={len(batch)}'
            inputs = batch[0]['images'].to(device).permute(0, 3, 1, 2)
            # reshape(-1) instead of squeeze(): a final batch of one still yields a list.
            positions = batch[0]['labels'].reshape(-1).tolist()
            labels = torch.tensor([labels_dict[p] for p in positions]).to(torch.long).to(device)
            outputs = model(inputs)
            target_softmax = F.softmax(outputs, dim=-1)[torch.arange(labels.size(0)), labels]
            losses = criterion(outputs, labels)
            hits = labels == outputs.argmax(dim=-1)
            for pos, ts, sample_loss, hit in zip(positions, target_softmax.tolist(), losses.tolist(), hits.tolist()):
                val_data['target_softmax'][pos] = ts
                val_data['idx_losses'][pos] = sample_loss
                val_data['is_correct'][pos] = hit
            running_loss += losses.mean().item()
            num_batches += 1
            total_correct += int(hits.sum().item())
            total_samples += labels.size(0)
            pbar.set_postfix({'Evaluation Loss': f'{running_loss / num_batches:.4f} Acc: {100 * total_correct / total_samples:.2f}%'})
            val_data['image_count'] += inputs.shape[0]
    if len(val_data['target_softmax']) != 50_000:
        print(f'Missing {50_000 - len(val_data["target_softmax"])} indices from the 50k. might be expected.')
    val_data['final_loss'] = running_loss / num_batches if num_batches else float('nan')
    val_data['final_acc'] = 100 * total_correct / total_samples if total_samples else float('nan')
    return val_data

def main_normal_train(args=None):
    """Train a timm model with Adam on the DALI ImageNet loaders, evaluating every epoch.

    Uses pretrained weights unless ``--no-pretrained``. In that case there is no
    per-model "incorrect" subset, so evaluation runs on the full val set.

    After every epoch the accumulated train/val statistics are rewritten to
    ``data/imagenet_normal_train_status/<model>_<version>.json``; the directory
    is created if needed.
    """
    if args is None:
        args = parse_args()
    assert args.mode == 'normal-train'
    print(f'Args: {args}')

    model_name = None if args.no_pretrained else args.model
    train_loader, val_loader, labels_dict = get_dataloaders(model_name=model_name, val_num_samples=args.val_num_samples, train_size=args.train_size)
    model = create_model(args.model, pretrained=not args.no_pretrained).to(args.device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    json_data = {
        'version': 'sep22_v2',
        'lr': args.lr,
        'start_time': datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        'model_name': args.model,
        'train_data': [],
        'val_data': [],
    }
    # Create the output directory up front so the first epoch's results are not
    # lost to a FileNotFoundError after training.
    out_dir = Path('data/imagenet_normal_train_status')
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f'{args.model}_{json_data["version"]}.json'

    for epoch in range(args.epochs):
        train_data = train_single_epoch(model, train_loader, optimizer, args.device)
        val_data = evaluate_model(model, val_loader, labels_dict, args.device)
        json_data['train_data'].append(train_data)
        json_data['val_data'].append(val_data)
        train_loss = train_data['final_loss']
        val_loss = val_data['final_loss']

        print(f'Epoch {epoch+1}/{args.epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
        with open(out_path, 'w') as f:
            json.dump(json_data, f)


def main(args=None):
    """Entry point: run the sub-command selected by ``args.mode`` (parsed from argv if omitted)."""
    args = parse_args() if args is None else args
    mode = args.mode
    if mode == 'generate':
        main_generate(args)
        return
    if mode == 'overfit':
        main_overfit(args)
        return
    if mode == 'normal-train':
        main_normal_train(args)
        return
    raise ValueError(f"Unknown mode: {args.mode}")


if __name__ == '__main__':
    main()
