import os
import math
import itertools
import numpy as np
from tqdm import tqdm
import argparse
from omegaconf import OmegaConf

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import pytorch_lightning as pl

from .resnet import resnet18, resnet34
from .attention import AttentionStack, LayerNorm, AddBroadcastPosEmbed
from .utils import shift_dim, load_vqgan


class VideoGPT(pl.LightningModule):
    """Autoregressive prior over the discrete VQ-GAN codes of a video.

    A frozen VQ-GAN (loaded from ``hparams.vqgan_path``) turns every frame
    into a grid of code indices.  The indices are embedded, optionally
    combined with per-step action embeddings, passed through an
    ``AttentionStack`` and projected back onto the code vocabulary.  The
    training objective is the cross-entropy between those logits and the
    code indices.
    """

    def __init__(self, args, new_args=None):
        """Build the model from parsed hyper-parameters.

        Args:
            args: namespace of hyper-parameters; ``vars(args)`` is merged
                into ``self.hparams``.
            new_args: optional namespace whose entries override/extend
                ``self.hparams`` after ``args`` has been merged.

        Raises:
            FileNotFoundError: if ``<vqgan_path>/configs`` contains no file
                whose name includes ``"project.yaml"``.
        """
        super().__init__()
        self.hparams.update(vars(args))
        if new_args is not None:
            self.hparams.update(vars(new_args))

        # Load VQ-GAN model
        config_dir = f"{self.hparams.vqgan_path}/configs"
        config_name = next(
            (f for f in os.listdir(config_dir) if "project.yaml" in f), None
        )
        if config_name is None:
            raise FileNotFoundError(
                f"no '*project.yaml*' config file found in {config_dir}"
            )
        config = OmegaConf.load(f"{config_dir}/{config_name}")
        ckpt = f"{self.hparams.vqgan_path}/checkpoints/last.ckpt"
        self.vqgan = load_vqgan(config, ckpt)
        # The VQ-GAN is only used as a fixed tokenizer/decoder.
        for p in self.vqgan.parameters():
            p.requires_grad = False
        self.vqgan.eval()

        # Defaults for the "no frame conditioning" case.  They must exist
        # even when `use_vq_cond` is off, because `forward`, `process_batch`
        # and the AttentionStack constructor below read them unconditionally.
        self.use_frame_cond = False
        frame_cond_shape = None
        if self.hparams.use_vq_cond:
            # ResNet18 for frame conditioning
            self.use_frame_cond = args.n_cond_frames > 0
            if self.use_frame_cond:
                cond_downscale = min(self.hparams.downscale, 8)
                cond_dim = 240
                frame_cond_shape = (
                    args.n_cond_frames,
                    args.resolution // cond_downscale,
                    args.resolution // cond_downscale,
                    cond_dim,
                )
                self.resnet = resnet18(
                    1,
                    (1, cond_downscale, cond_downscale),
                    resnet_dim=cond_dim,
                )
                self.cond_pos_embd = AddBroadcastPosEmbed(
                    shape=frame_cond_shape[:-1], embd_dim=frame_cond_shape[-1]
                )

        # define attention layer
        # (frames, latent height, latent width) of the code grid
        self.shape = (
            args.sequence_length,
            args.resolution // self.hparams.downscale,
            args.resolution // self.hparams.downscale,
        )

        self.vocab_size = self.vqgan.quantize.n_e
        self.embedding_in = nn.Embedding(self.vocab_size, args.hidden_dim)
        self.embedding_in.weight.data.normal_(std=0.02)

        if self.hparams.action_cond:
            self.action_in = nn.Linear(
                self.hparams.action_dim, args.hidden_dim, bias=False
            )

        self.fc_in = nn.Linear(self.vqgan.embedding_dim, args.hidden_dim, bias=False)
        self.fc_in.weight.data.normal_(std=0.02)

        self.attn_stack = AttentionStack(
            self.shape,
            args.hidden_dim,
            args.heads,
            args.layers,
            args.dropout,
            args.attn_type,
            args.attn_dropout,
            args.class_cond_dim,
            frame_cond_shape,
        )

        self.norm = LayerNorm(args.hidden_dim, args.class_cond_dim)

        # Output projection starts at zero, i.e. all logits are initially 0.
        self.fc_out = nn.Linear(args.hidden_dim, self.vocab_size, bias=False)
        self.fc_out.weight.data.zero_()

        if self.hparams.augment_image:
            w_, h_ = self.hparams.augment_magnitude, self.hparams.augment_magnitude
            # 6-tuple padding: pads the last two dims by the magnitude and
            # leaves the third-to-last dim untouched.
            self.nn_pad = nn.ZeroPad2d((w_, w_, h_, h_, 0, 0))

        # caches for faster decoding (if necessary)
        self.frame_cond_cache = None

        self.save_hyperparameters()

    def embed(self, x):
        """Look up the input embedding of the code indices ``x``."""
        return self.embedding_in(x)

    def cond_embed(self, x):
        """Encode conditioning frames with the ResNet.

        Only usable when frame conditioning is enabled, since ``self.resnet``
        is created only in that case.
        """
        return self.resnet(x)

    def forward(
        self,
        targets,
        cond,
        decode_step=None,
        decode_idx=None,
        fast_long_decode=False,
    ):
        """Compute logits over the code vocabulary and the cross-entropy loss.

        Args:
            targets: integer code indices; they are both embedded as the
                model input and used as the classification targets.
            cond: conditioning container.  ``cond["frame_cond"]`` is read
                when frame conditioning is enabled and ``cond["action"]``
                when ``hparams.action_cond`` is set.  The caller's object is
                not modified.
            decode_step: ``None`` for a full (training/eval) pass, otherwise
                the index of the current sampling step.  Step ``0`` fills
                the frame-conditioning cache, later steps reuse it.
            decode_idx: position of the token being decoded; its first entry
                selects the action for that step when sampling.
            fast_long_decode: forwarded unchanged to the attention stack.
                The original code passed this name without defining it,
                which raised ``NameError`` on every call.

        Returns:
            Tuple ``(loss, logits)``.
        """
        if self.use_frame_cond:
            # Shallow copy: the embedded frames replace the raw ones only in
            # our local view, so calling forward again with the same `cond`
            # does not embed already-embedded features.
            cond = dict(cond)
            if decode_step is None:
                cond["frame_cond"] = self.cond_pos_embd(
                    self.cond_embed(cond["frame_cond"])
                )
            else:
                if decode_step == 0:
                    # Conditioning frames do not change while sampling, so
                    # encode them once and reuse the result afterwards.
                    self.frame_cond_cache = self.cond_pos_embd(
                        self.cond_embed(cond["frame_cond"])
                    )
                cond["frame_cond"] = self.frame_cond_cache

        h = self.embed(targets)
        if self.hparams.action_cond:
            # cond[action]: B, T, D
            # h: B, T, H, W, D  => actions are broadcast over H and W
            if decode_step is None:
                action = cond["action"]
            else:
                # at inference time, h shape: [B, 1, 1, 1, D] but action: [B, T, D]
                action = cond["action"][:, decode_idx[0] : decode_idx[0] + 1]
            h = h + self.action_in(action.unsqueeze(2).unsqueeze(2))

        h = self.attn_stack(h, cond, decode_step, decode_idx, fast_long_decode)
        h = self.norm(h, cond)
        logits = self.fc_out(h)

        # cross_entropy expects the class dimension at position 1.
        loss = F.cross_entropy(shift_dim(logits, -1, 1), targets)
        return loss, logits

    def process_batch(self, batch, mode="eval"):
        """Turn a raw batch into VQ-GAN code targets plus conditioning.

        Args:
            batch: mapping with a ``"video"`` tensor of shape
                ``(B, C, T, H, W)`` and, when ``hparams.action_cond`` is set,
                an ``"action"`` tensor.
            mode: ``"training"`` enables the random translation augmentation
                (if ``hparams.augment_image`` is set); any other value
                disables it.

        Returns:
            Tuple ``(targets, cond)`` where ``targets`` holds the code
            indices reshaped to ``(B, T, h, w)`` and ``cond`` is a dict with
            the optional ``"action"`` and ``"frame_cond"`` entries.
        """
        x = batch["video"]

        cond = dict()
        with torch.no_grad():
            if mode == "training" and self.hparams.augment_image:
                x = self.translate(x)

        if self.hparams.action_cond:
            action = batch["action"]
            # Prepend one all-zero action along the time axis.
            cond["action"] = F.pad(action, (0, 0, 1, 0))  # [B, T, act_dim]

        if self.use_frame_cond:
            cond["frame_cond"] = x[:, :, : self.hparams.n_cond_frames]

        B, C, T, H, W = x.shape

        x = shift_dim(x, 1, 2)
        x = x.reshape(B * T, C, H, W)

        # implementation trick to avoid high memory usage of VQ-VAE
        # extracts discrete codes from chunks
        chunk = self.hparams.vqgan_chunk
        encodings = [
            self.vqgan.encode_code(x[start : start + chunk])
            for start in range(0, B * T, chunk)
        ]
        targets = torch.cat(encodings)
        targets = targets.reshape(B, T, targets.shape[1], targets.shape[2])

        return targets, cond

    def get_reconstruction(self, video):
        """Run ``video`` (B, C, T, H, W) through the VQ-GAN frame by frame.

        Returns the first element of the VQ-GAN output, reshaped back to
        ``(B, C, T, H, W)``.
        """
        # video: B, C, T, H, W
        B, C, T, H, W = video.shape
        frames = video.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
        recon_video = self.vqgan(frames)[0]
        recon_video = recon_video.reshape(B, T, C, H, W)
        return recon_video.permute(0, 2, 1, 3, 4)

    def training_step(self, batch, batch_idx):
        """Lightning hook: compute and log the training loss for one batch."""
        targets, cond = self.process_batch(batch, mode="training")
        loss, _ = self(targets, cond)
        self.log("train/loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Lightning hook: compute and log the validation loss for one batch."""
        targets, cond = self.process_batch(batch, mode="eval")
        loss, _ = self(targets, cond)
        self.log("val/loss", loss, prog_bar=True)

    def translate(self, x):
        """Randomly translate every sample of ``x`` (B, C, T, H, W).

        The input is zero-padded spatially by ``hparams.augment_magnitude``
        and an ``H x W`` window is cropped at a per-sample random offset
        drawn from ``[0, 2 * augment_magnitude)``.  Requires ``self.nn_pad``,
        which only exists when ``hparams.augment_image`` is set.
        """
        b, c, t, h, w = x.shape
        max_offset = int(2 * self.hparams.augment_magnitude)
        h_aug = np.random.randint(max_offset, size=b)
        w_aug = np.random.randint(max_offset, size=b)
        x = self.nn_pad(x)
        crops = [
            x[i, :, :, h_aug[i] : h_aug[i] + h, w_aug[i] : w_aug[i] + w]
            for i in range(b)
        ]
        return torch.stack(crops)

    def sample(
        self,
        n,
        batch=None,
        top_k=0,
    ):
        """Autoregressively sample videos, one code token at a time.

        Args:
            n: number of samples to generate.
                NOTE(review): ground-truth codes and conditioning come from
                ``batch``, so ``n`` presumably has to match its batch size
                whenever conditioning is used -- confirm with callers.
            batch: batch in the format expected by ``process_batch``; it
                provides the conditioning and the ground-truth codes of the
                first ``hparams.n_cond_frames`` frames.  Required.
            top_k: if ``> 0``, sample only among the ``top_k`` most likely
                codes (clamped to the vocabulary size).

        Returns:
            Tensor of shape ``(n, C, T, H, W)`` with values clamped to
            ``[0, 1]``.

        Raises:
            ValueError: if ``batch`` is ``None``.
        """
        if batch is None:
            raise ValueError("sample() requires a `batch` to condition on")

        device = self.embedding_in.weight.device

        gt_targets, cond = self.process_batch(batch)

        shape = (self.shape[0], self.shape[1], self.shape[2])

        samples = torch.zeros((n,) + shape, dtype=torch.long, device=device)
        idxs = list(itertools.product(*[range(s) for s in shape]))

        # Tokens belonging to the conditioning frames are copied from the
        # ground truth instead of being sampled.
        n_cond_tokens = self.shape[1] * self.shape[2] * self.hparams.n_cond_frames

        with torch.no_grad():
            prev_idx = None
            for i, idx in enumerate(tqdm(idxs)):
                batch_idx_slice = (slice(None, None), *[slice(j, j + 1) for j in idx])
                batch_idx = (slice(None, None), *idx)
                is_cond_token = i < n_cond_tokens

                # The model input is the token at the previous position.  For
                # the very first step there is none, so the current slot is
                # used as an arbitrary placeholder:
                # does not matter what value since it will be shifted anyways
                input_idx = batch_idx_slice if prev_idx is None else prev_idx
                source = gt_targets if is_cond_token else samples
                sample_inputs = source[input_idx]

                # The model is run on conditioning tokens as well; only its
                # prediction is discarded for them.
                logits = self(sample_inputs, cond, decode_step=i, decode_idx=idx)[1]
                # squeeze all possible dim except batch dimension
                logits = (
                    logits.squeeze().unsqueeze(0)
                    if logits.shape[0] == 1
                    else logits.squeeze()
                )

                if is_cond_token:
                    samples[batch_idx] = gt_targets[batch_idx]
                else:
                    if top_k > 0:
                        # Remove all tokens with a probability less than the last token of the top-k
                        k = min(top_k, logits.shape[-1])
                        indices_to_remove = (
                            logits < torch.topk(logits, k)[0][..., -1, None]
                        )
                        logits[indices_to_remove] = -float("Inf")

                    probs = F.softmax(logits, dim=-1)
                    samples[batch_idx] = torch.multinomial(probs, 1).squeeze(-1)

                prev_idx = batch_idx_slice

            # sample: BTHW
            B, T, H_, W_ = samples.shape
            samples = samples.reshape(B * T, H_, W_)
            samples = self.vqgan.decode_code(samples)
            # Rescale as (x + 1) / 2 and clamp to [0, 1].
            samples = (samples + 1.0) / 2.0
            samples = samples.clamp(0, 1)
            _, C, H, W = samples.shape
            samples = samples.reshape(B, T, C, H, W).permute(0, 2, 1, 3, 4)

        return samples

    def configure_optimizers(self):
        """Lightning hook: Adam over all parameters with ``hparams.lr``."""
        return torch.optim.Adam(
            self.parameters(), lr=self.hparams.lr, betas=(0.9, 0.999)
        )

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with model options.

        NOTE(review): several hyper-parameters read by this class (e.g.
        ``vqgan_path``, ``downscale``, ``lr``) are not declared here and are
        presumably provided by other parsers -- confirm.
        """
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--vqgan', type=str, default='kinetics_stride4x4x4',
                            help='path to vqgan ckpt, or model name to download pretrained')
        parser.add_argument('--n_cond_frames', type=int, default=0)
        parser.add_argument('--class_cond', action='store_true')

        # VideoGPT hyperparameters
        parser.add_argument('--hidden_dim', type=int, default=576)
        parser.add_argument('--heads', type=int, default=4)
        parser.add_argument('--layers', type=int, default=8)
        parser.add_argument('--dropout', type=float, default=0.2)
        parser.add_argument('--attn_type', type=str, default='full',
                            choices=['full', 'sparse'])
        parser.add_argument('--attn_dropout', type=float, default=0.3)

        return parser
