
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from mup.coord_check import get_coord_data, plot_coord_data
from mup import MuAdam, MuSGD, get_shapes, make_base_shapes, set_base_shapes
#import matplotlib.pyplot as plt
import preact_resnet


def test(epoch, nets, metrics, loader=None, criterion=None):
    """Evaluate every network in ``nets`` and their logit-averaged ensemble.

    Each net is switched to ``eval()`` mode and run under ``torch.no_grad()``
    over ``loader``; inputs/targets are moved to the module-level ``device``.
    Four entries are appended to ``metrics``:

    * ``'test_loss'``: per-batch average of the loss *summed* over all nets
      (not divided by ``len(nets)``; left unchanged so values stay comparable
      with metrics already stored in checkpoints -- TODO confirm against the
      training script).
    * ``'ens_test_loss'``: per-batch average loss of the mean logits.
    * ``'test_acc'``: accuracy (%) pooled over all nets and samples.
    * ``'ens_test_acc'``: accuracy (%) of the mean-logit ensemble.

    On every 5th epoch a checkpoint ``state`` dict is assembled and the global
    ``best_acc`` is updated; the actual ``torch.save`` call is disabled.

    Args:
        epoch: current epoch number (selects the every-5-epochs branch).
        nets: non-empty sequence of models evaluated on the same batches.
        metrics: dict whose four keys above hold lists; mutated in place.
        loader: iterable of ``(inputs, targets)`` batches that supports
            ``len``; defaults to the module-level ``testloader``.
        criterion: loss callable ``criterion(logits, targets)``; defaults to
            ``nn.CrossEntropyLoss()``.

    Returns:
        The same ``metrics`` dict.

    Raises:
        ValueError: if ``nets`` is empty or ``loader`` yields no batches.
    """
    from utils import progress_bar
    global best_acc

    nets = list(nets)
    if not nets:
        raise ValueError('test() needs at least one network')
    if loader is None:
        loader = testloader
    loss_fn = criterion if criterion is not None else nn.CrossEntropyLoss()

    for net in nets:
        net.eval()
    n_nets = len(nets)

    test_loss = 0.0
    ens_test_loss = 0.0
    correct = 0
    total = 0
    correct_ens = 0
    total_ens = 0
    num_batches = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(loader):
            num_batches = batch_idx + 1
            inputs, targets = inputs.to(device), targets.to(device)
            # Sized from the first output, so any number of classes works.
            mean_logit = None
            for net in nets:
                outputs = net(inputs)
                if mean_logit is None:
                    mean_logit = torch.zeros_like(outputs)
                mean_logit += 1 / n_nets * outputs
                test_loss += loss_fn(outputs, targets).item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
            ens_test_loss += loss_fn(mean_logit, targets).item()
            total_ens += targets.size(0)
            _, predict_ens = mean_logit.max(1)
            correct_ens += predict_ens.eq(targets).sum().item()
            # The bar length must match the loader actually being iterated.
            progress_bar(batch_idx, len(loader), 'Loss: %.3f | Ens Loss: %.3f | Acc: %.3f%% (%d/%d) | Ens Acc: %.3f%% (%d/%d)'
                         % (test_loss / num_batches, ens_test_loss / num_batches, 100. * correct / total, correct, total,
                            100. * correct_ens / total_ens, correct_ens, total_ens))

    if num_batches == 0:
        raise ValueError('test loader yielded no batches')

    metrics['test_loss'] += [test_loss / num_batches]
    metrics['ens_test_loss'] += [ens_test_loss / num_batches]
    metrics['test_acc'] += [100. * correct / total]
    metrics['ens_test_acc'] += [100. * correct_ens / total_ens]

    # Save checkpoint.
    acc = 100. * correct / total
    if epoch % 5 == 0:
        print('Saving..')
        state = {
            'nets': [net.state_dict() for net in nets],
            'metrics': metrics,
            'acc': acc,
            'epoch': epoch,
        }
        os.makedirs('checkpoint', exist_ok=True)
        # NOTE: saving is intentionally disabled in this analysis script;
        # re-enable the line below to actually write ``state`` to disk.
        #torch.save(state, save_path + f'/ckpt_N_{args.width_mult}_epoch_{epoch}_.pth')
        best_acc = acc

    return metrics


# Directory holding the training run's checkpoints (``ckpt_N_*.pth``) and
# per-epoch logit dumps (``logits_correct_N_*.pth``). Put name here, or pass
# ``--save_path`` on the command line. An empty string means the current
# working directory.
save_path = "" # Put name here

# Sub-modules of the preact ResNet whose outputs are recorded.
_LAYER_NAMES = ('layer1', 'layer2', 'layer3', 'layer4')


def _build_arg_parser():
    """Return the command-line parser.

    Most options mirror the training script so the same command lines are
    accepted; only the width/arch/epoch/base-shape/path options are used here.
    """
    # NOTE(review): the description text is inherited from the training script.
    parser = argparse.ArgumentParser(description=''
    '''
    PyTorch CIFAR10 Training, with μP.

    To save base shapes info, run e.g.

        python main.py --save_base_shapes resnet18.bsh --width_mult 1

    To train using MuAdam (or MuSGD), run

        python main.py --width_mult 2 --load_base_shapes resnet18.bsh --optimizer {muadam,musgd}

    To test coords, run

        python main.py --load_base_shapes resnet18.bsh --optimizer sgd --lr 0.1 --coord_check

        python main.py --load_base_shapes resnet18.bsh --optimizer adam --lr 0.001 --coord_check

    If you don't specify a base shape file, then you are using standard parametrization, e.g.

        python main.py --width_mult 2 --optimizer {muadam,musgd}

    Here muadam (resp. musgd) would have the same result as adam (resp. sgd).

    Note that models of different depths need separate `.bsh` files.
    ''', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--resume', '-r', action='store_true',
                        help='resume from checkpoint')
    parser.add_argument('--arch', type=str, default='resnet18')
    parser.add_argument('--optimizer', default='musgd', choices=['sgd', 'adam', 'musgd', 'muadam'])
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--num_ens', type=int, default=4)
    parser.add_argument('--width_mult', type=float, default=1)
    parser.add_argument('--save_base_shapes', type=str, default='',
                        help='file location to save base shapes at')
    parser.add_argument('--load_base_shapes', type=str, default='',
                        help='file location to load base shapes from')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--test_batch_size', type=int, default=128)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--test_num_workers', type=int, default=2)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--coord_check', action='store_true',
                        help='test μ parametrization is correctly implemented by collecting statistics on coordinate distributions for a few steps of training.')
    parser.add_argument('--coord_check_nsteps', type=int, default=3,
                        help='Do coord check with this many steps.')
    parser.add_argument('--coord_check_nseeds', type=int, default=1,
                        help='number of seeds for coord check')
    parser.add_argument('--seed', type=int, default=1111,
                        help='random seed')
    parser.add_argument('--save_path', type=str, default=save_path,
                        help='directory containing the checkpoints and logit dumps written during training')
    parser.add_argument('--result_path', type=str, default='./saved_losses_kernels_acts',
                        help='directory to write the exported .npy files to (created if missing)')
    return parser


def _save_base_shapes(args):
    """Write a mup base-shape file for ``args.arch`` to ``args.save_base_shapes``.

    The base model uses ``args.width_mult``. A second model at twice that width
    tells mup which dimensions scale with width.
    """
    print(f'saving base shapes at {args.save_base_shapes}')
    arch = getattr(preact_resnet, args.arch)
    base_shapes = get_shapes(arch(wm=args.width_mult))
    delta_shapes = get_shapes(arch(wm=args.width_mult * 2))
    make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes)
    print('done')


def _make_cifar10_loaders(args):
    """Build the CIFAR-10 ``(trainloader, testloader)`` pair rooted at ``../dataset``.

    The training split uses random crop/flip augmentation. The test split is
    only normalized and not shuffled, so its batch order is deterministic.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root='../dataset', train=True, download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    testset = torchvision.datasets.CIFAR10(
        root='../dataset', train=False, download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.test_num_workers)
    return train_loader, test_loader


def _load_ensemble(args, ckpt_dir, epoch):
    """Rebuild the ensemble stored at ``epoch`` and return ``(models, metrics)``.

    One model per entry of the checkpoint's ``'nets'`` list is built with
    ``args.arch`` / ``args.width_mult``. mup base shapes are attached before the
    trained weights are loaded. With no ``--load_base_shapes`` file, ``None``
    is passed, which selects standard parametrization.

    Raises:
        ValueError: if the checkpoint contains no networks.
    """
    ckpt_file = os.path.join(ckpt_dir, f'ckpt_N_{args.width_mult}_epoch_{epoch}_.pth')
    # map_location lets GPU-saved checkpoints load on CPU-only machines; the
    # models are moved to the target device later.
    chkpt = torch.load(ckpt_file, map_location='cpu')
    state_dicts = chkpt['nets']
    if not state_dicts:
        raise ValueError(f'checkpoint {ckpt_file} contains no networks')

    base_shapes = args.load_base_shapes or None
    models = [getattr(preact_resnet, args.arch)(wm=args.width_mult) for _ in state_dicts]
    for model in models:
        set_base_shapes(model, base_shapes)
    for model, state in zip(models, state_dicts):
        model.load_state_dict(state)
    return models, chkpt['metrics']


def _load_correct_logits(ckpt_dir, width_mult, num_epochs, num_models):
    """Stack ``logits_correct[e][0][0]`` from each epoch's dump into one array.

    Reads ``logits_correct_N_{width_mult}_epoch_{t}_.pth`` for
    ``t in range(num_epochs)``. The result is indexed ``[model][epoch]...``.
    The inner layout of ``'logits_correct'`` is defined by the training script
    (presumably per-model tensors; TODO confirm).
    """
    logit_preds = [[] for _ in range(num_models)]
    for t in range(num_epochs):
        path = os.path.join(ckpt_dir, f'logits_correct_N_{width_mult}_epoch_{t}_.pth')
        logits = torch.load(path, map_location='cpu')['logits_correct']
        for e in range(num_models):
            logit_preds[e].append(logits[e][0][0].cpu().numpy())
    return np.array(logit_preds)


def _collect_layer_stats(models, inputs):
    """Return ensemble-averaged Gram kernels and sampled activations per layer.

    Each model (already on ``inputs``' device) is put in ``eval()`` mode and run
    once on ``inputs``, with temporary forward hooks on ``layer1``..``layer4``.
    ``eval()`` means any normalization/dropout layers use inference behaviour and
    do not update running statistics. For each 4-D layer output ``phi``, the
    kernel ``einsum('ijkl,mjkl->im') / (C*H*W)`` is averaged over the models,
    and ``phi[:2, :, 0, 0]`` is kept for each model.

    Returns:
        ``(avg_kernels, phis)``: lists indexed by layer. ``avg_kernels[l]`` is a
        ``(batch, batch)`` tensor. ``phis[l]`` is a list with one
        ``(2, channels)`` array per model.
    """
    activation = {}

    def get_activation(name):
        def hook(module, inp, output):
            activation[name] = output.detach()
        return hook

    avg_kernels = [None] * len(_LAYER_NAMES)
    phis = [[] for _ in _LAYER_NAMES]
    with torch.no_grad():
        for model in models:
            model.eval()
            handles = [getattr(model, name).register_forward_hook(get_activation(name))
                       for name in _LAYER_NAMES]
            try:
                model(inputs)
            finally:
                for handle in handles:
                    handle.remove()
            for l, name in enumerate(_LAYER_NAMES):
                phi = activation[name]
                K_le = torch.einsum('ijkl,mjkl->im', phi, phi) / phi.shape[1] / phi.shape[2] / phi.shape[3]
                if avg_kernels[l] is None:
                    # Sized from the real batch rather than --test_batch_size.
                    avg_kernels[l] = torch.zeros_like(K_le)
                avg_kernels[l] += 1 / len(models) * K_le
                phis[l].append(phi[:2, :, 0, 0].cpu().numpy())
    return avg_kernels, phis


if __name__ == '__main__':
    args = _build_arg_parser().parse_args()

    # Assigned at module level on purpose: ``test`` reads ``device``,
    # ``testloader``, ``trainloader`` and ``criterion`` as globals.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if args.save_base_shapes:
        _save_base_shapes(args)
        raise SystemExit(0)

    print('==> Preparing data..')
    trainloader, testloader = _make_cifar10_loaders(args)
    criterion = nn.CrossEntropyLoss()

    epoch = args.epochs
    models, metrics = _load_ensemble(args, args.save_path, epoch)
    E = len(models)

    print(metrics['train_loss'])
    print(metrics['ens_train_loss'])
    print(metrics['test_loss'])
    print(metrics['ens_test_loss'])

    result_path = args.result_path
    os.makedirs(result_path, exist_ok=True)
    np.save(os.path.join(result_path, f'test_acc_N_{args.width_mult}_epoch_{epoch}.npy'),
            torch.tensor(metrics['test_acc']).numpy())
    np.save(os.path.join(result_path, f'ens_test_acc_N_{args.width_mult}_epoch_{epoch}.npy'),
            torch.tensor(metrics['ens_test_acc']).numpy())

    logit_preds = _load_correct_logits(args.save_path, args.width_mult, args.epochs, E)
    np.save(os.path.join(result_path, f'corr_logits_N_{args.width_mult}_epoch_{epoch}.npy'), logit_preds)

    for model in models:
        model.to(device)

    # Every model sees the same first test batch (the test loader is not shuffled).
    inputs, targets = next(iter(testloader))
    all_avg_K, all_phi = _collect_layer_stats(models, inputs.to(device))

    np.save(os.path.join(result_path, 'target_labels.npy'), targets.cpu().numpy())
    for l, (avg_K, phis) in enumerate(zip(all_avg_K, all_phi), start=1):
        np.save(os.path.join(result_path, f'kernels_N_{args.width_mult}_l_{l}_epoch_{epoch}.npy'),
                avg_K.cpu().numpy())
        # (2, E * channels): each model's channels side by side.
        phi_concat = np.concatenate(phis, axis=1)
        np.save(os.path.join(result_path, f'activation_N_{args.width_mult}_l_{l}_epoch_{epoch}.npy'), phi_concat)

