import torch

# Seed the three random number generators used by this script (PyTorch, NumPy
# and Python's `random`) with the fixed value 0. Each generator is seeded
# right after its module is imported, and all of this runs before the
# project modules below are imported.
torch.manual_seed(0)
import numpy as np

np.random.seed(0)
import random

random.seed(0)

# Install a global "ignore" filter so no Python warning is shown for the rest
# of the process. NOTE(review): this also hides deprecation and numerical
# warnings coming from third-party libraries.
import warnings

warnings.filterwarnings("ignore")

import argparse

# from src.optimizer_DLRT.dlrt_optimizer import DLRT_Optimizer
import trainer
# from src.training import compresser
# from src.training import profiler_trainer

from dataset_utils import choose_dataset
# from model_utils import choose_model
from models.vgg import VGG,VGG_types
from models.lenet5 import Lenet5
from models.alexnet import AlexNet


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(token)``, which is ``True``
    for every non-empty string -- including ``"False"`` and ``"0"``. This
    converter interprets the text instead.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` for true/t/yes/y/1/on, ``False`` for false/f/no/n/0/off or an
        empty string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string maps to False, matching what ``bool('')`` returned.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _build_parser():
    """Create the command-line parser for the training script.

    Boolean options accept an explicit value (``--save_weights True``,
    ``--save_weights false``) or no value at all (``--save_weights``), in
    which case they are set to ``True``. Help texts use ``%(default)s`` so the
    documented default can never drift from the real one.

    Returns:
        The fully configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description='Pytorch dlrt training for vgg of imagenet')

    def add_flag(name, **kwargs):
        # Shared definition for every on/off option: default off, parsed by
        # _str2bool, and usable as a bare switch thanks to nargs='?'/const.
        parser.add_argument(name, type=_str2bool, nargs='?', const=True, default=False, **kwargs)

    # Arguments for network training
    parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--epochs', type=int, default=60, metavar='N',
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--epochs_ft', type=int, default=30, metavar='N',
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=0.05, metavar='LR',
                        help='learning rate for dlrt optimizer (default: %(default)s)')
    parser.add_argument('--tau', type=float, default=0.2, metavar='tau',
                        help='cutting rank for dlrt optimizer (default: %(default)s)')
    parser.add_argument('--wd', type=float, default=0.001, metavar='wd',
                        help='weight decay on S and weights')
    parser.add_argument('--momentum', type=float, default=0.1, metavar='MOMENTUM',
                        help='momentum (default: %(default)s)')
    parser.add_argument('--workers', type=int, default=1, metavar='WORKERS',
                        help='number of workers for the dataloaders (default: %(default)s)')
    parser.add_argument('--deco', type=str, default='cp', metavar='DECO',
                        help='Decomposition to use (cp or tucker or mat)')

    # Arguments for network save n load
    add_flag('--save_weights', metavar='SAVE_WEIGHTS',
             help='save the weights of the best validation model during the run (default: %(default)s)')
    add_flag('--save_progress', help='save progress csv (TEST)')
    add_flag('--load_weights', help='load standard weights for the model (TEST)')
    parser.add_argument('--load_model_path', type=str, default=None, metavar='LOAD_MODEL_PATH',
                        help='Loads the model given the full path including the filename. Basepath is where main.py is '
                             'located. '
                             'The user needs to take care to load the correct model for the dataset (default: None)')

    parser.add_argument("--net_name", default='lenet5',
                        choices=["lenet5", "vgg16", 'lenet5_rgb', 'alexnet', 'alexnet_nobn', 'lenet5_blackandwhite',
                                 'resnet18', 'resnet50', 'vit_b_16', "toaddother"])
    parser.add_argument("--dataset_name", default='mnist',
                        choices=["mnist", "cifar10", "fashion_mnist", 'svhn', 'cifar10_b&w', 'imagenet'])
    parser.add_argument('--cv_run', type=int, default=0,
                        help='number of cross validation run to add to savename (default: %(default)s)')

    # Arguments for Low-Rank Discretization
    add_flag('--chain_init', help='add chain initialization (TEST)')
    add_flag('--tucker', help='add tucker convolution (TEST)')
    add_flag('--adaptive', help='add tucker convolution (TEST)')
    add_flag('--compresser', help='layerwise compression (TEST)')
    add_flag('--mat_dlrt', help='add matrix linear dlrt layers (TEST) (default: %(default)s)')
    add_flag('--baseline', help='add matrix linear dlrt layers (TEST) (default: %(default)s)')
    parser.add_argument('--device', type=str, default='cuda', help='device (cuda or cpu)')
    parser.add_argument('--save_name', type=str, default='comparison', help='savename of results')

    # Misc Arguments
    add_flag('--profiler', help='toggle Timing profiler')
    parser.add_argument('--datapath', type=str, default="../../../../data02/zangrando/",
                        help='path to imagenet folder')
    return parser


def _select_network(args):
    """Instantiate the network named by ``args.net_name`` on ``args.device``.

    Args:
        args: Parsed command-line namespace; it is also forwarded to the model
            constructors.

    Returns:
        The constructed model after ``.to(args.device)`` has been called.

    Raises:
        NotImplementedError: If ``args.net_name`` is accepted by the parser
            but has no constructor wired up here.
    """
    if args.net_name == 'lenet5':
        net = Lenet5(args)
    elif args.net_name == 'vgg16':
        # NOTE(review): the positional values are passed through unchanged;
        # their meaning is defined by models.vgg.VGG -- confirm there.
        net = VGG(VGG_types['VGG16'], 3, 32, 32, 256, 10, args=args)
    elif args.net_name == 'alexnet':
        net = AlexNet(10, args.device, args=args)
    else:
        # Several parser choices have no model behind them. Fail here with a
        # clear message instead of hitting an unbound variable later on.
        raise NotImplementedError(
            f"net_name '{args.net_name}' is not implemented; available: lenet5, vgg16, alexnet")
    net.to(args.device)
    return net


def main():
    """Parse the command line, build model/data/optimizer and start training.

    Steps: parse arguments, fall back to CPU when CUDA is unavailable, build
    the selected network, load the dataset loaders, create an SGD optimizer
    with a ``ReduceLROnPlateau`` scheduler, and hand everything to
    ``trainer.train``.

    Raises:
        NotImplementedError: If the requested ``--net_name`` has no model
            implementation in this script.
    """
    args = _build_parser().parse_args()

    # setup cuda: honour the requested device only when CUDA is available
    device = args.device if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")
    args.device = device

    def accuracy(outputs, labels):
        """Return the number of correct top-1 predictions as a float16 scalar tensor."""
        # `.to()` converts the boolean mask directly; wrapping an existing
        # tensor in torch.tensor() is discouraged and emits a warning.
        # NOTE: float16 represents integers exactly only up to 2048.
        hits = torch.argmax(outputs.detach(), dim=1) == labels
        return torch.sum(hits.to(torch.float16))

    criterion = torch.nn.CrossEntropyLoss()

    # -------- Network Selection -----------
    f = _select_network(args)

    # -------- Dataset Selection -----------
    train_loader, val_loader, test_loader = choose_dataset(dataset_name=args.dataset_name,
                                                           batch_size=args.batch_size,
                                                           num_workers=args.workers,
                                                           datapath=args.datapath)

    # -------- Optimizer Selection ---------
    optimizer = torch.optim.SGD(f.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)

    # -------- LR Scheduler Selection ---------
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)

    # -------- Trainer Selection ---------
    path = './results/'
    save_name = f'vanilla_{args.net_name}_{args.deco}_{args.dataset_name}_tau{args.tau}_mom{args.momentum}_lr{args.lr}_{args.cv_run}'

    print(
        f'TRAINING {args.net_name} ON {args.dataset_name} with parameters lr{args.lr},tau {args.tau},baseline {args.baseline}')
    trainer.train(f, optimizer=optimizer, criterion=criterion, train_loader=train_loader,
                  epoch_status_bar=True, validation_loader=val_loader, test_loader=test_loader,
                  metric=accuracy, epochs=args.epochs, device=args.device, path=path,
                  save_weights=args.save_weights, save_progress=args.save_progress,
                  scheduler=scheduler, save_name=save_name)


# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if "__main__" == __name__:
    main()
