import argparse
import os

import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from model import *
from utils import *

from betty.engine import Engine
from betty.problems import ImplicitProblem
from betty.configs import Config, EngineConfig


def _str2bool(value):
    """Parse a command-line boolean value.

    ``type=bool`` must not be used with argparse: every non-empty string is
    truthy, so ``bool("False")`` is ``True``. This parser accepts the usual
    true/false spellings (case-insensitive) and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description="Meta_Weight_Net")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--strategy", type=str, default="default")
parser.add_argument("--rollback", action="store_true")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--meta_net_hidden_size", type=int, default=100)
parser.add_argument("--meta_net_num_layers", type=int, default=1)

# Classifier (inner problem) SGD hyper-parameters.
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--dampening", type=float, default=0.0)
# Parsed with _str2bool so that "--nesterov False" really disables Nesterov momentum.
parser.add_argument("--nesterov", type=_str2bool, default=False)
parser.add_argument("--weight_decay", type=float, default=5e-4)
# Meta-network (outer problem) Adam hyper-parameters.
parser.add_argument("--meta_lr", type=float, default=1e-5)
parser.add_argument("--meta_weight_decay", type=float, default=0.0)

parser.add_argument("--dataset", type=str, default="cifar10")
parser.add_argument("--batch_size", type=int, default=100)

# NOTE(review): --correct and --reg_str are not read anywhere in this file.
parser.add_argument("--correct", action="store_true")
parser.add_argument("--reg_str", type=float, default=1.)
args = parser.parse_args()
print(args)
set_seed(args.seed)


# Per-channel normalisation statistics shared by the train and test pipelines.
# NOTE(review): the same values are used for both CIFAR-10 and CIFAR-100.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop of the 4-pixel-padded image and a
# random horizontal flip, followed by tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# "cifar10" selects CIFAR-10; any other --dataset value falls back to CIFAR-100.
if args.dataset == "cifar10":
    dataset_cls = datasets.CIFAR10
else:
    dataset_cls = datasets.CIFAR100

train_dataset = dataset_cls(
    root="./data", train=True, download=True, transform=transform_train
)
test_dataset = dataset_cls(root="./data", train=False, transform=transform_test)

# Independently shuffled views of the training set. train_dataloader feeds the
# Inner problem and meta_dataloader feeds the Outer problem.
# NOTE(review): meta_dataloader2 is not used anywhere in this file.
_train_loader_kwargs = dict(batch_size=args.batch_size, shuffle=True, pin_memory=True)
train_dataloader = torch.utils.data.DataLoader(train_dataset, **_train_loader_kwargs)
meta_dataloader = torch.utils.data.DataLoader(train_dataset, **_train_loader_kwargs)
meta_dataloader2 = torch.utils.data.DataLoader(train_dataset, **_train_loader_kwargs)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    pin_memory=True,
    num_workers=2,
)

# Iterations per epoch, counting only full batches.
EPOCH_LEN = len(train_dataset) // args.batch_size
TOTAL_ITER = EPOCH_LEN * 60
# LR-decay milestones, expressed in iterations.
# NOTE(review): training stops after 60 epochs, so only the 50-epoch milestone
# is ever reached; confirm that this is intentional.
DECAY_ITER1, DECAY_ITER2, DECAY_ITER3, DECAY_ITER4 = (
    EPOCH_LEN * epoch for epoch in (50, 80, 105, 125)
)


class Outer(ImplicitProblem):
    """Upper-level problem that learns the meta-weighting network.

    Its objective is the plain (unweighted) cross-entropy of the inner
    classifier on batches drawn from ``meta_dataloader``.
    """

    def training_step(self, batch):
        """Return the mean cross-entropy of the inner classifier on ``batch``."""
        inputs, labels = batch
        # The inner module returns a pair; only the first element is scored here.
        logits, _ = self.inner(inputs)
        return F.cross_entropy(logits, labels.long())

    def configure_train_data_loader(self):
        """Return the loader that supplies meta (upper-level) batches."""
        return meta_dataloader

    def configure_module(self):
        """Build the meta-network.

        ``input_size=2`` matches the two per-sample loss columns that the
        Inner problem concatenates before calling ``self.outer``.
        """
        return MLP(
            input_size=2,
            hidden_size=args.meta_net_hidden_size,
            num_layers=args.meta_net_num_layers,
        )

    def configure_optimizer(self):
        """Return an Adam optimizer over the meta-network parameters."""
        return optim.Adam(
            self.module.parameters(),
            lr=args.meta_lr,
            weight_decay=args.meta_weight_decay,
        )


class Inner(ImplicitProblem):
    """Lower-level problem that trains the ResNet32 classifier.

    Each sample's cross-entropy is scaled by a weight that the outer
    meta-network produces.
    """

    def training_step(self, batch):
        """Return the meta-weighted mean cross-entropy for one batch.

        The module's forward returns a pair ``(outputs, ema_outputs)``.
        ``ema_outputs`` presumably comes from an EMA copy of the network
        (see ``ema_update`` in ``param_callback``); TODO confirm in model.py.
        """
        inputs, labels = batch
        outputs, ema_outputs = self.forward(inputs)

        # Per-sample cross-entropy against the hard labels, reshaped to a column.
        loss_vector = F.cross_entropy(outputs, labels.long(), reduction="none")
        loss_vector_reshape = torch.reshape(loss_vector, (-1, 1))

        # Per-sample cross-entropy of the current predictions against the soft
        # targets given by softmax(ema_outputs).
        ema_prob = F.softmax(ema_outputs, dim=-1)
        ema_loss_vector = torch.sum(-F.log_softmax(outputs, dim=-1) * ema_prob, dim=-1)
        ema_loss_vector_reshape = torch.reshape(ema_loss_vector, (-1, 1))

        # Both losses are detached: they are only features for the meta-net,
        # so no classifier gradient flows through them.
        meta_inputs = torch.cat(
            [loss_vector_reshape.detach(), ema_loss_vector_reshape.detach()], dim=1
        )
        weight = self.outer(meta_inputs, self._global_step)
        loss = torch.mean(weight * loss_vector_reshape)

        return loss

    def configure_train_data_loader(self):
        """Return the loader that supplies classifier (lower-level) batches."""
        return train_dataloader

    def configure_module(self):
        """Build a ResNet32 with 10 outputs for CIFAR-10 and 100 otherwise."""
        num_classes = 10 if args.dataset == "cifar10" else 100
        return ResNet32(num_classes)

    def configure_optimizer(self):
        """Return SGD over the classifier parameters, configured from CLI args."""
        optimizer = optim.SGD(
            self.module.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            dampening=args.dampening,
            weight_decay=args.weight_decay,
            nesterov=args.nesterov,
        )
        return optimizer

    def configure_scheduler(self):
        """Return a step schedule that multiplies the LR by 0.2 at each milestone.

        The milestones are iteration counts, not epoch counts.
        """
        scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=[DECAY_ITER1, DECAY_ITER2, DECAY_ITER3, DECAY_ITER4],
            gamma=0.2,
        )
        return scheduler

    def param_callback(self):
        """Refresh the module's EMA parameters.

        Under the "zero"/"distributed" strategies the model is reached through
        the ``.module`` attribute of a wrapper.
        """
        if args.strategy == "default":
            self.module.ema_update()
        elif args.strategy in ["zero", "distributed"]:
            self.module.module.ema_update()
        # NOTE(review): any other strategy silently skips the EMA update.

    def epoch_callback(self):
        """Print the epoch and global-step counters."""
        print("Epoch:", self.epoch_counter, "|| global step:", self._global_step)



# Best test accuracy (percent) seen so far; updated by ReweightingEngine.validation.
best_acc = -1


class ReweightingEngine(Engine):
    """Engine with a test-set validation pass and per-validation checkpoints."""

    @torch.no_grad()
    def validation(self):
        """Evaluate the classifier on the test set and checkpoint both problems.

        The state dicts of ``self.inner`` and ``self.outer`` are saved to
        ``save_<dataset>/cls_<step>.pt`` and ``save_<dataset>/mwn_<step>.pt``.
        The directory is created if it is missing.

        Returns:
            dict: ``{"acc": test accuracy this round (percent),
            "best_acc": best accuracy so far (percent)}``.
        """
        global best_acc
        correct = 0
        total = 0
        # The method is already wrapped in torch.no_grad(), so no inner context is needed.
        for x, target in test_dataloader:
            x, target = x.to(args.device), target.to(args.device)
            out, *_ = self.inner(x)
            correct += (out.argmax(dim=1) == target).sum().item()
            total += x.size(0)
        acc = correct / total * 100
        if best_acc < acc:
            best_acc = acc

        # torch.save does not create missing parent directories.
        os.makedirs("save_{}".format(args.dataset), exist_ok=True)
        torch.save(self.inner.state_dict(), "save_{}/cls_{}.pt".format(args.dataset, self.global_step))
        torch.save(self.outer.state_dict(), "save_{}/mwn_{}.pt".format(args.dataset, self.global_step))
        return {"acc": acc, "best_acc": best_acc}


# Both levels use the "darts" hypergradient approximation.
outer_config = Config(
    type="darts",
    fp16=args.fp16,
    log_step=100,
    retain_graph=True,
)
inner_config = Config(
    type="darts",
    fp16=args.fp16,
    unroll_steps=2,
    darts_alpha=0.01,
)
# Validate (and checkpoint) once every EPOCH_LEN iterations.
engine_config = EngineConfig(
    train_iters=TOTAL_ITER,
    valid_step=EPOCH_LEN,
    strategy=args.strategy,
    roll_back=args.rollback,
)

outer = Outer(name="outer", config=outer_config)
inner = Inner(name="inner", config=inner_config)
problems = [outer, inner]

# Dependency graph: outer -> inner (upper-to-lower) and inner -> outer (lower-to-upper).
u2l = {outer: [inner]}
l2u = {inner: [outer]}
dependencies = {"l2u": l2u, "u2l": u2l}

engine = ReweightingEngine(config=engine_config, problems=problems, dependencies=dependencies)
engine.run()
