import torch
import argparse
import surrogate_models
import pretrained_models
from datasets import create_dataset
import itertools
from utils.util import batch_gen, Discounted_UCB1_tuned
import torch.nn.functional as F
import os
import sys
import time
import timm


def update(idx, model, dataset, type = 'both'): #type in [both, to_surrogate, from_surrogate]
    """Run one forward/backward pass for a single pre-trained candidate.

    Draws one batch from ``loaders[dataset]``, extracts the candidate model's
    features (without gradients), computes the surrogate's features
    conditioned on ``surrogate.get_embeddings(idx)``, and back-propagates a
    loss equal to the negated mean cosine similarity between each feature set
    and its prediction from the other side. Gradients are only accumulated
    (``loss.backward()``); zeroing and stepping the optimizer is left to the
    caller.

    Reads these module-level globals, which the ``__main__`` block sets up:
    ``loaders``, ``models``, ``surrogate``, ``predictors_to_surrogate``,
    ``predictors_from_surrogate`` and ``device``.

    Args:
        idx: index passed to ``surrogate.get_embeddings``.
        model: key into ``models`` and into both predictor dicts.
        dataset: key into ``loaders``. If it contains ``'+'``, the loader is
            expected to yield ``(input_original, input_surrogate, _)``;
            otherwise ``(input, _)``, and the same input is fed to both
            networks.
        type: ``'both'`` (average of the two similarities),
            ``'to_surrogate'`` or ``'from_surrogate'``.

    Returns:
        The similarity that was maximized (i.e. ``-loss``), as a float.

    Raises:
        ValueError: if ``type`` is invalid. This is raised before any batch is
            drawn or any network is run.
    """
    # Fail fast. ValueError subclasses Exception, so callers catching
    # Exception are unaffected.
    if type not in ('both', 'to_surrogate', 'from_surrogate'):
        raise ValueError('invalid type, must be both, to_surrogate or from_surrogate')

    if '+' in dataset:
        # Paired dataset: separate inputs for the candidate and the surrogate.
        input_original, input_surrogate, _ = next(loaders[dataset])
        input_original, input_surrogate = input_original.to(device), input_surrogate.to(device)
    else:
        # Single dataset: the same batch goes to both networks.
        input, _ = next(loaders[dataset])
        input = input.to(device)
        input_original = input
        input_surrogate = input

    # No autograd graph for the candidate model; only the surrogate and the
    # predictors receive gradients.
    with torch.no_grad():
        original_feature = models[model](input_original)['feature']

    surrogate_feature = surrogate(input_surrogate, surrogate.get_embeddings(idx))

    # Predict each feature space from the other.
    predicted_surrogate_feature = predictors_to_surrogate[model](original_feature)
    predicted_original_feature = predictors_from_surrogate[model](surrogate_feature)

    sim_to_surrogate = F.cosine_similarity(surrogate_feature, predicted_surrogate_feature).mean()
    sim_from_surrogate = F.cosine_similarity(original_feature, predicted_original_feature).mean()

    if type == 'both':
        loss = -0.5 * (sim_to_surrogate + sim_from_surrogate)
    elif type == 'to_surrogate':
        loss = -sim_to_surrogate
    else:  # 'from_surrogate' (validated above)
        loss = -sim_from_surrogate

    loss.backward()

    # loss is the negated similarity; float negation is exact, so this equals
    # the selected similarity's .item() bit-for-bit.
    return -loss.item()

if __name__ == '__main__':
    # Script entry point. The objects built here (loaders, models, surrogate,
    # predictors_to_surrogate, predictors_from_surrogate, device) are
    # deliberately module-level globals: ``update`` reads them directly.

    parser = argparse.ArgumentParser(description='training a surrogate model with embeddings of pre-trained models, \
                                    to be later used for downstream model selection')

    parser.add_argument('--name', type=str, default='noname', help='the name of the trial')

    #surrogate model
    parser.add_argument('--surrogate_type', type=str, default='FeatureWeighting+resnet18', help='architecture for the surrogate model')
    parser.add_argument('--temperature', type=float, default=1.0, help = 'temperature for the sigmoid over embeddings')
    parser.add_argument('--fix_backbone', action='store_true')

    #datasets
    parser.add_argument('--datasets', type=str, default='ImageNet_val',
                        help='datasets for training the surrogate model, separated with :')

    #pre-trained models
    parser.add_argument('--pretrained_models', type=str,
                        default='ResNet18_Weights.IMAGENET1K_V1:EfficientNet_B0_Weights.IMAGENET1K_V1:GoogLeNet_Weights.IMAGENET1K_V1:Swin_T_Weights.IMAGENET1K_V1',
                        help='list of pre-trained models (i.e. the candiates for model selection); separated with :')
    parser.add_argument('--pretrained_datasets', type=str, default='ImageNet_val:ImageNet_val:ImageNet_val:ImageNet-swin-t_val',
                        help='corresponding datasets for each model')


    #multi-armed bandit
    parser.add_argument('--bandit', action='store_true')
    parser.add_argument('--bandit_gamma', type=float, default=0.99)
    parser.add_argument('--separate_arms', action='store_true')

    parser.add_argument('--min_sim', action='store_true')

    #optimization
    parser.add_argument('--batch_size', type=int, default = 128)
    parser.add_argument('--steps', type=int, default = 40000)
    parser.add_argument('--num_workers', type=int, default=4)

    # `choices` rejects an unknown optimizer at parse time. Otherwise it would
    # fall through both optimizer branches and crash later with a NameError.
    parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'AdamW'],
                        help='type of the optimizer, in [SGD, AdamW]')
    parser.add_argument('--learning_rate', type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=5e-4)

    parser.add_argument('--lr_decay', type=float, default=0.1)
    parser.add_argument('--step_size', type=int, default=10000)

    parser.add_argument('--log_freq', type=int, default=1000)
    parser.add_argument('--save_freq', type=int, default=10000)

    parser.add_argument('--overwrite', action='store_true')


    args = parser.parse_args()

    # Candidates and their datasets are parallel ':'-separated lists. Validate
    # them before touching the filesystem. A mismatch would otherwise be
    # silently truncated by zip() below, while the surrogate still allocates
    # one embedding per model.
    args.pretrained_models = args.pretrained_models.split(':')
    args.pretrained_datasets = args.pretrained_datasets.split(':')
    if len(args.pretrained_models) != len(args.pretrained_datasets):
        raise ValueError('--pretrained_models has %d entries but --pretrained_datasets has %d; they must match one-to-one'
                         % (len(args.pretrained_models), len(args.pretrained_datasets)))
    n_pretrain_model = len(args.pretrained_models)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # torch.cuda.synchronize() fails when CUDA is unavailable, so only call it
    # when actually running on a GPU.
    use_cuda = device.type == 'cuda'


    save_dir = './checkpoint/'
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    save_dir = os.path.join(save_dir, args.name)
    if os.path.exists(save_dir):
        if not args.overwrite:
            raise Exception(save_dir + ' already exists; activate --overwrite to overwrite')
    else:
        os.makedirs(save_dir)

    #prepare surrrogate
    surrogate = surrogate_models.create(surrogate_type = args.surrogate_type, temperature = args.temperature)
    surrogate.init_embeddings(n_pretrain_model)
    surrogate.to(device)

    if args.fix_backbone:
        surrogate.eval()
        surrogate.set_backbone(requires_grad = False)
    else:
        surrogate.train()


    #prepare datasets
    datasets = {}
    for dataset in args.datasets.split(':'):
        datasets[dataset] = create_dataset(dataset)


    #prepare pretrained models (candidates for model selection)
    #and their corresponding data loaders
    #and their corresponding predictors
    models = {}
    loaders = {}
    predictors_from_surrogate = {}
    predictors_to_surrogate = {}

    params = [surrogate.parameters()]

    for model, dataset in zip(args.pretrained_models, args.pretrained_datasets):
        models[model] = pretrained_models.get_feature_extractor(model)
        models[model].to(device)
        models[model].eval()

        print ('output_dim:', models[model].output_dim)

        # One linear predictor in each direction between the two feature spaces.
        predictors_from_surrogate[model] = torch.nn.Linear(surrogate.output_dim, models[model].output_dim)
        predictors_to_surrogate[model] = torch.nn.Linear(models[model].output_dim, surrogate.output_dim)

        predictors_from_surrogate[model].to(device)
        predictors_from_surrogate[model].eval()

        predictors_to_surrogate[model].to(device)
        predictors_to_surrogate[model].eval()

        params.append(predictors_from_surrogate[model].parameters())
        params.append(predictors_to_surrogate[model].parameters())

        if dataset not in loaders:
            # Loaders are keyed by --pretrained_datasets, which may name datasets
            # missing from --datasets. The defaults do this: ImageNet-swin-t_val
            # is not listed. Create such datasets on demand instead of raising
            # KeyError.
            if dataset not in datasets:
                datasets[dataset] = create_dataset(dataset)
            loaders[dataset] = batch_gen(torch.utils.data.DataLoader(datasets[dataset],
                                                        batch_size = args.batch_size,
                                                        shuffle = True,
                                                        num_workers = args.num_workers,
                                                        drop_last = True))


    #prepare optimizer
    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(itertools.chain(*params),
                                    lr = args.learning_rate,
                                    momentum = args.momentum,
                                    weight_decay = args.weight_decay)
    elif args.optimizer == 'AdamW':
        # NOTE: --momentum and --weight_decay are not forwarded, so AdamW
        # runs with PyTorch's default weight decay.
        optimizer = torch.optim.AdamW(itertools.chain(*params),
                                      lr = args.learning_rate)
    else:
        raise ValueError('unknown optimizer: ' + args.optimizer)


    #prepare bandit (if enabled)
    if args.bandit:
        if args.separate_arms:
            arms = [(idx, model, dataset, 'to_surrogate') for idx, (model, dataset) in enumerate(zip(args.pretrained_models, args.pretrained_datasets))] + \
                [(idx, model, dataset, 'from_surrogate') for idx, (model, dataset) in enumerate(zip(args.pretrained_models, args.pretrained_datasets))]
        else:
            arms = [(idx, model, dataset, 'both') for idx, (model, dataset) in enumerate(zip(args.pretrained_models, args.pretrained_datasets))]

        bandit = Discounted_UCB1_tuned(arms, gamma = args.bandit_gamma)


    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = args.step_size, gamma = args.lr_decay)

    losses = []

    if use_cuda:
        torch.cuda.synchronize()
    start_time = time.time()

    for step in range(args.steps):
        optimizer.zero_grad()

        total_loss = 0.

        if args.bandit:
            # One arm per step. The reward is 1 - similarity, so arms with lower
            # similarity earn a higher reward.
            idx, model, dataset, direction = bandit.select()
            sim = update(idx, model, dataset, direction) #note that reward is opposite to loss
            bandit.update((idx, model, dataset, direction), reward = 1.0 - sim)

            total_loss = total_loss - sim

        else:
            # Every candidate contributes one batch per step; gradients are
            # accumulated across candidates before the single optimizer step.
            loss_per_candidate = []
            for idx, (model, dataset) in enumerate(zip(args.pretrained_models, args.pretrained_datasets)):
                if '+' in dataset:
                    input_original, input_surrogate, _ = next(loaders[dataset])
                    input_original, input_surrogate = input_original.to(device), input_surrogate.to(device)
                else:
                    input, _ = next(loaders[dataset])
                    input = input.to(device)
                    input_original = input
                    input_surrogate = input
                with torch.no_grad():
                    original_feature = models[model](input_original)['feature']

                surrogate_feature = surrogate(input_surrogate, surrogate.get_embeddings(idx))

                predicted_surrogate_feature = predictors_to_surrogate[model](original_feature)
                predicted_original_feature = predictors_from_surrogate[model](surrogate_feature)

                if args.min_sim:
                    # Optimize only the weaker of the two directions.
                    sim_to_surrogate = F.cosine_similarity(surrogate_feature, predicted_surrogate_feature).mean()
                    sim_from_surrogate = F.cosine_similarity(original_feature, predicted_original_feature).mean()
                    loss = - torch.minimum(sim_to_surrogate, sim_from_surrogate)
                else:
                    loss = - (F.cosine_similarity(surrogate_feature, predicted_surrogate_feature) +
                            F.cosine_similarity(original_feature, predicted_original_feature)).mean()

                total_loss = total_loss + loss.item()
                loss_per_candidate.append(loss.item())

                loss.backward()

        optimizer.step()
        scheduler.step()

        losses.append(total_loss)
        if (step + 1) % args.log_freq == 0:
            if args.bandit:
                print ("[step: %d][loss: %.4f]"%(step + 1, total_loss))
                bandit.print()
            else:
                print ("[step: %d][loss: %.4f][avg loss: %.4f]"%(step + 1, total_loss, total_loss / n_pretrain_model))
                print ("[loss per candiate: ", loss_per_candidate, ']')
            sys.stdout.flush()


        if (step + 1) % args.save_freq == 0 or step + 1 == args.steps:
            if step + 1 == args.steps:
                if use_cuda:
                    torch.cuda.synchronize()
                end_time = time.time()
                print ("total training time: %.4f"%(end_time - start_time, ))

            torch.save({
                    'args': args,
                    'losses': losses,
                    'surrogate': surrogate.state_dict()
                }
                , os.path.join(save_dir, str(step+1) + '.t7'))
            print ('checkpoint saved at ' + os.path.join(save_dir, str(step+1) + '.t7'))
            sys.stdout.flush()
    




        



            



