#!/usr/bin/env python3

"""
Demonstrates how to:
    * use the MAML wrapper for fast-adaptation,
    * use the benchmark interface to load mini-ImageNet, and
    * sample tasks and split them in adaptation and evaluation sets.

To contrast the use of the benchmark interface with directly instantiating mini-ImageNet datasets and tasks, compare with `protonet_miniimagenet.py`.
"""

import random
import numpy as np
import os
import torch
from torch import nn, optim

import learn2learn as l2l
from learn2learn.data.transforms import (NWays,
                                         KShots,
                                         LoadData,
                                         RemapLabels,
                                         ConsecutiveLabels)

# Expose only GPU index 5 to this process (value is hard-coded).
# NOTE(review): if that index does not exist on the host, `main()` presumably
# sees `torch.cuda.device_count() == 0` and silently trains on the CPU —
# confirm and adjust the index before running on another machine.
os.environ['CUDA_VISIBLE_DEVICES'] = '5'


def accuracy(predictions, targets):
    """Return the fraction of correctly classified samples.

    Args:
        predictions: Per-class scores; the predicted class is the arg-max
            along dimension 1.
        targets: Ground-truth class indices. The predicted classes are
            viewed with the same shape before being compared.

    Returns:
        A 0-dim float tensor: number of matches divided by
        ``targets.size(0)``.
    """
    predicted_classes = predictions.argmax(dim=1).view(targets.shape)
    num_correct = predicted_classes.eq(targets).sum().float()
    return num_correct / targets.size(0)

def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    """Adapt ``learner`` on one half of a task and score it on the other half.

    Args:
        batch: ``(data, labels)`` pair for a single task.
        learner: Callable model exposing ``adapt(loss_value)``.
        loss: Callable ``loss(predictions, labels)``.
        adaptation_steps: Number of ``learner.adapt`` calls to perform.
        shots: Adaptation samples per class.
        ways: Number of classes in the task.
        device: Device the batch is moved to before use.

    Returns:
        ``(evaluation_error, evaluation_accuracy)`` computed on the samples
        that were not used for adaptation.
    """
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Positions 0, 2, 4, ... (shots * ways of them) form the adaptation set;
    # every remaining position goes to the evaluation set.
    support_mask = np.zeros(inputs.size(0), dtype=bool)
    support_mask[2 * np.arange(shots * ways)] = True
    query_mask = torch.from_numpy(~support_mask)
    support_mask = torch.from_numpy(support_mask)
    support_inputs, support_targets = inputs[support_mask], targets[support_mask]
    query_inputs, query_targets = inputs[query_mask], targets[query_mask]

    # Inner loop: repeatedly adapt the learner on the adaptation set.
    for _ in range(adaptation_steps):
        learner.adapt(loss(learner(support_inputs), support_targets))

    # Score the adapted learner on the held-out evaluation set.
    query_predictions = learner(query_inputs)
    query_error = loss(query_predictions, query_targets)
    query_accuracy = accuracy(query_predictions, query_targets)
    return query_error, query_accuracy

def fast_adapt_c(batch, learner, loss, adaptation_steps, shots, ways, device, pgs):
    """Adapt on half of a task, then return evaluation metrics and a parameter embedding.

    The batch is split exactly as in ``fast_adapt``: positions 0, 2, 4, ...
    (``shots * ways`` of them) are used for adaptation and the remaining
    positions for evaluation. After the adaptation steps, the trainable
    parameters of ``learner.module`` are flattened and concatenated into a
    single row vector (the "partial" embedding of the task).

    Args:
        batch: ``(data, labels)`` pair for a single task.
        learner: Callable model exposing ``adapt(loss_value)`` and a
            ``module`` attribute with ``parameters()``.
        loss: Callable ``loss(predictions, labels)``.
        adaptation_steps: Number of ``learner.adapt`` calls to perform.
        shots: Adaptation samples per class.
        ways: Number of classes in the task.
        device: Device the batch is moved to before use.
        pgs: How many trailing trainable parameter tensors go into the
            embedding. Fractional values are truncated; a value ``<= 0`` (or
            one that truncates to 0) selects every trainable tensor, and a
            value larger than the number of tensors is clamped to all of them.

    Returns:
        ``(evaluation_error, evaluation_accuracy, partial_emb)`` where
        ``partial_emb`` has shape ``(1, D)``.

    Raises:
        ValueError: If ``learner.module`` has no trainable parameters.
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evaluation sets
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots*ways) * 2] = True
    evaluation_indices = torch.from_numpy(~adaptation_indices)
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

    # Adapt the model
    for _ in range(adaptation_steps):
        adaptation_error = loss(learner(adaptation_data), adaptation_labels)
        learner.adapt(adaptation_error)

    # Embed the task as the flattened parameters as they stand after the
    # adaptation steps.
    # NOTE(review): per the state-dict dump previously kept here, the last
    # two parameter tensors of MiniImagenetCNN are `classifier.weight` and
    # `classifier.bias`, so pgs=2 presumably embeds only the classifier head.
    params = [p for p in learner.module.parameters() if p.requires_grad]
    if not params:
        raise ValueError('learner.module has no trainable parameters to embed')
    count = int(pgs) if pgs > 0 else 0
    # A slice (unlike manual `len(params) - pgs + i` indexing) cannot wrap
    # around to the front of the list when `count > len(params)`.
    selected = params[-count:] if count > 0 else params
    partial_emb = torch.cat([p.reshape(1, -1) for p in selected], dim=-1)

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    evaluation_error = loss(predictions, evaluation_labels)
    evaluation_accuracy = accuracy(predictions, evaluation_labels)
    return evaluation_error, evaluation_accuracy, partial_emb

def fast_adapt_complete(batch, learner, loss, adaptation_steps, shots, ways, device, pgs):
    """Adapt on the *whole* task and return only its parameter embedding.

    Unlike ``fast_adapt`` / ``fast_adapt_c`` the batch is not split: every
    sample is used for adaptation and no evaluation is performed. After the
    adaptation steps, the trainable parameters of ``learner.module`` are
    flattened and concatenated into a single row vector.

    Args:
        batch: ``(data, labels)`` pair for a single task.
        learner: Callable model exposing ``adapt(loss_value)`` and a
            ``module`` attribute with ``parameters()``.
        loss: Callable ``loss(predictions, labels)``.
        adaptation_steps: Number of ``learner.adapt`` calls to perform.
        shots: Unused; kept so the signature matches ``fast_adapt_c``.
        ways: Unused; kept so the signature matches ``fast_adapt_c``.
        device: Device the batch is moved to before use.
        pgs: How many trailing trainable parameter tensors go into the
            embedding. Fractional values are truncated; a value ``<= 0`` (or
            one that truncates to 0) selects every trainable tensor, and a
            value larger than the number of tensors is clamped to all of them.

    Returns:
        ``(None, None, task_emb)`` — the two ``None`` placeholders mirror the
        error/accuracy slots of ``fast_adapt_c``; ``task_emb`` has shape
        ``(1, D)``.

    Raises:
        ValueError: If ``learner.module`` has no trainable parameters.
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Adapt the model on every sample of the task (no held-out split).
    for _ in range(adaptation_steps):
        adaptation_error = loss(learner(data), labels)
        learner.adapt(adaptation_error)

    # Embed the task as the flattened parameters as they stand after the
    # adaptation steps.
    params = [p for p in learner.module.parameters() if p.requires_grad]
    if not params:
        raise ValueError('learner.module has no trainable parameters to embed')
    count = int(pgs) if pgs > 0 else 0
    # A slice (unlike manual `len(params) - pgs + i` indexing) cannot wrap
    # around to the front of the list when `count > len(params)`.
    selected = params[-count:] if count > 0 else params
    task_emb = torch.cat([p.reshape(1, -1) for p in selected], dim=-1)

    return None, None, task_emb


def _meta_evaluate(maml, taskset, loss, adaptation_steps, shots, ways, device,
                   num_tasks):
    """Adapt fresh clones of ``maml`` on ``num_tasks`` tasks sampled from ``taskset``.

    Returns:
        ``(error_sum, accuracy_sum)`` as Python floats. The values are
        *summed* over tasks, not averaged; callers divide when printing.
    """
    error_sum = 0.0
    accuracy_sum = 0.0
    for _ in range(num_tasks):
        learner = maml.clone()
        batch = taskset.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           shots,
                                                           ways,
                                                           device)
        error_sum += evaluation_error.item()
        accuracy_sum += evaluation_accuracy.item()
    return error_sum, accuracy_sum


def main(
        ways=5,
        shots=1,
        meta_lr=0.003,
        fast_lr=0.5,
        meta_batch_size=32,
        adaptation_steps=1,
        num_iterations=60000,
        cuda=True,
        seed=0,#42
        times=1,
        in_c=0.1,
        out_c=0.1,
        pgs=2,
        first_order=True,
):
    """Meta-train MAML on mini-ImageNet with two cosine-similarity regularisers.

    For every training task the loss is the evaluation error of ``times``
    adaptation passes, minus ``in_c`` times the summed cosine similarity
    between each pass's parameter embedding and the embedding obtained by
    adapting on the complete task. After all tasks of a meta-batch have been
    back-propagated, ``out_c`` times the mean pairwise cosine similarity
    between the complete-task embeddings of the batch is back-propagated as
    well, and a single optimiser step is taken.

    Args:
        ways: Classes per task (also the model's output size).
        shots: Adaptation samples per class; tasks are sampled with
            ``2 * shots`` samples per class.
        meta_lr: Learning rate of the Adam meta-optimiser.
        fast_lr: Learning rate passed to the MAML wrapper.
        meta_batch_size: Tasks per meta-iteration (also the number of tasks
            used for every validation/test estimate).
        adaptation_steps: ``learner.adapt`` calls per adaptation.
        num_iterations: Number of meta-iterations.
        cuda: Use a GPU when one is visible; otherwise the CPU is used.
        seed: Seed for ``random``, NumPy and torch.
        times: Adaptation passes per training task. Passes after the first
            re-shuffle the samples inside each class block of the batch.
        in_c: Weight of the within-task similarity term.
        out_c: Weight of the between-task similarity term.
        pgs: Number of trailing parameter tensors used for the embeddings
            (forwarded to ``fast_adapt_c`` / ``fast_adapt_complete``).
        first_order: Passed to the MAML wrapper.
    """
    test_every = 1000
    best_iter = 0
    best_test_test = 0
    p_test = 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Tasksets using the benchmark interface
    tasksets = l2l.vision.benchmarks.get_tasksets('mini-imagenet',
                                                  train_samples=2*shots,
                                                  train_ways=ways,
                                                  test_samples=2*shots,
                                                  test_ways=ways,
                                                  root='~/data',
    )

    # Create model
    model = l2l.vision.models.MiniImagenetCNN(ways)
    model.to(device)
    maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=first_order)
    opt = optim.Adam(maml.parameters(), meta_lr)
    loss = nn.CrossEntropyLoss(reduction='mean')

    # All "best" trackers hold accuracies *summed* over meta_batch_size tasks.
    best_valid = 0
    test_best_valid = 0
    best_test = 0
    for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        batch_embs = []  # one normalised (1, D) complete-task embedding per task
        for task in range(meta_batch_size):
            # Compute meta-training loss
            evaluation_error = 0
            evaluation_accuracy = 0
            partial_embs = []  # one (1, D) embedding per adaptation pass
            batch = tasksets.train.sample()
            for i in range(times):
                if i > 0:
                    # Re-shuffle the samples inside each class block so that
                    # this pass adapts/evaluates on a different split.
                    # NOTE(review): assumes samples are grouped by class in
                    # blocks of 2 * shots — confirm against the task transforms.
                    data, label = batch
                    for c in range(ways):
                        start, stop = c*shots*2, (c+1)*shots*2
                        data[start:stop] = data[start:stop][torch.randperm(shots*2)]
                    batch = [data, label]
                learner = maml.clone()
                partial_error, partial_accuracy, partial_emb = fast_adapt_c(batch,
                                                                            learner,
                                                                            loss,
                                                                            adaptation_steps,
                                                                            shots,
                                                                            ways,
                                                                            device,
                                                                            pgs)
                partial_embs.append(partial_emb)
                evaluation_error += partial_error
                evaluation_accuracy += partial_accuracy

            # (times, D), built directly on the embeddings' device instead of
            # round-tripping every row through a CPU buffer.
            task_emb_matrix = nn.functional.normalize(torch.cat(partial_embs, dim=0),
                                                      p=2.0, dim=-1)

            learner = maml.clone()
            _, _, task_emb = fast_adapt_complete(batch,
                                                 learner,
                                                 loss,
                                                 adaptation_steps,
                                                 shots,
                                                 ways,
                                                 device,
                                                 pgs)
            task_emb = nn.functional.normalize(task_emb, p=2.0, dim=-1)

            # Rows are unit-norm, so the product holds cosine similarities.
            in_cosine = torch.sum(task_emb_matrix @ torch.transpose(task_emb, 0, 1))

            evaluation_error = (evaluation_error - in_c*in_cosine) / times
            evaluation_accuracy = evaluation_accuracy / times
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Recompute the complete-task embedding with a fresh graph: the
            # backward() above presumably released the graph of the previous
            # one, and this copy must stay differentiable until the
            # batch-level term is back-propagated after the task loop.
            learner = maml.clone()
            _, _, task_emb = fast_adapt_complete(batch,
                                                 learner,
                                                 loss,
                                                 adaptation_steps,
                                                 shots,
                                                 ways,
                                                 device,
                                                 pgs)
            batch_embs.append(nn.functional.normalize(task_emb, p=2.0, dim=-1))

            # Compute meta-validation loss
            valid_error, valid_accuracy = _meta_evaluate(maml,
                                                         tasksets.validation,
                                                         loss,
                                                         adaptation_steps,
                                                         shots,
                                                         ways,
                                                         device,
                                                         1)
            meta_valid_error += valid_error
            meta_valid_accuracy += valid_accuracy

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_batch_size)
        print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
        print('Meta Valid Error', meta_valid_error / meta_batch_size)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)
        print('Privious Best Meta Valid Accuracy', best_valid / meta_batch_size)
        print('Privious Meta Test Accuracy at Best Meta Valid Accuracy', test_best_valid / meta_batch_size)
        print('Privious Best Meta test Accuracy', best_test / meta_batch_size)
        print('Privious Best Meta test Accuracy test', best_test_test / meta_batch_size, ' at epoch', best_iter, 'Privious Meta test Accuracy', p_test / meta_batch_size)

        # Whenever validation accuracy ties or beats its best, measure test
        # accuracy for the current weights.
        if meta_valid_accuracy >= best_valid:
            best_valid = meta_valid_accuracy
            _, meta_test_accuracy = _meta_evaluate(maml,
                                                   tasksets.test,
                                                   loss,
                                                   adaptation_steps,
                                                   shots,
                                                   ways,
                                                   device,
                                                   meta_batch_size)
            print('Meta Test Accuracy at This Epoch', meta_test_accuracy / meta_batch_size)
            print('Previous Best Meta Test Accuracy', best_test / meta_batch_size)
            test_best_valid = meta_test_accuracy
            if meta_test_accuracy >= best_test:
                best_test = meta_test_accuracy

        # Periodic test evaluation; checkpoint the weights on a new best.
        if iteration % test_every == 0:
            _, meta_test_accuracy = _meta_evaluate(maml,
                                                   tasksets.test,
                                                   loss,
                                                   adaptation_steps,
                                                   shots,
                                                   ways,
                                                   device,
                                                   meta_batch_size)
            if best_test_test <= meta_test_accuracy:
                best_test_test = meta_test_accuracy
                best_iter = iteration
                torch.save(maml.state_dict(), 'best_test_in'+str(in_c)+'_out'+str(out_c)+'_shot'+str(shots)+'.pt')
            p_test = meta_test_accuracy
            print('Meta Test Accuracy at This Epoch', meta_test_accuracy / meta_batch_size)
            print('Previous Best Meta Test Accuracy', best_test_test / meta_batch_size, ' at epoch', best_iter)

        # Average the accumulated task gradients. Parameters that received no
        # gradient are skipped instead of raising on `None.data`.
        for p in maml.parameters():
            if p.grad is not None:
                p.grad.data.mul_(1.0 / meta_batch_size)

        # Between-task term: mean pairwise cosine similarity of the
        # complete-task embeddings. It is back-propagated *after* the
        # averaging above, so its gradient is not rescaled by it.
        batch_emb_matrix = torch.cat(batch_embs, dim=0)  # (meta_batch_size, D)
        out_cosine = torch.sum(torch.mm(batch_emb_matrix,
                                        torch.transpose(batch_emb_matrix, 1, 0)))
        out_cosine = out_cosine / (meta_batch_size*meta_batch_size)
        (out_c*out_cosine).backward()
        opt.step()

    # Final meta-test estimate with the last weights.
    meta_test_error, meta_test_accuracy = _meta_evaluate(maml,
                                                         tasksets.test,
                                                         loss,
                                                         adaptation_steps,
                                                         shots,
                                                         ways,
                                                         device,
                                                         meta_batch_size)
    print('Meta Test Error', meta_test_error / meta_batch_size)
    print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)

# Script entry point: run training with the default arguments of `main`.
if __name__ == '__main__':
    main()
