import argparse
import os
import shutil
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import models, transforms, datasets
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd

# Import custom optimizers
from zen import ZenGrad, ZenGrad_M
from lion_pytorch import Lion  # make sure lion-pytorch is installed
from adabelief_pytorch import AdaBelief

# ===== Simple utilities =====
class AverageMeter:
    """Track the most recent value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running sum, sample count and mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record *val* as the mean over *n* samples and refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (in percent) for each k in *topk*.

    *output* is scored along dim 1 (one row of class scores per sample) and
    *target* holds one class index per row. The result is a list with one
    single-element tensor per k, in the order given by *topk*.
    """
    n_samples = target.size(0)
    k_max = max(topk)
    # Indices of the k_max best-scoring classes, transposed to (k_max, n_samples)
    # so that slicing the first k rows gives the top-k predictions.
    top_idx = output.topk(k_max, 1, True, True)[1].t()
    hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))
    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale) for k in topk]

def mkdir_p(path):
    """Create *path* and any missing parents, like ``mkdir -p``; existing dirs are fine."""
    os.makedirs(name=path, exist_ok=True)

def savefig(fname):
    """Save the current pyplot figure to *fname*.

    Thin wrapper over ``plt.savefig``; the output format is left to matplotlib
    (typically inferred from the file extension).
    """
    plt.savefig(fname)

class Logger:
    """Minimal tab-separated text log, one row per ``append`` call.

    Usable as a context manager so the file is closed even if training raises::

        with Logger(path, title="run") as log:
            log.set_names(["a", "b"])
            log.append([1, 2])
    """

    def __init__(self, fpath, title=None, resume=False):
        """Open the log file.

        Args:
            fpath: Path of the log file. It is truncated unless *resume* is set.
            title: Optional text written as a ``# title`` comment line. This is
                also written when resuming, which marks where the resumed run starts.
            resume: Append to an existing file instead of truncating it.
        """
        # Explicit UTF-8 so the log's bytes do not depend on the platform locale.
        self.file = open(fpath, 'a' if resume else 'w', encoding='utf-8')
        self.names = []
        if title:
            self.file.write(f"# {title}\n")
            self.file.flush()

    def set_names(self, names):
        """Record the column names and write them as a tab-separated header row."""
        self.names = names
        self.file.write("\t".join(names) + "\n")
        # Flush like append() so the header is on disk even before the first row.
        self.file.flush()

    def append(self, values):
        """Write one row (each value converted with ``str``) and flush it to disk."""
        self.file.write("\t".join(map(str, values)) + "\n")
        self.file.flush()

    def close(self):
        """Close the underlying file."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

# ===== Argument parser =====
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training with Multiple Optimizers')
# Required: main() joins this path with 'train'/'val', which would otherwise
# fail with an opaque TypeError on None.
parser.add_argument('--data', metavar='DIR', required=True,
                    help='path to ImageNet dataset (ImageFolder layout with train/ and val/)')
parser.add_argument('--epochs', default=90, type=int, help='total number of epochs to run')
parser.add_argument('--train-batch', default=256, type=int, help='training mini-batch size')
parser.add_argument('--test-batch', default=200, type=int, help='validation mini-batch size')
# NOTE(review): 0.1 looks like an SGD-style default; adaptive optimizers such as
# the default 'adamw' usually need a much smaller value — confirm intended default.
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--weight-decay', default=1e-4, type=float, help='weight decay passed to the optimizer')
parser.add_argument('--schedule', type=int, nargs='+', default=[30, 60],
                    help='epochs at which the learning rate is multiplied by --gamma')
parser.add_argument('--gamma', type=float, default=0.1,
                    help='learning-rate multiplier applied at each --schedule milestone')
parser.add_argument('--checkpoint', default='checkpoint', type=str,
                    help='directory for logs and checkpoints (relative paths go under the workspace dir)')
parser.add_argument('--resume', default='', type=str, help='path to a checkpoint to resume from')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=sorted(name for name in models.__dict__
                                   if name.islower() and not name.startswith("__") and callable(models.__dict__[name])),
                    help='torchvision model architecture')
parser.add_argument('--gpu-id', default='0', type=str, help='value assigned to CUDA_VISIBLE_DEVICES')
parser.add_argument('--manualSeed', type=int, help='random seed (drawn at random if omitted)')
parser.add_argument('-e', '--evaluate', action='store_true', help='only evaluate on the validation set')
parser.add_argument('--optim', default='adamw', type=str,
                    choices=['adamw', 'lion', 'adabelief', 'nadam', 'zengrad', 'm-zengrad'],
                    help='optimizer choice')
args = parser.parse_args()
# Snapshot of the parsed options; adjust_learning_rate() keeps state['lr'] current.
state = dict(vars(args))

# ===== Workspace & CUDA setup =====
# Root for all run artifacts. Defaults to the original "/workspace"; set the
# WORKSPACE_DIR environment variable to relocate it (e.g. where "/" is read-only).
workspace_dir = os.environ.get("WORKSPACE_DIR", "/workspace")
mkdir_p(workspace_dir)
# A relative --checkpoint ends up under the workspace; an absolute one is kept,
# because os.path.join discards earlier parts before an absolute component.
args.checkpoint = os.path.join(workspace_dir, args.checkpoint)
mkdir_p(args.checkpoint)

# Restricts the visible GPUs; this only takes effect if CUDA has not already been
# initialized (e.g. by an imported module) before this line.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
use_cuda = torch.cuda.is_available()

# Seed Python's and torch's RNGs. A drawn seed is printed so the run can be repeated.
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
print(f"=> random seed: {args.manualSeed}")
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)

# Best validation top-1 accuracy seen so far (updated by main()).
best_acc = 0
# Per-epoch metric dicts collected by main() for the final export.
metrics_log = []

# ===== Main =====
def main():
    """Train an ImageNet classifier, or only evaluate it when ``--evaluate`` is set.

    Configuration comes from the module-level ``args`` and ``use_cuda``. A
    training run writes the following under ``args.checkpoint``:

    * ``log.txt``: one tab-separated row per epoch;
    * ``checkpoint.pth.tar`` every epoch, and ``model_best.pth.tar`` whenever
      validation top-1 improves;
    * ``training_metrics.xlsx`` (``.csv`` if no Excel writer is installed) and
      ``log.eps`` with loss/accuracy curves once training finishes.

    Updates the module-level ``best_acc`` and ``metrics_log``.
    """
    global best_acc
    start_epoch = 0

    train_loader, val_loader = _build_loaders()

    # ===== Model =====
    print(f"=> creating model '{args.arch}'")
    model = models.__dict__[args.arch](weights=None, num_classes=1000)
    # Always wrapped so state_dict keys keep the "module." prefix on every machine;
    # without visible GPUs PyTorch's DataParallel just calls the wrapped module.
    model = torch.nn.DataParallel(model)
    if use_cuda:
        model = model.cuda()
    cudnn.benchmark = True
    print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1e6))

    # ===== Loss & optimizer =====
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    if use_cuda:
        criterion = criterion.cuda()
    # Created after the model is moved so it references the final parameters.
    optimizer = get_optimizer(args.optim, model, args.lr, args.weight_decay)
    # AMP loss scaling is only used with CUDA; a disabled scaler is a pass-through.
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    # ===== Resume =====
    title = 'ImageNet-' + args.arch
    if args.resume:
        print('==> Resuming from checkpoint..')
        if not os.path.isfile(args.resume):
            raise FileNotFoundError(f"Checkpoint not found: {args.resume}")
        args.checkpoint = os.path.dirname(args.resume)
        # Map tensors to CPU when no GPU is available so GPU-saved checkpoints still load.
        checkpoint = torch.load(args.resume, map_location=None if use_cuda else 'cpu')
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Older checkpoints have no scaler entry; a disabled scaler saves an empty
        # dict, which an enabled scaler refuses to load.
        if checkpoint.get('scaler'):
            scaler.load_state_dict(checkpoint['scaler'])

    if args.evaluate:
        # Handled before the Logger exists so evaluation never truncates log.txt.
        print('\nEvaluation only')
        test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda)
        print(f'Test Loss: {test_loss:.8f}, Test Acc: {test_acc:.2f}')
        return

    log_path = os.path.join(args.checkpoint, 'log.txt')
    if args.resume:
        logger = Logger(log_path, title=title, resume=True)
    else:
        logger = Logger(log_path, title=title)
        logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

    # ===== Training loop =====
    try:
        for epoch in range(start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch)  # also refreshes state['lr']
            print(f'\nEpoch: [{epoch+1} | {args.epochs}] LR: {state["lr"]:.6f}')
            train_loss, train_acc = train(train_loader, model, criterion, optimizer, scaler, epoch, use_cuda)
            test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda)
            logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])

            metrics_log.append({
                "epoch": epoch + 1,
                "learning_rate": state['lr'],
                "train_loss": train_loss,
                "val_loss": test_loss,
                "train_acc": float(train_acc),
                "val_acc": float(test_acc),
            })

            is_best = test_acc > best_acc
            best_acc = max(test_acc, best_acc)
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'scaler': scaler.state_dict(),
            }, is_best, checkpoint=args.checkpoint)
    finally:
        logger.close()

    # NOTE: after a resume, metrics_log only holds the epochs run by this process.
    _export_metrics(metrics_log, args.checkpoint)
    print('Best acc:', best_acc)


def _build_loaders():
    """Build the ImageNet train/val DataLoaders from ``args.data`` (ImageFolder layout)."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.ImageFolder(os.path.join(args.data, 'train'), transform=train_transform)
    val_dataset = datasets.ImageFolder(os.path.join(args.data, 'val'), transform=val_transform)

    # Pinned host memory is only useful for host-to-GPU copies.
    train_loader = DataLoader(train_dataset, batch_size=args.train_batch, shuffle=True,
                              num_workers=8, pin_memory=use_cuda)
    val_loader = DataLoader(val_dataset, batch_size=args.test_batch, shuffle=False,
                            num_workers=8, pin_memory=use_cuda)
    return train_loader, val_loader


def _export_metrics(metrics, out_dir):
    """Write the per-epoch metrics table and a loss/accuracy plot into *out_dir*."""
    df = pd.DataFrame(metrics)
    table_path = os.path.join(out_dir, "training_metrics.xlsx")
    try:
        df.to_excel(table_path, index=False)
    except ImportError:
        # to_excel needs an optional engine (e.g. openpyxl); fall back to CSV so
        # a missing package does not discard the metrics at the end of a long run.
        table_path = os.path.join(out_dir, "training_metrics.csv")
        df.to_csv(table_path, index=False)
    print(f"Metrics saved to {table_path}")
    _plot_metrics(metrics, os.path.join(out_dir, 'log.eps'))


def _plot_metrics(metrics, fname):
    """Plot per-epoch loss and top-1 accuracy curves and save them to *fname*."""
    epochs = [row["epoch"] for row in metrics]
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    for key, label in (("train_loss", "train"), ("val_loss", "val")):
        ax_loss.plot(epochs, [row[key] for row in metrics], label=label)
    for key, label in (("train_acc", "train"), ("val_acc", "val")):
        ax_acc.plot(epochs, [row[key] for row in metrics], label=label)
    ax_loss.set(xlabel="epoch", ylabel="loss", title="Loss")
    ax_acc.set(xlabel="epoch", ylabel="top-1 accuracy (%)", title="Accuracy")
    ax_loss.legend()
    ax_acc.legend()
    fig.tight_layout()
    savefig(fname)  # saves the current figure, i.e. the one created above
    plt.close(fig)

# ===== Optimizer selector =====
def get_optimizer(opt_name, model, lr, weight_decay):
    """Instantiate the optimizer called *opt_name* over ``model.parameters()``.

    Every optimizer receives *lr* and *weight_decay*; 'adabelief' additionally
    gets ``rectify=True`` and 'm-zengrad' gets ``momentum=0.9``.

    Raises:
        ValueError: if *opt_name* is not one of the supported names.
    """
    # Lambdas delay the lookup of the third-party classes until one is selected.
    builders = (
        ('adamw', lambda params: torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)),
        ('lion', lambda params: Lion(params, lr=lr, weight_decay=weight_decay)),
        ('adabelief', lambda params: AdaBelief(params, lr=lr, weight_decay=weight_decay, rectify=True)),
        ('nadam', lambda params: torch.optim.NAdam(params, lr=lr, weight_decay=weight_decay)),
        ('zengrad', lambda params: ZenGrad(params, lr=lr, weight_decay=weight_decay)),
        ('m-zengrad', lambda params: ZenGrad_M(params, lr=lr, momentum=0.9, weight_decay=weight_decay)),
    )
    for name, build in builders:
        if opt_name == name:
            return build(model.parameters())
    raise ValueError(f"Unknown optimizer: {opt_name}")

# ===== Training =====
def train(train_loader, model, criterion, optimizer, scaler, epoch, use_cuda):
    """Run one training epoch, with autocast and loss scaling when *use_cuda* is set.

    Args:
        train_loader: Iterable of ``(inputs, targets)`` batches.
        model: Network to optimize; switched to train mode.
        criterion: Loss applied to ``(outputs, targets)``.
        optimizer: Stepped once per batch through *scaler*.
        scaler: The ``torch.cuda.amp.GradScaler`` created by ``main``.
        epoch: Zero-based epoch index, used only for the progress-bar label.
        use_cuda: Move batches to the GPU and enable CUDA autocast.

    Returns:
        ``(mean loss, mean top-1 accuracy in percent)`` over the epoch,
        weighted by batch size. Top-5 is tracked but not returned.
    """
    model.train()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    for inputs, targets in tqdm(train_loader, desc=f"Train Epoch {epoch+1}"):
        if use_cuda:
            # Asynchronous copies for both tensors; they overlap with compute when
            # the loader pins memory.
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)
        optimizer.zero_grad()

        # CUDA autocast is pointless (and only warns) without a GPU.
        with torch.cuda.amp.autocast(enabled=use_cuda):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        batch_size = inputs.size(0)
        prec1, prec5 = accuracy(outputs.detach(), targets, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(prec1[0].item(), batch_size)
        top5.update(prec5[0].item(), batch_size)

        # The scaler unscales gradients before the step and skips steps with inf/NaN grads.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    return losses.avg, top1.avg

def test(val_loader, model, criterion, epoch, use_cuda):
    """Evaluate *model* on *val_loader* in eval mode without gradient tracking.

    Runs in full precision (no autocast).

    Args:
        val_loader: Iterable of ``(inputs, targets)`` batches.
        model: Network to evaluate; switched to eval mode.
        criterion: Loss applied to ``(outputs, targets)``.
        epoch: Zero-based epoch index, used only for the progress-bar label.
        use_cuda: Move batches to the GPU before the forward pass.

    Returns:
        ``(mean loss, mean top-1 accuracy in percent)`` weighted by batch size.
        Top-5 is tracked but not returned.
    """
    model.eval()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, desc=f"Valid Epoch {epoch+1}"):
            if use_cuda:
                # Asynchronous copies for both tensors, matching train().
                inputs = inputs.cuda(non_blocking=True)
                targets = targets.cuda(non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            batch_size = inputs.size(0)
            prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(prec1[0].item(), batch_size)
            top5.update(prec5[0].item(), batch_size)

    return losses.avg, top1.avg

# ===== Learning rate adjustment =====
def adjust_learning_rate(optimizer, epoch):
    """Apply the step-decay schedule for *epoch* to every optimizer param group.

    The new LR is ``args.lr`` multiplied by ``args.gamma`` once for every
    milestone in ``args.schedule`` that *epoch* has reached. It is also stored
    in the module-level ``state['lr']`` for logging.
    """
    reached = sum(1 for milestone in args.schedule if epoch >= milestone)
    new_lr = args.lr
    for _ in range(reached):
        new_lr *= args.gamma
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    state['lr'] = new_lr

# ===== Checkpoint =====
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Save *state* to ``checkpoint/filename``; also copy it to ``model_best.pth.tar`` if *is_best*.

    The checkpoint is first written to a temporary sibling file and then moved
    into place with ``os.replace``. An interruption during ``torch.save``
    therefore cannot leave a truncated file where the previous good checkpoint was.

    Args:
        state: Picklable object to save (here, the training-state dict).
        is_best: Whether to also refresh ``model_best.pth.tar``.
        checkpoint: Target directory.
        filename: Checkpoint file name inside *checkpoint*. The original code
            overwrote this parameter, so it was silently ignored.
    """
    filepath = os.path.join(checkpoint, filename)
    tmp_path = filepath + '.tmp'
    torch.save(state, tmp_path)
    os.replace(tmp_path, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))

# Run training only when executed as a script; importing this module does not
# start training (although argument parsing and setup above still run at import).
if __name__ == '__main__':
    main()
