import torch
import argparse
import surrogate_models
import pretrained_models
from datasets import create_dataset
import itertools
from utils.util import batch_gen
import torch.nn.functional as F
import os
import sys
import torch.nn.utils.parametrize as parametrize
import timm

class spectral_norm_L1(torch.nn.Module):
    """Parametrization that rescales a weight by its largest absolute column sum.

    The absolute values of ``weight`` are summed along dim 0, the maximum of
    those sums is taken, and ``weight`` is divided by that maximum. For a 2-D
    ``out_dim x in_dim`` weight this divisor is the induced L1 operator norm.
    The divisor is floored at 1e-6 so an all-zero weight is returned unchanged
    instead of producing NaNs.
    """

    def forward(self, weight):
        # weight: out_dim * in_dim
        column_abs_sums = torch.sum(torch.abs(weight), dim=0)
        # Floor the divisor to keep the division well-defined for zero weights.
        divisor = torch.clamp(torch.max(column_abs_sums), min=1e-6)
        return weight / divisor

if __name__ == '__main__':
    # Script entry point. Depending on the mode flag it either
    #   --selection : tunes a selection embedding (plus a linear predictor) on
    #                 top of a fixed surrogate model loaded from --model, or
    #   --evaluation: tunes a linear prediction head on top of a fixed
    #                 pre-trained feature extractor identified by --model.
    # Checkpoints are written to ./checkpoint_downstream/<name>/<epoch>.t7.
    parser = argparse.ArgumentParser(description='using a surrogate/pretrained model on downstream task')

    parser.add_argument('--name', type=str, default='noname', help='the name of the trial')

    #surrogate model
    parser.add_argument('--surrogate_type', type=str, default='FeatureWeighting+resnet18', help='architecture for the surrogate model')
    parser.add_argument('--load_state_dict', action='store_true')

    #dataset
    parser.add_argument('--dataset_train', type=str, default='CIFAR10_train',
                        help = 'dataset for the downstream task')
    parser.add_argument('--dataset_val', type=str, default='CIFAR10_val',
                        help = 'dataset for the downstream task')
    parser.add_argument('--n_class', type=int, default=10,
                        help = 'the number of classes to be predicted')

    #the model to be used as the feature extractor
    # Exactly one mode is required. Letting argparse enforce it rejects a bad
    # invocation before any checkpoint directory is created on disk.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--selection', action='store_true', 
                       help = 'Model Selection: \
                        enable to tune the embedding with the given surrogate model fixed; \
                        mutually exclusive with --evaluation')
    group.add_argument('--evaluation', action='store_true',
                       help = 'Evaluating a pretrained model: \
                        enable to tune the prediction head with the given pre-trained model fixed; \
                        mutually exclusive with --selection')

    parser.add_argument('--model', type=str,
                        default='',
                        help='with --selection: path to the surrogate model; with --evaluation: identifier of the pre-trained model')

    parser.add_argument('--L1_reg', type=float, default=0.,
                        help = 'L1 regularization for the embedding')

    parser.add_argument('--spectral_norm', action='store_true',
                        help = 'enable spectral_norm for the predictor (for selection)')

    parser.add_argument('--spectral_norm_L1', action='store_true',
                        help = 'enable spectral_norm_L1 for the predictor (for selection)')

    #optimization
    parser.add_argument('--batch_size', type=int, default = 128)
    parser.add_argument('--epochs', type=int, default = 30)
    parser.add_argument('--num_workers', type=int, default=4)

    # Restricting the choices turns an unsupported optimizer name into a usage
    # error instead of a NameError when the scheduler is built.
    parser.add_argument('--optimizer', type=str, default='SGD',
                        choices=['SGD', 'AdamW'],
                        help='type of the optimizer, in [SGD, AdamW]')
    parser.add_argument('--learning_rate', type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=5e-4)

    parser.add_argument('--lr_decay', type=float, default=0.1)
    parser.add_argument('--step_size', type=int, default=10)

    parser.add_argument('--save_freq', type=int, default=10000)

    parser.add_argument('--overwrite', action='store_true')

    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Checkpoint directory: refuse to reuse an existing trial directory unless
    # --overwrite is given.
    save_dir = './checkpoint_downstream/'
    os.makedirs(save_dir, exist_ok=True)

    save_dir = os.path.join(save_dir, args.name)
    if os.path.exists(save_dir) and not args.overwrite:
        raise Exception(save_dir + ' already exists; activate --overwrite to overwrite')
    os.makedirs(save_dir, exist_ok=True)

    #prepare model
    params = []

    if args.selection:
        # NOTE(review): torch.load unpickles arbitrary Python objects (here the
        # stored args and possibly a whole module); only point --model at
        # trusted checkpoints.
        # map_location='cpu' lets a checkpoint written on a GPU be opened on a
        # CPU-only host; the model is moved to `device` right below.
        ckp = torch.load(args.model, map_location='cpu')

        if args.load_state_dict:
            # Rebuild the surrogate from the hyper-parameters stored with the
            # checkpoint, then restore its weights.
            model = surrogate_models.create(surrogate_type = ckp['args'].surrogate_type, temperature = ckp['args'].temperature)
            model.init_embeddings(len(ckp['args'].pretrained_models))
            model.load_state_dict(ckp['surrogate'])
        else:
            # The checkpoint holds the surrogate object itself.
            model = ckp['surrogate']

        model.init_selection_embedding()
        model.to(device)
        model.eval()

        # The selection parameters are optimized without weight decay.
        params.append({'params':model.selection_parameters(), 'weight_decay':0})

    else:
        # --evaluation (argparse guarantees exactly one mode is set).
        model = pretrained_models.get_feature_extractor(args.model)
        model.to(device)
        model.eval()

    predictor = torch.nn.Linear(model.output_dim, args.n_class)
    if args.spectral_norm:
        predictor = torch.nn.utils.parametrizations.spectral_norm(predictor, 'weight')
    if args.spectral_norm_L1:
        parametrize.register_parametrization(predictor, 'weight', spectral_norm_L1())

    predictor.to(device)
    params.append({'params':predictor.parameters()})

    #prepare datasets
    train_set = create_dataset(args.dataset_train)
    val_set = create_dataset(args.dataset_val)

    #prepare loaders
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size = args.batch_size,
                                               shuffle = True,
                                               num_workers = args.num_workers,
                                               drop_last = True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size = args.batch_size,
                                             shuffle = False,
                                             num_workers = args.num_workers,
                                             drop_last = False)

    #prepare optimizer
    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(params,
                                    lr = args.learning_rate,
                                    momentum = args.momentum,
                                    weight_decay = args.weight_decay)
    else:
        # 'AdamW' (the only other value allowed by --optimizer's choices).
        # NOTE(review): --weight_decay is not forwarded here, so AdamW applies
        # its own default to the predictor parameters -- confirm this is intended.
        optimizer = torch.optim.AdamW(params,
                                      lr = args.learning_rate)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = args.step_size, gamma = args.lr_decay)

    criterion = torch.nn.CrossEntropyLoss()

    def extract_features(inputs):
        # Returns (feature, embedding). In selection mode the embedding is the
        # current selection embedding and gradients can flow through it; in
        # evaluation mode the extractor runs without gradients and the
        # embedding is None.
        if args.selection:
            embedding = model.get_selection_embedding()
            return model(inputs, embedding, selection = True), embedding
        with torch.no_grad():
            return model(inputs)['feature'], None

    def report(phase, epoch_number, loss_sum, n_batches, n_correct, n_samples):
        # Prints mean loss per batch and accuracy; the divisors are guarded so
        # a loader that yielded no batches reports zeros instead of raising.
        print ('[%s] epoch %d | loss: %.3f | acc: %.3f%% (%d/%d)' %
               (phase, epoch_number, loss_sum / max(n_batches, 1),
                100. * float(n_correct) / max(n_samples, 1), n_correct, n_samples))

    for epoch in range(args.epochs):

        #training
        correct = 0
        count = 0
        total_loss = 0.
        batch_count = 0

        for inputs, target in train_loader:

            inputs, target = inputs.to(device), target.to(device)

            feature, v = extract_features(inputs)
            logits = predictor(feature)

            loss = criterion(logits, target)
            if args.selection and args.L1_reg > 0:
                # NOTE(review): a plain sum is an L1 penalty only if the
                # embedding entries are non-negative -- confirm against the
                # surrogate model's get_selection_embedding().
                loss = loss + args.L1_reg * v.sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(logits, 1)
            # .item() keeps the running counter a plain Python int.
            correct += predicted.eq(target).sum().item()
            total_loss += loss.item()
            count += target.size(0)
            batch_count += 1

        report('train', epoch + 1, total_loss, batch_count, correct, count)

        scheduler.step()

        #validation
        correct = 0
        count = 0
        total_loss = 0.
        batch_count = 0

        with torch.no_grad():
            for inputs, target in val_loader:

                inputs, target = inputs.to(device), target.to(device)

                feature, _ = extract_features(inputs)
                logits = predictor(feature)

                # The validation loss is the plain criterion (no L1 term).
                loss = criterion(logits, target)

                _, predicted = torch.max(logits, 1)
                correct += predicted.eq(target).sum().item()
                total_loss += loss.item()
                count += target.size(0)
                batch_count += 1

        report('eval', epoch + 1, total_loss, batch_count, correct, count)

        # Save every `save_freq` epochs and always after the final epoch.
        if (epoch + 1) % args.save_freq == 0 or epoch + 1 == args.epochs:
            if args.selection:
                v = model.get_selection_embedding()
                save = {'args': args, 'embedding': v}
            elif parametrize.is_parametrized(predictor):
                # Parametrized modules cannot be pickled as whole objects
                # (torch only supports serializing them via state_dict()), so
                # store the state dict under its own key in that case.
                save = {'args': args, 'predictor_state_dict': predictor.state_dict()}
            else:
                save = {'args': args, 'predictor': predictor}

            checkpoint_path = os.path.join(save_dir, str(epoch+1) + '.t7')
            torch.save(save, checkpoint_path)
            print ('checkpoint saved at ' + checkpoint_path)
            sys.stdout.flush()





    
