#!/usr/bin/env python3

import argparse
import numpy as np
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
import learn2learn as l2l
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels
# Restrict the CUDA devices visible to this process to the entry "2".
# NOTE(review): the value is hard-coded and overwrites anything already set in
# the environment; on a host without a GPU at index 2 no CUDA device would be
# visible and the script would presumably fall back to CPU - confirm intent.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

def pairwise_distances_logits(a, b):
    """Return the negated squared Euclidean distance between every row pair.

    Args:
        a: 2-D tensor of shape (n, d).
        b: 2-D tensor of shape (m, d).

    Returns:
        Tensor of shape (n, m) whose entry [i, j] is ``-||a[i] - b[j]||^2``,
        so a closer pair gets a larger (less negative) logit.
    """
    rows, cols = a.shape[0], b.shape[0]
    # Line both inputs up on a common (n, m, d) grid before differencing.
    lhs = a.unsqueeze(1).expand(rows, cols, -1)
    rhs = b.unsqueeze(0).expand(rows, cols, -1)
    squared_diff = (lhs - rhs) ** 2
    return -squared_diff.sum(dim=2)


def accuracy(predictions, targets):
    """Return the fraction of samples whose arg-max class equals the target.

    Args:
        predictions: per-class scores; the class axis is dimension 1.
        targets: ground-truth class indices, one per sample.

    Returns:
        0-dim float tensor: number of correct predictions divided by
        ``targets.size(0)``.
    """
    hits = predictions.argmax(dim=1).view(targets.shape).eq(targets)
    n_correct = hits.sum().float()
    return n_correct / targets.size(0)


class Convnet(nn.Module):
    """Embedding network: a learn2learn ``CNN4Backbone`` followed by a flatten.

    Args:
        x_dim: number of input channels handed to the backbone.
        hid_dim: hidden size handed to the backbone.
        z_dim: accepted but not used anywhere in this class.
    """

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        backbone = l2l.vision.models.CNN4Backbone(
            hidden_size=hid_dim,
            channels=x_dim,
            max_pool=True,
        )
        self.encoder = backbone
        # Fixed constant, not derived from the constructor arguments.
        # Presumably the flattened embedding width for 84x84 inputs - TODO confirm.
        self.out_channels = 1600

    def forward(self, x):
        """Encode ``x`` and flatten each sample into a single feature vector."""
        features = self.encoder(x)
        batch_size = features.size(0)
        return features.view(batch_size, -1)

def fast_adapt(model, batch, ways, shot, query_num, metric=None, device=None):
    """Score one few-shot task with class prototypes.

    The samples are sorted by label, embedded by ``model``, and split per
    class into the first ``shot`` samples (support) and the remaining ones
    (query).  Each class prototype is the mean of its support embeddings and
    the query samples are classified by ``metric(query, prototypes)``.

    Args:
        model: callable mapping the task's data tensor to 2-D embeddings.
        batch: ``(data, labels)`` pair; a leading singleton batch dimension
            (as added by a ``DataLoader``) is squeezed away.
        ways: number of classes in the task.
        shot: support samples per class.
        query_num: query samples per class.
        metric: callable ``(query, prototypes) -> logits``; defaults to
            ``pairwise_distances_logits``.
        device: device to run on; when omitted it is taken from the model's
            first parameter, or from ``data`` if the model has no parameters.

    Returns:
        ``(loss, acc)``: the cross-entropy loss and the accuracy on the
        query samples.

    Note:
        Assumes every class contributes exactly ``shot + query_num`` samples
        and that labels are ``0 .. ways - 1``, so that after sorting class
        ``c`` occupies one contiguous run whose prototype is row ``c``.
    """
    if metric is None:
        metric = pairwise_distances_logits
    data, labels = batch
    if device is None:
        # nn.Module exposes no `device` attribute, so infer it from the weights.
        first_param = next(model.parameters(), None)
        device = data.device if first_param is None else first_param.device
    data = data.to(device)
    labels = labels.to(device)

    # Sort data samples by labels so each class forms one contiguous run.
    # TODO: Can this be replaced by ConsecutiveLabels ?
    order = torch.sort(labels).indices
    data = data.squeeze(0)[order].squeeze(0)
    labels = labels.squeeze(0)[order].squeeze(0)

    # Compute support and query embeddings
    embeddings = model(data)
    support_mask = np.zeros(data.size(0), dtype=bool)
    class_starts = np.arange(ways) * (shot + query_num)
    for offset in range(shot):
        # Mark the offset-th sample of every class as support.
        support_mask[class_starts + offset] = True
    query_mask = torch.from_numpy(~support_mask)
    support_mask = torch.from_numpy(support_mask)
    prototypes = embeddings[support_mask].reshape(ways, shot, -1).mean(dim=1)
    query = embeddings[query_mask]
    labels = labels[query_mask].long()

    # Use the caller-supplied metric instead of a hard-wired distance.
    logits = metric(query, prototypes)
    loss = F.cross_entropy(logits, labels)
    acc = accuracy(logits, labels)
    return loss, acc

def fast_adapt_c(model, batch, ways, shot, query_num, metric=None, device=None):
    """Score one pre-sorted few-shot task and also return a task embedding.

    Unlike ``fast_adapt`` this function does not sort or squeeze the batch:
    ``data`` and ``labels`` must already be ordered by label, without a
    leading batch dimension, with ``shot + query_num`` samples per class.

    Args:
        model: callable mapping ``data`` to 2-D embeddings.
        batch: ``(data, labels)`` pair, already sorted by label.
        ways: number of classes in the task.
        shot: support samples per class (the first ``shot`` of each class).
        query_num: query samples per class.
        metric: callable ``(query, prototypes) -> logits``; defaults to
            ``pairwise_distances_logits``.
        device: optional device to move the batch to; when omitted the batch
            is used where it already lives.

    Returns:
        ``(loss, acc, emb)``: cross-entropy loss and accuracy on the query
        samples, and ``emb`` of shape ``(1, ways)`` holding, per class, the
        mean over all features of that class's prototype.
    """
    if metric is None:
        metric = pairwise_distances_logits
    data, labels = batch
    if device is not None:
        # No-op when the caller has already placed the batch on `device`.
        data = data.to(device)
        labels = labels.to(device)

    # Compute support and query embeddings
    embeddings = model(data)
    support_mask = np.zeros(data.size(0), dtype=bool)
    class_starts = np.arange(ways) * (shot + query_num)
    for offset in range(shot):
        # Mark the offset-th sample of every class as support.
        support_mask[class_starts + offset] = True
    query_mask = torch.from_numpy(~support_mask)
    support_mask = torch.from_numpy(support_mask)
    prototypes = embeddings[support_mask].reshape(ways, shot, -1).mean(dim=1)
    query = embeddings[query_mask]
    labels = labels[query_mask].long()

    # Use the caller-supplied metric instead of a hard-wired distance.
    logits = metric(query, prototypes)
    loss = F.cross_entropy(logits, labels)
    acc = accuracy(logits, labels)

    # Collapse each (ways, d) prototype to one scalar -> (ways,) -> (1, ways).
    emb = torch.mean(prototypes, 1).reshape(1, -1)
    return loss, acc, emb

def fast_adapt_complete(model, batch, ways, shot, query_num, metric=None, device=None):
    """Embed a whole pre-sorted task, treating every sample as support.

    The first ``ways * (shot + query_num)`` samples of ``data`` are embedded,
    averaged per class, and each class mean is then averaged over its
    features.  No loss or accuracy is computed.

    Args:
        model: callable mapping ``data`` to 2-D embeddings.
        batch: ``(data, labels)`` pair; ``data`` must already be ordered by
            class with ``shot + query_num`` samples per class.  ``labels``
            is not read.
        ways: number of classes in the task.
        shot: support samples per class.
        query_num: query samples per class (folded into the support set).
        metric: accepted for signature parity with ``fast_adapt``; unused.
        device: accepted for signature parity with ``fast_adapt``; unused.

    Returns:
        ``(None, None, emb)`` where ``emb`` has shape ``(1, ways)``.
    """
    data, _labels = batch
    per_class = shot + query_num

    embeddings = model(data)

    # Boolean mask over the samples: True for every slot of every class.
    keep = np.zeros(data.size(0), dtype=bool)
    class_starts = np.arange(ways) * per_class
    slots = (class_starts[:, None] + np.arange(per_class)).ravel()
    keep[slots] = True

    selected = embeddings[torch.from_numpy(keep)]
    class_means = selected.reshape(ways, per_class, -1).mean(dim=1)
    # (ways, d) -> (ways,) -> (1, ways)
    emb = class_means.mean(dim=1).reshape(1, -1)
    return None, None, emb




def _build_task_loader(dataset, ways, samples_per_class, **task_kwargs):
    """Wrap `dataset` into a shuffled DataLoader of N-way few-shot tasks.

    `samples_per_class` is the total (support + query) number of samples drawn
    per class; `task_kwargs` is forwarded to `l2l.data.TaskDataset`
    (e.g. `num_tasks`).
    """
    meta_dataset = l2l.data.MetaDataset(dataset)
    task_transforms = [
        NWays(meta_dataset, ways),
        KShots(meta_dataset, samples_per_class),
        LoadData(meta_dataset),
        RemapLabels(meta_dataset),
    ]
    tasks = l2l.data.TaskDataset(
        meta_dataset,
        task_transforms=task_transforms,
        **task_kwargs,
    )
    return DataLoader(tasks, pin_memory=True, shuffle=True)


def _evaluate(model, loader, ways, shot, query_num, device):
    """Average `fast_adapt` loss and accuracy over every task from `loader`.

    Runs under `torch.no_grad()` since nothing is backpropagated here, so no
    autograd graph is kept per task.  The caller is responsible for putting
    `model` in eval mode.  Returns `(mean_loss, mean_accuracy)` as Python
    floats, the accuracy being a fraction (not a percentage).
    """
    total_loss = 0.0
    total_acc = 0.0
    n_tasks = 0
    with torch.no_grad():
        for batch in loader:
            loss, acc = fast_adapt(model,
                                   batch,
                                   ways,
                                   shot,
                                   query_num,
                                   metric=pairwise_distances_logits,
                                   device=device)
            total_loss += loss.item()
            total_acc += float(acc)
            n_tasks += 1
    return total_loss / n_tasks, total_acc / n_tasks


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--max-epoch', type=int, default=50000)
    parser.add_argument('--shot', type=int, default=1)  # 1, 5
    parser.add_argument('--test-way', type=int, default=5)
    parser.add_argument('--test-shot', type=int, default=1)  # 1, 5
    parser.add_argument('--test-query', type=int, default=30)
    parser.add_argument('--train-query', type=int, default=5)  # 1, 5
    parser.add_argument('--train-way', type=int, default=5)  # 30, 20
    # Explicit types: without them a value given on the command line arrives
    # as a str, which breaks range()/arithmetic below and makes "--gpu 0" truthy.
    parser.add_argument('--gpu', type=int, default=1)
    parser.add_argument('--meta-batch-size', type=int, default=32)
    parser.add_argument('--inc', type=float, default=0.1)
    parser.add_argument('--outc', type=float, default=0.1)
    parser.add_argument('--times', type=int, default=1)
    args = parser.parse_args()
    test_every = 10
    best_iter = 0
    best_test_test = 0
    p_test = 0
    print(args)
    lr_step = 200  # 20*15/args.train_query
    device = torch.device('cpu')
    if args.gpu and torch.cuda.device_count():
        print("Using gpu")
        torch.cuda.manual_seed(43)
        device = torch.device('cuda')

    model = Convnet()
    model.to(device)
    to_tensor = torchvision.transforms.ToTensor()
    path_data = '~/data'
    train_dataset = l2l.vision.datasets.TieredImagenet(
        root=path_data, mode='train', download=True, transform=to_tensor)
    valid_dataset = l2l.vision.datasets.TieredImagenet(
        root=path_data, mode='validation', download=True, transform=to_tensor)
    test_dataset = l2l.vision.datasets.TieredImagenet(
        root=path_data, mode='test', download=True, transform=to_tensor)

    train_loader = _build_task_loader(
        train_dataset, args.train_way, args.train_query + args.shot)
    valid_loader = _build_task_loader(
        valid_dataset, args.test_way, args.test_query + args.test_shot,
        num_tasks=200)
    test_loader = _build_task_loader(
        test_dataset, args.test_way, args.test_query + args.test_shot,
        num_tasks=2000)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=int(lr_step), gamma=0.5)
    best_valid = 0
    test_best_valid = 0
    best_test = 0
    # Samples per class in a training task (support + query).
    per_class = args.shot + args.train_query
    for epoch in range(1, args.max_epoch + 1):
        model.train()

        loss_ctr = 0
        n_loss = 0
        n_acc = 0

        for _ in range(100):
            optimizer.zero_grad()
            batch_embs = []
            for _task in range(args.meta_batch_size):
                batch = next(iter(train_loader))
                loss = 0
                acc = 0
                partial_embs = []
                for _rep in range(args.times):
                    data, labels = batch
                    data = data.to(device)
                    labels = labels.to(device)
                    # Indexing with the sort order yields a fresh tensor on
                    # every repetition, so the in-place shuffle below never
                    # touches a tensor used by an earlier forward pass.
                    sort = torch.sort(labels)
                    data = data.squeeze(0)[sort.indices].squeeze(0)
                    labels = labels.squeeze(0)[sort.indices].squeeze(0)
                    # Shuffle the samples inside each class so a different
                    # subset lands in the support slots on every repetition.
                    for c in range(args.train_way):
                        lo = c * per_class
                        hi = lo + per_class
                        data[lo:hi] = data[lo:hi][torch.randperm(per_class)]
                    batch = [data, labels]
                    losst, acct, partial_emb = fast_adapt_c(
                        model,
                        batch,
                        args.train_way,
                        args.shot,
                        args.train_query,
                        metric=pairwise_distances_logits,
                        device=device)
                    partial_embs.append(partial_emb)
                    loss += losst
                    acc += acct

                # (times, ways) matrix of unit-norm partial task embeddings,
                # assembled on the embeddings' own device.
                task_emb_matrix = nn.functional.normalize(
                    torch.cat(partial_embs, dim=0), p=2.0, dim=-1)
                _, _, task_emb = fast_adapt_complete(
                    model,
                    batch,
                    args.train_way,
                    args.shot,
                    args.train_query,
                    metric=pairwise_distances_logits,
                    device=device)
                task_emb = nn.functional.normalize(task_emb, p=2.0, dim=-1)
                # Reward agreement between the partial-task embeddings and
                # the embedding of the complete task.
                in_cosine = torch.sum(
                    task_emb_matrix @ torch.transpose(task_emb, 0, 1))
                loss -= args.inc * in_cosine
                loss /= args.times
                loss /= args.meta_batch_size
                loss.backward()
                n_loss += loss.item()
                n_acc += float(acc) / (args.times * args.meta_batch_size)

                # The backward pass above presumably releases the graph behind
                # the previous `task_emb`, so it is recomputed here to give the
                # batch-level term below a graph to backpropagate through.
                _, _, task_emb = fast_adapt_complete(
                    model,
                    batch,
                    args.train_way,
                    args.shot,
                    args.train_query,
                    metric=pairwise_distances_logits,
                    device=device)
                batch_embs.append(
                    nn.functional.normalize(task_emb, p=2.0, dim=-1))

            loss_ctr += 1

            # Penalise similarity between the embeddings of different tasks
            # of the same meta-batch (mean of all pairwise cosines).
            batch_emb_matrix = torch.cat(batch_embs, dim=0)
            out_cosine = torch.sum(torch.mm(
                batch_emb_matrix, torch.transpose(batch_emb_matrix, 1, 0)))
            out_cosine = out_cosine / (args.meta_batch_size * args.meta_batch_size)
            (args.outc * out_cosine).backward()
            optimizer.step()
        lr_scheduler.step()

        print('epoch {}, train, loss={:.4f} acc={:.4f}'.format(
            epoch, n_loss/loss_ctr, n_acc/loss_ctr))

        model.eval()

        valid_loss, valid_acc = _evaluate(
            model, valid_loader,
            args.test_way, args.test_shot, args.test_query, device)

        print('epoch {}, val, loss={:.4f} acc={:.4f}, best_valid_previous={:.4f}'.format(
            epoch, valid_loss, valid_acc, best_valid))
        improved = best_valid <= valid_acc
        if improved or (epoch > 150 and epoch % 10 == 0):
            if improved:
                best_valid = valid_acc
            _, test_acc = _evaluate(
                model, test_loader,
                args.test_way, args.test_shot, args.test_query, device)
            test_best_valid = test_acc * 100
            if test_best_valid >= best_test:
                best_test = test_best_valid
            if epoch % test_every == 0:
                p_test = test_best_valid
                if test_best_valid >= best_test_test:
                    best_test_test = test_best_valid
                    best_iter = epoch
        elif epoch % test_every == 0:
            _, test_acc = _evaluate(
                model, test_loader,
                args.test_way, args.test_shot, args.test_query, device)
            test_test = test_acc * 100
            if test_test >= best_test_test:
                best_test_test = test_test
                best_iter = epoch
            p_test = test_test
        print('test at best valid= {:.4f}, best test= {:.4f}, previous test= {:.4f}, best test test= {:.4f} at iter {}'.format(test_best_valid,best_test,p_test,best_test_test,best_iter))

    # Final evaluation.  eval() is repeated so the model is in eval mode even
    # when the epoch loop above did not run at all.
    model.eval()
    _, final_acc = _evaluate(
        model, test_loader,
        args.test_way, args.test_shot, args.test_query, device)
    # Percent, like every other test figure tracked above; comparing the raw
    # fraction against the percent-valued `best_test` would mix units.
    final_test = final_acc * 100
    if final_test >= best_test:
        best_test = final_test
    print('final test= {:.4f} ,best test= {:.4f} '.format(final_test, best_test))
        
