import torch
import torch.nn as nn

from torch.utils.data import TensorDataset, DataLoader
from set_model import MeanPooling
from algorithms import lr_mae
from models.get_model import get_model

from transformers import ViTMAEForPreTraining, ViTMAEConfig, get_cosine_schedule_with_warmup
import kornia.augmentation as K
import timm.optim.optim_factory as optim_factory

import math

def _build_transform():
    """Return the batch augmentation pipeline used during MAE pre-training.

    The pipeline is a random resized crop to 84x84 (scale 0.2-1.0), a random
    horizontal flip (p=0.5), and per-channel normalization with the hard-coded
    mean/std below. It is applied to each batch after the batch is moved to
    the training device.
    """
    return nn.Sequential(
        K.RandomResizedCrop(
            size=(84, 84), scale=(0.2, 1.0), ratio=(0.75, 1.3333333333333333),
            # NOTE(review): resample=2 is presumably bicubic in kornia's
            # Resample enum -- confirm against the installed kornia version.
            resample=2,
            p=1.0, same_on_batch=False
        ),
        K.RandomHorizontalFlip(
            p=0.5, same_on_batch=False
        ),
        K.Normalize(
            (.4713, .4503, .4039), (0.2750, 0.2661, 0.2824), p=1.
        ),
    )


def _build_mae_config():
    """Return the ViT-MAE configuration for 84x84 RGB inputs.

    The encoder is 8 layers wide 512 with 8 heads, using 6x6 patches. The
    decoder is 3 layers wide 128 with 8 heads. The masking ratio is 75%, and
    the reconstruction target uses per-patch normalized pixels
    (`norm_pix_loss=True`).
    """
    return ViTMAEConfig(
        hidden_size=512,
        num_hidden_layers=8,
        num_attention_heads=8,
        intermediate_size=512,
        hidden_act='gelu',
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        is_encoder_decoder=False,
        image_size=84,
        patch_size=6,
        num_channels=3,
        qkv_bias=True,
        decoder_num_attention_heads=8,
        decoder_hidden_size=128,
        decoder_num_hidden_layers=3,
        decoder_intermediate_size=128,
        mask_ratio=0.75,
        norm_pix_loss=True,
    )


def _evaluate_and_log(args, episodes, encoder, decoder, dataset, device, logger, support, tag):
    """Run `lr_mae.run` with `support` shots and log its mean accuracy and CI under `tag`."""
    # `args.support` is mutated (and left set) because lr_mae.run takes no
    # explicit shot-count argument; presumably it reads args.support -- confirm.
    args.support = support
    mean, ci = lr_mae.run(args, episodes, encoder, decoder, dataset, device)
    logger.meter(tag, 'accuracy', mean)
    logger.meter(tag, 'ci', ci)


def run(args, train_loader, meta_val_ds, meta_test_ds, device, logger):
    """Pre-train a ViT-MAE on `train_loader` and evaluate it with few-shot episodes.

    Args:
        args: Run configuration. Reads `lr`, `batch_size`, `training_epochs`,
            `debug`, `test_every`, `val_episodes`, `test_episodes`, and the
            optional `warmup_epochs` (default 40). `args.support` is
            overwritten before each evaluation.
        train_loader: Sized iterable of image batches. Each batch must support
            `.to(device)` and be accepted by the kornia pipeline and by
            `ViTMAEForPreTraining` (presumably float tensors of shape
            (B, 3, H, W) -- confirm).
        meta_val_ds: Dataset passed to `lr_mae.run` for validation.
        meta_test_ds: Dataset passed to `lr_mae.run` for the final test.
        device: Device the MAE model and each batch are moved to.
        logger: Object providing `register_model_to_save`, `meter` and `step`.

    Validation (5-shot) runs every `args.test_every` epochs. On the last epoch,
    the test sets are evaluated with 1, 5, 20 and 50 shots. The learning rate
    is scaled linearly with batch size (base lr * batch_size / 256) and
    follows a cosine schedule with linear warmup. If `training_epochs` is
    smaller than the warmup length, the LR never reaches its peak.
    """
    # data
    transform = _build_transform()

    # model
    model = ViTMAEForPreTraining(_build_mae_config()).to(device)
    encoder = model.vit
    # NOTE(review): MeanPooling is handed to lr_mae.run together with the
    # encoder; presumably it pools token features into one embedding -- confirm.
    decoder = MeanPooling()
    head = model.decoder

    logger.register_model_to_save(encoder, 'encoder')
    logger.register_model_to_save(decoder, 'decoder')
    logger.register_model_to_save(head, 'head')

    # optimization
    lr = args.lr * args.batch_size / 256
    # Weight decay 0.05; timm's helper exempts 1-D params (biases, norms) from decay.
    param_groups = optim_factory.param_groups_weight_decay(model, 0.05)
    optimizer = torch.optim.AdamW(param_groups, lr=lr, betas=(0.9, 0.95))
    steps_per_epoch = len(train_loader)
    warmup_epochs = getattr(args, 'warmup_epochs', 40)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_epochs * steps_per_epoch,
        num_training_steps=args.training_epochs * steps_per_epoch,
    )

    for epoch in range(1, args.training_epochs + 1):
        # Evaluation below may switch modules to eval mode, so restore train
        # mode at the start of every epoch (nothing inside the batch loop
        # changes it).
        encoder.train()
        decoder.train()
        head.train()

        for data in train_loader:
            optimizer.zero_grad()

            # data augmentation
            data = transform(data.to(device))

            loss = model(data).loss
            loss.backward()
            optimizer.step()

            # Detach so the logger never holds a reference to the autograd graph.
            logger.meter('train', 'ins vs ins loss', loss.detach())

            scheduler.step()

            if args.debug:
                break

        if epoch % args.test_every == 0:
            _evaluate_and_log(
                args, args.val_episodes, encoder, decoder, meta_val_ds, device,
                logger, support=5, tag='5shot-val',
            )

        if epoch == args.training_epochs:
            for support in [1, 5, 20, 50]:
                _evaluate_and_log(
                    args, args.test_episodes, encoder, decoder, meta_test_ds, device,
                    logger, support=support, tag=f'{support}shot-test',
                )

        logger.step()
