import numpy as np
import pytest
import torch
import torch.nn as nn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import EpochBasedRunner, IterBasedRunner, build_optimizer
from mmcv.utils import get_logger
from mmcv.utils.logging import print_log
from torch.utils.data import DataLoader, Dataset

from openmixup.hooks import PreciseBNHook


class ExampleDataset(Dataset):

    def __init__(self):
        self.index = 0

    def __getitem__(self, idx):
        results = dict(imgs=torch.tensor([1.0], dtype=torch.float32))
        return results

    def __len__(self):
        return 1


class BiggerDataset(ExampleDataset):

    def __init__(self, fixed_values=range(0, 12)):
        assert len(self) == len(fixed_values)
        self.fixed_values = fixed_values

    def __getitem__(self, idx):
        results = dict(
            imgs=torch.tensor([self.fixed_values[idx]], dtype=torch.float32))
        return results

    def __len__(self):
        # a bigger dataset
        return 12


class ExampleModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Linear(1, 1)
        self.bn = nn.BatchNorm1d(1)
        self.test_cfg = None

    def forward(self, imgs, return_loss=False):
        return self.bn(self.conv(imgs))

    def train_step(self, data_batch, optimizer, **kwargs):
        outputs = {
            'loss': 0.5,
            'log_vars': {
                'accuracy': 0.98
            },
            'num_samples': 1
        }
        return outputs


class SingleBNModel(ExampleModel):

    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm1d(1)
        self.test_cfg = None

    def forward(self, imgs, return_loss=False):
        return self.bn(imgs)


class GNExampleModel(ExampleModel):

    def __init__(self):
        super().__init__()
        self.conv = nn.Linear(1, 1)
        self.bn = nn.GroupNorm(1, 1)
        self.test_cfg = None


class NoBNExampleModel(ExampleModel):

    def __init__(self):
        super().__init__()
        self.conv = nn.Linear(1, 1)
        self.test_cfg = None

    def forward(self, imgs, return_loss=False):
        return self.conv(imgs)


def test_precise_bn():
    optimizer_cfg = dict(
        type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

    test_dataset = ExampleDataset()
    loader = DataLoader(test_dataset, batch_size=2)
    model = ExampleModel()
    optimizer = build_optimizer(model, optimizer_cfg)
    logger = get_logger('precise_bn')
    runner = EpochBasedRunner(
        model=model,
        batch_processor=None,
        optimizer=optimizer,
        logger=logger,
        max_epochs=1)

    with pytest.raises(AssertionError):
        # num_samples must be larger than 0
        precise_bn_hook = PreciseBNHook(num_samples=-1)
        runner.register_hook(precise_bn_hook)
        runner.run([loader], [('train', 1)])

    with pytest.raises(AssertionError):
        # interval must be larger than 0
        precise_bn_hook = PreciseBNHook(interval=0)
        runner.register_hook(precise_bn_hook)
        runner.run([loader], [('train', 1)])

    with pytest.raises(AssertionError):
        # interval must be larger than 0
        runner = EpochBasedRunner(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            logger=logger,
            max_epochs=1)
        precise_bn_hook = PreciseBNHook(interval=0)
        runner.register_hook(precise_bn_hook)
        runner.run([loader], [('train', 1)])

    with pytest.raises(AssertionError):
        # only support EpochBaseRunner
        runner = IterBasedRunner(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            logger=logger,
            max_epochs=1)
        precise_bn_hook = PreciseBNHook(interval=2)
        runner.register_hook(precise_bn_hook)
        print_log(runner)
        runner.run([loader], [('train', 1)])

    # test non-DDP model
    test_bigger_dataset = BiggerDataset()
    loader = DataLoader(test_bigger_dataset, batch_size=2)
    loaders = [loader]
    precise_bn_hook = PreciseBNHook(num_samples=4)
    assert precise_bn_hook.num_samples == 4
    assert precise_bn_hook.interval == 1
    runner = EpochBasedRunner(
        model=model,
        batch_processor=None,
        optimizer=optimizer,
        logger=logger,
        max_epochs=1)
    runner.register_hook(precise_bn_hook)
    runner.run(loaders, [('train', 1)])

    # test DP model
    test_bigger_dataset = BiggerDataset()
    loader = DataLoader(test_bigger_dataset, batch_size=2)
    loaders = [loader]
    precise_bn_hook = PreciseBNHook(num_samples=4)
    assert precise_bn_hook.num_samples == 4
    assert precise_bn_hook.interval == 1
    model = MMDataParallel(model)
    runner = EpochBasedRunner(
        model=model,
        batch_processor=None,
        optimizer=optimizer,
        logger=logger,
        max_epochs=1)
    runner.register_hook(precise_bn_hook)

    # test model w/ gn layer
    loader = DataLoader(test_bigger_dataset, batch_size=2)
    loaders = [loader]
    precise_bn_hook = PreciseBNHook(num_samples=4)
    assert precise_bn_hook.num_samples == 4
    assert precise_bn_hook.interval == 1
    model = GNExampleModel()
    runner = EpochBasedRunner(
        model=model,
        batch_processor=None,
        optimizer=optimizer,
        logger=logger,
        max_epochs=1)
    runner.register_hook(precise_bn_hook)
    runner.run(loaders, [('train', 1)])

    # test model without bn layer
    loader = DataLoader(test_bigger_dataset, batch_size=2)
    loaders = [loader]
    precise_bn_hook = PreciseBNHook(num_samples=4)
    assert precise_bn_hook.num_samples == 4
    assert precise_bn_hook.interval == 1
    model = NoBNExampleModel()
    runner = EpochBasedRunner(
        model=model,
        batch_processor=None,
        optimizer=optimizer,
        logger=logger,
        max_epochs=1)
    runner.register_hook(precise_bn_hook)
    runner.run(loaders, [('train', 1)])

    # test how precise it is
    loader = DataLoader(test_bigger_dataset, batch_size=2)
    loaders = [loader]
    precise_bn_hook = PreciseBNHook(num_samples=12)
    assert precise_bn_hook.num_samples == 12
    assert precise_bn_hook.interval == 1
    model = SingleBNModel()
    runner = EpochBasedRunner(
        model=model,
        batch_processor=None,
        optimizer=optimizer,
        logger=logger,
        max_epochs=1)
    runner.register_hook(precise_bn_hook)
    runner.run(loaders, [('train', 1)])
    imgs_list = list()
    for loader in loaders:
        for i, data in enumerate(loader):
            imgs_list.append(np.array(data['imgs']))
    mean = np.mean([np.mean(batch) for batch in imgs_list])
    # bassel correction used in Pytorch, therefore ddof=1
    var = np.mean([np.var(batch, ddof=1) for batch in imgs_list])
    assert np.equal(mean, np.array(
        model.bn.running_mean)), (mean, np.array(model.bn.running_mean))
    assert np.equal(var, np.array(
        model.bn.running_var)), (var, np.array(model.bn.running_var))

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='requires CUDA support')
    def test_ddp_model_precise_bn():
        # test DDP model
        test_bigger_dataset = BiggerDataset()
        loader = DataLoader(test_bigger_dataset, batch_size=2)
        loaders = [loader]
        precise_bn_hook = PreciseBNHook(num_samples=5)
        assert precise_bn_hook.num_samples == 5
        assert precise_bn_hook.interval == 1
        model = ExampleModel()
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=True)
        runner = EpochBasedRunner(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            logger=logger,
            max_epochs=1)
        runner.register_hook(precise_bn_hook)
        runner.run(loaders, [('train', 1)])
