import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
from torchvision import transforms
import datasets
import tokenizers
import transformers


def cifar100_datamodule(root, batch_size, num_workers, img_size=32):
    """Build a LightningDataModule for CIFAR-100 with a 40k/10k train/val split.

    The official 50,000-image training set is split deterministically (seed 42)
    into 40,000 training and 10,000 validation images; the official test set is
    used as the test split. Only the training split is augmented.

    Args:
        root: Directory where torchvision stores / downloads CIFAR-100.
        batch_size: Batch size passed to every dataloader.
        num_workers: Number of dataloader worker processes.
        img_size: Output image side length. When it differs from the native 32,
            a ``Resize`` is appended after normalization.

    Returns:
        A ``pl.LightningDataModule`` built with ``from_datasets``.
    """
    # NOTE(review): these mean/std values look like the widely used CIFAR-10
    # statistics rather than CIFAR-100's; kept unchanged for reproducibility.
    test_tsfms = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
    if img_size != 32:
        test_tsfms.transforms.append(transforms.Resize(img_size))
    # Augmentations run on the PIL image, then the shared eval pipeline.
    train_tsfms = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            test_tsfms,
        ]
    )

    # Two views of the same training files: one augmented, one not. A Subset
    # returned by random_split has no transform of its own (assigning
    # `subset.transform` is silently ignored), so the validation subset must
    # index into a dataset constructed with the evaluation transforms.
    train_full = torchvision.datasets.CIFAR100(
        root, train=True, download=True, transform=train_tsfms
    )
    eval_full = torchvision.datasets.CIFAR100(
        root, train=True, download=True, transform=test_tsfms
    )
    train, val_split = torch.utils.data.random_split(
        train_full, [40000, 10000], generator=torch.Generator().manual_seed(42)
    )
    val = torch.utils.data.Subset(eval_full, val_split.indices)
    test = torchvision.datasets.CIFAR100(
        root, train=False, download=True, transform=test_tsfms
    )

    return pl.LightningDataModule.from_datasets(
        train,
        val,
        test,
        batch_size=batch_size,
        num_workers=num_workers,
    )


class LambdaLayer(nn.Module):
    """Adapt an arbitrary callable into an ``nn.Module``.

    Useful for dropping a parameter-free function (e.g. a lambda) into a place
    that expects a module, such as a residual shortcut. ``forward`` applies the
    stored callable to its single input and returns the result unchanged.
    """

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        result = fn(x)
        return result


class BasicBlock1(nn.Module):
    """Two-conv residual block with a parameter-free ("option A") shortcut.

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN. It is added to
    the shortcut and passed through a final ReLU. When the block changes
    resolution or width, the shortcut subsamples the input spatially by
    ``stride`` and zero-pads the channel dimension up to ``planes`` instead of
    using a learned projection.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first conv (and of the shortcut subsampling).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            # Pad exactly the missing channels, split front/back (any odd one
            # goes to the back). For the usual planes == 2 * in_planes case
            # this is planes // 4 on each side.
            extra = planes - in_planes
            pad_front = extra // 2
            pad_back = extra - pad_front
            # x[..., ::stride, ::stride] yields ceil(H / stride) rows, which
            # matches the output height of the 3x3/pad-1 conv1 above.
            self.shortcut = LambdaLayer(
                lambda x: F.pad(
                    x[:, :, ::stride, ::stride],
                    (0, 0, 0, 0, pad_front, pad_back),
                    "constant",
                    0,
                )
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = F.relu(out + self.shortcut(x))
        return out


class ResNet32(nn.Module):
    """CIFAR-style ResNet-32 built from ``BasicBlock1``.

    A 3x3 stem conv (16 channels) is followed by three stages of five blocks
    each at 16, 32 and 64 channels. The second and third stages halve the
    resolution in their first block. Global average pooling and a linear
    classifier produce ``num_classes`` logits.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        self.in_planes = 16
        depth_per_stage = (5, 5, 5)

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(16, depth_per_stage[0], stride=1)
        self.layer2 = self._make_layer(32, depth_per_stage[1], stride=2)
        self.layer3 = self._make_layer(64, depth_per_stage[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

    def _make_layer(self, planes, num_blocks, stride):
        """Stack blocks for one stage; only the first block uses ``stride``.

        Side effect: advances ``self.in_planes`` to this stage's output width
        so the next stage's first block receives the right channel count.
        """
        block_strides = [stride, *([1] * (num_blocks - 1))]
        blocks = []
        for block_stride in block_strides:
            blocks.append(BasicBlock1(self.in_planes, planes, block_stride))
            self.in_planes = planes * BasicBlock1.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            h = stage(h)
        # Global average pool: kernel size equals the final feature-map width.
        h = F.avg_pool2d(h, h.shape[3])
        h = h.view(h.shape[0], -1)
        return self.linear(h)


class BasicBlock2(nn.Module):
    """Two-conv residual block with a 1x1-conv projection shortcut when needed.

    Main path: conv(kernel_size, stride) -> BN -> ReLU -> conv(kernel_size) -> BN.
    It is added to the shortcut and passed through a final ReLU.

    Args:
        channels: Number of output channels.
        kernel_size: Side length of both main-path convolutions. Padding is
            ``kernel_size // 2`` so odd kernel sizes preserve spatial size and
            the residual addition lines up (for 3 this is a padding of 1).
        in_channels: Number of input channels; defaults to ``channels``.
        downsample: If True, the first conv and the shortcut use stride 2.
    """

    def __init__(self, channels, kernel_size, in_channels=None, downsample=False):
        super().__init__()

        # By default, in_channels = channels (number of out channels)
        if in_channels is None:
            in_channels = channels

        stride = 2 if downsample else 1
        # "Same" padding for odd kernels; a hard-coded 1 only works for 3x3.
        padding = kernel_size // 2

        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size, stride, padding=padding)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size, padding=padding)
        self.bn2 = nn.BatchNorm2d(channels)

        # If this block downsamples its input or changes the number of
        # channels, apply the same transformation (via a 1x1 conv) to the
        # input that is added to the block's output.
        if downsample or in_channels != channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(channels),
            )
        # Otherwise the shortcut is just the identity function.
        else:
            self.shortcut = nn.Sequential()

    def forward(self, input):
        x = F.relu(self.bn1(self.conv1(input)))
        x = self.bn2(self.conv2(x))
        return F.relu(x + self.shortcut(input))


class ResNet34(nn.Module):
    """A 34-layer ResNet with modifications for CIFAR-100 (32x32 inputs, 100 classes).

    Unlike the ImageNet variant, the stem is a single stride-1 3x3 conv with no
    max-pool, so for a 32x32 input the four sections run at 32, 16, 8 and 4
    pixels. Section depths follow the standard ResNet-34 layout [3, 4, 6, 3]:
    1 stem conv + 16 blocks * 2 convs + 1 linear layer = 34 weighted layers.

    Args:
        num_classes: Size of the classifier output (default 100 for CIFAR-100).
    """

    def __init__(self, num_classes=100):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)

        self.section1 = self._make_section(64, 64, num_blocks=3, downsample=False)
        self.section2 = self._make_section(64, 128, num_blocks=4, downsample=True)
        # Six blocks here, as in the standard ResNet-34 layout.
        self.section3 = self._make_section(128, 256, num_blocks=6, downsample=True)
        self.section4 = self._make_section(256, 512, num_blocks=3, downsample=True)

        self.fc = nn.Linear(512, num_classes)

    @staticmethod
    def _make_section(in_channels, channels, num_blocks, downsample):
        """Build one section: the first block adapts width/resolution, the rest keep it."""
        blocks = [
            BasicBlock2(channels, 3, in_channels=in_channels, downsample=downsample)
        ]
        blocks.extend(BasicBlock2(channels, 3) for _ in range(num_blocks - 1))
        return nn.Sequential(*blocks)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.section1(x)
        x = self.section2(x)
        x = self.section3(x)
        x = self.section4(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        return self.fc(torch.flatten(x, 1))


# For Transformer problem
class WMT14_EN_DE_DataModule(pl.LightningDataModule):
    """WMT14 English-German data: BPE tokenizer training and tokenized splits.

    Intended usage is two-step. ``train_tokenizer()`` learns a shared
    English/German BPE vocabulary and saves it to ``TOKENIZER_FILE``.
    ``get_datamodule()`` then loads that file, tokenizes every split with
    ``datasets.map`` and wraps the splits in a Lightning data module.

    Args:
        batch_size: Batch size forwarded to ``get_datamodule`` by ``train_tokenizer``.
        num_workers: Dataloader worker count forwarded the same way.
    """

    # Path the trained tokenizer is written to and read back from.
    TOKENIZER_FILE = "data/bpe-tokenizer-wmt14-en-de.json"

    def __init__(self, batch_size, num_workers):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset = None

    def train_tokenizer(self):
        """Train and save a shared BPE tokenizer, then pre-tokenize all splits."""
        import os

        # First we download the dataset and train+save our BPE tokenizer.
        dataset = datasets.load_dataset("wmt14", "de-en", split="train+validation+test")

        # Give the BPE model an unknown token so characters absent from the
        # learned vocabulary map to [UNK] instead of being dropped.
        tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()

        trainer = tokenizers.trainers.BpeTrainer(
            special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
            vocab_size=37000,
        )

        def batch_iterator(dataset, batch_size=1000):
            # Each yielded batch holds both the English and the German side of
            # up to `batch_size` sentence pairs, so one vocabulary covers both.
            for i in range(0, len(dataset), batch_size):
                datapoints = dataset[i : i + batch_size]["translation"]
                yield [x["en"] for x in datapoints] + [x["de"] for x in datapoints]

        tokenizer.train_from_iterator(
            batch_iterator(dataset), trainer=trainer, length=2 * len(dataset)
        )
        # Make sure the target directory exists before saving.
        directory = os.path.dirname(self.TOKENIZER_FILE)
        if directory:
            os.makedirs(directory, exist_ok=True)
        tokenizer.save(self.TOKENIZER_FILE)

        # Then we tokenize the dataset. Doing this once caches it to disk.
        self.get_datamodule(num_workers=self.num_workers, batch_size=self.batch_size)

    def prepare_data(self):
        # Intentionally a no-op; see train_tokenizer / get_datamodule.
        pass

    @classmethod
    def get_datamodule(cls, num_workers=4, batch_size=None):
        """Tokenize all WMT14 de-en splits and wrap them in a LightningDataModule.

        Args:
            num_workers: Dataloader worker processes.
            batch_size: Dataloader batch size. If None, falls back to a
                class-level ``batch_size`` attribute when one exists, else 1.

        Returns:
            A ``pl.LightningDataModule`` over the train/validation/test splits,
            each exposing ``en_input_ids`` / ``de_input_ids`` as torch tensors
            (other columns are kept via ``output_all_columns``).
        """
        if batch_size is None:
            # `batch_size` is normally an instance attribute, which a
            # classmethod cannot see; only use a class-level one if defined.
            batch_size = getattr(cls, "batch_size", 1)

        dataset = datasets.load_dataset("wmt14", "de-en")
        tokenizer = transformers.PreTrainedTokenizerFast(
            tokenizer_file=cls.TOKENIZER_FILE,
            unk_token="[UNK]",
            pad_token="[PAD]",
        )

        def tokenization(example):
            en_tokens = tokenizer([x["en"] for x in example["translation"]])
            de_tokens = tokenizer([x["de"] for x in example["translation"]])
            return {
                "en_input_ids": en_tokens.input_ids,
                "de_input_ids": de_tokens.input_ids,
            }

        for split in dataset.keys():
            dataset[split] = dataset[split].map(tokenization, batched=True)
            dataset[split].set_format(
                "torch",
                columns=["en_input_ids", "de_input_ids"],
                output_all_columns=True,
            )

        # NOTE(review): sequences have variable length, and `from_datasets`
        # uses the default collate function. Batching with batch_size > 1 will
        # likely need a padding collate_fn (e.g. using tokenizer.pad_token_id);
        # TODO confirm and wire one in via custom dataloaders.
        return pl.LightningDataModule.from_datasets(
            dataset["train"],
            dataset["validation"],
            dataset["test"],
            batch_size=batch_size,
            num_workers=num_workers,
        )