import torch
import torch.nn as nn

from torch.utils.data import TensorDataset, DataLoader
from set_model import get_set_model
from algorithms import lr_warmstart
from models.get_model import get_model
from utils import adjust_learning_rate

def return_loss_fn(x, t=0.5, eps=1e-8):
    """Return the temperature-scaled contrastive loss over paired rows of *x*.

    Rows are treated as consecutive positive pairs: row ``2k`` is paired with
    row ``2k + 1``. For every row the loss is the negative log of
    ``exp(sim(row, partner) / t)`` divided by the sum of ``exp(sim(row, j) / t)``
    over all other rows ``j``, where ``sim`` is cosine similarity. The mean
    over all rows is returned.

    Args:
        x: 2-D tensor of embeddings with an even number of rows.
        t: Temperature dividing the cosine similarities.
        eps: Lower bound on the product of norms, guarding against division
            by zero for zero-norm rows.

    Returns:
        A scalar tensor holding the mean loss.

    Raises:
        ValueError: If *x* has an odd number of rows, so that pairs cannot
            be formed.
    """
    num_rows = x.size(0)
    if num_rows % 2 != 0:
        raise ValueError(
            f"return_loss_fn expects an even number of rows, got {num_rows}"
        )

    # Temperature-scaled cosine similarities, 2N x 2N.
    n = torch.norm(x, p=2, dim=1, keepdim=True)
    logits = (x @ x.t()) / (n * n.t()).clamp(min=eps) / t

    # Partner of row i: i + 1 for even i, i - 1 for odd i.
    idx = torch.arange(num_rows, device=x.device)
    partner = idx + 1 - 2 * (idx % 2)
    positives = logits[partner, idx]

    # Exclude each row's similarity with itself by masking the diagonal
    # rather than subtracting exp(1 / t): the subtraction is only exact when
    # the self-similarity is exactly 1, which fails for zero-norm rows.
    self_mask = torch.eye(num_rows, dtype=torch.bool, device=x.device)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # -log(exp(pos) / sum(exp(others))) computed in log-space so that small
    # temperatures do not overflow the exponential.
    return torch.mean(torch.logsumexp(logits, dim=0) - positives)

def run(args, train_loader, meta_val_ds, meta_test_ds, device, logger):
    """Train an encoder, a set decoder and a projection head, then evaluate.

    Each training step augments every batch element ``args.repeat_augmentations``
    times, encodes all views, and minimises
    ``ins_vs_ins_loss + args.beta * set_vs_set_loss``, both computed with
    ``return_loss_fn`` on the head's output. Every ``args.test_every`` epochs
    ``lr_warmstart.run`` is called on ``meta_val_ds`` with ``args.support = 5``;
    after the last epoch it is called on ``meta_test_ds`` for supports
    1, 5, 20 and 50. Note that ``args.support`` is overwritten in place.

    Args:
        args: Configuration object; the attributes read here are ``transform``,
            ``model``, ``img_size``, ``pre_img_size``, ``lr``,
            ``lr_scheduling``, ``repeat_augmentations``, ``beta``,
            ``training_epochs``, ``test_every``, ``val_episodes``,
            ``test_episodes`` and ``debug``.
        train_loader: Iterable yielding batches that support ``.to(device)``.
        meta_val_ds: Dataset handed to ``lr_warmstart.run`` for validation.
        meta_test_ds: Dataset handed to ``lr_warmstart.run`` for testing.
        device: Device the modules and batches are moved to.
        logger: Object providing ``register_model_to_save``, ``meter`` and
            ``step``.
    """
    augment = args.transform

    # Backbone. A single random image is pushed through it to read off the
    # size of the last output dimension.
    # NOTE(review): this probe runs with autograd enabled and in whatever
    # mode get_model returns the module in -- confirm that is intended.
    encoder = get_model(args.model, args.img_size).to(device)
    probe = torch.randn(1, 3, args.img_size, args.img_size).to(device)
    last_hidden_size = encoder(probe).shape[-1]

    # Set model operating on encoder features.
    decoder = get_set_model(args, last_hidden_size, device)

    # Projection head shared by both losses; output width is a quarter of
    # the feature width.
    projection_size = int(last_hidden_size/4)
    head = nn.Sequential(
        nn.Linear(last_hidden_size, last_hidden_size),
        nn.BatchNorm1d(last_hidden_size),
        nn.LeakyReLU(),
        nn.Linear(last_hidden_size, projection_size),
    ).to(device)

    modules = (encoder, decoder, head)
    for label, module in zip(('encoder', 'decoder', 'head'), modules):
        logger.register_model_to_save(module, label)

    # One parameter group per module, all starting from the same rate.
    optimizer = torch.optim.Adam(
        [{"params": module.parameters(), "lr": args.lr} for module in modules]
    )

    def evaluate(tag, support, episodes, dataset):
        # Run the warm-start evaluation and log its mean accuracy and CI.
        args.support = support
        mean, ci = lr_warmstart.run(args, episodes, encoder, decoder, dataset, device)
        logger.meter(tag, 'accuracy', mean)
        logger.meter(tag, 'ci', ci)

    step = 1
    for epoch in range(1, args.training_epochs+1):
        for batch in train_loader:
            # Re-enter train mode on every step.
            for module in modules:
                module.train()
            optimizer.zero_grad()

            # lr scheduling
            if args.lr_scheduling:
                adjust_learning_rate(args, optimizer, train_loader, step)

            # Replicate each sample, augment every copy, then regroup the
            # copies per sample.
            batch = batch.to(device)
            views = torch.stack([batch]*args.repeat_augmentations, dim=1)
            views = views.reshape(-1, 3, args.pre_img_size, args.pre_img_size)
            views = augment(views)
            views = views.reshape(
                -1, args.repeat_augmentations, 3, args.img_size, args.img_size)
            n_samples, n_views, channels, height, width = views.size()

            # Per-view features, regrouped as (sample, view, feature).
            flat_views = views.reshape(
                n_samples*n_views, channels, height, width).contiguous()
            features = encoder(flat_views)
            features = features.reshape(n_samples, n_views, -1).contiguous()

            # Instance-vs-instance loss on the first two views of each sample.
            first_two = features[:, :2, :].contiguous()
            first_two = first_two.reshape(n_samples*2, -1).contiguous()
            ins_vs_ins_loss = return_loss_fn(head(first_two))

            # Set-vs-set loss: the views of each sample are split into two
            # halves and each half is summarised by the decoder.
            half_sets = features.reshape(
                n_samples*2, n_views // 2, last_hidden_size).contiguous()
            set_summary = decoder(half_sets)
            set_summary = set_summary.reshape(
                n_samples*2, last_hidden_size).contiguous()
            set_vs_set_loss = return_loss_fn(head(set_summary))

            loss = ins_vs_ins_loss + args.beta*set_vs_set_loss
            loss.backward()
            optimizer.step()

            logger.meter('train', 'ins vs ins loss', ins_vs_ins_loss)
            logger.meter('train', 'set vs set loss', set_vs_set_loss)
            step += 1

            # Debug runs process a single batch per epoch.
            if args.debug:
                break

        if epoch % args.test_every == 0:
            evaluate('5shot-val', 5, args.val_episodes, meta_val_ds)

        if epoch == args.training_epochs:
            for support in (1, 5, 20, 50):
                evaluate(f'{support}shot-test', support, args.test_episodes, meta_test_ds)

        logger.step()
