import argparse
import os
import shutil
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, RandomSampler
from torchvision import transforms, datasets
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import timm

# ===== Import custom optimizers =====
from zen import ZenGrad, ZenGrad_M
from lion_pytorch import Lion
from adabelief_pytorch import AdaBelief

# ===== Utility Classes =====
class AverageMeter:
    """Track the most recent value of a metric plus its count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic (latest value, mean, weighted sum and sample count)."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh the running average."""
        self.count += n
        self.sum += val * n
        self.val = val
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Compute top-k classification accuracy, in percent, for each k in ``topk``.

    Args:
        output: 2-D score tensor; classes are indexed along dim 1.
        target: Ground-truth class indices, one per row of ``output``.
        topk: The k values to report.

    Returns:
        A list with one 1-element tensor per entry of ``topk``: the percentage of
        rows whose target is among the k highest-scoring classes. A k larger than
        the number of classes yields 100%.
    """
    # torch.topk cannot return more entries than there are classes; any
    # k >= num_classes covers every class, so clamping keeps the result exact.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    # After the transpose, row i holds the i-th best guess for every sample.
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # correct[:k] with k > maxk simply takes all rows.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def mkdir_p(path):
    """Create directory ``path`` together with any missing parents (like ``mkdir -p``).

    An already-existing directory is accepted silently.
    """
    os.makedirs(name=path, exist_ok=True)

def savefig(fname, **kwargs):
    """Save the current matplotlib figure to ``fname``.

    Any extra keyword arguments (for example ``dpi``, ``format`` or
    ``bbox_inches``) are forwarded unchanged to ``matplotlib.pyplot.savefig``.
    """
    plt.savefig(fname, **kwargs)

class Logger:
    """Minimal tab-separated log writer for training metrics.

    Header and data rows are flushed as soon as they are written, so the file
    can be followed while a run is in progress. The logger is also a context
    manager, which closes the file even if the enclosed code raises.
    """

    def __init__(self, fpath, title=None, resume=False):
        """Open the log file.

        Args:
            fpath: Path of the log file.
            title: Optional title, written as a ``# title`` line when truthy.
            resume: Append to an existing file instead of truncating it.
        """
        self.file = open(fpath, 'a' if resume else 'w', encoding='utf-8')
        if title:
            self.file.write(f"# {title}\n")
        self.names = []

    def set_names(self, names):
        """Remember the column names and write them as a tab-separated header row."""
        self.names = names
        self.file.write("\t".join(names) + "\n")
        # Flush so the header reaches disk now, not only with the first data row.
        self.file.flush()

    def append(self, values):
        """Write one tab-separated row (each value passed through ``str``) and flush."""
        self.file.write("\t".join(map(str, values)) + "\n")
        self.file.flush()

    def close(self):
        """Close the underlying file."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always release the file handle; exceptions are not suppressed.
        self.close()

# ===== Argument Parser =====
parser = argparse.ArgumentParser(description='ViT-S/16 Training')
# Required: the train/val dataset paths are always built from it.
parser.add_argument('--data', metavar='DIR', required=True, help='path to ImageNet dataset')
parser.add_argument('--steps', default=100000, type=int, help='total training steps')
parser.add_argument('--train-batch', default=256, type=int)
parser.add_argument('--test-batch', default=200, type=int)
# NOTE(review): 0.1 is a large default for adaptive optimizers such as AdamW; confirm per optimizer.
parser.add_argument('--lr', default=0.1, type=float)
parser.add_argument('--weight-decay', default=1e-4, type=float)
parser.add_argument('--workspace', default='/workspace', type=str,
                    help='root directory; a relative --checkpoint is placed under it')
parser.add_argument('--checkpoint', default='checkpoint', type=str)
parser.add_argument('--resume', default='', type=str)
parser.add_argument('--gpu-id', default='0', type=str)
parser.add_argument('--manualSeed', type=int)
parser.add_argument('-e', '--evaluate', action='store_true')
parser.add_argument('--optim', default='adamw', type=str,
                    choices=['adamw', 'lion', 'adabelief', 'nadam', 'zengrad', 'm-zengrad'],
                    help='optimizer choice')
args = parser.parse_args()
state = {k: v for k, v in args._get_kwargs()}

# ===== Workspace & CUDA setup =====
workspace_dir = args.workspace
mkdir_p(workspace_dir)
# os.path.join keeps an absolute --checkpoint as-is and nests a relative one under the workspace.
args.checkpoint = os.path.join(workspace_dir, args.checkpoint)
mkdir_p(args.checkpoint)

# Must be set before CUDA is first initialised in this process, otherwise it has no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
use_cuda = torch.cuda.is_available()

if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
# Report the seed so a run with an auto-chosen seed can be reproduced.
print(f"=> Random seed: {args.manualSeed}")
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)

best_acc = 0
metrics_log = []

# ===== main() helpers (they read the module-level ``args`` / ``use_cuda``) =====
def _imagenet_normalize():
    """Return the ImageNet per-channel mean/std normalization transform."""
    return transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])


def _build_train_loader():
    """Build the training DataLoader; it yields exactly ``args.steps`` batches."""
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        _imagenet_normalize(),
    ])
    train_dataset = datasets.ImageFolder(os.path.join(args.data, 'train'), transform=train_transform)
    # Sampling with replacement decouples the step budget from the dataset size:
    # steps * batch samples -> exactly args.steps full batches.
    train_sampler = RandomSampler(train_dataset, replacement=True,
                                  num_samples=args.steps * args.train_batch)
    return DataLoader(train_dataset, batch_size=args.train_batch, sampler=train_sampler,
                      num_workers=8, pin_memory=True)


def _build_val_loader():
    """Build the unshuffled validation DataLoader (resize 256, center-crop 224)."""
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        _imagenet_normalize(),
    ])
    val_dataset = datasets.ImageFolder(os.path.join(args.data, 'val'), transform=val_transform)
    return DataLoader(val_dataset, batch_size=args.test_batch, shuffle=False,
                      num_workers=8, pin_memory=True)


def _build_model():
    """Create the pretrained ViT-S/16 and wrap it in DataParallel on CUDA."""
    print("=> Creating model 'vit_small_patch16_224'")
    model = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=1000)
    if use_cuda:
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
    print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1e6))
    return model


def _resume_from(path, model, optimizer, scheduler):
    """Restore model/optimizer/scheduler state from a checkpoint written by ``main``.

    Returns:
        ``(next_step, best_acc)`` taken from the checkpoint.
    """
    print(f"=> Resuming from checkpoint '{path}'")
    # NOTE(review): torch.load unpickles data; only pass checkpoints you created.
    ckpt = torch.load(path, map_location='cpu')
    # A DataParallel-wrapped model saves keys with a 'module.' prefix. Strip it and
    # load into the unwrapped module so CPU and GPU runs can share checkpoints.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt['state_dict'].items()
    }
    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(state_dict)
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['scheduler'])
    return ckpt['step'] + 1, ckpt['best_acc']


def _plot_metrics(records, fname):
    """Plot train/valid top-1 accuracy against step and save the figure to ``fname``."""
    fig, ax = plt.subplots()
    if records:
        steps = [r["step"] for r in records]
        ax.plot(steps, [r["train_acc"] for r in records], label='Train Acc.')
        ax.plot(steps, [r["val_acc"] for r in records], label='Valid Acc.')
        ax.legend()
    ax.set_xlabel('step')
    ax.set_ylabel('top-1 accuracy (%)')
    savefig(fname)  # plt.subplots made `fig` the current figure
    plt.close(fig)


# ===== Main =====
def main():
    """Train ViT-S/16 on ImageNet for ``args.steps`` optimizer steps, or evaluate with ``-e``.

    With ``--resume`` the model, optimizer and scheduler state are restored and
    training continues from the step after the saved one. Evaluation runs every
    1000 steps and after the final step. Each evaluation appends a row to
    ``log.txt``, rewrites ``checkpoint.pth.tar`` and updates ``model_best.pth.tar``
    on improvement. At the end, the metrics collected in this run are exported,
    along with an accuracy plot (``log.eps``).
    """
    global best_acc

    val_loader = _build_val_loader()
    model = _build_model()

    # ===== Loss & Optimizer =====
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    if use_cuda:
        criterion = criterion.cuda()
    optimizer = get_optimizer(args.optim, model, args.lr, args.weight_decay)
    # Stepped once per optimizer step, so the cosine schedule spans the whole run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.steps, eta_min=1e-6)
    # Loss scaling is only meaningful with CUDA autocast; disable it on CPU.
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    start_step = 0
    if args.resume:
        start_step, best_acc = _resume_from(args.resume, model, optimizer, scheduler)

    if args.evaluate:
        test_loss, test_acc = test(val_loader, model, criterion, start_step, use_cuda)
        print(f' Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}')
        return

    # ===== Logger =====
    resuming = bool(args.resume)
    logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                    title=None if resuming else 'ImageNet-ViT-S16', resume=resuming)
    if not resuming:
        logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

    # ===== Step-based training loop =====
    train_loader = _build_train_loader()
    # One persistent iterator for the whole run. Calling iter() on the DataLoader
    # every step would respawn its worker processes and throw away their prefetch.
    train_iter = iter(train_loader)
    eval_every = 1000
    try:
        for step in range(start_step, args.steps):
            train_loss, train_acc = train(train_iter, model, criterion, optimizer, scaler, step, use_cuda)
            scheduler.step()
            # Read after scheduler.step(): this is the LR the next step will use.
            current_lr = scheduler.get_last_lr()[0]

            # Evaluate periodically and always after the final step, so the
            # end-of-training model is validated and checkpointed too.
            if step % eval_every != 0 and step + 1 != args.steps:
                continue

            test_loss, test_acc = test(val_loader, model, criterion, step, use_cuda)
            logger.append([current_lr, train_loss, test_loss, train_acc, test_acc])

            metrics_log.append({
                "step": step,
                "learning_rate": current_lr,
                "train_loss": train_loss,
                "val_loss": test_loss,
                "train_acc": float(train_acc),
                "val_acc": float(test_acc),
            })

            is_best = test_acc > best_acc
            best_acc = max(test_acc, best_acc)
            save_checkpoint({
                'step': step,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }, is_best, checkpoint=args.checkpoint)
    finally:
        logger.close()

    df = pd.DataFrame(metrics_log)
    excel_path = os.path.join(args.checkpoint, "training_metrics.xlsx")
    try:
        df.to_excel(excel_path, index=False)
        print(f"Metrics saved to {excel_path}")
    except ImportError:
        # to_excel needs an optional engine (e.g. openpyxl); don't lose the metrics over it.
        csv_path = os.path.join(args.checkpoint, "training_metrics.csv")
        df.to_csv(csv_path, index=False)
        print(f"Excel writer unavailable; metrics saved to {csv_path}")
    _plot_metrics(metrics_log, os.path.join(args.checkpoint, 'log.eps'))
    print('Best acc:', best_acc)

# ===== Optimizer Selector =====
def get_optimizer(opt_name, model, lr, weight_decay):
    """Construct the optimizer called ``opt_name`` over all parameters of ``model``.

    Every choice receives the same ``lr`` and ``weight_decay``. AdaBelief is
    additionally built with ``rectify=True`` and M-ZenGrad with ``momentum=0.9``.

    Raises:
        ValueError: If ``opt_name`` is not a supported optimizer name.
    """
    # Factories are lazy, so only the selected optimizer class is ever touched.
    factories = {
        'adamw': lambda params: torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay),
        'lion': lambda params: Lion(params, lr=lr, weight_decay=weight_decay),
        'adabelief': lambda params: AdaBelief(params, lr=lr, weight_decay=weight_decay, rectify=True),
        'nadam': lambda params: torch.optim.NAdam(params, lr=lr, weight_decay=weight_decay),
        'zengrad': lambda params: ZenGrad(params, lr=lr, weight_decay=weight_decay),
        'm-zengrad': lambda params: ZenGrad_M(params, lr=lr, momentum=0.9, weight_decay=weight_decay),
    }
    factory = factories.get(opt_name)
    if factory is None:
        raise ValueError(f"Unknown optimizer: {opt_name}")
    return factory(model.parameters())

# ===== Training =====
def train(train_loader, model, criterion, optimizer, scaler, step, use_cuda):
    """Run exactly one optimizer step on the next training batch.

    Args:
        train_loader: A DataLoader or an iterator over ``(inputs, targets)`` batches.
            When calling once per step, pass a persistent iterator
            (``iter(loader)``). ``iter()`` on a DataLoader builds a brand-new
            iterator, and new worker processes, on every call.
        model: Network to update; switched to train mode here.
        criterion: Loss function applied to ``(outputs, targets)``.
        optimizer: Optimizer stepped through ``scaler``.
        scaler: ``GradScaler`` used for the mixed-precision backward pass and step.
        step: Current global step. Not used here; kept for call compatibility.
        use_cuda: Move the batch to the GPU and enable CUDA autocast.

    Returns:
        ``(loss, top1)`` for this single batch, with top-1 accuracy in percent.

    Raises:
        RuntimeError: If ``train_loader`` has no batches left.
    """
    model.train()

    # iter() on an iterator returns it unchanged, so a persistent iterator
    # advances by exactly one batch per call.
    try:
        inputs, targets = next(iter(train_loader))
    except StopIteration:
        raise RuntimeError(
            "training data exhausted before the step budget was reached"
        ) from None

    if use_cuda:
        # Non-blocking copies can overlap with compute when the batch is in pinned memory.
        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
    optimizer.zero_grad()

    # Autocast is enabled only on CUDA. Disabling it explicitly on CPU avoids the
    # "CUDA not available" fallback warning.
    with torch.autocast(device_type='cuda', enabled=use_cuda):
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # Only top-1 is reported; asking for top-5 would also require >= 5 classes.
    (prec1,) = accuracy(outputs.detach(), targets, topk=(1,))

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    return loss.item(), prec1.item()

def test(val_loader, model, criterion, step, use_cuda):
    """Evaluate ``model`` on every batch of ``val_loader`` without gradients.

    Args:
        val_loader: Iterable of ``(inputs, targets)`` batches.
        model: Network to evaluate; switched to eval mode here.
        criterion: Loss function applied to ``(outputs, targets)``.
        step: Current global step, shown in the progress-bar label.
        use_cuda: Move each batch to the GPU before the forward pass.

    Returns:
        ``(mean loss, mean top-1 accuracy in percent)``, both weighted by batch size.
    """
    model.eval()
    loss_meter, acc1_meter, acc5_meter = AverageMeter(), AverageMeter(), AverageMeter()

    with torch.no_grad():
        progress = tqdm(val_loader, desc=f"Valid Step {step}")
        for images, labels in progress:
            if use_cuda:
                images = images.cuda()
                labels = labels.cuda(non_blocking=True)
            logits = model(images)
            batch_loss = criterion(logits, labels)

            acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
            size = images.size(0)
            loss_meter.update(batch_loss.item(), size)
            acc1_meter.update(acc1[0].item(), size)
            acc5_meter.update(acc5[0].item(), size)

    return loss_meter.avg, acc1_meter.avg

# ===== Checkpoint =====
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``checkpoint/filename`` and optionally keep a best-model copy.

    Args:
        state: Object handed to ``torch.save`` (in this script, a dict of state dicts and metrics).
        is_best: If true, also copy the saved file to ``checkpoint/model_best.pth.tar``.
        checkpoint: Destination directory; expected to exist already.
        filename: File name to write inside ``checkpoint``.
    """
    filepath = os.path.join(checkpoint, filename)
    # Write to a temporary file, then atomically swap it in, so a crash
    # mid-save cannot leave a truncated file where the last good checkpoint was.
    tmp_path = filepath + '.tmp'
    torch.save(state, tmp_path)
    os.replace(tmp_path, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))

if __name__ == "__main__":
    # Script entry point: importing this module does not call main().
    main()
