#!/usr/bin/env python3

"""
Demonstrates how to:
    * use the MAML wrapper for fast-adaptation,
    * use the benchmark interface to load mini-ImageNet, and
    * sample tasks and split them into adaptation and evaluation sets.

To contrast the use of the benchmark interface with directly instantiating mini-ImageNet datasets and tasks, compare with `protonet_miniimagenet.py`.
"""

import random
import numpy as np
import os
import torch
from torch import nn, optim

import learn2learn as l2l
from learn2learn.data.transforms import (NWays,
                                         KShots,
                                         LoadData,
                                         RemapLabels,
                                         ConsecutiveLabels)

# Default to GPU index 2, but keep any CUDA_VISIBLE_DEVICES value that is
# already set in the environment instead of silently overwriting it.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '2')


def accuracy(predictions, targets):
    """Return the fraction of predictions whose arg-max equals the target.

    Arguments:
        predictions: score tensor; the predicted class is the arg-max
            taken along dimension 1.
        targets: tensor of ground-truth class indices. The arg-max result
            is viewed to this tensor's shape before comparison.

    Returns:
        A float tensor holding the number of matches divided by
        `targets.size(0)`.
    """
    predicted_classes = predictions.argmax(dim=1).view(targets.shape)
    num_correct = (predicted_classes == targets).sum().float()
    num_targets = targets.size(0)
    return num_correct / num_targets


def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    """Adapt `learner` on half of a task and score it on the other half.

    Arguments:
        batch: `(data, labels)` pair for one task.
        learner: callable model that also exposes `adapt(loss_value)`.
        loss: callable `loss(predictions, labels)`.
        adaptation_steps: number of `learner.adapt` calls to perform.
        shots: samples per class used for adaptation.
        ways: number of classes in the task.
        device: device the batch is moved to.

    Returns:
        `(evaluation_error, evaluation_accuracy)` measured on the samples
        that were not used for adaptation.
    """
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Positions 0, 2, 4, ..., 2 * (shots * ways - 1) make up the adaptation
    # set; every remaining position belongs to the evaluation set.
    support_mask = np.zeros(inputs.size(0), dtype=bool)
    support_positions = np.arange(shots*ways) * 2
    support_mask[support_positions] = True
    query_mask = torch.from_numpy(~support_mask)
    support_mask = torch.from_numpy(support_mask)

    support_inputs = inputs[support_mask]
    support_targets = targets[support_mask]
    query_inputs = inputs[query_mask]
    query_targets = targets[query_mask]

    # Fast adaptation on the adaptation set.
    for _ in range(adaptation_steps):
        support_error = loss(learner(support_inputs), support_targets)
        learner.adapt(support_error)

    # Score the adapted learner on the held-out samples.
    query_predictions = learner(query_inputs)
    query_error = loss(query_predictions, query_targets)
    query_accuracy = accuracy(query_predictions, query_targets)
    return query_error, query_accuracy

def fast_adapt_c(batch, learner, loss, adaptation_steps, shots, ways, device, pgs):
    """Adapt on half of a task, then return its evaluation metrics and a
    flattened embedding of the adapted parameters.

    Arguments:
        batch: `(data, labels)` pair for one task.
        learner: callable model exposing `adapt(loss_value)` and a `module`
            attribute whose `parameters()` are embedded.
        loss: callable `loss(predictions, labels)`.
        adaptation_steps: number of `learner.adapt` calls to perform.
        shots: samples per class used for adaptation.
        ways: number of classes in the task.
        device: device the batch is moved to.
        pgs: number of trailing trainable parameter tensors to embed
            (truncated to an int). A value that truncates to zero or less
            embeds every trainable parameter; a value larger than the
            number of trainable parameters also embeds all of them.

    Returns:
        `(evaluation_error, evaluation_accuracy, partial_emb)` where
        `partial_emb` has shape `(1, total_selected_elements)`. If the
        learner has no trainable parameters, `partial_emb` is the integer
        0 (kept from the original behaviour).
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evaluation sets: positions
    # 0, 2, ..., 2 * (shots * ways - 1) are used for adaptation and every
    # other position for evaluation.
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots*ways) * 2] = True
    evaluation_indices = torch.from_numpy(~adaptation_indices)
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

    # Adapt the model
    for _ in range(adaptation_steps):
        adaptation_error = loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(adaptation_error)

    # Flatten the selected parameters into one row vector with a single
    # concatenation instead of growing the tensor one piece at a time.
    params = [p for p in learner.module.parameters() if p.requires_grad]
    pgs = int(pgs)
    selected = params[-pgs:] if pgs > 0 else params
    if selected:
        partial_emb = torch.cat([p.reshape(1, -1) for p in selected], dim=-1)
    else:
        partial_emb = 0

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    evaluation_error = loss(predictions, evaluation_labels)
    evaluation_accuracy = accuracy(predictions, evaluation_labels)
    return evaluation_error, evaluation_accuracy, partial_emb

def fast_adapt_complete(batch, learner, loss, adaptation_steps, shots, ways, device, pgs):
    """Adapt on the whole task and return a flattened embedding of the
    adapted parameters.

    Unlike `fast_adapt_c`, every sample in the batch is used for
    adaptation and no evaluation is performed.

    Arguments:
        batch: `(data, labels)` pair for one task.
        learner: callable model exposing `adapt(loss_value)` and a `module`
            attribute whose `parameters()` are embedded.
        loss: callable `loss(predictions, labels)`.
        adaptation_steps: number of `learner.adapt` calls to perform.
        shots: unused; kept so the signature matches `fast_adapt_c`.
        ways: unused; kept so the signature matches `fast_adapt_c`.
        device: device the batch is moved to.
        pgs: number of trailing trainable parameter tensors to embed
            (truncated to an int). A value that truncates to zero or less
            embeds every trainable parameter; a value larger than the
            number of trainable parameters also embeds all of them.

    Returns:
        `(None, None, task_emb)` where `task_emb` has shape
        `(1, total_selected_elements)`. If the learner has no trainable
        parameters, `task_emb` is the integer 0 (kept from the original
        behaviour).
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Adapt the model on the complete batch
    for _ in range(adaptation_steps):
        adaptation_error = loss(learner(data), labels)
        learner.adapt(adaptation_error)

    # Flatten the selected parameters into one row vector with a single
    # concatenation instead of growing the tensor one piece at a time.
    params = [p for p in learner.module.parameters() if p.requires_grad]
    pgs = int(pgs)
    selected = params[-pgs:] if pgs > 0 else params
    if selected:
        task_emb = torch.cat([p.reshape(1, -1) for p in selected], dim=-1)
    else:
        task_emb = 0
    return None, None, task_emb


def _evaluate_tasks(maml, taskset, num_tasks, loss, adaptation_steps, shots, ways, device):
    """Run `fast_adapt` on `num_tasks` tasks sampled from `taskset`.

    Returns the summed (not averaged) evaluation error and accuracy.
    """
    total_error = 0.0
    total_accuracy = 0.0
    for _ in range(num_tasks):
        learner = maml.clone()
        batch = taskset.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           shots,
                                                           ways,
                                                           device)
        total_error += evaluation_error.item()
        total_accuracy += evaluation_accuracy.item()
    return total_error, total_accuracy


def main(
        ways=5,
        shots=5,
        meta_lr=0.003,
        fast_lr=0.5,
        meta_batch_size=32,
        adaptation_steps=1,
        num_iterations=60000,
        cuda=True,
        seed=0,
):
    """Meta-train a ResNet12 with first-order MAML on mini-ImageNet.

    Two cosine terms on flattened parameter embeddings are added to the
    usual meta-loss: a per-task term between the embedding adapted on half
    the task and the one adapted on the whole task, and a per-batch term
    over all pairs of whole-task embeddings.

    Arguments:
        ways: number of classes per task.
        shots: adaptation samples per class (tasks hold `2 * shots` per class).
        meta_lr: learning rate of the Adam meta-optimizer.
        fast_lr: learning rate passed to the MAML wrapper.
        meta_batch_size: number of tasks per meta-iteration and per
            evaluation pass.
        adaptation_steps: number of adaptation steps per task.
        num_iterations: number of meta-iterations.
        cuda: use CUDA when at least one device is visible.
        seed: seed for `random`, NumPy and PyTorch.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Tasksets using the benchmark interface
    tasksets = l2l.vision.benchmarks.get_tasksets('mini-imagenet',
                                                  train_samples=2*shots,
                                                  train_ways=ways,
                                                  test_samples=2*shots,
                                                  test_ways=ways,
                                                  root='~/data',
                                                  )

    # Create model
    model = l2l.vision.models.ResNet12(ways)
    model.to(device)
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=True)
    opt = optim.Adam(maml.parameters(), meta_lr)
    loss = nn.CrossEntropyLoss(reduction='mean')

    # Running bests are stored as sums over `meta_batch_size` tasks and are
    # divided by `meta_batch_size` only when printed.
    best_valid = 0
    test_best_valid = 0
    best_test = 0
    for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        task_embeddings = []
        for task in range(meta_batch_size):
            # Compute meta-training loss
            learner = maml.clone()
            batch = tasksets.train.sample()
            evaluation_error, evaluation_accuracy, partial_emb = fast_adapt_c(batch,
                                                                              learner,
                                                                              loss,
                                                                              adaptation_steps,
                                                                              shots,
                                                                              ways,
                                                                              device, pgs=2)
            learner = maml.clone()
            _, _, task_emb = fast_adapt_complete(batch,
                                                 learner,
                                                 loss,
                                                 adaptation_steps,
                                                 shots,
                                                 ways,
                                                 device, pgs=2)
            partial_emb = nn.functional.normalize(partial_emb, p=2.0, dim=-1)
            task_emb = nn.functional.normalize(task_emb, p=2.0, dim=-1)
            # Cosine similarity between the half-task and whole-task embeddings.
            in_cosine = torch.sum(partial_emb @ torch.transpose(task_emb, 0, 1))

            # Subtracting the similarity means minimizing the loss increases it.
            evaluation_error -= 0.1*in_cosine

            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # NOTE(review): the whole-task embedding is recomputed here,
            # presumably because the backward() above frees the graph of the
            # first one and this copy is needed for the batch-level backward
            # at the end of the iteration -- confirm before merging the calls.
            learner = maml.clone()
            _, _, task_emb = fast_adapt_complete(batch,
                                                 learner,
                                                 loss,
                                                 adaptation_steps,
                                                 shots,
                                                 ways,
                                                 device, pgs=2)
            # Each embedding is a (1, D) row; they are stacked after the loop
            # so the matrix lives on the same device as the model.
            task_embeddings.append(nn.functional.normalize(task_emb, p=2.0, dim=-1))

            # Compute meta-validation loss
            learner = maml.clone()
            batch = tasksets.validation.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               learner,
                                                               loss,
                                                               adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_batch_size)
        print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
        print('Meta Valid Error', meta_valid_error / meta_batch_size)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)
        print('Privious Best Meta Valid Accuracy', best_valid / meta_batch_size)
        print('Privious Meta Test Accuracy at Best Meta Valid Accuracy', test_best_valid / meta_batch_size)
        print('Privious Best Meta test Accuracy', best_test / meta_batch_size)

        # Re-measure meta-test accuracy whenever validation accuracy ties or
        # beats the best value seen so far.
        if meta_valid_accuracy >= best_valid:
            best_valid = meta_valid_accuracy
            meta_test_error, meta_test_accuracy = _evaluate_tasks(maml,
                                                                  tasksets.test,
                                                                  meta_batch_size,
                                                                  loss,
                                                                  adaptation_steps,
                                                                  shots,
                                                                  ways,
                                                                  device)
            print('Meta Test Accuracy at This Epoch', meta_test_accuracy / meta_batch_size)
            print('Previous Best Meta Test Accuracy', best_test / meta_batch_size)
            test_best_valid = meta_test_accuracy
            if meta_test_accuracy >= best_test:
                best_test = meta_test_accuracy

        # Average the accumulated gradients and optimize. Parameters that
        # received no gradient are skipped instead of raising on `None`.
        for p in maml.parameters():
            if p.grad is not None:
                p.grad.data.mul_(1.0 / meta_batch_size)

        # Mean pairwise cosine similarity between the task embeddings of this
        # batch. Its gradient is added after the averaging above, as before.
        batch_emb_matrix = torch.cat(task_embeddings, dim=0)
        batch_emb_matrix_t = torch.transpose(batch_emb_matrix, 1, 0)
        out_cosine = torch.sum(torch.mm(batch_emb_matrix, batch_emb_matrix_t))
        out_cosine /= (meta_batch_size*meta_batch_size)
        (0.1*out_cosine).backward()
        opt.step()

    # Final meta-test pass with the parameters from the last iteration.
    meta_test_error, meta_test_accuracy = _evaluate_tasks(maml,
                                                          tasksets.test,
                                                          meta_batch_size,
                                                          loss,
                                                          adaptation_steps,
                                                          shots,
                                                          ways,
                                                          device)
    print('Meta Test Error', meta_test_error / meta_batch_size)
    print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)


# Run the experiment with the default hyper-parameters when this file is
# executed as a script (not when it is imported).
if __name__ == '__main__':
    main()
