import torch
import math


from torchmeta.utils.data import BatchMetaDataLoader
from options import args_parser

from maml.datasets import get_benchmark_by_name, GetTaskPool, TaskPoolDataset,MetaProxTaskPoolDataset
from maml.metalearners import ModelAgnosticMetaLearning, MetaMinibatchProx, iMAML, FOMuML
from tqdm import tqdm
from torch.utils.data import DataLoader
from maml import utils
from maml.hessianfree import HessianFree

import wandb
wandb.login()

# Algorithm names grouped by the metalearner class that implements them.
_MAML_FAMILY = ('MAML', 'DynamicMAML', 'FOMAML', 'DynamicFOMAML', 'MetaSGD', 'DynamicMetaSGD')
_IMAML_FAMILY = ('iMAML', 'DynamiciMAML')
_FOMUML_FAMILY = ('FOMuML', 'DynamicFOMuML')
_METAPROX_FAMILY = ('MetaProx', 'DynamicMetaProx')
_SUPPORTED_ALGORITHMS = _MAML_FAMILY + _IMAML_FAMILY + _FOMUML_FAMILY + _METAPROX_FAMILY


def _collate_fn_for(dataset):
    """Return the DataLoader collate function used for ``dataset``.

    Raises:
        ValueError: if no collate function is defined for ``dataset``.
    """
    if dataset == 'omniglot':
        return utils.omniglot_collate_fn
    if dataset == 'sinusoid':
        return utils.sinusoid_collate_fn
    raise ValueError(f"Unsupported dataset {dataset!r}: expected 'omniglot' or 'sinusoid'.")


def _get_or_create_task_pool(pool_file, meta_dataset, num_tasks, num_workers):
    """Load the task pool cached under ``pool_file``, or sample and cache a new one.

    A new pool is drawn by ``GetTaskPool`` (``num_tasks`` tasks) from a shuffled
    ``BatchMetaDataLoader`` over ``meta_dataset`` with one task per batch, then
    saved with ``utils.save_task_pool`` so the next run with the same file name
    loads it instead of re-sampling.
    """
    if utils.task_pool_exists(pool_file):
        return utils.load_task_pool(pool_file)
    task_loader = BatchMetaDataLoader(meta_dataset,
                                      batch_size=1,
                                      shuffle=True,
                                      num_workers=num_workers,
                                      pin_memory=True)
    task_pool = GetTaskPool(task_loader, num_tasks)
    utils.save_task_pool(task_pool, pool_file)
    return task_pool


def _build_metalearner(args, benchmark, meta_optimizer, device):
    """Instantiate the metalearner selected by ``args.algorithm``.

    ``Dynamic*`` variants forward ``args.beta`` as ``beta``; all other variants
    use ``beta=1.0``.

    Raises:
        ValueError: if ``args.algorithm`` is not in ``_SUPPORTED_ALGORITHMS``.
    """
    algorithm = args.algorithm
    beta = args.beta if 'Dynamic' in algorithm else 1.0

    if algorithm in _MAML_FAMILY:
        # MetaSGD = MAML with learned, per-parameter inner step sizes.
        uses_meta_sgd = 'MetaSGD' in algorithm
        return ModelAgnosticMetaLearning(benchmark.model,
                                         meta_optimizer,
                                         first_order='FOMAML' in algorithm,
                                         num_adaptation_steps=args.num_steps,
                                         per_param_step_size=uses_meta_sgd,
                                         learn_step_size=uses_meta_sgd,
                                         step_size=args.step_size,
                                         beta=beta,
                                         dyfac=args.batch_size / args.dyfac,
                                         loss_function=benchmark.loss_function,
                                         device=device)
    if algorithm in _IMAML_FAMILY:
        return iMAML(benchmark.model,
                     meta_optimizer,
                     num_adaptation_steps=args.num_steps,
                     step_size=args.step_size,
                     lamda=args.lamda,
                     n_cg=args.cg_steps,
                     beta=beta,
                     dyfac=args.batch_size / args.dyfac,
                     loss_function=benchmark.loss_function,
                     device=device)
    if algorithm in _FOMUML_FAMILY:
        # NOTE(review): unlike the other learners, FOMuML receives args.dyfac
        # unscaled (not batch_size / dyfac) -- confirm this is intentional.
        return FOMuML(benchmark.model,
                      optimizer=meta_optimizer,
                      lamda=args.lamda,
                      num_adaptation_steps=args.num_steps,
                      step_size=args.step_size,
                      beta=beta,
                      dyfac=args.dyfac,
                      loss_function=benchmark.loss_function,
                      device=device)
    if algorithm in _METAPROX_FAMILY:
        # MetaProx takes dyfac truncated to an int.
        return MetaMinibatchProx(benchmark.model,
                                 meta_optimizer,
                                 lamda=args.lamda,
                                 num_adaptation_steps=args.num_steps,
                                 step_size=args.step_size,
                                 meta_lr=args.meta_lr,
                                 beta=beta,
                                 dyfac=int(args.batch_size / args.dyfac),
                                 loss_function=benchmark.loss_function,
                                 device=device)
    raise ValueError(f"Unsupported algorithm {algorithm!r}; expected one of {_SUPPORTED_ALGORITHMS}.")


def _train_and_evaluate(args, benchmark, collate_fn, device):
    """Build task pools, loaders and the metalearner, then run the epoch loop.

    Each epoch trains once, evaluates on the validation pool, logs
    train/test loss and their absolute difference to wandb, and prints a
    progress line with the best (lowest) validation loss so far.
    """
    # Cache file names encode every setting that affects which tasks are sampled.
    train_task_pool_file = f'train_{args.dataset}_{args.num_ways}-way_{args.num_shots}-shots_{args.num_shots_test}-test_taskNumber_{args.num_tasks}'
    val_task_pool_file = f'val_{args.dataset}_{args.num_ways}-way_{args.num_shots}-shots_{args.num_shots_test}-test_taskNumber_{args.test_num_tasks}'

    meta_train_task_pool = _get_or_create_task_pool(train_task_pool_file,
                                                    benchmark.meta_train_dataset,
                                                    args.num_tasks,
                                                    args.num_workers)
    meta_val_task_pool = _get_or_create_task_pool(val_task_pool_file,
                                                  benchmark.meta_val_dataset,
                                                  args.test_num_tasks,
                                                  args.num_workers)

    train_dataset = TaskPoolDataset(meta_train_task_pool)
    metaprox_train_dataset = MetaProxTaskPoolDataset(meta_train_task_pool)
    val_dataset = TaskPoolDataset(meta_val_task_pool)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    # No collate_fn here: MetaProxTaskPoolDataset items use the default collation.
    metaprox_train_dataloader = DataLoader(metaprox_train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.num_workers,
                                           pin_memory=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                collate_fn=collate_fn,
                                num_workers=args.num_workers,
                                pin_memory=True)

    meta_optimizer = torch.optim.SGD(benchmark.model.parameters(), lr=args.meta_lr, momentum=0.9)
    metalearner = _build_metalearner(args, benchmark, meta_optimizer, device)
    uses_metaprox_loader = args.algorithm in _METAPROX_FAMILY

    best_value = None
    for epoch in tqdm(range(args.num_epochs), desc="Epochs Progress", leave=True, dynamic_ncols=True):
        if uses_metaprox_loader:
            train_loss = metalearner.train(metaprox_train_dataloader, train_dataloader, epoch)
        else:
            train_loss = metalearner.train(train_dataloader)

        # Validation/Test step
        val_loss = metalearner.evaluate(val_dataloader)
        loss_diff = abs(train_loss - val_loss)

        # A single call per epoch: every wandb.log call advances wandb's step,
        # so separate calls would scatter one epoch's metrics across rows.
        wandb.log({"train loss": train_loss,
                   "test loss": val_loss,
                   "generalization error": loss_diff,
                   'epoch': epoch})

        if best_value is None or val_loss < best_value:
            best_value = val_loss
        tqdm.write(
            f"Epoch {epoch + 1}: Train Loss = {train_loss:.4f}, Test Loss = {val_loss:.4f}, |Diff| = {loss_diff:.4f}, Best Value = {best_value:.4f}"
        )


def main(args):
    """Meta-train ``args.algorithm`` on ``args.dataset``, logging per-epoch losses to wandb.

    Args:
        args: parsed options from ``options.args_parser`` (dataset, algorithm,
            task-pool sizes, learning rates, loader settings, ...).

    Raises:
        ValueError: if ``args.algorithm`` is unsupported or ``args.dataset``
            has no collate function. Both are checked before any task pool
            is sampled.
    """
    # Fail fast on bad configuration, before the potentially slow task sampling.
    if args.algorithm not in _SUPPORTED_ALGORITHMS:
        raise ValueError(f"Unsupported algorithm {args.algorithm!r}; expected one of {_SUPPORTED_ALGORITHMS}.")
    collate_fn = _collate_fn_for(args.dataset)

    device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else "cpu")

    benchmark = get_benchmark_by_name(args.dataset,
                                      args.folder,
                                      args.num_ways,
                                      args.num_shots,
                                      args.num_shots_test,
                                      hidden_size=args.hidden_size)
    try:
        _train_and_evaluate(args, benchmark, collate_fn, device)
    finally:
        # Release dataset resources even when training fails; check each
        # dataset individually rather than assuming both expose close().
        for meta_dataset in (benchmark.meta_train_dataset, benchmark.meta_val_dataset):
            if hasattr(meta_dataset, 'close'):
                meta_dataset.close()

    
def wandb_project_for(dataset):
    """Return the wandb project that runs on ``dataset`` are logged to.

    Raises:
        ValueError: if ``dataset`` has no associated project.
    """
    if dataset == 'omniglot':
        return 'MetaLearning'
    if dataset == 'sinusoid':
        return 'sinusoid'
    raise ValueError(f"No wandb project configured for dataset {dataset!r}.")


def build_experiment_name(args):
    """Return the wandb run name encoding ``args.algorithm`` and its key hyperparameters.

    The "taks" spelling is kept as-is so names match runs already logged.

    Raises:
        ValueError: if ``args.algorithm`` has no naming scheme.
    """
    algorithm = args.algorithm
    if algorithm == 'FOMAML':
        return f"{args.dataset}_Algorithm_{args.algorithm}_olr_{args.meta_lr}_ilr_{args.step_size}_{args.num_steps}-numSteps"
    if algorithm in ('MAML', 'MetaSGD', 'FOMuML'):
        return f"{args.dataset}_Algorithm_{args.algorithm}_{args.num_shots}-shot_{args.num_shots_test}_test_{args.num_steps}-numSteps_{args.num_tasks}_taks_ilr_{args.step_size}_olr_{args.meta_lr}"
    if algorithm == 'iMAML':
        return f"{args.dataset}_Algorithm_{args.algorithm}_{args.num_steps}-numSteps_{args.num_tasks}_taks_ilr_{args.step_size}_olr_{args.meta_lr}_cg_{args.cg_steps}"
    if algorithm == 'MetaProx':
        return f"{args.dataset}_Algorithm_{args.algorithm}_{args.num_steps}-numSteps_ilr_{args.step_size}_olr_{args.meta_lr}_{args.lamda}_lamda"
    if algorithm in ('DynamicMAML', 'DynamiciMAML', 'DynamicFOMAML', 'DynamicMetaSGD', 'DynamicFOMuML'):
        return f"{args.dataset}_Algorithm_{args.algorithm}_{args.num_steps}-numSteps_{args.num_tasks}_taks_{args.beta}_beta_{args.dyfac}_fac"
    if algorithm == 'DynamicMetaProx':
        return f"{args.dataset}_Algorithm_{args.algorithm}_{args.num_steps}-numSteps_{args.beta}_beta_{args.dyfac}_fac_ilr_{args.step_size}_olr_{args.meta_lr}"
    raise ValueError(f"No experiment-name scheme for algorithm {algorithm!r}.")


if __name__ == '__main__':
    # One wandb run per seed; e.g. [333, 666, 999] for a multi-seed sweep.
    seeds = [666]
    args = args_parser()
    project = wandb_project_for(args.dataset)
    exname = build_experiment_name(args)
    group_name = f"{args.dataset}"
    for seed in seeds:
        utils.set_seed(seed)
        # Add mode='disabled' to wandb.init to run without logging.
        wandb.init(config=args, project=project, name=exname, group=group_name)
        main(args)
        # Close this run so the next seed's wandb.init starts a fresh one.
        wandb.finish()
     