import argparse

import numpy
import torch
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from model import *


# Command-line interface for the sample-ranking script.
def _build_parser():
    """Return the argument parser holding every option this script accepts."""
    cli = argparse.ArgumentParser(description="Meta_Weight_Net")

    # Valued options as (flag, type, default), listed in their --help order:
    #   --device                device the models and batches are moved to
    #   --dataset               "cifar10" selects CIFAR-10, any other value
    #                           selects CIFAR-100; also used in checkpoint paths
    #   --meta_net_hidden_size  hidden_size handed to the MLP weight network
    #   --meta_net_num_layers   num_layers handed to the MLP weight network
    valued_options = (
        ("--device", str, "cuda"),
        ("--dataset", str, "cifar10"),
        ("--meta_net_hidden_size", int, 100),
        ("--meta_net_num_layers", int, 1),
    )
    for flag, value_type, default in valued_options:
        cli.add_argument(flag, type=value_type, default=default)

    # Switch: skip the model-based ranking and emit a random permutation.
    cli.add_argument("--random", action="store_true")
    return cli


parser = _build_parser()
args = parser.parse_args()

# Mean / std triples handed to transforms.Normalize by both pipelines.
_CHANNEL_MEAN = (0.4914, 0.4822, 0.4465)
_CHANNEL_STD = (0.2023, 0.1994, 0.2010)


def _tensor_and_normalize():
    """Return the ToTensor + Normalize steps that end both pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]


# Augmenting pipeline: random 32x32 crop (after 4 px padding) and random
# horizontal flip, followed by tensor conversion and normalization.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    + _tensor_and_normalize()
)

# Pipeline without random steps: tensor conversion and normalization only.
transform_test = transforms.Compose(_tensor_and_normalize())

# Any --dataset value other than "cifar10" falls through to CIFAR-100.
if args.dataset == "cifar10":
    dataset_cls = datasets.CIFAR10
else:
    dataset_cls = datasets.CIFAR100

# The training split is read through `transform_test`, i.e. without the
# random crop / flip steps of `transform_train`.
train_dataset = dataset_cls(
    root="./data",
    train=True,
    download=True,
    transform=transform_test,
)

# shuffle=False: batches arrive in dataset order, so values collected batch
# by batch line up with dataset indices.
_loader_kwargs = {
    "batch_size": 100,
    "shuffle": False,
    "pin_memory": True,
}
train_dataloader = torch.utils.data.DataLoader(train_dataset, **_loader_kwargs)
# Checkpoint ids (the number embedded in the cls_*/mwn_* file names) whose
# per-sample weights are averaged together.
average_list = [23000, 23500, 24000, 24500, 25000]


def _load_models(checkpoint_id):
    """Load the classifier and weight network saved under ``checkpoint_id``.

    Both state dicts are read from ``save_<dataset>/`` (key ``"module"``),
    and both models are moved to ``args.device`` and put in eval mode.

    Returns:
        A ``(classifier, mwn)`` tuple.
    """
    num_classes = 10 if args.dataset == "cifar10" else 100

    classifier = ResNet32(num_classes)
    # map_location="cpu": deserialize on the CPU first so checkpoints written
    # on a GPU still load when --device is "cpu" or CUDA is unavailable; the
    # model is moved to the requested device right after.
    classifier.load_state_dict(
        torch.load(
            "save_{}/cls_{}.pt".format(args.dataset, checkpoint_id),
            map_location="cpu",
        )["module"]
    )
    classifier.to(args.device)
    classifier.eval()

    mwn = MLP(
        input_size=2,
        hidden_size=args.meta_net_hidden_size,
        num_layers=args.meta_net_num_layers,
    )
    mwn.load_state_dict(
        torch.load(
            "save_{}/mwn_{}.pt".format(args.dataset, checkpoint_id),
            map_location="cpu",
        )["module"]
    )
    mwn.to(args.device)
    mwn.eval()
    return classifier, mwn


def _sample_weights(classifier, mwn):
    """Run ``mwn`` over ``train_dataloader`` and return its outputs.

    For every batch the weight network receives two columns: the per-sample
    cross-entropy of the classifier's first output against the labels, and
    the cross-entropy of that same output against the softmax of the
    classifier's second (EMA) output.

    Returns:
        A numpy array with the per-batch outputs concatenated in loader order.
    """
    per_batch = []
    with torch.no_grad():
        for inputs, labels in train_dataloader:
            inputs, labels = inputs.to(args.device), labels.to(args.device)
            outputs, ema_outputs = classifier(inputs)

            loss = F.cross_entropy(outputs, labels.long(), reduction="none")
            loss = torch.reshape(loss, (-1, 1))

            ema_prob = F.softmax(ema_outputs, dim=-1)
            ema_loss = torch.sum(
                -F.log_softmax(outputs, dim=-1) * ema_prob, dim=-1
            )
            ema_loss = torch.reshape(ema_loss, (-1, 1))

            # No .detach() needed: nothing tracks gradients under no_grad().
            weight = mwn(torch.cat([loss, ema_loss], dim=1), test=True)

            # squeeze() turns a one-sample batch into a 0-d array; atleast_1d
            # keeps it a length-1 sequence so it can still be concatenated.
            # NOTE(review): assumes mwn yields one weight per sample - confirm.
            per_batch.append(numpy.atleast_1d(weight.squeeze().cpu().numpy()))
    return numpy.concatenate(per_batch)


if args.random:
    # Baseline ordering: a random permutation of the dataset indices.
    sorted_idx = numpy.random.permutation(len(train_dataset))
else:
    weights_total = []
    for idx in average_list:
        classifier, mwn = _load_models(idx)
        weights_total.append(_sample_weights(classifier, mwn))
    # Average each sample's weight across the checkpoints.
    weights_total = numpy.array(weights_total).mean(axis=0)

    # argsort is ascending; reverse it to put the largest weights first.
    sorted_idx = numpy.argsort(weights_total)[::-1]
    sorted_weight = [weights_total[i] for i in sorted_idx]
    print(sorted_weight[:100])
    print(sorted_weight[-100:])

# Save under the mode-specific name so a --random run does not overwrite the
# weight-sorted index file.
filename = "sorted_index.pt" if not args.random else "random_index.pt"
torch.save(sorted_idx, filename)
