import argparse
from utils  import get_whole_dataset, set_dataset_specs, get_network, get_rded_syn_ds, cutmix
import torch         
from tqdm import tqdm       
import numpy as np      
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F    
import torch.nn as nn       
import math 

# Compute device shared by the whole script: CUDA when available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def eval(model, dl):
    """Return the top-1 accuracy of ``model`` over ``dl``, in percent.

    NOTE: this name shadows the Python builtin ``eval`` inside this module;
    it is kept for compatibility with existing callers.

    The model is switched to eval mode and left in that mode; callers that
    continue training must call ``model.train()`` themselves.

    Args:
        model: classifier whose output is reduced with ``argmax(1)`` to get
            the predicted class per sample.
        dl: iterable of ``(inputs, labels)`` batches; both are moved to the
            module-level ``device``.

    Returns:
        float: ``100 * correct / total``.

    Raises:
        ValueError: if ``dl`` yields no samples.
    """
    model.eval()
    num_samples = 0
    crrct = 0

    # Inference only: disable autograd so no graph/activations are retained.
    with torch.no_grad():
        for x, y in dl:
            x, y = x.to(device), y.to(device)
            out = model(x)
            crrct += (out.argmax(1) == y).sum().item()
            num_samples += x.shape[0]

    if num_samples == 0:
        raise ValueError("eval() received a data loader with no samples")
    return crrct / num_samples * 100.


def train_rded(args):
    """Distil a frozen teacher into a freshly built student on RDED synthetic data.

    Steps:
      1. Build the teacher with ``get_network(args)``, load the weights in
         ``args.model_ckpt`` and freeze it.
      2. Build a new student with ``get_network(args)``, the synthetic training
         set (``get_rded_syn_ds``) and the validation split returned by
         ``get_whole_dataset``.
      3. Train the student for ``args.epochs`` epochs with AdamW on the KL
         divergence between temperature-softened teacher and student outputs.
         Ground-truth labels are not used in the loss.
      4. Print validation accuracy every 100 epochs and at the end, then save
         the student's ``state_dict`` under a name that encodes subset,
         architecture, epochs and final accuracy.

    Args:
        args: parsed CLI namespace. Fields read directly here: model_ckpt,
            batch_size, init_acc, adamw_lr, adamw_weight_decay, epochs, T,
            subset, arch. The helper functions may read more.

    Raises:
        ValueError: if the synthetic training set yields no batches.
    """
    # Frozen teacher. map_location lets a checkpoint saved on GPU be loaded when
    # `device` has fell back to CPU; `.to(device)` is a no-op if get_network
    # already placed the model there.
    model_teacher = get_network(args).to(device)
    model_teacher.load_state_dict(torch.load(args.model_ckpt, map_location=device))
    model_teacher.eval()
    for p in model_teacher.parameters():
        p.requires_grad = False

    model = get_network(args).to(device)
    _, ds_val = get_whole_dataset(args)
    ds_train = get_rded_syn_ds(args)

    train_dl = torch.utils.data.DataLoader(ds_train, batch_size=args.batch_size, shuffle=True)
    val_dl = torch.utils.data.DataLoader(ds_val, batch_size=args.batch_size, shuffle=False)
    # Fail early: with no batches, `loss` below would never be assigned.
    if len(train_dl) == 0:
        raise ValueError("RDED synthetic training set is empty; nothing to distil")

    if args.init_acc:
        acc_teacher = eval(model_teacher, val_dl)
        print(f"Teacher Acc: {np.round(acc_teacher, 2)}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.adamw_lr, weight_decay=args.adamw_weight_decay)

    def lr_fun(step):
        """LR multiplier for epoch ``step``: 0.5 * (1 + cos(pi * step / (2 * epochs))).

        NOTE(review): the extra ``/ 2`` traverses only the first half of the
        cosine, so the multiplier decays from 1.0 towards 0.5 (not 0) over
        ``args.epochs``. Kept unchanged from the original script; confirm intent.
        """
        if step > args.epochs:
            return 0
        return 0.5 * (1.0 + math.cos(math.pi * step / args.epochs / 2))

    scheduler = LambdaLR(optimizer, lr_fun, last_epoch=-1)

    # KLDivLoss takes log-probabilities as input and probabilities as target.
    loss_function_kl = nn.KLDivLoss(reduction="batchmean")

    for epoch in range(args.epochs):
        model.train()
        # Labels are unused: the student learns only from the teacher's soft targets.
        for x, _y in tqdm(train_dl, total=len(train_dl)):
            x = x.to(device)
            # NOTE(review): CutMix (the imported `cutmix` helper, `--cutmix` flag)
            # is currently NOT applied to `x`.

            out = model(x)
            soft_student = F.log_softmax(out / args.T, dim=1)
            with torch.no_grad():
                out_teacher = model_teacher(x)
                soft_teacher = F.softmax(out_teacher / args.T, dim=1)

            # NOTE: the loss is not rescaled by T**2.
            loss = loss_function_kl(soft_student, soft_teacher)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()
        if (epoch + 1) % 100 == 0:
            acc = eval(model, val_dl)
            # `loss` is the last batch's loss of this epoch.
            print(f"Epoch: {epoch}, Val Acc: {np.round(acc, 2)}, Loss: {np.round(loss.item(), 3)}")

    acc = eval(model, val_dl)
    acc = str(np.round(acc, 2))
    print(f"Final Val Acc: {acc}")
    torch.save(model.state_dict(), f"model_method_RDED_data_{args.subset}_arch_{args.arch}_ep_{args.epochs}_acc_{acc}.pth")


def build_parser():
    """Build the command-line parser for the RDED student-distillation script.

    Returns:
        argparse.ArgumentParser: parser whose namespace is passed to
        ``set_dataset_specs`` and then ``train_rded``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--subset", type=str, default="imagenette")
    parser.add_argument("--size", type=int, default=224)
    parser.add_argument("--epochs", type=int, default=300)

    parser.add_argument("--nclass", type=int, default=1000)
    # `type=list` would split a single token into characters ("12" -> ['1', '2']);
    # accept one or more whitespace-separated values instead.
    # NOTE(review): assumes class ids are integer indices — confirm against set_dataset_specs.
    parser.add_argument("--classes", type=int, nargs="+")
    parser.add_argument("--input_size", type=int, default=224)
    parser.add_argument("--batch_size", type=int, default=100)

    parser.add_argument("--factor", type=int, default=2)

    parser.add_argument("--root_dir", type=str, default='/home/public/')
    parser.add_argument("--collage_save_dir", type=str)

    parser.add_argument("--arch", type=str, default='resnet18')
    parser.add_argument("--model_ckpt", type=str, required=True)

    parser.add_argument("--init_acc", action="store_true")

    parser.add_argument("--adamw-lr", type=float, default=0.001, help="adamw learning rate")
    parser.add_argument("--adamw-weight-decay", type=float, default=0.01, help="adamw weight decay")
    parser.add_argument("--T", type=float, default=20., help="temperature")
    # NOTE(review): the cutmix call in train_rded is commented out, so this value is currently unused.
    parser.add_argument("--cutmix", type=float, default=1.0, help="cutmix alpha, cutmix enabled if > 0. (default: 1.0)")
    parser.add_argument("--max-scale-crops", type=float, default=1, help="argument in RandomResizedCrop")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    set_dataset_specs(args)

    train_rded(args)