import argparse
import os
import numpy as np
import torch
import torch.nn as nn
from mup import get_shapes, make_base_shapes, set_base_shapes
import random
import wandb
from utils import load_data, get_model, get_optimizers, process_args, coord_check, coord_check_depth
from train_test import train, test
import torch.distributed as dist
from functools import partial
import transformers



def get_run_name(args):
    """Build a '/'-separated run identifier from the experiment hyperparameters.

    The identifier is used as the wandb run name and, with '/' replaced by '-',
    as the name of the per-run output directory.

    Args:
        args: namespace exposing ``arch, dataset, epochs, lr, seed, momentum,
            batch_size, res_scaling, width_mult, depth, parametr,
            depth_scale_first, skip_scaling, beta, gamma_zero, weight_decay,
            norm``.

    Returns:
        str: e.g. ``"model_resnet18/dataset_cifar10/epoch_21/lr_0.1000/..."``.
        ``lr`` and ``res_scaling`` are rendered with 4 decimals when numeric.
    """

    def _fmt4(value):
        # ``--res_scaling`` is declared ``type=str`` (default 'none'); applying
        # ':.4f' to a str raises ValueError, so non-numeric values are rendered
        # verbatim instead of crashing the run.
        try:
            return format(value, ".4f")
        except (TypeError, ValueError):
            return str(value)

    template = (
        "model_{}/dataset_{}/epoch_{}/lr_{}/seed_{}/momentum_{}/batch_size_{}/"
        "res_scaling_{}/width_mult_{}/depth_{}/parametr_{}/depth_scale_first_{}/"
        "skip_scaling_{}/beta_{}/gamma_zero_{}/weight_decay_{}/norm_{}"
    )
    return template.format(
        args.arch, args.dataset, args.epochs, _fmt4(args.lr), args.seed, args.momentum,
        args.batch_size, _fmt4(args.res_scaling), args.width_mult, args.depth,
        args.parametr, args.depth_scale_first, args.skip_scaling, args.beta,
        args.gamma_zero, args.weight_decay, args.norm)
    
if __name__ == '__main__':
    # Script entry point: parse CLI flags, expand every hyperparameter passed as
    # -1 into a grid of candidates, then for each grid point either run a
    # coord check, save base shapes, or train/evaluate an ensemble of models.
    import sys
    from itertools import product

    parser = argparse.ArgumentParser(description=''
    '''
    PyTorch CIFAR10 Training, with μP.

    To save base shapes info, run e.g.

        python main.py --save_base_shapes resnet18.bsh --width_mult 1

    To train using MuAdam (or MuSGD), run

        python main.py --width_mult 2 --load_base_shapes resnet18.bsh --optimizer {muadam,musgd}

    To test coords, run

        python main.py --load_base_shapes resnet18.bsh --optimizer sgd --lr 0.1 --coord_check

        python main.py --load_base_shapes resnet18.bsh --optimizer adam --lr 0.001 --coord_check

    If you don't specify a base shape file, then you are using standard parametrization, e.g.

        python main.py --width_mult 2 --optimizer {muadam,musgd}

    Here muadam (resp. musgd) would have the same result as adam (resp. sgd).

    Note that models of different depths need separate `.bsh` files.
    ''', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--lr', default=5.0, type=float, help='learning rate')
    parser.add_argument('--resume', '-r', action='store_true',
                        help='resume from checkpoint')
    parser.add_argument('--resume_epoch', type=int, default=-1)
    parser.add_argument('--arch', type=str, default='resnet18')
    parser.add_argument('--optimizer', default='musgd', choices=['sgd', 'adam', 'musgd', 'muadam', 'musgd_depth'])
    parser.add_argument('--epochs', type=int, default=21)
    parser.add_argument('--num_ens', type=int, default=1)
    parser.add_argument('--width_mult', type=float, default=2.0)
    parser.add_argument('--save_base_shapes', type=str, default='',
                        help='file location to save base shapes at')
    parser.add_argument('--load_base_shapes', type=str, default='',
                        help='file location to load base shapes from')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--test_batch_size', type=int, default=64)
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--test_num_workers', type=int, default=16)
    parser.add_argument('--momentum', type=float, default=0.0)
    parser.add_argument('--coord_check', action='store_true',
                        help='test μ parametrization is correctly implemented by collecting statistics on coordinate distributions for a few steps of training.')
    parser.add_argument('--coord_check_nsteps', type=int, default=3,
                        help='Do coord check with this many steps.')
    parser.add_argument('--coord_check_nseeds', type=int, default=10,
                        help='number of seeds for coord check')
    parser.add_argument('--seed', type=int, default=1111,
                        help='random seed')
    parser.add_argument('--save_dir', type=str, default='test/',
                        help='file location to save results')
    parser.add_argument('--res_scaling', type=str, default='none')
    parser.add_argument('--data_path', type=str, default='./data')
    parser.add_argument('--parametr', type=str, default="mup")
    parser.add_argument('--depth_scale_first', type=str, default='none')
    parser.add_argument('--dataset', type=str, default='imgnet')
    parser.add_argument('--depth_mult', type=int, default=1)
    parser.add_argument('--skip_scaling', type=float, default=1,
                        help='set to zero to use an MLP without skip connections')
    parser.add_argument('--coord_check_depth', action='store_true',
                        help='test μ parametrization is correctly implemented by collecting statistics on coordinate distributions for a few steps of training.')
    parser.add_argument('--beta', type=float, default=1,
                        help='scaling factor for the residual branch. To use together with res_scaling parameter')
    parser.add_argument('--gamma_zero', type=float, default=1,
                        help='controls the amount of feature learning.')
    parser.add_argument('--norm', type=str, default='none',
                        help='normalization layer')

    parser.add_argument('--multiprocessing', action='store_true',
                        help='distributed computation for multiple GPU (single node)')
    parser.add_argument('--schedule', action='store_true', help='cosine anneal schedule')
    parser.add_argument('--warmup_steps', type=int, default=0)

    args = parser.parse_args()

    # --- Sweep grid -------------------------------------------------------
    # Any hyperparameter passed as -1 is expanded into a grid of candidates.
    # Only lr, batch_size and momentum count towards the "search" warning.
    n_searched = 0
    if args.lr == -1:
        lrs = np.logspace(-3, -0.5, num=10) if args.arch != "vit" else np.logspace(-4, -2, num=10)
        n_searched += 1
    else:
        lrs = [args.lr]

    if args.batch_size == -1:
        batch_sizes = np.logspace(4, 9, num=6, base=2)  # 16, 32, ..., 512
        n_searched += 1
    else:
        batch_sizes = [args.batch_size]

    if args.momentum == -1:
        momenta = np.linspace(0.1, 1, num=10)
        n_searched += 1
    else:
        momenta = [args.momentum]

    width_mults = np.logspace(1, 4, num=4, base=2) if args.width_mult == -1 else [args.width_mult]  # 2, 4, 8, 16
    betas = [0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0] if args.beta == -1 else [args.beta]
    gamma_zeros = np.linspace(0.3, 1, num=8) if args.gamma_zero == -1 else [args.gamma_zero]
    lambdas = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1] if args.weight_decay == -1 else [args.weight_decay]

    if n_searched > 1:
        print(f"Warning: performing hyperparameter search over {n_searched} parameters. It might take a while")

    args = process_args(args)

    # --- Loop-invariant setup ---------------------------------------------
    # Number of classes per supported dataset.
    dataset_num_classes = {"imgnet": 1000, "tiny_imgnet": 200, "cifar10": 10}
    if args.dataset not in dataset_num_classes:
        raise ValueError(
            f"Unknown dataset {args.dataset!r}; expected one of {sorted(dataset_num_classes)}")
    num_classes = dataset_num_classes[args.dataset]

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("device = " + device)
    if args.multiprocessing:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def seed_worker(worker_id):
        """Seed numpy/python RNGs in a DataLoader worker from its torch seed."""
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def wu_scheduler(opt, warmup_steps):
        """Linear warmup of the lr multiplier from 0 to 1 over ``warmup_steps`` steps."""
        return torch.optim.lr_scheduler.LambdaLR(
            opt,
            lr_lambda=lambda step: min(1.0, step / warmup_steps),
        )

    # -1 means "no step budget". With a batch-size sweep it is fixed once, from
    # the first batch size, and training stops after that many updates.
    max_updates = -1

    # Same iteration order as nesting the loops in this argument order.
    sweep = product(gamma_zeros, betas, width_mults, lrs, batch_sizes, momenta, lambdas)
    for gamma_zero, beta, width_mult, lr, batch_size, momentum, lamb in sweep:
        if len(lrs) == 1 and len(batch_sizes) > 1:
            print("Setting learning rate based on batch size")
            args.lr = lr * batch_size / batch_sizes[0] * gamma_zero**2
        else:
            args.lr = lr * gamma_zero**2
        args.batch_size = int(batch_size)
        args.momentum = momentum
        args.width_mult = int(width_mult)
        args.beta = beta
        args.gamma_zero = gamma_zero
        args.weight_decay = lamb

        print(f"Proccesing hyperparmeters: learning rate {args.lr}, batch size {args.batch_size}, "
              f"momentum {args.momentum}, width_mult {args.width_mult}, beta {args.beta}, "
              f"gamma_zero {args.gamma_zero} weight_decay {args.weight_decay}")

        # --- Logging / output directory ---
        run_name = get_run_name(args)
        args.save_path = os.path.join(args.save_dir, run_name.replace("/", "-"))
        # makedirs also creates a nested save_dir and tolerates existing dirs.
        os.makedirs(args.save_path, exist_ok=True)

        wandb.init(
            # set the wandb project where this run will be logged
            entity="large_depth_team",
            project="large_depth_transfers",
            # track hyperparameters and run metadata
            config=args.__dict__,
        )
        wandb.run.name = run_name

        # Set the random seed manually for reproducibility.
        torch.manual_seed(args.seed)

        # --- Data ---
        if not args.save_base_shapes:
            print('==> Preparing data..')
            g = torch.Generator()
            g.manual_seed(args.seed)
            trainloader, testloader = load_data(args, generator=g, seed_worker=seed_worker)
            if len(batch_sizes) > 1 and max_updates == -1:
                # Budget = epochs x n_batches of the FIRST batch size, shared by all.
                max_updates = args.epochs * len(trainloader)
                args.epochs = 1000  # large sentinel: training breaks on max_updates first
                print(f"Training for {max_updates} steps")

        if args.coord_check:
            print('testing parametrization')
            os.makedirs('coord_checks', exist_ok=True)
            plotdir = 'coord_checks'
            coord_check(mup=True,
                        lr=args.lr, optimizer=args.optimizer, nsteps=args.coord_check_nsteps, arch=args.arch, base_shapes=args.load_base_shapes,
                        nseeds=args.coord_check_nseeds, dataloader=trainloader, args=args, device=device, plotdir=plotdir, legend=False)
            coord_check(mup=False,
                        lr=args.lr, optimizer=args.optimizer, nsteps=args.coord_check_nsteps, arch=args.arch, base_shapes=args.load_base_shapes,
                        nseeds=args.coord_check_nseeds, dataloader=trainloader, args=args, device=device, plotdir=plotdir, legend=False)
            sys.exit()

        if args.coord_check_depth:
            print('testing parametrization')
            os.makedirs('coord_checks_depth', exist_ok=True)
            plotdir = 'coord_checks_depth'
            # TODO: make list readable from file (now can only handle conv case)
            fpath = "source_code/base_file"
            base_shapes_list = [
                os.path.join(fpath, name)
                for name in (
                    "conv_cifar10_d3.bsh",
                    "conv_cifar10_d6.bsh",
                    "conv_cifar10_d9.bsh",
                    "conv_cifar10_d12.bsh",
                    "conv_cifar10_d30.bsh",
                )
            ]
            coord_check_depth(mup=True,
                              lr=args.lr, optimizer=args.optimizer, nsteps=args.coord_check_nsteps, arch=args.arch, base_shapes_list=base_shapes_list,
                              nseeds=args.coord_check_nseeds, dataloader=trainloader, args=args, depths_mult=[1, 2, 3, 4, 10], device=device, plotdir=plotdir, legend=True)
            coord_check_depth(mup=False,
                              lr=args.lr, optimizer=args.optimizer, nsteps=args.coord_check_nsteps, arch=args.arch, base_shapes_list=base_shapes_list,
                              nseeds=args.coord_check_nseeds, dataloader=trainloader, args=args, depths_mult=[1, 2, 3, 4, 10], device=device, plotdir=plotdir, legend=True)
            sys.exit()

        # --- Model: one independently seeded net per ensemble member ---
        print('==> Building model..')
        nets = []
        for e in range(args.num_ens):
            torch.manual_seed(e)
            nets.append(get_model(args.arch, args.width_mult, args.depth_mult, args))
        print(nets[0])

        # need to run a save base shapes to get the 1k output channels
        if args.save_base_shapes:
            print(f'saving base shapes at {args.save_base_shapes}')
            base_net = get_model(args.arch, 1, args.depth_mult, args)
            base_shapes = get_shapes(base_net)
            delta_net = get_model(args.arch, 4, args.depth_mult, args)
            delta_shapes = get_shapes(delta_net)

            print(base_net)
            print(delta_shapes)
            make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes)
            print('done and exit')
            sys.exit()

        nets = [net.to(device) for net in nets]

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        if args.load_base_shapes:
            print(f'loading base shapes from {args.load_base_shapes}')
            for net in nets:
                set_base_shapes(net, args.load_base_shapes)
            print('done')
        else:
            print('using standard parametrization')
            for net in nets:
                set_base_shapes(net, None)
            print('done')
            args.parametr = 'sp'

        # NOTE(review): --depth_scale_first is parsed as a str (default 'none');
        # unless process_args converts it to a number this is always True.
        if args.depth_scale_first != 1.0:
            for net in nets:
                net.init_parameters_depth()

        start_epoch = 0  # start from epoch 0 or last checkpoint epoch
        if args.resume:
            print('==> Resuming from checkpoint..')
            checkpoint = torch.load(os.path.join(args.save_path, f"model_ckpt_N_{args.width_mult}_last_.pth"))
            # The epoch to resume from comes from the --resume_epoch flag.
            state = torch.load(os.path.join(args.save_path, f'ckpt_N_{args.width_mult}_batches_{args.resume_epoch}_.pth'))
            for net, net_weights in zip(nets, checkpoint['nets']):
                net.load_state_dict(net_weights)
            start_epoch = state['epoch'] + 1

        # Bare modules: used for optimizers and checkpoints so parameter names /
        # state_dict keys never carry DataParallel's "module." prefix.
        base_nets = nets
        if args.multiprocessing:
            # Wrap only now: set_base_shapes (presumably matches parameter names),
            # init_parameters_depth (a method DataParallel does not forward) and
            # checkpoint loading above all operate on the unwrapped models.
            # Assumes GPUs on a single node.
            device_ids = list(range(torch.cuda.device_count()))
            print("DEVICE IDs")
            print(device_ids)
            nets = [nn.DataParallel(net, device_ids) for net in nets]

        criterion = nn.CrossEntropyLoss()

        # Same parameter tensors as the (possibly wrapped) nets used for training.
        optimizers = get_optimizers(base_nets, args)

        if args.schedule:
            # With a batch-size sweep args.epochs holds a large sentinel and
            # training stops after max_updates steps, so the cosine decay must
            # span max_updates rather than args.epochs * len(trainloader).
            if max_updates != -1:
                num_training_steps = max_updates
            else:
                num_training_steps = args.epochs * len(trainloader)
            scheduler = partial(
                transformers.get_cosine_schedule_with_warmup,
                num_warmup_steps=args.warmup_steps,
                num_training_steps=num_training_steps,
                num_cycles=0.5,
            )
            schedulers = [scheduler(optimizer) for optimizer in optimizers]
        elif args.warmup_steps > 0:
            schedulers = [wu_scheduler(optimizer, args.warmup_steps) for optimizer in optimizers]
        else:
            schedulers = []

        # --- Train / evaluate, checkpointing after every epoch ---
        metrics = {'train_loss': [], 'ens_train_loss': [], 'test_loss': [], 'ens_test_loss': [], 'test_acc': [], 'ens_test_acc': [], 'train_acc': [], 'ens_train_acc': []}
        batches_seen = 0
        for epoch in range(start_epoch, start_epoch + args.epochs):
            metrics, batches_seen = train(epoch, batches_seen, nets, metrics, num_classes, trainloader, optimizers, criterion, device, schedulers, log=True, max_updates=max_updates)
            metrics = test(nets, metrics, num_classes, testloader, criterion, device)

            print('Saving..')
            state = {
                'metrics': metrics,
                'epoch': epoch,
            }
            torch.save(state, os.path.join(args.save_path, f'ckpt_N_{args.width_mult}_batches_{epoch}_.pth'))
            net_state = {'nets': [net.state_dict() for net in base_nets]}
            torch.save(net_state, os.path.join(args.save_path, f'model_ckpt_N_{args.width_mult}_last_.pth'))
            if max_updates != -1 and batches_seen >= max_updates:
                print("exiting")
                break

        wandb.finish()

