import torch 
import argparse
import numpy as np  
from utils  import get_network, set_dataset_specs, get_dataset, get_normalize_trans, cutmix
from torchvision import transforms      
import torchvision  
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm       
import torch.nn.functional as F 
import math     
import copy as copy 
import pickle as pkl        

# Device that every batch is moved to before a forward pass (GPU when available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@torch.no_grad()
def eval(model, dl):
    """Return the top-1 accuracy of ``model`` on ``dl`` as a percentage.

    NOTE: this shadows the builtin ``eval``; the name is kept because the
    training functions in this module call it.

    The model is switched to eval mode and left there; callers switch it back
    with ``model.train()`` before the next training epoch. Gradients are
    disabled for the whole call by the ``torch.no_grad`` decorator.

    Args:
        model: classifier; the prediction is ``model(x).argmax(1)``.
        dl: sized iterable of ``(x, y)`` batches. ``y`` is compared directly
            to the predicted class index, so it is assumed to hold integer
            class labels (TODO confirm for every loader passed in).

    Returns:
        float: ``100 * correct / total_samples``.

    Raises:
        ValueError: if ``dl`` yields no samples (would otherwise be a
            ZeroDivisionError).
    """
    model.eval()
    num_samples = 0
    crrct = 0

    for x, y in tqdm(dl, total=len(dl)):
        x, y = x.to(device), y.to(device)
        out = model(x)
        crrct += (out.argmax(1) == y).sum().item()
        num_samples += x.shape[0]

    if num_samples == 0:
        raise ValueError("eval() received a data loader that yielded no samples")
    return crrct / num_samples * 100.


 
def train(args):
    """Train a network from scratch on a synthetic ImageFolder dataset.

    Two modes, selected by ``args.softlbl``:

    * soft-label (``args.softlbl`` set): images are only resized to
      ``args.input_size``. Every sample's target is replaced by a row of the
      tensor loaded from ``args.softlbl_file``. The loss is KL divergence
      between ``log_softmax(student / T)`` and ``softmax(soft_label / T)``
      with ``T = args.T``.
    * hard-label: RandomResizedCrop + horizontal flip. With
      ``args.mixtype == "cutmix"``, ``cutmix`` is applied to each input batch.
      Cross-entropy is computed on the ImageFolder class indices. Passing
      ``args.softlbl_file`` in this mode is rejected and the process exits.

    Optimizer is SGD (``args.sgd``) or AdamW. The per-epoch LR multiplier
    decays to 0 over ``args.epochs``: cosine with ``args.cos``, otherwise
    linear.

    Every ``args.eval_every`` epochs, validation accuracy is printed and
    ``model_data_{subset}_arch_{arch}_ep_{epochs}.pth`` is overwritten with
    the current weights. After training, a pickle
    ``results_{scenario_name}_acc_{best_acc}_seed_{seed}.pkl`` is written. It
    holds ``best_model``, ``best_acc``, ``final_acc`` and ``model`` (state
    dicts / floats).

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).

    Raises:
        ValueError: if ``args.softlbl`` is set without ``args.softlbl_file``.
        SystemExit: if ``args.softlbl_file`` is given for hard-label training.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # Batches are moved to `device` below, so the model must live there too;
    # this is a no-op if get_network already placed it.
    model = get_network(args, pretrained=False).to(device)

    normalize, _ = get_normalize_trans(args)
    _, ds_val, *_ = get_dataset(args, None)

    if args.softlbl:
        # Deterministic resize only: presumably each stored soft label belongs
        # to one fixed view of its image, so random crops would not match it.
        train_trans = transforms.Compose([
            transforms.Resize((args.input_size, args.input_size)),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_trans = transforms.Compose([
            transforms.RandomResizedCrop(args.input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    ds_train = torchvision.datasets.ImageFolder(root=args.syn_data_root, transform=train_trans)

    aug_func = None  # batch-level input augmentation; only used for hard labels
    if args.softlbl:
        if args.softlbl_file is None:
            raise ValueError("--softlbl requires --softlbl_file")
        softlbl = torch.load(args.softlbl_file)
        # NOTE(review): assumes row i of `softlbl` belongs to ds_train.samples[i]
        # (ImageFolder's sorted file order) — confirm against the producer.
        ds_train.targets = softlbl.clone()
        # Slice-assign to keep the same list object, so any other reference to
        # it (torchvision's ImageFolder keeps `imgs` as an alias) sees the change.
        ds_train.samples[:] = [(path, softlbl[i]) for i, (path, _) in enumerate(ds_train.samples)]
        criterion = torch.nn.KLDivLoss(reduction="batchmean")
    else:
        if args.softlbl_file is not None:
            print("Warning: softlbl file is used for hard label training, ##########################################")
            # Non-zero status: this is a configuration error, not a clean finish.
            raise SystemExit(1)
        if args.mixtype == "cutmix":
            aug_func = cutmix
        criterion = torch.nn.CrossEntropyLoss()

    train_dl = torch.utils.data.DataLoader(ds_train, batch_size=args.batch_size, shuffle=True,
                                           num_workers=args.workers, pin_memory=True)
    val_dl = torch.utils.data.DataLoader(ds_val, batch_size=args.batch_size, shuffle=False,
                                         num_workers=args.workers, pin_memory=True)

    if args.sgd:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=args.mom)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.adamw_lr, weight_decay=args.adamw_weight_decay)

    def lr_lambda(step):
        # LR multiplier for epoch `step` (sch.step() runs once per epoch):
        # cosine with --cos, otherwise linear; 0 from step == args.epochs on.
        if step >= args.epochs:
            return 0.0
        if args.cos:
            return 0.5 * (1. + math.cos(math.pi * step / args.epochs))
        return 1.0 - step / args.epochs

    # Built unconditionally so that sch.step() below always has a scheduler.
    sch = LambdaLR(optimizer, lr_lambda, last_epoch=-1)

    save_dict = {}
    acc = 0
    best_acc = 0
    loss = None

    for epoch in range(args.epochs):
        model.train()
        for x, y in tqdm(train_dl, total=len(train_dl)):
            x, y = x.to(device), y.to(device)

            if aug_func is not None:
                # NOTE(review): only the images are passed to cutmix; targets are
                # used unchanged afterwards — confirm this is intended.
                x = aug_func(x)

            out = model(x)

            if args.softlbl:
                # Soft-label rows are treated as logits: temperature softmax on both sides.
                out = F.log_softmax(out / args.T, dim=1)
                partial_soft_label = F.softmax(y / args.T, dim=1)
                loss = criterion(out, partial_soft_label)
            else:
                loss = criterion(out, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        sch.step()

        if (epoch + 1) % args.eval_every == 0:
            # Drop references to the last training batch before evaluating to
            # lower peak GPU memory; rebinding (unlike `del`) also works when
            # the loader yielded nothing.
            x = y = out = None
            torch.cuda.empty_cache()
            acc = eval(model, val_dl)
            if acc > best_acc:
                best_acc = acc
                save_dict['best_model'] = copy.deepcopy(model.state_dict())

            last_loss = loss.item() if loss is not None else float("nan")
            print(f"Epoch: {epoch}, Val Acc: {np.round(acc, 2)}, Best Acc:{np.round(best_acc, 2)}, Loss: {np.round(last_loss, 3)}")
            torch.save(model.state_dict(), f"model_data_{args.subset}_arch_{args.arch}_ep_{args.epochs}.pth")

    acc = eval(model, val_dl)
    last_loss = loss.item() if loss is not None else float("nan")
    print(f"Final acc: {np.round(best_acc, 2)}, Loss: {np.round(last_loss, 3)}")

    save_dict["final_acc"] = acc
    save_dict['best_acc'] = best_acc
    save_dict['model'] = model.state_dict()
    # Context manager so the results file is flushed and closed deterministically.
    with open(f"results_{args.scenario_name}_acc_{np.round(best_acc, 2)}_seed_{args.seed}.pkl", "wb") as fh:
        pkl.dump(save_dict, fh)


def train_softlbl_with_aug(args):
    """Distill a pretrained teacher into a from-scratch student on synthetic images.

    Student and teacher are both built by ``get_network(args, ...)``
    (``pretrained=False`` / ``pretrained=True``); the teacher is frozen.
    Each training batch gets RandomResizedCrop + horizontal flip. With
    ``args.mixtype == "cutmix"`` it is also passed through ``cutmix``. The
    teacher then labels that same augmented batch on the fly, so the
    ImageFolder's own class labels are never used.

    The loss is ``KLDivLoss(log_softmax(student / T), softmax(teacher / T))``
    with ``T = args.T``. Gradients are accumulated over
    ``args.grad_acc_steps`` batches before each optimizer step. The LR
    multiplier decays to 0 over ``args.epochs``: cosine with ``args.cos``,
    otherwise linear.

    Validation accuracy is printed every ``args.eval_every`` epochs and
    measured once more after training. The results pickle
    ``results_{scenario_name}_acc_{best_acc}_seed_{seed}.pkl`` contains every
    ``args`` field plus ``best_model``, ``best_acc``, ``final_acc`` and
    ``model``.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
    """
    # Seed like train() so repeated runs with the same seed are reproducible.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # Batches are moved to `device`, so both networks must live there too;
    # .to() is a no-op if get_network already placed them.
    model = get_network(args, pretrained=False).to(device)
    model_teacher = get_network(args, pretrained=True).to(device)
    model_teacher.eval()
    for p in model_teacher.parameters():
        p.requires_grad = False

    normalize, _ = get_normalize_trans(args)
    _, ds_val, *_ = get_dataset(args, None)

    train_trans = transforms.Compose([
        transforms.RandomResizedCrop(args.input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    ds_train = torchvision.datasets.ImageFolder(root=args.syn_data_root, transform=train_trans)

    criterion = torch.nn.KLDivLoss(reduction="batchmean")
    aug_func = cutmix if args.mixtype == "cutmix" else None

    train_dl = torch.utils.data.DataLoader(ds_train, batch_size=args.batch_size, shuffle=True,
                                           num_workers=args.workers, pin_memory=True)
    val_dl = torch.utils.data.DataLoader(ds_val, batch_size=args.batch_size, shuffle=False,
                                         num_workers=args.workers, pin_memory=True)

    if args.sgd:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=args.mom)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.adamw_lr, weight_decay=args.adamw_weight_decay)

    def lr_lambda(step):
        # LR multiplier for epoch `step` (sch.step() runs once per epoch):
        # cosine with --cos, otherwise linear; 0 from step == args.epochs on.
        if step >= args.epochs:
            return 0.0
        if args.cos:
            return 0.5 * (1. + math.cos(math.pi * step / args.epochs))
        return 1.0 - step / args.epochs

    # Built unconditionally so that sch.step() below always has a scheduler.
    sch = LambdaLR(optimizer, lr_lambda, last_epoch=-1)

    save_dict = dict(vars(args))

    acc = 0
    best_acc = 0
    loss = None
    for epoch in range(args.epochs):
        model.train()
        # Any gradients left from an incomplete accumulation window at the end
        # of the previous epoch are discarded here.
        optimizer.zero_grad()
        for b_ind, (x, _) in enumerate(tqdm(train_dl, total=len(train_dl))):
            x = x.to(device)

            if aug_func is not None:
                x = aug_func(x)

            out = model(x)
            with torch.no_grad():
                out_teacher = model_teacher(x)

            log_p_student = F.log_softmax(out / args.T, dim=1)
            p_teacher = F.softmax(out_teacher / args.T, dim=1)
            # Divide so the accumulated gradient is the mean over the window.
            loss = criterion(log_p_student, p_teacher) / args.grad_acc_steps
            loss.backward()

            if (b_ind + 1) % args.grad_acc_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        sch.step()

        if (epoch + 1) % args.eval_every == 0:
            acc = eval(model, val_dl)
            if acc > best_acc:
                best_acc = acc
                # state_dict() returns live references to the parameters; copy it
                # so later updates do not overwrite the best snapshot.
                save_dict['best_model'] = copy.deepcopy(model.state_dict())

            last_loss = loss.item() if loss is not None else float("nan")
            print(f"Epoch: {epoch}, Val Acc: {np.round(acc, 2)}, Best Acc:{np.round(best_acc, 2)}, Loss: {np.round(last_loss, 3)}")

    # Evaluate the final weights explicitly: the last periodic eval may be stale
    # or may never have run (args.epochs < args.eval_every).
    acc = eval(model, val_dl)
    last_loss = loss.item() if loss is not None else float("nan")
    print(f"Final acc, Val Acc: {np.round(best_acc, 2)}, Loss: {np.round(last_loss, 3)}")

    save_dict["final_acc"] = acc
    save_dict['best_acc'] = best_acc
    save_dict['model'] = model.state_dict()
    # Seed in the filename (as in train()) so per-seed runs do not overwrite each other.
    with open(f"results_{args.scenario_name}_acc_{np.round(best_acc, 2)}_seed_{args.seed}.pkl", "wb") as fh:
        pkl.dump(save_dict, fh)


def create_patches(args):
    """Placeholder for patch creation; not implemented yet and always returns ``None``.

    Args:
        args: parsed command-line namespace (currently unused).
    """

if __name__ == "__main__":
    # Command-line entry point: parse options, let set_dataset_specs fill in
    # dataset-dependent fields, then run one training job per seed in
    # [seed_offset, seed_offset + exp_num).
    parser = argparse.ArgumentParser()
    parser.add_argument("--subset", type=str, default="imagenette")
    parser.add_argument("--size", type=int, default=224)
    parser.add_argument("--st_cls", type=int, default=0)
    parser.add_argument("--end_cls", type=int, default=None)
    parser.add_argument("--exp_num", type=int, default=3)
    parser.add_argument("--grad_acc_steps", type=int, default=1)

    parser.add_argument("--nclass", type=int, default=1000)
    # nargs="+" yields a list of strings; the previous type=list split a single
    # argument into individual characters ("ab" -> ["a", "b"]).
    parser.add_argument("--classes", nargs="+")
    parser.add_argument("--init_resize", type=int, default=256)
    parser.add_argument("--input_size", type=int, default=224)
    parser.add_argument("--diff_input_size", type=int, default=512)
    parser.add_argument("--val_ipc", type=int, default=50)
    parser.add_argument("--factor", type=int, default=4)
    parser.add_argument("--root_dir", type=str, default='/home/public/')
    parser.add_argument("--arch", type=str, default='resnet18')
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--wd", type=float, default=1e-4)
    parser.add_argument("--mom", type=float, default=.9)

    parser.add_argument('--adamw_lr', type=float, default=0.001)
    parser.add_argument('--adamw_weight_decay', type=float, default=0.01)

    parser.add_argument("--dataset_name_dict", type=str)
    parser.add_argument("--syn_data_root", type=str, required=True)
    parser.add_argument("--softlbl_file", type=str, default=None)
    parser.add_argument("--softlbl", action="store_true", default=False)

    parser.add_argument("--use_teacher", action="store_true", default=False)

    parser.add_argument("--mixtype", type=str, default=None)
    parser.add_argument("--cos", action="store_true", default=False)
    parser.add_argument("--sgd", action="store_true", default=False)
    parser.add_argument("--eval_every", type=int, default=10)
    parser.add_argument("--T", type=float, default=20)
    parser.add_argument("--scenario_name", type=str, required=True)
    parser.add_argument("--model_ckpt", type=str)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--workers", type=int, default=8)
    parser.add_argument("--seed_offset", type=int, default=0)
    args = parser.parse_args()

    set_dataset_specs(args)
    seeds = np.arange(args.exp_num) + args.seed_offset

    print(f"######## Seeds: {seeds} ########")

    for seed in seeds:
        # Plain int rather than numpy.int64, so the seeding calls and the
        # pickled args see a native Python value.
        args.seed = int(seed)
        if args.use_teacher:
            print("######## Training with teacher ########")
            train_softlbl_with_aug(args)
        else:
            print("######## Training without teacher ########")
            train(args)
