
'''Train CIFAR10 with PyTorch.'''
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from mup.coord_check import get_coord_data, plot_coord_data
from mup import MuAdam, MuSGD, get_shapes, make_base_shapes, set_base_shapes
import random
import resnet


# Output location used as the path prefix for the files test() saves
# (per-epoch logit dumps and periodic checkpoints) and for the checkpoint
# read when --resume is given. Fill this in before running.
save_path = ""


def coord_check(mup, lr, optimizer, nsteps, arch, base_shapes, nseeds, device='cuda', plotdir='', legend=False):
    """Collect coordinate-check statistics for `arch` over several widths.

    Lazily-constructed models for the width multipliers 0.25, 0.5, 1 and 2 are
    handed to ``mup.coord_check.get_coord_data`` together with a CIFAR-10
    training loader (batch size 1, no shuffling).

    Args:
        mup: If true, models get `base_shapes` via ``set_base_shapes``;
            otherwise they are set up with ``set_base_shapes(model, None)``.
        lr: Learning rate forwarded to ``get_coord_data``.
        optimizer: Optimizer name; a ``'mu'`` prefix is stripped before it is
            forwarded (e.g. ``'musgd'`` becomes ``'sgd'``).
        nsteps: Number of training steps forwarded to ``get_coord_data``.
        arch: Name of the constructor to look up on the ``resnet`` module.
        base_shapes: Base shapes passed to ``set_base_shapes`` when `mup` is
            true.
        nseeds: Number of seeds forwarded to ``get_coord_data``.
        device: Device the models are moved to.
        plotdir: Currently unused because plotting is disabled; kept for
            interface compatibility.
        legend: Currently unused because plotting is disabled; kept for
            interface compatibility.

    Returns:
        Whatever ``get_coord_data`` returns, so that the collected statistics
        are not discarded while plotting is disabled.
    """
    optimizer = optimizer.replace('mu', '')

    def gen(w, standparam=False):
        # Return a zero-argument factory so the model is only built when
        # get_coord_data asks for it.
        def f():
            model = getattr(resnet, arch)(wm=w).to(device)
            if standparam:
                set_base_shapes(model, None)
            else:
                set_base_shapes(model, base_shapes)
            return model
        return f

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    trainset = torchvision.datasets.CIFAR10(
        root='../dataset', train=True, download=True, transform=transform_train)
    dataloader = torch.utils.data.DataLoader(
        trainset, batch_size=1, shuffle=False)

    widths = 2**np.arange(-2., 2)
    models = {w: gen(w, standparam=not mup) for w in widths}
    df = get_coord_data(models, dataloader, mup=mup, lr=lr, optimizer=optimizer, nseeds=nseeds, nsteps=nsteps)

    # Plotting is disabled; the call that used to run here was:
    #   prm = 'μP' if mup else 'SP'
    #   plot_coord_data(df, legend=legend,
    #       save_to=os.path.join(plotdir, f'{prm.lower()}_{arch}_{optimizer}_coord.png'),
    #       suptitle=f'{prm} {arch} {optimizer} lr={lr} nseeds={nseeds}',
    #       face_color='xkcd:light grey' if not mup else None)
    return df


# Training
def train(epoch, nets, metrics):
    """Train every ensemble member for one epoch and record the metrics.

    On each batch every network in ``nets`` takes its own optimisation step
    with the optimizer stored at the same index of the module-level
    ``optimizers`` list. The members' logits are averaged to score the
    ensemble as a whole.

    Reads the module-level names ``trainloader``, ``device``, ``optimizers``
    and ``criterion``.

    Args:
        epoch: Epoch index; only used for the progress message.
        nets: Non-empty sequence of models to train.
        metrics: Dict with list-valued keys ``'train_loss'``,
            ``'ens_train_loss'``, ``'train_acc'`` and ``'ens_train_acc'``.

    Returns:
        The same ``metrics`` dict with one value appended to each of the four
        keys: the per-batch loss averaged over members and batches, the
        per-batch loss of the averaged logits, the accuracy pooled over all
        members, and the accuracy of the averaged logits (both in percent).

    Raises:
        ValueError: If ``nets`` is empty or ``trainloader`` yields no batch.
    """
    if not nets:
        raise ValueError('train() needs at least one network')

    print('\nEpoch: %d' % epoch)
    for net in nets:
        net.train()

    E = len(nets)
    train_loss = 0
    ens_train_loss = 0
    correct = 0
    total = 0
    correct_ens = 0
    total_ens = 0
    num_batches = 0
    for inputs, targets in trainloader:
        num_batches += 1
        inputs, targets = inputs.to(device), targets.to(device)
        # Built from the first member's output so the number of classes is
        # taken from the model instead of being hard-coded.
        mean_logit = None
        for e, net in enumerate(nets):
            optimizers[e].zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizers[e].step()

            # The ensemble average is only used for metrics, so detach it:
            # otherwise it keeps every member's autograd graph referenced.
            scaled = (1.0 / E) * outputs.detach()
            mean_logit = scaled if mean_logit is None else mean_logit + scaled

            train_loss += loss.item() / E
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        ens_train_loss += criterion(mean_logit, targets).item()
        total_ens += targets.size(0)
        _, predict_ens = mean_logit.max(1)
        correct_ens += predict_ens.eq(targets).sum().item()

    if num_batches == 0:
        raise ValueError('trainloader yielded no batches')

    metrics['train_loss'] += [train_loss / num_batches]
    metrics['ens_train_loss'] += [ens_train_loss / num_batches]
    metrics['train_acc'] += [100. * correct / total]
    metrics['ens_train_acc'] += [100. * correct_ens / total_ens]
    return metrics

def test(epoch, nets, metrics):
    """Evaluate the ensemble on ``testloader`` and save logits/checkpoints.

    Every network in ``nets`` is evaluated individually and the members'
    logits are averaged to score the ensemble. After the pass, the logit each
    member assigned to the true class of every test sample is saved, and on
    every fifth epoch a checkpoint with all state dicts is written as well.

    Reads the module-level names ``testloader``, ``device``, ``criterion``,
    ``args`` (for ``width_mult``) and ``save_path``; updates ``best_acc``.

    Args:
        epoch: Epoch index; used in the output file names and to decide
            whether a checkpoint is written (``epoch % 5 == 0``).
        nets: Non-empty sequence of models to evaluate.
        metrics: Dict with list-valued keys ``'test_loss'``,
            ``'ens_test_loss'``, ``'test_acc'`` and ``'ens_test_acc'``.

    Returns:
        The same ``metrics`` dict with one value appended to each of the four
        keys.

    Raises:
        ValueError: If ``nets`` is empty or ``testloader`` yields no batch.
    """
    global best_acc
    if not nets:
        raise ValueError('test() needs at least one network')

    for net in nets:
        net.eval()

    num_nets = len(nets)
    test_loss = 0
    ens_test_loss = 0
    correct = 0
    total = 0
    correct_ens = 0
    total_ens = 0
    num_batches = 0
    # all_logits_correct[e] holds one 1-D tensor per batch for member e.
    all_logits_correct = [[] for _ in nets]
    with torch.no_grad():
        for inputs, targets in testloader:
            num_batches += 1
            inputs, targets = inputs.to(device), targets.to(device)
            # Built from the first member's output so the number of classes
            # is taken from the model instead of being hard-coded.
            mean_logit = None
            for e, net in enumerate(nets):
                outputs = net(inputs)
                # Logit of the true class for each sample, shape (batch,).
                # Indexing with outputs[:, targets] would instead build a
                # (batch, batch) matrix pairing every sample with every target.
                correct_logit = outputs.gather(1, targets.unsqueeze(1)).squeeze(1)
                all_logits_correct[e].append(correct_logit)

                scaled = (1 / num_nets) * outputs
                mean_logit = scaled if mean_logit is None else mean_logit + scaled

                loss = criterion(outputs, targets)
                test_loss += loss.item() / num_nets
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
            ens_test_loss += criterion(mean_logit, targets).item()
            total_ens += targets.size(0)
            _, predict_ens = mean_logit.max(1)
            correct_ens += predict_ens.eq(targets).sum().item()

    if num_batches == 0:
        raise ValueError('testloader yielded no batches')

    metrics['test_loss'] += [test_loss / num_batches]
    metrics['ens_test_loss'] += [ens_test_loss / num_batches]
    metrics['test_acc'] += [100. * correct / total]
    metrics['ens_test_acc'] += [100. * correct_ens / total_ens]
    acc = 100. * correct / total

    # os.path.join keeps an empty save_path relative to the working directory
    # rather than turning the file name into an absolute '/...' path.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    logit_state = {'logits_correct': all_logits_correct}
    torch.save(logit_state, os.path.join(
        save_path, f'logits_correct_N_{args.width_mult}_epoch_{epoch}_.pth'))

    # Save checkpoint.
    if epoch % 5 == 0:
        print('Saving..')
        state = {
            'nets': [net.state_dict() for net in nets],
            'metrics': metrics,
            'acc': acc,
            'epoch': epoch,
        }
        # Preserve the original side effect of ensuring ./checkpoint exists;
        # the checkpoint file itself is written under save_path.
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, os.path.join(
            save_path, f'ckpt_N_{args.width_mult}_epoch_{epoch}_.pth'))

    # Track the best pooled accuracy seen so far rather than the latest one.
    best_acc = max(best_acc, acc)

    return metrics

if __name__ == '__main__':
    # Script entry point. The names assigned at this level (args, device,
    # best_acc, trainloader, testloader, criterion, optimizers) are module
    # globals that train() and test() read directly, so they must stay here
    # rather than move into a function.
    import sys

    parser = argparse.ArgumentParser(description=''
    '''
    PyTorch CIFAR10 Training, with μP.

    To save base shapes info, run e.g.

        python main.py --save_base_shapes resnet18.bsh --width_mult 1

    To train using MuAdam (or MuSGD), run

        python main.py --width_mult 2 --load_base_shapes resnet18.bsh --optimizer {muadam,musgd}

    To test coords, run

        python main.py --load_base_shapes resnet18.bsh --optimizer sgd --lr 0.1 --coord_check

        python main.py --load_base_shapes resnet18.bsh --optimizer adam --lr 0.001 --coord_check

    If you don't specify a base shape file, then you are using standard parametrization, e.g.

        python main.py --width_mult 2 --optimizer {muadam,musgd}

    Here muadam (resp. musgd) would have the same result as adam (resp. sgd).

    Note that models of different depths need separate `.bsh` files.
    ''', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--resume', '-r', action='store_true',
                        help='resume from checkpoint')
    parser.add_argument('--arch', type=str, default='resnet18')
    parser.add_argument('--optimizer', default='musgd', choices=['sgd', 'adam', 'musgd', 'muadam'])
    parser.add_argument('--epochs', type=int, default=21)
    parser.add_argument('--num_ens', type=int, default=8)
    parser.add_argument('--width_mult', type=float, default=1)
    parser.add_argument('--save_base_shapes', type=str, default='',
                        help='file location to save base shapes at')
    parser.add_argument('--load_base_shapes', type=str, default='',
                        help='file location to load base shapes from')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--test_batch_size', type=int, default=128)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--num_workers', type=int, default=0)
    parser.add_argument('--test_num_workers', type=int, default=0)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--coord_check', action='store_true',
                        help='test μ parametrization is correctly implemented by collecting statistics on coordinate distributions for a few steps of training.')
    parser.add_argument('--coord_check_nsteps', type=int, default=3,
                        help='Do coord check with this many steps.')
    parser.add_argument('--coord_check_nseeds', type=int, default=1,
                        help='number of seeds for coord check')
    parser.add_argument('--seed', type=int, default=1111,
                        help='random seed')
    args = parser.parse_args()
    if args.num_ens < 1:
        # train()/test() average over the ensemble and divide by its size.
        parser.error('--num_ens must be at least 1')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    E = args.num_ens
    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)

    # Data (not needed when only the base shapes are being saved)
    if not args.save_base_shapes:
        print('==> Preparing data..')
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        g = torch.Generator()
        g.manual_seed(args.seed)

        def seed_worker(worker_id):
            # Derive the numpy / random seeds of each loader worker from the
            # torch seed that worker was given.
            worker_seed = torch.initial_seed() % 2**32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        trainset = torchvision.datasets.CIFAR10(
            root='../dataset', train=True, download=True, transform=transform_train)
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, generator=g, worker_init_fn=seed_worker)

        testset = torchvision.datasets.CIFAR10(
            root='../dataset', train=False, download=True, transform=transform_test)
        testloader = torch.utils.data.DataLoader(
            testset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.test_num_workers, generator=g, worker_init_fn=seed_worker)

        classes = ('plane', 'car', 'bird', 'cat', 'deer',
                'dog', 'frog', 'horse', 'ship', 'truck')

    if args.coord_check:
        print('testing parametrization')
        os.makedirs('coord_checks', exist_ok=True)
        plotdir = 'coord_checks'
        # Run the check once with muP and once with standard parametrization.
        for use_mup in (True, False):
            coord_check(mup=use_mup,
                lr=args.lr, optimizer=args.optimizer, nsteps=args.coord_check_nsteps, arch=args.arch, base_shapes=args.load_base_shapes, nseeds=args.coord_check_nseeds, device=device, plotdir=plotdir, legend=False)
        sys.exit()

    # Model
    print('==> Building model..')
    build_model = getattr(resnet, args.arch)

    if args.save_base_shapes:
        print(f'saving base shapes at {args.save_base_shapes}')
        # Only parameter shapes are needed, so build one model at the
        # requested width and one at half that width instead of the ensemble.
        base_shapes = get_shapes(build_model(wm=args.width_mult))
        delta_shapes = get_shapes(build_model(wm=args.width_mult/2))
        make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes)
        print('done and exit')
        sys.exit()

    nets = []
    for e in range(E):
        # Seed with the member index so each member gets its own
        # reproducible initialisation.
        torch.manual_seed(e)
        nets.append(build_model(wm=args.width_mult))

    nets = [net.to(device) for net in nets]

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    if args.load_base_shapes:
        print(f'loading base shapes from {args.load_base_shapes}')
        for net in nets:
            set_base_shapes(net, args.load_base_shapes)
        print('done')
    else:
        print('using standard parametrization')
        # Every member has to be configured, not just one of them.
        for net in nets:
            set_base_shapes(net, None)
        print('done')

    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        # test() names its checkpoints ckpt_N_<width_mult>_epoch_<epoch>_.pth;
        # resume from the latest epoch stored for this width multiplier.
        ckpt_dir = save_path or '.'
        ckpt_prefix = f'ckpt_N_{args.width_mult}_epoch_'
        ckpt_suffix = '_.pth'
        saved_epochs = {}
        for fname in os.listdir(ckpt_dir):
            if fname.startswith(ckpt_prefix) and fname.endswith(ckpt_suffix):
                tag = fname[len(ckpt_prefix):-len(ckpt_suffix)]
                if tag.isdigit():
                    saved_epochs[int(tag)] = os.path.join(ckpt_dir, fname)
        if not saved_epochs:
            raise FileNotFoundError(
                f'no checkpoint matching {ckpt_prefix}*{ckpt_suffix} found in {ckpt_dir!r}')
        checkpoint = torch.load(saved_epochs[max(saved_epochs)], map_location=device)
        if len(checkpoint['nets']) != len(nets):
            raise ValueError(
                f"checkpoint holds {len(checkpoint['nets'])} networks but --num_ens is {len(nets)}")
        # The checkpoint stores state dicts, so load them into the models
        # built above. Optimizer state is not part of the checkpoint.
        for net, state_dict in zip(nets, checkpoint['nets']):
            net.load_state_dict(state_dict)
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']

    criterion = nn.CrossEntropyLoss()
    # One optimizer per ensemble member; train() indexes this list by member.
    if args.optimizer == 'musgd':
        optimizers = [MuSGD(net.parameters(), lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay) for net in nets]
    elif args.optimizer == 'muadam':
        optimizers = [MuAdam(net.parameters(), lr=args.lr) for net in nets]
    elif args.optimizer == 'sgd':
        optimizers = [optim.SGD(net.parameters(), lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay) for net in nets]
    elif args.optimizer == 'adam':
        optimizers = [optim.Adam(net.parameters(), lr=args.lr) for net in nets]
    else:
        raise ValueError(f'unknown optimizer: {args.optimizer}')
    schedulers = [torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) for optimizer in optimizers]

    metrics = {'train_loss': [], 'ens_train_loss': [], 'test_loss': [], 'ens_test_loss': [], 'test_acc': [], 'ens_test_acc': [], 'train_acc': [], 'ens_train_acc': []}
    for epoch in range(start_epoch, start_epoch+args.epochs):
        # Evaluate first, so the metrics for an epoch describe the model
        # before that epoch's training pass.
        metrics = test(epoch, nets, metrics)
        metrics = train(epoch, nets, metrics)
        for scheduler in schedulers:
            scheduler.step()

