import os
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
from ray import tune
from get_loaders import *
from simple_cnn import LeNet5
from resnext import ResNeXt29_2x64d
from resnet import *
from utils.utils import get_linear_schedule_with_warmup


def train(model, optimizer, scheduler, train_loader, *, device="cuda"):
    """Run one training pass over ``train_loader``.

    Puts ``model`` in training mode and, for every ``(data, target)`` batch,
    takes one optimizer step on the cross-entropy loss of the model output.

    Args:
        model: module whose output is passed to ``F.cross_entropy`` as logits.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler stepped once per *batch* (not per
            epoch), or ``None`` to leave the learning rate untouched.
        train_loader: iterable of ``(data, target)`` tensor pairs.
        device: device each batch is moved to. Defaults to ``"cuda"``, which
            matches the previous hard-coded ``.cuda()`` behaviour.
    """
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            # Per-iteration schedule: advance after every optimizer step.
            scheduler.step()


def test(model, data_loader, *, device="cuda"):
    """Return the top-1 accuracy of ``model`` over ``data_loader``.

    Args:
        model: module producing per-class scores along dimension 1.
        data_loader: iterable of ``(data, target)`` tensor pairs.
        device: device each batch is moved to. Defaults to ``"cuda"``, which
            matches the previous hard-coded ``.cuda()`` behaviour.

    Returns:
        Fraction (a Python float) of samples whose arg-max prediction equals
        the target.

    Raises:
        ZeroDivisionError: if ``data_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            _, predicted = torch.max(model(data), 1)
            total += target.size(0)
            # Keep the running count as a tensor on-device; converting with
            # .item() every batch would force a host/device sync per batch.
            correct += (predicted == target).sum()
    if total == 0:
        raise ZeroDivisionError("data_loader yielded no samples; accuracy is undefined")
    return int(correct) / total


# The name matches pytest's default ``test*`` pattern; without this, any test
# module that star-imports this file would collect it as a (broken) test.
test.__test__ = False


class CIFAR(tune.Trainable):
    """Ray Tune trainable that fits a ResNet34 classifier on CIFAR loaders.

    ``config`` keys read by ``_setup``:

    * ``"task_config"``: dict with ``"data_root"``, ``"batch_size"``,
      ``"num_class"`` and ``"max_t"`` (length of the LR schedule in training
      iterations; each ``_train`` call is one pass over the training loader).
    * ``"optimizer"``: optimizer class, called as ``cls(params, **kwargs)``.
    * ``"warmup"``: warmup length in passes over the training loader (may be
      fractional; truncated to a whole number of batches).
    * ``"decay_rate"``: forwarded to ``get_linear_schedule_with_warmup``.
    * ``"args"``: not used here, but excluded from the optimizer kwargs.

    Every remaining key is passed to the optimizer constructor as a keyword
    argument.
    """

    # Keys consumed by _setup itself; every other config key is an optimizer kwarg.
    _NON_OPTIMIZER_KEYS = frozenset(
        ("args", "task_config", "optimizer", "warmup", "decay_rate")
    )
    # Sibling of model.pth holding optimizer / LR-scheduler state, so that
    # model.pth itself stays a plain model state_dict.
    _TRAIN_STATE_FILE = "train_state.pth"

    def _setup(self, config):
        """Build the model, data loaders, optimizer and per-batch LR schedule."""
        task_config = config["task_config"]
        opt_cls = config["optimizer"]
        fpath = task_config["data_root"]
        batch_size = task_config["batch_size"]
        max_t = task_config["max_t"]
        warmup = config["warmup"]
        decay_rate = config["decay_rate"]
        num_class = task_config["num_class"]

        self.model = ResNet34(num_classes=num_class).cuda()
        self.train_loader, self.test_loader = get_cifar_loaders(fpath, num_class, batch_size)

        # Filter into a new dict instead of `del`-ing keys from ``config``, so
        # the caller's dict is never mutated.
        opt_kwargs = {
            key: value
            for key, value in config.items()
            if key not in self._NON_OPTIMIZER_KEYS
        }
        self.optimizer = opt_cls(self.model.parameters(), **opt_kwargs)

        # train() steps the scheduler once per batch, so both lengths are
        # expressed in batches.
        iters_per_epoch = len(self.train_loader)
        warmup_iters = int(warmup * iters_per_epoch)
        total_iters = max_t * iters_per_epoch
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, warmup_iters, total_iters, decay_rate
        )

    def _train(self):
        """Train for one pass over the training loader, then report test accuracy."""
        train(self.model, self.optimizer, self.scheduler, self.train_loader)
        acc = test(self.model, self.test_loader)
        return {"mean_accuracy": acc, "early_stop": False}

    def _save(self, checkpoint_dir):
        """Write model weights plus optimizer/scheduler state; return the model path.

        Without the optimizer and scheduler state, a restored trial would rebuild
        them from scratch in ``_setup``: the LR schedule would restart at warmup
        and optimizer state would be lost.
        """
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)

        train_state = {"optimizer": self.optimizer.state_dict()}
        # The scheduler comes from a project helper whose type isn't visible
        # here; only persist it if it supports the standard state_dict API.
        if hasattr(self.scheduler, "state_dict"):
            train_state["scheduler"] = self.scheduler.state_dict()
        torch.save(train_state, os.path.join(checkpoint_dir, self._TRAIN_STATE_FILE))
        return checkpoint_path

    def _restore(self, checkpoint_path):
        """Load model weights and, when present, optimizer/scheduler state."""
        self.model.load_state_dict(torch.load(checkpoint_path))

        train_state_path = os.path.join(
            os.path.dirname(checkpoint_path), self._TRAIN_STATE_FILE
        )
        if not os.path.exists(train_state_path):
            # Checkpoint written before optimizer/scheduler state was saved.
            return
        train_state = torch.load(train_state_path)
        self.optimizer.load_state_dict(train_state["optimizer"])
        if "scheduler" in train_state and hasattr(self.scheduler, "load_state_dict"):
            self.scheduler.load_state_dict(train_state["scheduler"])
