import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import TensorDataset, DataLoader
from set_model import MeanPooling
from algorithms import lr
from models.get_model import get_model
from utils import adjust_learning_rate

@torch.no_grad()
def dequeue_and_enqueue(queue, ptr, keys):
    """Overwrite the oldest columns of a circular key queue with ``keys``.

    Args:
        queue: 2-D tensor of shape (feature_dim, K); each column holds one
            stored key. Modified in place.
        ptr: integer index of the next column to overwrite, in ``[0, K)``.
        keys: 2-D tensor of shape (batch_size, feature_dim) with the new keys.

    Returns:
        The advanced write pointer, ``(ptr + batch_size) % K``.

    Raises:
        ValueError: if ``batch_size`` is larger than the queue length ``K``.
    """
    batch_size = keys.shape[0]
    K = queue.shape[1]
    if batch_size > K:
        raise ValueError(
            f"cannot enqueue {batch_size} keys into a queue of length {K}")

    end = ptr + batch_size
    if end <= K:
        queue[:, ptr:end] = keys.T
    else:
        # The batch runs past the end of the ring buffer: wrap the remainder
        # to the front. This removes the old K % batch_size == 0 requirement,
        # which failed on a smaller final batch from the data loader.
        split = K - ptr
        queue[:, ptr:] = keys[:split].T
        queue[:, :end - K] = keys[split:].T
    return end % K

def ema(model, model_ema, t=0.999):
    """Blend ``model``'s parameters into ``model_ema`` in place.

    Every EMA parameter becomes ``t * ema_param + (1 - t) * param``.
    Parameters are paired positionally through ``parameters()``, so the two
    modules are expected to share an architecture. Buffers are not touched.
    """
    keep = t
    blend = 1. - t
    for ema_param, src_param in zip(model_ema.parameters(), model.parameters()):
        ema_param.data.mul_(keep).add_(src_param.data, alpha=blend)

def return_loss_fn(x, y):
    """Return the batch mean of ``2 - 2 * cos(x_i, y_i)`` over rows.

    This equals the squared L2 distance between the L2-normalised rows of
    ``x`` and ``y``. ``y`` is detached, so gradients flow only into ``x``.
    Cosine similarity is taken along ``dim=1``.

    The eps-clamped cosine keeps the loss finite when a row has zero norm.
    The previous explicit division by ``|x| * |y|`` produced NaN/inf there.
    """
    y = y.detach()
    cos = F.cosine_similarity(x, y, dim=1, eps=1e-8)
    return (2. - 2. * cos).mean()

def _two_views(data, transform, pre_img_size, img_size):
    """Augment each image in ``data`` twice; return the two view batches.

    The duplicated batch is reshaped to (-1, 3, pre_img_size, pre_img_size)
    before ``transform`` and to (-1, 2, 3, img_size, img_size) after it.
    """
    stacked = torch.stack([data] * 2, dim=1)
    stacked = stacked.reshape(-1, 3, pre_img_size, pre_img_size)
    views = transform(stacked).reshape(-1, 2, 3, img_size, img_size)
    return views[:, 0], views[:, 1]


def _moco_loss(q, k, queue, temperature):
    """Cross-entropy over [positive key | queued negatives] similarity logits.

    Column 0 of the logits is ``q . k`` and the remaining columns are ``q``
    against every stored key in ``queue`` (feature_dim, K). The target class
    is therefore always 0.
    """
    l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
    l_neg = torch.einsum('nc,ck->nk', [q, queue])
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, labels)


def _evaluate(args, episodes, encoder, decoder, dataset, device, logger,
              supports, split):
    """Run ``lr.run`` for each support size and log accuracy / ci.

    Sets ``args.support`` before each call. Meters are named
    ``f'{support}shot-{split}'``.
    """
    for support in supports:
        args.support = support
        mean, ci = lr.run(args, episodes, encoder, decoder, dataset, device)
        logger.meter(f'{support}shot-{split}', 'accuracy', mean)
        logger.meter(f'{support}shot-{split}', 'ci', ci)


def run(args, train_loader, meta_val_ds, meta_test_ds, device, logger,
        queue_size=65536, temperature=0.07, momentum=0.999):
    """Train an encoder contrastively against an EMA encoder and a key queue.

    Each step augments a batch twice. It encodes view 0 with ``encoder`` and
    view 1 with its EMA copy, then minimises a cross-entropy that ranks the
    matching EMA key above ``queue_size`` queued keys. Validation (5-shot)
    runs every ``args.test_every`` epochs, and test (1/5/20/50-shot) runs
    after the last epoch.

    Args:
        args: config namespace. Fields read here include transform, model,
            img_size, pre_img_size, lr, training_epochs, lr_scheduling,
            debug, test_every, val_episodes, test_episodes. ``args.support``
            is overwritten before each evaluation.
        train_loader: iterable of image batches. Presumably each batch is a
            tensor (N, 3, pre_img_size, pre_img_size); confirm with caller.
        meta_val_ds / meta_test_ds: datasets forwarded to ``lr.run``.
        device: torch device for models and data.
        logger: object providing register_model_to_save / meter / step.
        queue_size: number of stored negative keys (K).
        temperature: softmax temperature applied to the logits (T).
        momentum: EMA coefficient for the key encoder (m).
    """
    transform = args.transform

    encoder = get_model(args.model, args.img_size).to(device)
    # Probe the output width with a dummy image. Use eval/no_grad so the
    # probe builds no graph and does not fold random-noise statistics into
    # normalisation running buffers (which encoder_ema would then copy).
    encoder.eval()
    with torch.no_grad():
        probe = torch.randn(1, 3, args.img_size, args.img_size).to(device)
        last_hidden_size = encoder(probe).shape[-1]
    encoder.train()

    encoder_ema = get_model(args.model, args.img_size).to(device)
    encoder_ema.load_state_dict(encoder.state_dict(), strict=True)
    # The EMA encoder is only ever updated through ema(), never by autograd.
    encoder_ema.requires_grad_(False)

    decoder = MeanPooling()

    logger.register_model_to_save(encoder, 'encoder')
    logger.register_model_to_save(decoder, 'decoder')

    optimizer = torch.optim.Adam(
        [{"params": encoder.parameters(), "lr": args.lr},
         {"params": decoder.parameters(), "lr": args.lr}]
    )

    # Ring buffer of unit-norm keys, one key per column.
    queue = torch.randn((last_hidden_size, queue_size), requires_grad=False).to(device)
    queue = F.normalize(queue, dim=0)
    ptr = 0

    step = 1
    for epoch in range(1, args.training_epochs + 1):
        for data in train_loader:
            encoder.train()
            decoder.train()
            optimizer.zero_grad()

            if args.lr_scheduling:
                adjust_learning_rate(args, optimizer, train_loader, step)

            view_q, view_k = _two_views(
                data.to(device), transform, args.pre_img_size, args.img_size)

            q = F.normalize(encoder(view_q), dim=1)

            with torch.no_grad():
                # Update the EMA encoder first, then compute keys with it.
                ema(encoder, encoder_ema, t=momentum)
                k = F.normalize(encoder_ema(view_k), dim=1)

            loss = _moco_loss(q, k, queue, temperature)
            loss.backward()
            optimizer.step()

            # Enqueue only after backward(): the loss graph saved `queue`, so
            # modifying it in place any earlier would fail autograd's version
            # check unless a full copy of the queue were made every step.
            ptr = dequeue_and_enqueue(queue, ptr, k)

            # NOTE(review): this passes the loss tensor itself; if the logger
            # stores it, consider loss.item() — depends on logger's API.
            logger.meter('train', 'ins vs ins loss', loss)
            step += 1

            if args.debug:
                break

        if epoch % args.test_every == 0:
            _evaluate(args, args.val_episodes, encoder, decoder, meta_val_ds,
                      device, logger, (5,), 'val')

        if epoch == args.training_epochs:
            _evaluate(args, args.test_episodes, encoder, decoder, meta_test_ds,
                      device, logger, (1, 5, 20, 50), 'test')

        logger.step()
            
