import argparse
import os
from datetime import datetime

import torch
import torch.nn as nn
import torchvision
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm

import utils
from data.neural_net.custom_data import get_adjusted_flower_datasets
from data.neural_net.misc import get_logger
from hyperparams.load import get_config

# Project-wide configuration; used below for the experiments directory
# (config.dirs) and the wandb entity (config.wandb).
config = get_config()

# Training hyperparameters. Passed to Trainer, pushed to wandb.config and
# stored inside every checkpoint.
args = {
    'max_epochs': 200,
    'batch_size': 64,
    'lr': 1e-5,
    'n_classes': 62,
}


class Trainer:
    """Train a classifier with Adam + cross-entropy and checkpoint the best model.

    Validation runs every ``val_every`` epochs and after the final epoch; a
    checkpoint is written whenever validation accuracy improves.

    Args:
        model: ``nn.Module`` to train; it is moved to ``device``.
        args: Hyperparameter dict. Reads ``'batch_size'``, ``'lr'`` and
            ``'max_epochs'``; stored verbatim in each checkpoint.
        datasets: Mapping with ``'train'`` and ``'val'`` datasets whose items
            unpack as ``(input, label)``.
        device: Device that the model and every batch are moved to.
        exp_dir: Checkpoint directory. ``None`` falls back to the module-level
            ``exp_dir`` created by the script entry point.
        logger: Logger for progress messages. ``None`` falls back to the
            module-level ``logger`` created by the script entry point.
        val_every: Validation period in epochs (default 5).
    """

    def __init__(self, model, args, datasets, device, exp_dir=None,
                 logger=None, val_every=5):
        self.device = device
        self.model = model.to(self.device)
        self.args = args
        self.exp_dir = exp_dir
        self.logger = logger
        self.val_every = val_every
        self.loaders = dict(
            train=DataLoader(datasets['train'], batch_size=args['batch_size'],
                             shuffle=True, pin_memory=True),
            val=DataLoader(datasets['val'], batch_size=512))

        # Only optimise parameters the model builder left unfrozen.
        trainable_params = [p for p in self.model.parameters()
                            if p.requires_grad]
        self.opt = torch.optim.Adam(trainable_params, self.args['lr'])
        self.criterion = nn.CrossEntropyLoss()

    @property
    def _log(self):
        """Injected logger if given, otherwise the module-level ``logger``."""
        return self.logger if self.logger is not None else logger

    def save_checkpoint(self, epoch):
        """Save weights, hyperparameters and epoch to ``model_epoch_{epoch}.pt``."""
        checkpoint = {'state_dict': self.model.state_dict(),
                      'args': self.args,
                      'epoch': epoch}
        # Fall back to the script-level global only when no dir was injected.
        out_dir = self.exp_dir if self.exp_dir is not None else exp_dir
        save_path = os.path.join(out_dir, f'model_epoch_{epoch}.pt')
        torch.save(checkpoint, save_path)
        self._log.info(f'New checkpoint, saved at {save_path}')

    def train(self):
        """Run the full training loop, validating and checkpointing as it goes."""
        # -inf guarantees the first validation always produces a checkpoint,
        # even if its accuracy is 0%.
        best_acc = float('-inf')
        max_epochs = self.args['max_epochs']
        for epoch in range(1, max_epochs + 1):
            # Training
            self.model.train()
            self._run_epoch(epoch)

            # Validation: periodic, plus the final epoch so trailing epochs
            # are never left unevaluated when max_epochs % val_every != 0.
            if epoch % self.val_every == 0 or epoch == max_epochs:
                self._log.info('Validation:')
                self.model.eval()
                _, acc = self._run_epoch(epoch, split='val')
                if acc > best_acc:
                    best_acc = acc
                    self.save_checkpoint(epoch)

        self._log.info('Finished training')

    def _run_epoch(self, epoch, split='train'):
        """Make one pass over ``split``.

        Optimiser steps are taken only when ``split == 'train'``.

        Returns:
            ``(mean_loss, accuracy_percent)`` as Python floats.

        Raises:
            ValueError: if the split's dataset is empty.
        """
        is_train = split == 'train'
        loader = self.loaders[split]
        total_loss, n_correct = 0.0, 0

        self._log.info(f'\nEpoch: {epoch} ({split}):')
        for x, y in tqdm(loader, leave=False, position=0):

            # Forward the model
            x = x.to(self.device)
            # CrossEntropyLoss accepts int64 class indices directly; this is
            # numerically identical to one-hot float targets and does not
            # depend on the module-level args['n_classes'].
            y = y.to(self.device).long()
            with torch.set_grad_enabled(is_train):
                output = self.model(x)
                _, preds = torch.max(output, 1)
                loss = self.criterion(output, y)

            # Optimize
            if is_train:
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

            # Statistics (accumulate as Python numbers, not device tensors)
            total_loss += loss.item() * x.size(0)
            n_correct += (preds == y).sum().item()

        n = len(loader.dataset)
        if n == 0:
            raise ValueError(f"'{split}' dataset is empty")
        loss = total_loss / n
        acc = 100.0 * n_correct / n
        wandb.log(data={f'loss_{split}': loss, f'acc_{split}': acc},
                  step=epoch)
        self._log.info(f'Loss: {loss:.2f}, Acc: {acc:.2f}%')
        return loss, acc


def make_dir(debug=False):
    """Create a fresh, timestamp-named experiment directory.

    The directory lives under ``<experiments>/resnet/flower/{debug|runs}/``.

    Args:
        debug: Put the run under ``debug`` instead of ``runs``.

    Returns:
        ``(exp_dir, run_id)``, where ``run_id`` is the filesystem-safe
        timestamp used as the directory name.
    """
    stamp = datetime.now().isoformat()
    # ':' is not allowed in Windows paths; 'T' is padded for readability.
    run_id = stamp.replace(':', '-').replace('T', '_T_')
    bucket = 'debug' if debug else 'runs'
    exp_dir = os.path.join(config.dirs['experiments'], 'resnet/flower',
                           bucket, run_id)
    os.makedirs(exp_dir)  # raises if it already exists
    return exp_dir, run_id


def one_hot(y, num_classes=None):
    """Encode class labels as float32 one-hot vectors.

    Args:
        y: Tensor of class labels; cast to int64 before encoding.
        num_classes: Width of the encoding. ``None`` (the default) keeps the
            previous behaviour of reading the module-level
            ``args['n_classes']``.

    Returns:
        float32 tensor of shape ``(*y.shape, num_classes)``.
    """
    if num_classes is None:
        num_classes = args['n_classes']
    encoded = nn.functional.one_hot(y.to(torch.int64), num_classes=num_classes)
    return encoded.to(torch.float32)


def get_resnet_for_flower_dataset(n_classes=62):
    """Build a pretrained torchvision ResNet-18 with a new ``n_classes`` head.

    All parameters are frozen except those in ``layer4`` and the new ``fc``
    layer, so only those receive gradients.

    Args:
        n_classes: Number of output classes for the replacement head.

    Returns:
        The modified ``resnet18`` model.
    """
    # NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in
    # favour of `weights=`; kept here for compatibility with older versions.
    model = torchvision.models.resnet18(pretrained=True)
    # Size the head from the existing one instead of hard-coding 512, so this
    # stays correct if the backbone is ever swapped.
    model.fc = torch.nn.Linear(model.fc.in_features, n_classes)
    model.requires_grad_(False)
    for layer_name in ('layer4', 'fc'):
        getattr(model, layer_name).requires_grad_(True)
    return model


if __name__ == '__main__':
    p = argparse.ArgumentParser()
    p.add_argument('--debug', action='store_true')
    # nargs='*' makes `--tags a b` produce a list. Without it, multiple values
    # are rejected and a single value reaches wandb as a bare string.
    p.add_argument('--tags', nargs='*', default=[])
    opts = p.parse_args()
    device = utils.setup_device()

    # exp_dir and logger are module globals that Trainer reads.
    exp_dir, run_id = make_dir(opts.debug)
    logger = get_logger(exp_dir)
    logger.info(f'run_path: {exp_dir}')
    wandb.init(project='resnet_flower',
               entity=config.wandb['entity'], id=run_id, name=run_id,
               tags=['debug'] if opts.debug else opts.tags)
    wandb.config.update(args)

    datasets = get_adjusted_flower_datasets()
    model = get_resnet_for_flower_dataset(args['n_classes'])
    trainer = Trainer(model, args, datasets, device)
    trainer.train()
