import argparse
from utils  import get_whole_dataset, set_dataset_specs, get_network, get_syn_ds, cutmix
import torch         
from tqdm import tqdm       
import numpy as np      
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F    
import torch.nn as nn       
import math 
from synth_utils import synth_dataset_parallel       
import torch.multiprocessing as mp      

# Module-wide compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


@torch.no_grad()
def eval(model, dl):
    """Return the top-1 accuracy of ``model`` over ``dl`` as a percentage.

    The model is switched to evaluation mode (and left in it). Every batch is
    moved to the module-level ``device`` before the forward pass.

    Args:
        model: Module mapping an input batch to per-class scores, with the
            class dimension at index 1.
        dl: Iterable of ``(inputs, labels)`` batches.

    Returns:
        Percentage in ``[0, 100]`` of samples whose arg-max prediction equals
        the label, or ``0.0`` when ``dl`` yields no samples.
    """
    model.eval()
    num_samples = 0
    num_correct = 0

    for x, y in dl:
        x, y = x.to(device), y.to(device)
        predictions = model(x).argmax(1)
        num_correct += (predictions == y).sum().item()
        num_samples += x.shape[0]

    # An empty loader would otherwise divide by zero.
    if num_samples == 0:
        return 0.0

    return num_correct / num_samples * 100.


def train_augmented(args):
    """Distill a frozen teacher into a new student on regenerated synthetic data.

    A teacher network is built with ``get_network`` and restored from
    ``args.model_ckpt``; a second network from ``get_network`` is the student.
    On epoch 0 and every ``args.generate_every`` epochs the synthetic dataset
    is regenerated with ``synth_dataset_parallel`` and reloaded from
    ``args.chopped_save_dir``. The student is trained with AdamW to match the
    teacher's temperature-softened outputs (KL divergence, temperature
    ``args.T``). Validation accuracy is printed every 10 epochs and once at
    the end, after which the student's ``state_dict`` is saved to a file in
    the current working directory whose name encodes ``args.subset``,
    ``args.arch``, ``args.epochs`` and the final accuracy.

    Args:
        args: Parsed command-line namespace. Fields read here: ``model_ckpt``,
            ``batch_size``, ``init_acc``, ``epochs``, ``seed_num``,
            ``generate_every``, ``adamw_lr``, ``adamw_weight_decay``, ``T``,
            ``chopped_save_dir``, ``subset`` and ``arch``. The namespace is
            also forwarded to the dataset/network/synthesis helpers.

    Raises:
        ValueError: If ``args.seed_num`` is smaller than 1.
        RuntimeError: If a freshly generated synthetic dataset yields no
            training batches.
    """
    if args.seed_num < 1:
        raise ValueError(f"--seed_num must be >= 1, got {args.seed_num}")

    # Frozen teacher. map_location="cpu" lets a checkpoint that was saved from
    # a GPU load on a CPU-only host; load_state_dict copies the values into
    # the model's own parameters afterwards.
    model_teacher = get_network(args)
    model_teacher.load_state_dict(torch.load(args.model_ckpt, map_location="cpu"))
    # Inputs are moved to `device` below, so the teacher must live there too
    # (a no-op if get_network already placed it on that device).
    model_teacher.to(device)
    model_teacher.eval()
    for p in model_teacher.parameters():
        p.requires_grad = False

    model = get_network(args)
    _, ds_val = get_whole_dataset(args)
    val_dl = torch.utils.data.DataLoader(ds_val, batch_size=args.batch_size, shuffle=False)

    if args.init_acc:
        acc_teacher = eval(model_teacher, val_dl)
        print(f"Teacher Acc: {np.round(acc_teacher, 2)}")

    # The scheduler is stepped once every `seed_num` epochs, i.e. about
    # epochs // seed_num times in total. Clamp to 1 so that runs with
    # epochs < seed_num do not divide by zero inside the LR lambda.
    gen_epochs = max(1, args.epochs // args.seed_num)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.adamw_lr, weight_decay=args.adamw_weight_decay
    )

    def lr_fun(step):
        # Cosine over a quarter period: multiplier 1.0 at step 0, falling to
        # 0.5 at step == gen_epochs.
        if step > args.epochs:
            return 0
        return 0.5 * (1.0 + math.cos(math.pi * step / gen_epochs / 2))

    scheduler = LambdaLR(optimizer, lr_fun, last_epoch=-1)
    loss_function_kl = nn.KLDivLoss(reduction="batchmean")

    for epoch in range(args.epochs):

        if epoch == 0 or epoch % args.generate_every == 0:
            # NOTE(review): the seed is only evaluated on regeneration epochs,
            # so with generate_every=5 and seed_num=10 only seeds 0 and 5 are
            # ever used -- confirm this is intended.
            gen_seed = epoch % args.seed_num
            # Move the student off the accelerator and release cached memory
            # before running the synthesis step.
            model.to("cpu")
            torch.cuda.empty_cache()
            synth_dataset_parallel(args, gen_seed)
            ds_train = get_syn_ds(args, args.chopped_save_dir)
            train_dl = torch.utils.data.DataLoader(ds_train, batch_size=args.batch_size, shuffle=True)
            if len(train_dl) == 0:
                raise RuntimeError(
                    f"synthetic dataset in {args.chopped_save_dir!r} produced no training batches"
                )

        model.train().to(device)
        # Labels are not used: the only training signal is the teacher output.
        for x, _ in tqdm(train_dl, total=len(train_dl)):
            x = x.to(device)

            out = model(x)
            soft_student = F.log_softmax(out / args.T, dim=1)
            with torch.no_grad():
                out_teacher = model_teacher(x)
                soft_teacher = F.softmax(out_teacher / args.T, dim=1)

            loss = loss_function_kl(soft_student, soft_teacher)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Drop the last batch so its memory can be returned by empty_cache().
        del x
        torch.cuda.empty_cache()

        if (epoch+1) % args.seed_num == 0:
            scheduler.step()

        if (epoch+1) % 10 == 0:
            acc = eval(model, val_dl)
            # The reported loss is that of the final batch of this epoch.
            print(f"Epoch: {epoch}, Val Acc: {np.round(acc, 2)}, Loss: {np.round(loss.item(), 3)}")

    acc = eval(model, val_dl)
    acc = str(np.round(acc, 2))
    print(f"Final Val Acc: {acc}")
    torch.save(model.state_dict(), f"model_method_variablenum_data_{args.subset}_arch_{args.arch}_ep_{args.epochs}_acc_{acc}.pth")


if __name__ == "__main__":
    # Select the 'fork' start method for torch.multiprocessing up front.
    mp.set_start_method('fork')

    # Command-line options as (flag, add_argument keyword arguments), listed
    # in registration order.
    cli_options = [
        # Dataset / schedule.
        ("--subset", {"type": str, "default": "imagenette"}),
        ("--size", {"type": int, "default": 224}),
        ("--epochs", {"type": int, "default": 300}),
        ("--generate_every", {"type": int, "default": 5}),
        ("--seed_num", {"type": int, "default": 10}),
        # Class selection.
        ("--st_cls", {"type": int, "default": 0}),
        ("--end_cls", {"type": int, "default": None}),
        ("--nclass", {"type": int, "default": 1000}),
        # NOTE(review): type=list turns the raw argument string into a list of
        # single characters -- confirm whether this flag is meant to be passed
        # on the command line at all.
        ("--classes", {"type": list}),
        # Sizes and parallelism.
        ("--input_size", {"type": int, "default": 224}),
        ("--batch_size", {"type": int, "default": 100}),
        ("--diff_input_size", {"type": int, "default": 512}),
        ("--factor", {"type": int, "default": 4}),
        ("--ngpu", {"type": int, "default": 4}),
        # Synthetic data locations.
        ("--ipc", {"type": int, "default": 2}),
        ("--root_dir", {"type": str, "default": '/home/public/'}),
        ("--chopped_save_dir", {"type": str, "required": True}),
        ("--collage_save_dir", {"type": str, "required": True}),
        ("--emb_root", {"type": str, "required": True}),
        # Networks.
        ("--arch", {"type": str, "default": 'resnet18'}),
        ("--model_ckpt", {"type": str, "required": True}),
        ("--init_acc", {"action": "store_true"}),
        # Optimisation / distillation.
        ("--adamw-lr", {"type": float, "default": 0.001, "help": "adamw learning rate"}),
        ("--adamw-weight-decay", {"type": float, "default": 0.01, "help": "adamw weight decay"}),
        ("--T", {"type": float, "default": 20., "help": "temperature"}),
        ("--cutmix", {"type": float, "default": 1.0, "help": "cutmix alpha, cutmix enabled if > 0. (default: 1.0)"}),
        ("--max-scale-crops", {"type": float, "default": 1, "help": "argument in RandomResizedCrop"}),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    # Let the dataset helper process the namespace before training starts.
    set_dataset_specs(args)

    train_augmented(args)