#!/usr/bin/env python3

import random
import copy
import numpy as np
import torch
import learn2learn as l2l
import os
# Expose only GPU index 6 to this process. The assignment overwrites any
# CUDA_VISIBLE_DEVICES value already present in the environment.
# NOTE(review): hard-coded device id -- confirm it matches the target machine.
os.environ['CUDA_VISIBLE_DEVICES'] = '6'


def accuracy(predictions, targets):
    """Return the fraction of samples whose arg-max class matches the target.

    Args:
        predictions: Score tensor; the predicted class is the arg-max over
            dimension 1.
        targets: Label tensor. The predicted classes are viewed to this
            tensor's shape before the comparison.

    Returns:
        A 0-dim float tensor: the number of matches divided by
        ``targets.size(0)``.
    """
    predicted_classes = torch.argmax(predictions, dim=1).view_as(targets)
    num_correct = torch.eq(predicted_classes, targets).sum().float()
    return num_correct / targets.size(0)


def fast_adapt(batch, learner, adapt_opt, loss, adaptation_steps, shots, ways, batch_size, device,return_emb=False,pgs=2):
    """Adapt ``learner`` on half of a task batch and score it on the rest.

    The samples at even positions ``0, 2, ..., 2*(shots*ways - 1)`` form the
    adaptation set; every other sample forms the evaluation set.

    Args:
        batch: ``(data, labels)`` pair of tensors for one task.
        learner: Module that is updated in place by ``adapt_opt``.
        adapt_opt: Optimizer stepping ``learner``'s parameters.
        loss: Callable ``loss(predictions, labels)`` returning a scalar tensor.
        adaptation_steps: Number of optimizer steps to take.
        shots: Samples per class in the adaptation set.
        ways: Number of classes in the task.
        batch_size: Number of adaptation samples drawn (with replacement)
            for each step.
        device: Device the batch is moved to.
        return_emb: When true, also return a flattened-parameter embedding.
        pgs: Number of trailing trainable parameter tensors to flatten into
            the embedding. Values that are non-positive after ``int()``
            conversion, or larger than the number of trainable tensors,
            select all of them.

    Returns:
        ``(valid_error, valid_accuracy)``, or
        ``(valid_error, valid_accuracy, partial_emb)`` when ``return_emb`` is
        true. ``partial_emb`` has shape ``(1, total_selected_elements)`` and
        stays attached to the learner's parameters for autograd.

    Raises:
        ValueError: If an embedding is requested but ``learner`` has no
            trainable parameters.
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evaluation sets
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots * ways) * 2] = True
    evaluation_indices = torch.from_numpy(~adaptation_indices)
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

    # Adapt the model on random mini-batches of the adaptation set
    for _ in range(adaptation_steps):
        idx = torch.randint(
            adaptation_data.size(0),
            size=(batch_size, )
        )
        adapt_opt.zero_grad()
        error = loss(learner(adaptation_data[idx]), adaptation_labels[idx])
        error.backward()
        adapt_opt.step()

    # Evaluate the adapted model
    predictions = learner(evaluation_data)
    valid_error = loss(predictions, evaluation_labels)
    # The loss value is additionally divided by the evaluation-set size.
    valid_error /= len(evaluation_data)
    valid_accuracy = accuracy(predictions, evaluation_labels)
    if not return_emb:
        return valid_error, valid_accuracy

    params = [p for p in learner.parameters() if p.requires_grad]
    if not params:
        raise ValueError('learner has no trainable parameters to embed')
    pgs = int(pgs)
    # Slicing (instead of computing explicit indices) cannot wrap around to
    # the front of the list when pgs exceeds the number of tensors.
    selected = params[-pgs:] if pgs > 0 else params
    partial_emb = torch.cat([p.reshape(1, -1) for p in selected], dim=-1)
    return valid_error, valid_accuracy, partial_emb

def fast_adapt_c(batch, learner, adapt_opt, loss, adaptation_steps, shots, ways, batch_size, device,pn,pgs):
    """Adapt ``learner`` on the first ``pn`` samples of each class block.

    The batch is split into ``ways`` consecutive blocks starting at
    ``floor(k * len(data) / ways)``; the first ``pn`` positions of every block
    form the adaptation set.
    NOTE(review): this presumes samples are grouped by class in contiguous
    blocks -- confirm against the task sampler.

    Args:
        batch: ``(data, labels)`` pair of tensors for one task.
        learner: Module that is updated in place by ``adapt_opt``.
        adapt_opt: Optimizer stepping ``learner``'s parameters.
        loss: Callable ``loss(predictions, labels)`` returning a scalar tensor.
        adaptation_steps: Number of optimizer steps to take.
        shots: Unused; kept for signature parity with ``fast_adapt``.
        ways: Number of blocks the batch is divided into.
        batch_size: Number of adaptation samples drawn (with replacement)
            for each step.
        device: Device the batch is moved to.
        pn: Number of leading samples taken from each block.
        pgs: Number of trailing trainable parameter tensors to flatten into
            the embedding. Values that are non-positive after ``int()``
            conversion, or larger than the number of trainable tensors,
            select all of them.

    Returns:
        Tensor of shape ``(1, total_selected_elements)`` holding the selected
        parameters of the adapted learner, attached to them for autograd.

    Raises:
        ValueError: If ``learner`` has no trainable parameters.
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate out the adaptation set. Integer arithmetic gives the same
    # offsets as float division followed by truncation, without a float
    # round-trip.
    num_samples = data.size(0)
    block_starts = (np.arange(ways) * num_samples) // ways
    support_positions = (block_starts[:, None] + np.arange(pn)[None, :]).reshape(-1)
    adaptation_indices = np.zeros(num_samples, dtype=bool)
    adaptation_indices[support_positions] = True
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]

    # Adapt the model on random mini-batches of the adaptation set
    for _ in range(adaptation_steps):
        idx = torch.randint(
            adaptation_data.size(0),
            size=(batch_size, )
        )
        adapt_opt.zero_grad()
        error = loss(learner(adaptation_data[idx]), adaptation_labels[idx])
        error.backward()
        adapt_opt.step()

    params = [p for p in learner.parameters() if p.requires_grad]
    if not params:
        raise ValueError('learner has no trainable parameters to embed')
    pgs = int(pgs)
    # Slicing (instead of computing explicit indices) cannot wrap around to
    # the front of the list when pgs exceeds the number of tensors.
    selected = params[-pgs:] if pgs > 0 else params
    return torch.cat([p.reshape(1, -1) for p in selected], dim=-1)


def fast_adapt_complete(batch, learner, adapt_opt, loss, adaptation_steps, shots, ways, batch_size, device,pgs=2):
    """Adapt ``learner`` on the even-indexed samples and return its embedding.

    The samples at even positions ``0, 2, ..., 2*(shots*ways - 1)`` form the
    adaptation set; no evaluation is performed.

    Args:
        batch: ``(data, labels)`` pair of tensors for one task.
        learner: Module that is updated in place by ``adapt_opt``.
        adapt_opt: Optimizer stepping ``learner``'s parameters.
        loss: Callable ``loss(predictions, labels)`` returning a scalar tensor.
        adaptation_steps: Number of optimizer steps to take.
        shots: Samples per class in the adaptation set.
        ways: Number of classes in the task.
        batch_size: Number of adaptation samples drawn (with replacement)
            for each step.
        device: Device the batch is moved to.
        pgs: Number of trailing trainable parameter tensors to flatten into
            the embedding. Values that are non-positive after ``int()``
            conversion, or larger than the number of trainable tensors,
            select all of them.

    Returns:
        Tensor of shape ``(1, total_selected_elements)`` holding the selected
        parameters of the adapted learner, attached to them for autograd.

    Raises:
        ValueError: If ``learner`` has no trainable parameters.
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate out the adaptation set (every other sample, starting at 0)
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots * ways) * 2] = True
    adaptation_indices = torch.from_numpy(adaptation_indices)
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]

    # Adapt the model on random mini-batches of the adaptation set
    for _ in range(adaptation_steps):
        idx = torch.randint(
            adaptation_data.size(0),
            size=(batch_size, )
        )
        adapt_opt.zero_grad()
        error = loss(learner(adaptation_data[idx]), adaptation_labels[idx])
        error.backward()
        adapt_opt.step()

    params = [p for p in learner.parameters() if p.requires_grad]
    if not params:
        raise ValueError('learner has no trainable parameters to embed')
    pgs = int(pgs)
    # Slicing (instead of computing explicit indices) cannot wrap around to
    # the front of the list when pgs exceeds the number of tensors.
    selected = params[-pgs:] if pgs > 0 else params
    return torch.cat([p.reshape(1, -1) for p in selected], dim=-1)

def main(
        experiment='dev',
        problem='mini-imagenet',
        ways=5,
        train_shots=15,
        test_shots=5,
        meta_lr=1.0,
        meta_bsz=32,
        fast_lr=0.001,
        train_bsz=10,
        test_bsz=15,
        train_steps=8,
        test_steps=50,
        iterations=100000,
        test_interval=1000,
        save='',
        cuda=1,
        seed=42,
        valid_bz=100,
        test_bz=200,
        inc=0.01,
        outc=0.01,
        sub_n=5,
        pgs=2
):
    """Meta-train a CNN on mini-ImageNet tasks and periodically evaluate it.

    Each iteration adapts ``meta_bsz`` deep copies of the model with Adam,
    sets the model gradient to ``model - mean(adapted copies)`` (a
    Reptile-style update) and applies it with SGD. Two cosine-similarity
    terms over flattened-parameter embeddings are also back-propagated.
    Every ``test_interval`` iterations the model is scored on validation and
    test tasks. Metrics are printed; nothing is returned.

    Args:
        experiment: Accepted but not read by this function.
        problem: Accepted but not read; the task set is always
            ``'mini-imagenet'``.
        ways: Classes per task and output size of the CNN.
        train_shots: Training tasks are sampled with ``2*train_shots``
            samples per class.
        test_shots: Validation/test tasks are sampled with ``2*test_shots``
            samples per class.
        meta_lr: Learning rate of the outer SGD optimizer.
        meta_bsz: Number of tasks per outer iteration.
        fast_lr: Learning rate of the inner Adam optimizer.
        train_bsz: Inner mini-batch size during training.
        test_bsz: Inner mini-batch size during validation/testing.
        train_steps: Inner adaptation steps during training.
        test_steps: Inner adaptation steps during validation and the
            second test evaluation.
        iterations: Number of outer iterations.
        test_interval: Evaluate whenever ``iteration % test_interval == 0``.
        save: Accepted but not read by this function.
        cuda: Truthy to use CUDA when a device is available.
        seed: Seed for ``random``, NumPy and PyTorch.
        valid_bz: Number of validation tasks per evaluation.
        test_bz: Number of test tasks per evaluation.
        inc: Weight of the within-task cosine term.
        outc: Weight of the across-task cosine term.
        sub_n: Samples per class block used by ``fast_adapt_c``.
        pgs: Number of trailing parameter tensors used for embeddings.
    """
    cuda = bool(cuda)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    train_tasks, valid_tasks, test_tasks = l2l.vision.benchmarks.get_tasksets(
        'mini-imagenet',
        train_samples=2*train_shots,
        train_ways=ways,
        test_samples=2*test_shots,
        test_ways=ways,
        root='~/data',
    )
    # Best-so-far accuracy *sums* (divided by the task count when printed).
    bestv=0
    bestt1=0
    besttv1=0
    bestt5=0
    besttv5=0
    # Create model
    model = l2l.vision.models.MiniImagenetCNN(ways)
    model.to(device)
    opt = torch.optim.SGD(model.parameters(), meta_lr)
    adapt_opt = torch.optim.Adam(model.parameters(), lr=fast_lr, betas=(0, 0.999))
    adapt_opt_state = adapt_opt.state_dict()
    loss = torch.nn.CrossEntropyLoss(reduction='mean')

    def _clone_learner():
        # Deep-copy the model and give the copy an Adam optimizer restored
        # from the most recently saved inner-optimizer state.
        clone = copy.deepcopy(model)
        inner_opt = torch.optim.Adam(
            clone.parameters(),
            lr=fast_lr,
            betas=(0, 0.999)
        )
        inner_opt.load_state_dict(adapt_opt_state)
        return clone, inner_opt

    train_inner_errors = []
    train_inner_accuracies = []
    valid_inner_errors = []
    valid_inner_accuracies = []
    test_inner_errors = []
    test_inner_accuracies = []

    for iteration in range(iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        meta_test_error1 = 0.0
        meta_test_accuracy1 = 0.0
        meta_test_error5 = 0.0
        meta_test_accuracy5 = 0.0

        in_loss=0
        # zero-grad the parameters
        for p in model.parameters():
            p.grad = torch.zeros_like(p.data)

        for task in range(meta_bsz):
            # Compute meta-training loss
            learner, adapt_opt = _clone_learner()
            batch = train_tasks.sample()
            evaluation_error, evaluation_accuracy,task_emb = fast_adapt(batch,
                                                               learner,
                                                               adapt_opt,
                                                               loss,
                                                               train_steps,
                                                               train_shots,
                                                               ways,
                                                               train_bsz,
                                                               device,
                                                               return_emb=True,
                                                               pgs=pgs)
            adapt_opt_state = adapt_opt.state_dict()
            # Accumulate the negated adapted weights; the model's own weights
            # are added after averaging, giving grad = model - mean(adapted).
            # The keyword form replaces the deprecated add_(alpha, tensor).
            for p, l in zip(model.parameters(), learner.parameters()):
                p.grad.data.add_(l.data, alpha=-1.0)

            # Second adaptation of a fresh copy on the same batch, using only
            # the first sub_n samples of each class block.
            learner, adapt_opt = _clone_learner()
            partial_emb = fast_adapt_c(batch,
                                        learner,
                                        adapt_opt,
                                        loss,
                                        train_steps,
                                        train_shots,
                                        ways,
                                        train_bsz,
                                        device,
                                        pn=sub_n,
                                        pgs=pgs)
            adapt_opt_state = adapt_opt.state_dict()

            # Cosine similarity between the two embeddings of this task.
            task_emb_matrix=torch.nn.functional.normalize(partial_emb, p=2.0, dim=-1)
            task_emb=torch.nn.functional.normalize(task_emb, p=2.0, dim=-1)
            in_cosine=torch.sum(task_emb_matrix@torch.transpose(task_emb,0,1))
            in_loss-=inc*in_cosine

            if task==0:
                # Allocate alongside the embeddings to avoid a device hop.
                batch_emb_matrix=torch.zeros((meta_bsz,task_emb.size(-1)), device=task_emb.device)
            batch_emb_matrix[task]=task_emb.squeeze()

            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

        if iteration % test_interval == 0:
            for task in range(valid_bz):
                # Compute meta-validation loss
                learner, adapt_opt = _clone_learner()
                batch = valid_tasks.sample()
                evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                                    learner,
                                                                    adapt_opt,
                                                                    loss,
                                                                    test_steps,
                                                                    test_shots,
                                                                    ways,
                                                                    test_bsz,
                                                                    device)
                meta_valid_error += evaluation_error.item()
                meta_valid_accuracy += evaluation_accuracy.item()

            for task in range(test_bz):
                # Compute meta-testing loss after a single adaptation step
                learner, adapt_opt = _clone_learner()
                batch = test_tasks.sample()
                evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                                    learner,
                                                                    adapt_opt,
                                                                    loss,
                                                                    1,
                                                                    test_shots,
                                                                    ways,
                                                                    test_bsz,
                                                                    device)
                meta_test_error1 += evaluation_error.item()
                meta_test_accuracy1 += evaluation_accuracy.item()
                # Compute meta-testing loss after test_steps adaptation steps
                # on a fresh task. shots must equal test_shots because the
                # task was sampled with 2*test_shots samples per class.
                learner, adapt_opt = _clone_learner()
                batch = test_tasks.sample()
                evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                                    learner,
                                                                    adapt_opt,
                                                                    loss,
                                                                    test_steps,
                                                                    test_shots,
                                                                    ways,
                                                                    test_bsz,
                                                                    device)
                meta_test_error5 += evaluation_error.item()
                meta_test_accuracy5 += evaluation_accuracy.item()
        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_bsz)
        print('Meta Train Accuracy', meta_train_accuracy / meta_bsz)
        print('best valid', bestv / valid_bz,'best test1@test', bestt1 / test_bz,'best test1', besttv1 / test_bz,'best test5@test', bestt5 / test_bz,'best test5', besttv5 / test_bz)
        if iteration % test_interval == 0:
            # besttv* only advance when validation accuracy is at its best;
            # bestt* track the best test accuracy regardless of validation.
            if meta_valid_accuracy>=bestv:
                bestv=meta_valid_accuracy
                if meta_test_accuracy1>=besttv1:
                    besttv1=meta_test_accuracy1
                if meta_test_accuracy5>=besttv5:
                    besttv5=meta_test_accuracy5
            if meta_test_accuracy1>=bestt1:
                bestt1=meta_test_accuracy1
            if meta_test_accuracy5>=bestt5:
                bestt5=meta_test_accuracy5
            print('Meta Valid Error', meta_valid_error / valid_bz)
            print('Meta Valid Accuracy', meta_valid_accuracy / valid_bz)
            print('Meta Test Error1', meta_test_error1 / test_bz,'Meta Test Error5', meta_test_error5 / test_bz)
            print('Meta Test Accuracy1', meta_test_accuracy1 / test_bz,'Meta Test Accuracy5', meta_test_accuracy5 / test_bz)

        # Track quantities
        train_inner_errors.append(meta_train_error / meta_bsz)
        train_inner_accuracies.append(meta_train_accuracy / meta_bsz)
        if iteration % test_interval == 0:
            valid_inner_errors.append(meta_valid_error / valid_bz)
            valid_inner_accuracies.append(meta_valid_accuracy / valid_bz)
            test_inner_errors.append(meta_test_error1 / test_bz)
            test_inner_accuracies.append(meta_test_accuracy1 / test_bz)

        # Average the accumulated gradients and optimize
        for p in model.parameters():
            p.grad.data.mul_(1.0 / meta_bsz).add_(p.data)
        # Sum of all pairwise dot products between the normalised task
        # embeddings (diagonal included), scaled by meta_bsz*(meta_bsz-1).
        batch_emb_matrix_t=torch.transpose(batch_emb_matrix,1,0)
        out_cosine=torch.sum(torch.mm(batch_emb_matrix,batch_emb_matrix_t))
        out_cosine/=(meta_bsz*(meta_bsz-1))
        # NOTE(review): the embeddings are built from parameters of
        # deep-copied learners, so this backward() populates gradients on
        # those copies -- confirm whether it is meant to influence `model`.
        (in_loss/meta_bsz+outc*out_cosine).backward()
        opt.step()



# Run training with the default hyper-parameters when executed as a script.
if __name__ == '__main__':
    main()
