import tempfile

import matplotlib.pyplot as plt
import numpy as np
import scipy.io.wavfile as wavfile
import skimage
import torch
import torchvision
import wandb
from einops import rearrange
from PIL import Image
from skvideo.datasets import bigbuckbunny, bikes
from torch.utils.data import Dataset
from torchaudio.io import StreamWriter
from torchvision.transforms import v2

from experiments.neural_datasets.inr_utils import make_grid, make_image_grid


def get_cameraman_tensor(sidelength, normalize=False, use_old=None):
    """Load the cameraman image (or a replacement file) as a square float tensor.

    Args:
        sidelength: Output height and width in pixels. The image is resized to
            exactly ``(sidelength, sidelength)`` so that it always matches the
            square coordinate grid built by ``ImageFitting`` (resizing with a
            single int would only fix the shorter edge of a non-square image).
        normalize: If True, map values from [0, 1] to [-1, 1] via
            ``(x - 0.5) / 0.5``.
        use_old: Optional path to an image file to load instead of
            ``skimage.data.camera()``.

    Returns:
        A ``float32`` tensor of shape ``(C, sidelength, sidelength)`` with values
        in [0, 1] (or [-1, 1] when ``normalize`` is True). For the built-in
        grayscale cameraman image C is 1; for ``use_old`` it depends on the
        file's mode.
    """
    if use_old is not None:
        # Copy the pixels out so the file handle is closed right away instead
        # of lingering until the lazily-loaded PIL image is garbage collected.
        with Image.open(use_old) as src:
            img = src.copy()
    else:
        img = Image.fromarray(skimage.data.camera())
    transform = v2.Compose(
        [
            v2.Resize((sidelength, sidelength)),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ]
        + ([v2.Normalize(torch.Tensor([0.5]), torch.Tensor([0.5]))] if normalize else [])
    )
    img = transform(img)
    return img


class ImageFitting(Dataset):
    """Pixel-coordinate -> intensity dataset for a square grayscale image.

    Each item is a ``(coordinate, intensity, 0.0)`` triple. The trailing
    ``0.0`` occupies the slot that ``VideoAudioFitting`` uses for its
    video/audio mask.
    """

    def __init__(self, sidelength, coord_range=None, normalize=False, use_old=None):
        super().__init__()
        if coord_range is None:
            coord_range = (-1, 1)
        self.sidelength = sidelength
        self.normalize = normalize

        image = get_cameraman_tensor(sidelength, normalize=normalize, use_old=use_old)
        # (C, H, W) -> (H, W, C) -> one intensity per row; the view(-1, 1)
        # only succeeds for a single-channel image.
        self.data = image.permute(1, 2, 0).view(-1, 1)

        grid = make_image_grid(
            [sidelength, sidelength], batch_size=1, coord_range=coord_range
        )
        self.coords = grid[0]

        self.input_dim, self.output_dim = 2, 1

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        coord, target = self.coords[idx], self.data[idx]
        return coord, target, 0.0

    def inv_normalize(self, x):
        """Map values from [-1, 1] back to [0, 1] when normalization is on."""
        return x * 0.5 + 0.5 if self.normalize else x

    def savefigs(self, eval_model_output, eval_img_grad, eval_img_laplacian, filename):
        """Save prediction, gradient and Laplacian side by side to ``{filename}.png``."""
        panels = (eval_model_output, eval_img_grad, eval_img_laplacian)
        side = self.sidelength
        fig, axes = plt.subplots(
            1, len(panels), figsize=(6 * len(panels), 6), squeeze=False
        )
        for ax, panel in zip(axes[0], panels):
            ax.imshow(panel.cpu().view(side, side).numpy())
        plt.savefig(f"{filename}.png", dpi=300, bbox_inches="tight")
        plt.close(fig)

    def wandb(self, eval_model_output, eval_img_grad=None, eval_img_laplacian=None, title=None):
        """Build a wandb image grid of the prediction (plus derivatives when given)."""
        has_derivatives = eval_img_grad is not None and eval_img_laplacian is not None
        if has_derivatives:
            panels = [eval_model_output, eval_img_grad, eval_img_laplacian]
            stacked = torch.stack([panel.cpu() for panel in panels], dim=0)
        else:
            stacked = eval_model_output.cpu().unsqueeze(0)
        stacked = rearrange(
            stacked, "b (h w) c -> b c h w", h=self.sidelength, w=self.sidelength
        )
        grid = torchvision.utils.make_grid(stacked, scale_each=True, normalize=True)
        grid_np = grid.permute(1, 2, 0).cpu().numpy()
        return {"image": wandb.Image(grid_np, caption=title)}


class AudioFitting(Dataset):
    """Time-coordinate -> amplitude dataset built from a WAV file.

    Each item is a ``(coordinate, sample, 0.0)`` triple. Coordinates are
    requested from ``make_grid`` over the range [-100, 100].
    """

    def __init__(self, filename):
        """Read ``filename`` with ``scipy.io.wavfile`` and build the coordinate grid.

        Sample values are converted to float without rescaling, so e.g. an
        int16 file keeps its int16 value range.
        """
        super().__init__()
        self.rate, self.data = wavfile.read(filename)

        self.normalize = False
        self.data = torch.from_numpy(self.data).float()
        if len(self.data.shape) == 1:
            # Mono file -> (samples, 1) so every row is one target vector.
            self.data = self.data.unsqueeze(1)
        self.coords = make_grid(
            [len(self.data)], coord_range=[-100, 100], batch_size=1
        )[0]

        self.input_dim = 1
        # One output per audio channel (1 for mono files, as before); a fixed 1
        # would silently broadcast against multi-channel targets.
        self.output_dim = self.data.shape[1]

    def to(self, device):
        """Move coordinates and targets to ``device`` in place."""
        self.coords = self.coords.to(device)
        self.data = self.data.to(device)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.coords[idx], self.data[idx], 0.0

    def inv_normalize(self, x):
        """Undo normalization (a no-op here because ``normalize`` is always False)."""
        if self.normalize:
            return x * 0.5 + 0.5
        return x

    def _prediction(self, eval_model_output):
        """Return the model output as a CPU tensor with the same shape as ``self.data``.

        Reshaping prevents a flat ``(N,)`` prediction from broadcasting against
        the ``(N, C)`` targets into an ``(N, N)`` matrix.
        """
        return eval_model_output.detach().cpu().reshape(self.data.shape)

    def _residual(self, eval_model_output):
        """Prediction minus ground truth, computed on the CPU.

        ``self.data`` may live on an accelerator after ``to(device)``, so it is
        moved to the CPU before the subtraction.
        """
        return self._prediction(eval_model_output) - self.data.cpu()

    def savefigs(self, eval_model_output, filename):
        """Write the prediction and residual as WAV files plus a comparison plot.

        NOTE(review): the WAVs are written as float data without rescaling;
        float WAV players generally expect [-1, 1] — confirm this is intended.
        """
        prediction = self._prediction(eval_model_output)
        wavfile.write(f"{filename}.wav", self.rate, prediction.numpy().squeeze())
        # Save difference between ground truth and prediction
        wavfile.write(
            f"{filename}_diff.wav",
            self.rate,
            self._residual(eval_model_output).numpy().squeeze(),
        )
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        axes[0].plot(self.data.cpu().numpy())
        axes[0].set_title("Ground truth")
        axes[1].plot(prediction.numpy().squeeze())
        axes[1].set_title("Prediction")
        plt.savefig(f"{filename}.png", dpi=300, bbox_inches="tight")
        plt.close(fig)

    def wandb(self, eval_model_output, title=None):
        """Return wandb audio objects for the prediction and its residual."""
        return {
            "audio": wandb.Audio(
                self._prediction(eval_model_output).numpy().squeeze(),
                sample_rate=self.rate,
                caption=title,
            ),
            "audio_diff": wandb.Audio(
                self._residual(eval_model_output).numpy().squeeze(),
                sample_rate=self.rate,
                caption=f"{title}_diff",
            ),
        }


class VideoFitting(Dataset):
    """(t, y, x)-coordinate -> pixel dataset over every pixel of every frame.

    Each item is a ``(coordinate, pixel, 0.0)`` triple, where ``coordinate`` is
    one row of the grid returned by ``make_grid`` for the ``(t, h, w)`` shape.
    """

    def __init__(self, video_path=None, normalize=False):
        """Decode ``video_path`` and flatten it into per-pixel targets.

        Args:
            video_path: Video file readable by ``torchvision.io.read_video``;
                defaults to scikit-video's ``bikes`` sample.
            normalize: If True, pixel values are mapped from [0, 1] to [-1, 1].
        """
        super().__init__()
        if video_path is None:
            video_path = bikes()
        video = torchvision.io.read_video(video_path, pts_unit="sec")

        self.normalize = normalize
        self.normalize_transform = [v2.Normalize(torch.Tensor([0.5]), torch.Tensor([0.5]))] if normalize else []
        # Normalize(mean=-1, std=2) computes (x + 1) / 2, the inverse of the above.
        self.inverse_normalize_transform = [v2.Normalize(torch.Tensor([-1.0]), torch.Tensor([2.0]))] if normalize else []
        transform = v2.Compose(
            [
                v2.ToDtype(torch.float32, scale=True),
            ]
            + self.normalize_transform
        )
        self.data = transform(video[0])
        self.fps = video[2]["video_fps"]

        self.dims = self.data.shape[:3]
        self.data = rearrange(self.data, "t h w c -> (t h w) c")
        self.coords = make_grid(self.dims, batch_size=1)[0]

        self.input_dim = 3
        self.output_dim = 3

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.coords[idx], self.data[idx], 0.0

    def inv_normalize(self, x):
        """Map values from [-1, 1] back to [0, 1] when normalization is on."""
        if self.normalize:
            return x * 0.5 + 0.5
        return x

    def _frames_to_uint8(self, frames):
        """Undo the normalization, clamp to [0, 1] and convert to uint8 pixels.

        ``v2.ToDtype(torch.uint8, scale=True)`` scales float values up to the
        uint8 range without clamping them first, so model outputs slightly
        outside [0, 1] could wrap around in the cast.
        """
        for inverse in self.inverse_normalize_transform:
            frames = inverse(frames)
        return v2.ToDtype(torch.uint8, scale=True)(frames.clamp(0.0, 1.0))

    def savefigs(self, eval_model_output, filename):
        """Write the predicted video to ``{filename}.mp4``."""
        frames = rearrange(
            eval_model_output.cpu(),
            "(t h w) c -> t h w c",
            t=self.dims[0],
            h=self.dims[1],
            w=self.dims[2],
        )
        torchvision.io.write_video(
            f"{filename}.mp4",
            self._frames_to_uint8(frames),
            fps=self.fps,
        )

    def wandb(self, eval_model_output, title=None):
        """Return the predicted video as a ``wandb.Video`` (frames laid out t, c, h, w)."""
        frames = rearrange(
            eval_model_output.cpu(),
            "(t h w) c -> t c h w",
            t=self.dims[0],
            h=self.dims[1],
            w=self.dims[2],
        )
        return {
            "video": wandb.Video(self._frames_to_uint8(frames), fps=self.fps, caption=title),
        }


class VideoAudioFitting(Dataset):
    """Joint video + audio dataset for fitting one coordinate network.

    Video pixels and audio samples are stacked into a single flat dataset whose
    targets all have ``num_channels + num_audio_channels`` columns:

    * video rows hold the pixel values in the first ``num_channels`` columns
      and zeros in the audio columns;
    * audio rows hold zeros in the video columns and the (optionally
      standardized) audio in the remaining columns.

    ``video_audio_mask`` is True for video rows and False for audio rows, and
    is returned as the third element of every item.
    """

    def __init__(self, video_path=None, double_time=False, subsample_rate=1, normalize=False):
        """Load a pre-decoded video/audio triple and build coordinates and targets.

        Args:
            video_path: Path to a file loadable with ``torch.load`` holding a
                ``(video, audio, info)`` sequence. ``video`` is indexed as
                ``(t, h, w, c)``. ``audio`` is transposed so its second axis
                becomes channels (presumably ``(channels, samples)`` as returned
                by ``torchvision.io.read_video`` — confirm). ``info`` must contain
                ``"video_fps"`` and ``"audio_fps"``. Required.
            double_time: If True, every coordinate carries two time axes,
                ``[audio_time in [-100, 100], video_time in [-1, 1], y, x]``.
                Otherwise video and audio share one time axis in [-100, 100]:
                ``[time, y, x]``. Audio rows use 0 for ``y`` and ``x``.
            subsample_rate: Keep every ``subsample_rate``-th row and column of
                each frame.
            normalize: If True, map video from [0, 1] to [-1, 1] and
                standardize the audio with its own mean and standard deviation.

        Raises:
            ValueError: If ``video_path`` is None.
        """
        super().__init__()
        if video_path is None:
            # The former bigbuckbunny() fallback is disabled; fail with a clear
            # message instead of an obscure error from torch.load(None).
            raise ValueError(
                "VideoAudioFitting requires video_path pointing to a saved "
                "(video, audio, info) triple"
            )
        # NOTE(review): torch.load may unpickle arbitrary objects; only load trusted files.
        video_audio = torch.load(video_path)

        self.normalize = normalize
        self.normalize_transform = [v2.Normalize(torch.Tensor([0.5]), torch.Tensor([0.5]))] if normalize else []
        # Normalize(mean=-1, std=2) computes (x + 1) / 2, the inverse of the above.
        self.inverse_normalize_transform = [v2.Normalize(torch.Tensor([-1.0]), torch.Tensor([2.0]))] if normalize else []
        transform = v2.Compose(
            [
                v2.ToDtype(torch.float32, scale=True),
            ]
            + self.normalize_transform
        )
        self.video = transform(video_audio[0])
        self.audio = video_audio[1].T

        self.fps = video_audio[2]["video_fps"]
        self.audio_rate = video_audio[2]["audio_fps"]

        # Spatially subsample every frame.
        self.subsample_rate = subsample_rate
        self.video = self.video[:, :: self.subsample_rate, :: self.subsample_rate]

        self.dims = self.video.shape[:3]
        self.num_channels = self.video.shape[3]
        self.audio_dims = self.audio.shape[0]
        self.num_audio_channels = self.audio.shape[1]

        # Standardize the raw audio BEFORE zero-padding it with video columns:
        # padding first would pull the zeros into the mean/std and turn the
        # padded zeros into -mean/std.
        self.audio, self._audio_mean, self._audio_std = self._standardize_audio(
            self.audio, normalize
        )

        self.video = rearrange(self.video, "t h w c -> (t h w) c")
        self.video = torch.cat(
            [self.video, torch.zeros(self.video.shape[0], self.num_audio_channels)],
            dim=-1,
        )
        self.audio = torch.cat(
            [torch.zeros(self.audio_dims, self.num_channels), self.audio], dim=-1
        )

        video_time_range = (-1, 1) if double_time else (-100, 100)
        self.video_coords = make_grid(
            self.dims, coord_range=[video_time_range, (-1, 1), (-1, 1)], batch_size=1
        )[0]

        self.audio_coords = make_grid(
            [self.audio_dims], coord_range=[-100, 100], batch_size=1
        )[0]
        if double_time:
            # Second time axis for audio rows, on the video time scale.
            video_time = make_grid(
                [self.audio_dims], coord_range=[-1, 1], batch_size=1
            )[0]
            self.audio_coords = torch.cat([self.audio_coords, video_time], dim=-1)
        # Add zeros for x-y coordinates
        self.audio_coords = torch.cat(
            [self.audio_coords, torch.zeros(self.audio_coords.shape[0], 2)],
            dim=-1,
        )
        if double_time:
            # Prepend the audio-scale time axis to the video coordinates so both
            # modalities share the layout [audio_time, video_time, y, x].
            audio_time = make_grid(
                self.dims, coord_range=[(-100, 100), (-1, 1), (-1, 1)], batch_size=1
            )[0][:, [0]]
            self.video_coords = torch.cat([audio_time, self.video_coords], dim=-1)

        self.video_audio_mask = torch.cat(
            [
                torch.ones(self.video_coords.shape[0], dtype=torch.bool),
                torch.zeros(self.audio_coords.shape[0], dtype=torch.bool),
            ],
            dim=-1,
        )
        self.coords = torch.cat([self.video_coords, self.audio_coords], dim=0)
        self.data = torch.cat([self.video, self.audio], dim=0)

        self.video_to_audio_ratio = 1.0
        print(f"Video to audio ratio: {self.video_to_audio_ratio}")

        self.input_dim = 4 if double_time else 3
        self.output_dim = self.num_channels + self.num_audio_channels

    @staticmethod
    def _standardize_audio(audio, normalize):
        """Return ``(standardized_audio, mean, std)``.

        When ``normalize`` is False the identity statistics ``(0.0, 1.0)`` are
        used so the audio values are unchanged.
        """
        if normalize:
            mean, std = audio.mean(), audio.std()
        else:
            mean, std = 0.0, 1.0
        return (audio - mean) / std, mean, std

    def to(self, device):
        """Move coordinates, targets and the video/audio mask to ``device`` in place."""
        self.coords = self.coords.to(device)
        self.data = self.data.to(device)
        self.video_audio_mask = self.video_audio_mask.to(device)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.coords[idx], self.data[idx], self.video_audio_mask[idx]

    def inv_normalize_video(self, video):
        """Map video values from [-1, 1] back to [0, 1] when normalization is on."""
        if self.normalize:
            return video * 0.5 + 0.5
        return video

    def inv_normalize_audio(self, audio):
        """Undo the audio standardization applied in ``__init__``."""
        return audio * self._audio_std + self._audio_mean

    def _split_output(self, eval_model_output):
        """Split a full-dataset prediction into ``(video_part, audio_part)``.

        The mask is moved to the prediction's device so indexing works no
        matter where ``to(device)`` left it.
        """
        mask = self.video_audio_mask.to(eval_model_output.device)
        video_output = eval_model_output[mask][:, : self.num_channels]
        audio_output = eval_model_output[~mask][:, self.num_channels :]
        return video_output, audio_output

    def _video_to_uint8(self, frames):
        """Undo the video normalization, clamp to [0, 1] and convert to uint8.

        ``v2.ToDtype(torch.uint8, scale=True)`` scales float values without
        clamping them first, so out-of-range predictions could wrap around.
        """
        for inverse in self.inverse_normalize_transform:
            frames = inverse(frames)
        return v2.ToDtype(torch.uint8, scale=True)(frames.clamp(0.0, 1.0))

    def savefigs(self, eval_model_output, filename):
        """Mux the predicted audio and video into ``{filename}.mp4``."""
        video_output, audio_output = self._split_output(eval_model_output)

        video_output = rearrange(
            video_output.cpu(),
            "(t h w) c -> t c h w",
            t=self.dims[0],
            h=self.dims[1],
            w=self.dims[2],
        )
        video_output = self._video_to_uint8(video_output)
        audio_output = self.inv_normalize_audio(audio_output.cpu())

        s = StreamWriter(dst=f"{filename}.mp4")
        s.add_audio_stream(
            sample_rate=self.audio_rate, num_channels=self.num_audio_channels
        )
        # rgb24 presumably requires num_channels == 3 — TODO confirm for other inputs.
        s.add_video_stream(
            frame_rate=self.fps, height=self.dims[1], width=self.dims[2], format="rgb24"
        )

        with s.open():
            s.write_audio_chunk(0, audio_output)
            s.write_video_chunk(1, video_output)

    def wandb(self, eval_model_output, title=None):
        """Render the prediction to an mp4 and wrap it as wandb video and audio.

        NOTE(review): only the empty NamedTemporaryFile is deleted; the
        ``{tmp.name}.mp4`` written by ``savefigs`` stays on disk because the
        returned wandb objects reference it by path.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            self.savefigs(eval_model_output, tmp.name)
            return {
                "video": wandb.Video(f"{tmp.name}.mp4", fps=self.fps, caption=title),
                "audio": wandb.Audio(
                    f"{tmp.name}.mp4",
                    sample_rate=self.audio_rate,
                    caption=f"{title}_audio",
                ),
            }
