#%%
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch

from torchvision.models import resnet50,resnet18
from peft_utils.custom_peft import get_custom_peft
from peft_utils.bilevel_trainer import adalora_bilevel_trainer
from data_helper import load_cifar10
import argparse


def main():
    """Fine-tune a CIFAR-10-adapted ResNet-50 with low-rank adapters and a bilevel trainer.

    Command-line flags:
        --wd          AdamW weight decay (default 0.0).
        --lr          AdamW learning rate (default 1e-3).
        --rank        adapter rank passed to ``get_custom_peft`` (default 8).
        --epochs      number of epochs forwarded to the trainer (default 50).
        --final_cr    forwarded to the trainer as ``final_cr`` (default 0.8);
                      presumably a target compression ratio — confirm in the trainer.
        --dropout     adapter dropout passed to ``get_custom_peft`` (default 0.0).
        --task_name   run name forwarded to the trainer (default 'resnet_cifar10').
        --riemannian  string forwarded to the trainer (default '').
        --device      torch device holding the model (default 'cuda:2').
    """
    parser = argparse.ArgumentParser(prog='blo_resnet')

    parser.add_argument("--wd", dest="wd", type=float, default=0.0)
    parser.add_argument("--lr", dest="lr", type=float, default=1e-3)
    parser.add_argument("--rank", dest="rank", type=int, default=8)
    parser.add_argument("--epochs", dest="epochs", type=int, default=50)
    parser.add_argument("--final_cr", dest="final_cr", type=float, default=0.8)
    parser.add_argument("--dropout", dest="dropout", type=float, default=0.0)
    parser.add_argument("--task_name", dest="task_name", type=str, default='resnet_cifar10')
    parser.add_argument("--riemannian", dest="riemannian", type=str, default='')
    # Single source of truth for the device; the default keeps the original 'cuda:2'.
    parser.add_argument("--device", dest="device", type=str, default='cuda:2')

    args = parser.parse_args()
    device = torch.device(args.device)

    # Layer-name pattern handed to get_custom_peft to choose which layers get adapters.
    TARGET_MODULES = ['.conv']

    # Pretrained ResNet-50 adapted for small CIFAR-10 images: a 3x3 stride-1 stem
    # conv and no max-pool preserve early spatial resolution, and the head is
    # resized to 10 classes. The replaced conv1/fc are freshly initialised.
    model = resnet50(pretrained=True)
    model.conv1 = torch.nn.Conv2d(
        3, 64, kernel_size=3, stride=1, padding=1, bias=False, device=device
    )
    model.maxpool = torch.nn.Identity()
    model.fc = torch.nn.Linear(model.fc.in_features, 10, bias=True, device=device)
    model = model.to(device)

    # NOTE(review): 128 is presumably the batch size; the meaning of the second
    # argument (1) is defined in data_helper.load_cifar10 — confirm.
    train_loader, val_loader, _ = load_cifar10(128, 1)

    model, layers = get_custom_peft(
        model,
        'blo',
        target_layer_names=TARGET_MODULES,
        rank=args.rank,
        alpha=32,
        lora_dropout=args.dropout,
    )

    # Move again in case get_custom_peft created parameters off `device`.
    model.to(device)

    # Adapter parameters whose names do not contain '.s' (presumably the
    # singular-value factors are left to the bilevel trainer), plus the
    # freshly initialised stem conv and classifier head.
    # NOTE(review): a parameter registered directly on a layer as `s` is named
    # "s" (no dot) and would NOT be excluded by this filter — confirm naming.
    adapter_params = [
        param
        for layer in layers
        for name, param in layer.named_parameters()
        if '.s' not in name
    ]
    optimizer_UV = torch.optim.AdamW(
        adapter_params
        + list(model.conv1.parameters())
        + list(model.fc.parameters()),
        lr=args.lr,
        weight_decay=args.wd,
    )
    # These are torch's ConstantLR defaults, spelled out: the LR is scaled by 1/3
    # until the scheduler has been stepped 5 times, then returns to args.lr.
    scheduler_UV = torch.optim.lr_scheduler.ConstantLR(
        optimizer_UV, factor=1.0 / 3, total_iters=5
    )
    optimizers_and_schedulers = {
        "optimizer": optimizer_UV,
        "scheduler": scheduler_UV,
    }

    trainer = adalora_bilevel_trainer(
        model=model,
        train_dataset=train_loader,
        eval_dataset=val_loader,
        low_rank_layers=layers,
        max_epochs_ll=2,
        tau=500,
        optimizer_and_scheduler=optimizers_and_schedulers,
        riemannian=args.riemannian,
        final_cr=args.final_cr,
        task_name=args.task_name,
        epochs=args.epochs,
    )

    trainer.train()

if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
