import torch 
import argparse
import numpy as np  
from utils  import get_network, set_dataset_specs, get_dataset, get_normalize_trans
from torchvision import transforms      
from tqdm import tqdm       

# Run on CUDA whenever a GPU is visible to PyTorch; otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

@torch.no_grad()
def eval(model, dl):
    """Return the top-1 accuracy of ``model`` on ``dl`` as a percentage.

    Switches ``model`` to eval mode (and leaves it there), moves each batch to
    the module-level ``device`` and counts predictions whose ``argmax`` over
    dim 1 matches the label. Gradient tracking is disabled by the decorator.

    NOTE: the name shadows the ``eval`` builtin; kept for existing callers.

    Args:
        model: classifier; assumed to return (batch, num_classes) logits and
            to already live on ``device`` — TODO confirm.
        dl: iterable of ``(x, y)`` batches, e.g. a ``DataLoader``.

    Returns:
        Accuracy in ``[0, 100]`` as a float.

    Raises:
        ValueError: if ``dl`` yields no samples, so accuracy is undefined.
    """
    model.eval()
    num_samples = 0
    correct = 0

    for x, y in dl:
        x, y = x.to(device), y.to(device)
        out = model(x)
        correct += (out.argmax(1) == y).sum().item()
        num_samples += x.shape[0]

    if num_samples == 0:
        raise ValueError("eval() received an empty data loader; accuracy is undefined")

    return correct / num_samples * 100.



def _train_one_epoch(model, dl, criterion, optimizer):
    """Run one optimisation pass over ``dl`` and return the mean per-sample loss."""
    model.train()
    total_loss = 0.0
    seen = 0
    for x, y in tqdm(dl, total=len(dl)):
        x, y = x.to(device), y.to(device)
        loss = criterion(model(x), y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate on-device (no host sync per batch); weighting by batch
        # size keeps a smaller final batch from skewing the epoch mean.
        total_loss = total_loss + loss.detach() * x.shape[0]
        seen += x.shape[0]

    if seen == 0:
        return float("nan")
    return float(total_loss) / seen


def train(args):
    """Train a non-pretrained ``get_network`` model on the configured dataset.

    Uses SGD (lr / weight decay / momentum from ``args``) with a MultiStepLR
    schedule (gamma 0.1). After every epoch it prints validation accuracy and
    the mean training loss, then overwrites the checkpoint
    ``model_data_{subset}_arch_{arch}_ep_{epochs}.pth`` in the working directory.

    Args:
        args: parsed CLI namespace (after ``set_dataset_specs``). Reads
            ``input_size``, ``batch_size``, ``lr``, ``wd``, ``mom``, ``epochs``,
            ``subset``, ``arch`` and, if present, ``milestones``
            (defaults to ``[50, 150]``).
    """
    # ``.to`` is a no-op if get_network already placed the model on ``device``;
    # without it, batches moved to CUDA would meet a CPU-resident model.
    model = get_network(args, pretrained=False).to(device)

    normalize, _ = get_normalize_trans(args)
    train_trans = transforms.Compose([
        transforms.RandomResizedCrop(args.input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # NOTE(review): only the augmenting train transform is passed here;
    # confirm get_dataset gives ds_val its own deterministic transform.
    ds_train, ds_val, _, _, _ = get_dataset(args, train_trans)

    train_dl = torch.utils.data.DataLoader(ds_train, batch_size=args.batch_size, shuffle=True)
    val_dl = torch.utils.data.DataLoader(ds_val, batch_size=args.batch_size, shuffle=False)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=args.mom)
    # Milestones beyond args.epochs are never reached (e.g. 150 with the
    # default 90 epochs), so allow overriding them from args.
    milestones = getattr(args, "milestones", None) or [50, 150]
    sch = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    ckpt_path = f"model_data_{args.subset}_arch_{args.arch}_ep_{args.epochs}.pth"
    for epoch in range(args.epochs):
        train_loss = _train_one_epoch(model, train_dl, criterion, optimizer)
        sch.step()

        acc = eval(model, val_dl)
        print(f"Epoch: {epoch}, Val Acc: {np.round(acc, 2)}, Loss: {np.round(train_loss, 3)}")
        # Overwritten every epoch, so the file always holds the latest weights.
        torch.save(model.state_dict(), ckpt_path)


def create_patches(args):
    """Placeholder for patch creation; not implemented yet.

    Currently a no-op: ``args`` is ignored and ``None`` is returned.
    """
    return None

def build_parser():
    """Return the command-line parser for this training script.

    Options are consumed by ``train`` and by the ``utils`` helpers
    (``set_dataset_specs``, ``get_dataset``, ``get_network``, ...).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--subset", type=str, default="imagenette")
    parser.add_argument("--size", type=int, default=224)
    parser.add_argument("--st_cls", type=int, default=0)
    parser.add_argument("--end_cls", type=int, default=None)

    parser.add_argument("--nclass", type=int, default=1000)
    # ``type=list`` would split a single token into characters
    # (``--classes abc`` -> ['a', 'b', 'c']); take one token per class instead.
    # NOTE(review): values stay strings — convert downstream if they are
    # meant to be integer class indices.
    parser.add_argument("--classes", type=str, nargs="+", default=None)
    parser.add_argument("--init_resize", type=int, default=256)
    parser.add_argument("--input_size", type=int, default=224)
    parser.add_argument("--diff_input_size", type=int, default=512)
    parser.add_argument("--val_ipc", type=int, default=50)
    parser.add_argument("--factor", type=int, default=4)
    parser.add_argument("--root_dir", type=str, default='/home/public/')
    parser.add_argument("--arch", type=str, default='resnet18')
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=90)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--wd", type=float, default=1e-4)
    parser.add_argument("--mom", type=float, default=.9)
    # Epochs at which the LR is multiplied by 0.1; keep them below --epochs.
    parser.add_argument("--milestones", type=int, nargs="+", default=[50, 150])

    parser.add_argument("--dataset_name_dict", type=str)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    set_dataset_specs(args)
    train(args)


