from collections import OrderedDict
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat

# Module-wide default device: the CUDA device when PyTorch reports one as
# available, the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def exists(val):
    """Return ``True`` for every value except ``None`` (falsy values included)."""
    if val is None:
        return False
    return True


class MetaWrapper(nn.Module):
    """Pair a coordinate-based decoder with a coordinate grid and an MSE objective.

    The wrapper builds the coordinate grid for the configured data type
    (``'img'`` or ``'video'``), can restrict it to a random subset
    (``sample`` / ``random_sample``), feeds the coordinates to ``decoder`` and
    compares the prediction against the target signal.

    The decoder is called as ``decoder(coords, params=params)`` with ``coords``
    of shape ``(batch, n_points, coord_dim)`` and is expected to return a
    tensor of shape ``(batch, n_points, channels)``.
    """

    def __init__(self, P, decoder):
        """Store the configuration and register the coordinate grid as a buffer.

        Args:
            P: configuration object; ``data_type`` and ``data_size`` are read
                here, ``data_ratio`` is read by ``random_sample``.
            decoder: callable mapping coordinates to signal values; it must
                also expose ``meta_named_parameters()`` for
                ``get_batch_params``.

        Raises:
            NotImplementedError: if ``P.data_type`` is neither ``'img'`` nor
                ``'video'``.
        """
        super().__init__()
        self.P = P
        self.decoder = decoder
        self.data_type = P.data_type
        self.sampled_coord = None
        self.sampled_index = None
        self.gradncp_coord = None
        self.gradncp_index = None

        # NOTE: the attribute names are kept for backward compatibility.
        # ``self.width`` is the length of the FIRST spatial grid axis and
        # ``self.height`` the length of the SECOND one; the grid is flattened
        # row-major in that order.  (Presumably data_size is (C, H, W) for
        # images, which would make the names swapped -- confirm with callers.)
        if self.data_type == 'img':
            self.width = P.data_size[1]
            self.height = P.data_size[2]
            mgrid = self.shape_to_coords((self.width, self.height))
            mgrid = mgrid.flatten(0, 1)  # (h, w, c) -> (h * w, c)
        elif self.data_type == 'video':
            self.temporal = P.data_size[0]
            self.width = P.data_size[2]
            self.height = P.data_size[3]
            mgrid = self.shape_to_coords((self.temporal, self.width, self.height))
            mgrid = mgrid.flatten(1, 2)  # (t, h, w, c) -> (t, h * w, c)
        else:
            raise NotImplementedError()

        self.register_buffer('grid', mgrid)

    def shape_to_coords(self, spatial_shape):
        """Build a grid of coordinates in ``[-1, 1]`` for ``spatial_shape``.

        For ``'img'`` every entry of ``spatial_shape`` is a grid axis and the
        result has shape ``(*spatial_shape, len(spatial_shape))``.  For
        ``'video'`` the first entry is the number of frames: the spatial grid
        of the remaining axes is built once and replicated per frame, giving
        ``(t, *spatial_shape[1:], len(spatial_shape) - 1)``.

        Raises:
            NotImplementedError: for any other data type.
        """
        if self.data_type == 'img':
            axis_sizes = tuple(spatial_shape)
            n_frames = None
        elif self.data_type == 'video':
            axis_sizes = tuple(spatial_shape[1:])
            n_frames = spatial_shape[0]
        else:
            raise NotImplementedError()

        axes = [torch.linspace(-1.0, 1.0, size) for size in axis_sizes]
        # 'ij' is the historical default; passing it explicitly keeps the
        # layout stable and avoids the deprecation warning.
        grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)
        if n_frames is None:
            return grid
        # Every frame shares the same spatial grid; replicate instead of
        # recomputing it once per frame.
        return grid[None, ...].repeat(n_frames, *([1] * grid.dim()))

    def get_batch_params(self, params, batch_size):
        """Return ``params``, or the decoder's meta-parameters tiled per task.

        When ``params`` is ``None`` an ``OrderedDict`` is built in which every
        tensor yielded by ``decoder.meta_named_parameters()`` gets a new
        leading axis of length ``batch_size``.
        """
        if params is None:
            params = OrderedDict()
            for name, param in self.decoder.meta_named_parameters():
                params[name] = param[None, ...].repeat((batch_size,) + (1,) * len(param.shape))
        return params

    def coord_init(self):
        """Forget any coordinate subset so the full grid is used again."""
        self.sampled_coord = None
        self.sampled_index = None
        self.gradncp_coord = None
        self.gradncp_index = None

    def get_batch_coords(self, inputs=None, params=None):
        """Return the active coordinates tiled over the meta batch.

        The batch size is taken from ``inputs`` when given, otherwise from the
        first tensor in ``params``, otherwise it is 1.

        For video data a frame coordinate is prepended to every spatial
        coordinate and the frame axis is flattened away.  The frame coordinate
        is the raw frame index; for a sampled subset it is the index of the
        sampled frame in the full grid, so coordinates stay aligned with the
        frames selected from the target.

        Returns:
            ``(coords, meta_batch_size)`` with ``coords`` of shape
            ``(meta_batch_size, n_points, coord_dim)``.
        """
        if inputs is not None:
            meta_batch_size = inputs.size(0)
        elif params is not None:
            meta_batch_size = next(iter(params.values())).size(0)
        else:
            meta_batch_size = 1

        coords = self.grid if self.sampled_coord is None else self.sampled_coord

        if self.data_type == 'video':
            n_frames, n_points = coords.size(0), coords.size(1)
            use_sampled_frames = (
                self.sampled_coord is not None
                and self.sampled_index is not None
                and self.sampled_index.numel() == n_frames
            )
            if use_sampled_frames:
                frame_idx = self.sampled_index.to(device=coords.device, dtype=torch.float32)
            else:
                frame_idx = torch.arange(n_frames, device=coords.device, dtype=torch.float32)
            # Vectorised equivalent of concatenating [frame, *xy] point by point.
            frame_idx = frame_idx.view(n_frames, 1, 1).expand(n_frames, n_points, 1)
            coords = torch.cat([frame_idx, coords], dim=-1).flatten(0, 1)

        coords = coords.clone().detach()[None, ...].repeat((meta_batch_size,) + (1,) * len(coords.shape))
        return coords, meta_batch_size

    def forward(self, inputs, params=None, step_inner=-1, step_iter=-1):
        """Dispatch to ``forward_image`` or ``forward_video`` by data type.

        Raises:
            NotImplementedError: for any other data type.
        """
        if self.data_type == 'img':
            return self.forward_image(inputs, params, step_inner=step_inner, step_iter=step_iter)
        if self.data_type == 'video':
            return self.forward_video(inputs, params, step_inner=step_inner, step_iter=step_iter)
        raise NotImplementedError()

    def sample(self, sample_type, task_data, params):
        """Select a coordinate subset; only ``'random'`` is implemented.

        ``task_data`` and ``params`` are accepted but not used by the random
        strategy.

        Raises:
            NotImplementedError: for any other ``sample_type``.
        """
        if sample_type == 'random':
            self.random_sample()
        else:
            raise NotImplementedError()

    def random_sample(self):
        """Keep a random ``P.data_ratio`` fraction of the grid's first axis.

        The first axis indexes points for images (grid is ``(h * w, c)``) and
        frames for videos (grid is ``(t, h * w, c)``).  Stores the chosen
        indices in ``sampled_index`` and the matching coordinates in
        ``sampled_coord``, and returns the latter.
        """
        coord_size = self.grid.size(0)
        perm = torch.randperm(coord_size)
        self.sampled_index = perm[:int(self.P.data_ratio * coord_size)]
        self.sampled_coord = self.grid[self.sampled_index]
        return self.sampled_coord

    def forward_image(self, inputs=None, params=None, step_inner=-1, step_iter=-1):
        """Decode an image and, when a target is given, score it with MSE.

        Args:
            inputs: optional target of shape ``(b, c, h, w)``.
            params: optional per-task decoder parameters.
            step_inner, step_iter: accepted for interface compatibility; unused.

        Returns:
            * full grid, with ``inputs``: ``(loss, out)`` where ``out`` is
              ``(b, c, h, w)`` and ``loss`` is the element-wise squared error.
            * sampled grid, with ``inputs``: ``(loss, out)`` where ``out`` is
              ``(b, c, n_sampled)`` and ``loss`` is the per-sample mean
              squared error over the sampled points, shape ``(b,)``.
            * without ``inputs``: just ``out``.
        """
        coords, meta_batch_size = self.get_batch_coords(inputs, params)
        out = self.decoder(coords, params=params)  # (b, n_points, c)

        if self.sampled_coord is not None:
            # A subset cannot be folded back into an image; keep it flat.
            out = out.transpose(1, 2)  # (b, c, n_sampled)
            if inputs is None:
                return out
            target = inputs.flatten(2)[:, :, self.sampled_index]
            loss = F.mse_loss(
                target.reshape(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none'
            ).mean(dim=1)
            return loss, out

        if inputs is not None:
            _, _, rows, cols = inputs.size()
        else:
            # Grid layout: first spatial axis is ``width``, second is ``height``.
            rows, cols = self.width, self.height
        out = out.reshape(out.size(0), rows, cols, out.size(-1)).permute(0, 3, 1, 2)  # (b, c, h, w)

        if inputs is None:
            return out
        return F.mse_loss(inputs, out, reduction='none'), out

    def forward_video(self, inputs=None, params=None, step_inner=-1, step_iter=-1):
        """Decode a video and, when a target is given, score it with MSE.

        Args:
            inputs: optional target of shape ``(b, t, c, h, w)``.
            params: optional per-task decoder parameters.
            step_inner, step_iter: accepted for interface compatibility; unused.

        Returns:
            * full grid, with ``inputs``: ``(loss, out)`` where ``out`` is
              ``(b, t, c, h, w)`` and ``loss`` is the element-wise squared error.
            * sampled grid, with ``inputs``: ``(loss, out)`` where ``out`` is
              ``(b, n_sampled_frames, c, h, w)`` and ``loss`` is the
              per-sample mean squared error over the sampled frames, ``(b,)``.
              NOTE(review): sampling is treated as frame-level because
              ``random_sample`` indexes the first (frame) axis of the grid --
              confirm this is the intended behaviour for video.
            * without ``inputs``: just ``out``.
        """
        coords, meta_batch_size = self.get_batch_coords(inputs, params)
        out = self.decoder(coords, params=params)  # (b, frames * h * w, c)

        active = self.grid if self.sampled_coord is None else self.sampled_coord
        n_frames = active.size(0)
        if inputs is not None:
            _, _, _, rows, cols = inputs.size()
        else:
            # Grid layout: first spatial axis is ``width``, second is ``height``.
            rows, cols = self.width, self.height
        out = out.reshape(out.size(0), n_frames, rows, cols, out.size(-1)).permute(0, 1, 4, 2, 3)

        if inputs is None:
            return out
        if self.sampled_coord is None:
            return F.mse_loss(inputs, out, reduction='none'), out

        target = inputs[:, self.sampled_index]  # keep only the sampled frames
        loss = F.mse_loss(
            target.reshape(meta_batch_size, -1), out.reshape(meta_batch_size, -1), reduction='none'
        ).mean(dim=1)
        return loss, out
