import glob
import re
import os
import random
from pathlib import Path

from torch.utils.data import Dataset, Dataset, ConcatDataset, Subset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10
import cv2
import numpy as np
from einops import rearrange
# from functorch import make_functional, vmap
from nn import inr

from nn.inr import make_functional
from experiments.data import Batch
from einops.layers.torch import Rearrange


def get_mgrid(sidelen, dim=2):
    """Generate a flattened grid of (x, y, ...) coordinates in the range [-1, 1].

    Args:
        sidelen: number of evenly spaced samples along every axis.
        dim: number of axes of the grid.

    Returns:
        Tensor of shape ``(sidelen ** dim, dim)``. Rows are ordered so that the
        last coordinate varies fastest.
    """
    axes = tuple(dim * [torch.linspace(-1, 1, steps=sidelen)])
    # "ij" is the layout torch.meshgrid has always produced by default; passing
    # it explicitly pins the row order and avoids the warning emitted when the
    # argument is omitted.
    mgrid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)
    return mgrid.reshape(-1, dim)


class SineLayer(nn.Module):
    """Linear layer followed by ``sin(omega_0 * x)``.

    See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for the
    discussion of omega_0. It is a hyperparameter: different signals may need
    a different omega_0 in the first layer.

    Args:
        in_features: input width of the linear map.
        out_features: output width of the linear map.
        bias: whether the linear map has a bias term.
        is_first: selects the weight-initialization range (see ``init_weights``).
        omega_0: factor applied to the linear output before the sine.
    """

    def __init__(
        self, in_features, out_features, bias=True, is_first=False, omega_0=30
    ):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.init_weights()

    def init_weights(self):
        """Re-draw the linear weights uniformly from a symmetric interval.

        First layers use ``1 / in_features`` as the bound. Other layers use
        ``sqrt(6 / in_features) / omega_0``: dividing by omega_0 keeps the
        magnitude of activations constant while boosting gradients to the
        weight matrix (see supplement Sec. 1.5).
        """
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        """Return ``sin(omega_0 * linear(input))``."""
        scaled = self.omega_0 * self.linear(input)
        return torch.sin(scaled)

    def forward_with_intermediate(self, input):
        """Return the activation together with the pre-sine value.

        Useful for visualizing activation distributions.
        """
        scaled = self.omega_0 * self.linear(input)
        return torch.sin(scaled), scaled


# Image shape per dataset, laid out as (height, width, channels).
IMG_shapes = {"mnist": (28, 28, 1), "cifar": (32, 32, 3)}

# Constructor arguments for the Siren used on each dataset. The two
# configurations differ only in ``out_features``, which equals the number of
# image channels.
SIREN_kwargs = {
    dset_name: {
        "first_omega_0": 30,
        "hidden_features": 32,
        "hidden_layers": 1,
        "hidden_omega_0": 30.0,
        "in_features": 2,
        "out_features": img_shape[-1],
        "outermost_linear": True,
    }
    for dset_name, img_shape in IMG_shapes.items()
}


class Siren(nn.Module):
    """Stack of sine layers mapping coordinates to signal values.

    Args:
        in_features: width of the input coordinates.
        hidden_features: width of every hidden layer.
        hidden_layers: number of hidden SineLayers after the first layer.
        out_features: width of the output.
        outermost_linear: if True the last layer is a plain ``nn.Linear``;
            otherwise it is another SineLayer.
        first_omega_0: omega_0 of the first layer.
        hidden_omega_0: omega_0 of all later layers; it also scales the
            initialization range of the final linear layer.
    """

    def __init__(
        self,
        in_features,
        hidden_features,
        hidden_layers,
        out_features,
        outermost_linear=False,
        first_omega_0=30,
        hidden_omega_0=30.0,
    ):
        super().__init__()

        layers = [
            SineLayer(
                in_features, hidden_features, is_first=True, omega_0=first_omega_0
            )
        ]
        layers.extend(
            SineLayer(
                hidden_features,
                hidden_features,
                is_first=False,
                omega_0=hidden_omega_0,
            )
            for _ in range(hidden_layers)
        )

        if not outermost_linear:
            layers.append(
                SineLayer(
                    hidden_features,
                    out_features,
                    is_first=False,
                    omega_0=hidden_omega_0,
                )
            )
        else:
            head = nn.Linear(hidden_features, out_features)
            # Same range as a non-first SineLayer uses for its weights.
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                head.weight.uniform_(-bound, bound)
            layers.append(head)

        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Evaluate the network; returns ``(output, coords)``."""
        return self.net(coords), coords
    

def get_batch_siren(dset_type, device=None):
    """Build a function that renders a batch of SIREN parameters as images.

    Args:
        dset_type: key into ``SIREN_kwargs`` / ``IMG_shapes``
            (``"mnist"`` or ``"cifar"``).
        device: device on which the coordinate grid is placed. When ``None``
            (the default) CUDA is used if available, otherwise the CPU, so the
            function no longer fails on hosts without a GPU.

    Returns:
        A ``torch.vmap``-wrapped callable mapped over dimension 0 of its
        argument. For each entry it evaluates the functional SIREN on the full
        coordinate grid and returns the result permuted to channels-first.
        The parameter format is whatever ``make_functional`` expects
        (defined outside this file).
    """
    siren = Siren(**SIREN_kwargs[dset_type])
    func_model, _ = make_functional(siren)
    img_shape = IMG_shapes[dset_type]
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # The grid side length comes from the image shape instead of a second
    # hard-coded 28/32 switch.
    coords = get_mgrid(img_shape[0], 2).to(device)

    def func_inp(p):
        values = func_model(p, coords)[0]
        # (H*W, C) -> (H, W, C) -> (C, H, W)
        return torch.permute(values.reshape(*img_shape), (2, 0, 1))

    return torch.vmap(func_inp, (0,))


# def get_batch_siren(dset_type):
#     siren = Siren(**SIREN_kwargs[dset_type])
#     fmodel, params = inr.make_functional(siren)
#     vparams, vshapes = inr.params_to_tensor(params)
#     batch_sirens = torch.vmap(inr.wrap_func(fmodel, vshapes))

#     # func_model, _ = make_functional(siren)
#     # mgrid_len = 28 if dset_type == "mnist" else 32
#     # coords = get_mgrid(mgrid_len, 2).cuda()
#     # img_shape = IMG_shapes[dset_type]
#     # def func_inp(p):
#     #     values = func_model(p, coords)[0]
#     #     return torch.permute(values.reshape(*img_shape), (2, 0, 1))
#     return batch_sirens


class BatchSiren(nn.Module):
    """Evaluates a batch of SIRENs, given as weight/bias tensors, on fixed inputs.

    Args:
        dset_type: key into ``SIREN_kwargs`` selecting the SIREN architecture.
        input_init: tensor of input coordinates, stored as a non-trainable
            parameter and expanded over the batch in ``forward``.
    """

    def __init__(self, dset_type, input_init):
        super().__init__()
        # TODO fix hard coded
        template = Siren(**SIREN_kwargs[dset_type])
        fmodel, params = inr.make_functional(template)

        # Only the parameter shapes are needed to wrap the functional model.
        _, vshapes = inr.params_to_tensor(params)
        self.sirens = torch.vmap(inr.wrap_func(fmodel, vshapes))

        self.inputs = nn.Parameter(input_init, requires_grad=False)

        # NOTE hard coded maps: exactly three layers are supported, each
        # tensor carrying a trailing singleton axis that is dropped here.
        self.reshape_w0 = Rearrange("b i h0 1 -> b (h0 i)")
        self.reshape_w1 = Rearrange("b h0 h1 1 -> b (h1 h0)")
        self.reshape_w2 = Rearrange("b h1 h2 1 -> b (h2 h1)")

        self.reshape_b0 = Rearrange("b h0 1 -> b h0")
        self.reshape_b1 = Rearrange("b h1 1 -> b h1")
        self.reshape_b2 = Rearrange("b h2 1 -> b h2")

    def forward(self, weights, biases):
        """Flatten the per-layer parameters and run every SIREN on the inputs.

        Args:
            weights: sequence whose first three entries are the layer weights.
            biases: sequence whose first three entries are the layer biases.

        Returns:
            The first element of what the batched functional model returns.
        """
        w0 = self.reshape_w0(weights[0])
        b0 = self.reshape_b0(biases[0])
        w1 = self.reshape_w1(weights[1])
        b1 = self.reshape_b1(biases[1])
        w2 = self.reshape_w2(weights[2])
        b2 = self.reshape_b2(biases[2])
        # Layout per sample: weight then bias, layer by layer.
        params_flat = torch.cat([w0, b0, w1, b1, w2, b2], dim=-1)

        batch_size = params_flat.shape[0]
        batched_inputs = self.inputs.expand(batch_size, -1, -1)
        out = self.sirens(params_flat, batched_inputs)
        # out = torch.cat(out, dim=-1)
        return out[0]
        

def load_path_as_image(path, dset_type="cifar"):
    """Load a SIREN checkpoint and render it as an image tensor.

    Args:
        path: file containing the state dict of a ``Siren`` built with
            ``SIREN_kwargs[dset_type]``.
        dset_type: key into ``SIREN_kwargs`` / ``IMG_shapes``. Defaults to
            ``"cifar"``, which was previously hard-coded.

    Returns:
        Tensor of shape ``(side, side, out_features)`` holding the network
        output mapped through ``x * 0.5 + 0.5``. It carries no autograd graph.
    """
    # The model and grid live on the CPU, so map the checkpoint there as well;
    # otherwise a CUDA-saved file cannot be loaded on a CPU-only host.
    # NOTE(security): torch.load unpickles the file - only load trusted files.
    state_dict = torch.load(path, map_location="cpu")
    model = Siren(**SIREN_kwargs[dset_type])
    model.load_state_dict(state_dict)
    model.eval()
    side = IMG_shapes[dset_type][0]
    grid = get_mgrid(side, 2)
    # Inference only: skip building a graph that was detached right away.
    with torch.no_grad():
        out = model(grid)[0]
    img = rearrange(out, "(n m) d -> n m d", n=side)
    return img.detach() * 0.5 + 0.5


def state_dict_to_tensors(state_dict):
    """Split a state dict into two equally long lists: weights and biases.

    The keys are consumed in pairs and are assumed to be ordered as
    [0.weight, 0.bias, 1.weight, 1.bias, ...]. The second key of every pair
    must end with "bias" (checked with an assert). Each returned tensor gets a
    new leading axis of size 1.

    Returns:
        ``(weights, biases)`` as two lists of tensors.
    """
    names = list(state_dict.keys())
    weights = []
    biases = []
    for pos in range(0, len(names), 2):
        weights.append(state_dict[names[pos]][None])
        bias_name = names[pos + 1]
        assert bias_name.endswith("bias")
        biases.append(state_dict[bias_name][None])
    return weights, biases


class INRDataset(Dataset):
    """Dataset of saved INR state dicts, served as per-layer weights and biases.

    Checkpoints are discovered under ``path`` with ``glob_pattern``. The code
    assumes files are named ``net{idx}.pth`` and that the parent directory
    name ends in ``{label}s`` with a single-digit label.

    Args:
        split: ``(start, stop)`` pair; only checkpoints whose index satisfies
            ``start <= idx < stop`` are kept.
        path: root directory searched for checkpoints.
        glob_pattern: pattern handed to ``Path.glob``.
        statistics_path: optional file with normalization statistics. When
            given, items are normalized with ``stats["weights"|"biases"]
            ["mean"|"std"]``, each a per-layer sequence.
        augmentation: whether ``__getitem__`` applies ``_augment``.
        translation_scale: half-width of the uniform input translation.
        rotation_degree: maximum absolute rotation angle in degrees.
        noise_scale: multiplier of the per-tensor std added in the noise step.
        drop_rate: dropout probability applied to weights and biases.
        resize_scale: half-width of the random scale applied to the first
            layer's weights (scale is drawn from ``1 +/- resize_scale``).
        pos_scale: half-width of the per-unit rescaling; 0 disables it.
        quantile_dropout: upper bound of the random quantile below which
            entries are zeroed; 0 disables it.
    """

    def __init__(
        self,
        split,
        path,
        glob_pattern,
        statistics_path=None,
        augmentation=False,
        translation_scale=0.25,
        rotation_degree=45,
        noise_scale=1e-1,
        drop_rate=1e-2,
        resize_scale=0.2,
        pos_scale=0.0,
        quantile_dropout=0.0,
    ):
        self.split = split
        self.path = Path(path)
        self.glob_pattern = glob_pattern
        self.normalize = statistics_path is not None

        self.augmentation = augmentation
        self.translation_scale = translation_scale
        self.rotation_degree = rotation_degree
        self.noise_scale = noise_scale
        self.drop_rate = drop_rate
        self.resize_scale = resize_scale
        self.pos_scale = pos_scale
        self.quantile_dropout = quantile_dropout

        files = list(
            sorted(self.path.glob(self.glob_pattern), key=lambda p: int(p.stem[3:]))
        )  # assumes a naming convention of "net{idx}.pth"
        # Select by the index encoded in the file name, not by list position.
        self.files = [p for p in files if split[0] <= int(p.stem[3:]) < split[1]]
        self.labels = [
            int(f.parent.name[-2]) for f in self.files
        ]  # assumes a naming convention of "*_{label}s/"

        if self.normalize:
            self.stats = torch.load(statistics_path, map_location="cpu")

    def _normalize(self, weights, biases):
        """Standardize every layer with its stored mean and std."""
        wm, ws = self.stats["weights"]["mean"], self.stats["weights"]["std"]
        bm, bs = self.stats["biases"]["mean"], self.stats["biases"]["std"]

        weights = tuple((w - m) / s for w, m, s in zip(weights, wm, ws))
        biases = tuple((b - m) / s for b, m, s in zip(biases, bm, bs))

        return weights, biases

    @staticmethod
    def rotation_mat(degree=30.0):
        """Return a 2x2 rotation matrix for an angle drawn from [-degree, degree]."""
        angle = torch.empty(1).uniform_(-degree, degree)
        angle_rad = angle * (torch.pi / 180)
        # Plain floats make the 2x2 result shape explicit instead of relying on
        # torch.tensor turning one-element tensors into scalars.
        cos = torch.cos(angle_rad).item()
        sin = torch.sin(angle_rad).item()
        return torch.tensor([[cos, -sin], [sin, cos]])

    def _augment(self, weights, biases):
        """Return augmented copies of ``weights`` and ``biases``.

        The input tensors are left untouched. Weights are expected in
        ``(in_features, out_features)`` layout, which is what ``__getitem__``
        passes in.

        Returns:
            ``(weights, biases)`` as two tuples of tensors.
        """
        new_weights, new_biases = list(weights), list(biases)

        # translation: push a random input offset through the first
        # ``depth`` layers and add the result to the bias of the last of them.
        translation = torch.empty(weights[0].shape[0]).uniform_(
            -self.translation_scale, self.translation_scale
        )
        depth = random.sample(range(1, len(weights)), 1)[0]
        shift = translation
        for layer in range(depth):
            shift = shift @ weights[layer]
        target = depth - 1
        # Out-of-place on purpose: an in-place "+=" would modify the caller's
        # bias tensor, because new_biases shares its elements with ``biases``.
        new_biases[target] = new_biases[target] + shift

        # rotation (only for 2-D inputs)
        if new_weights[0].shape[0] == 2:
            rot_mat = self.rotation_mat(self.rotation_degree)
            new_weights[0] = rot_mat @ new_weights[0]

        # noise
        # NOTE(review): this adds the constant ``std * noise_scale`` to every
        # entry rather than random noise - confirm whether that is intended.
        new_weights = [w + w.std() * self.noise_scale for w in new_weights]
        new_biases = [
            b + b.std() * self.noise_scale if b.shape[0] > 1 else b for b in new_biases
        ]

        # dropout
        new_weights = [F.dropout(w, p=self.drop_rate) for w in new_weights]
        new_biases = [F.dropout(b, p=self.drop_rate) for b in new_biases]

        # scale
        # todo: can also apply to deeper layers
        rand_scale = 1 + (torch.rand(1).item() - 0.5) * 2 * self.resize_scale
        new_weights[0] = new_weights[0] * rand_scale

        # positive scale: multiply the outputs of a layer by a per-unit factor
        # and divide the matching inputs of the next layer by the same factor.
        if self.pos_scale > 0:
            for layer in range(len(new_weights) - 1):
                # todo: we do a lot of duplicated stuff here
                out_dim = new_biases[layer].shape[0]
                scale = torch.from_numpy(
                    np.random.uniform(
                        1 - self.pos_scale, 1 + self.pos_scale, out_dim
                    ).astype(np.float32)
                )
                inv_scale = 1.0 / scale
                new_weights[layer] = new_weights[layer] * scale
                new_biases[layer] = new_biases[layer] * scale
                new_weights[layer + 1] = (new_weights[layer + 1].T * inv_scale).T

        # quantile dropout: zero every entry whose magnitude falls below a
        # randomly chosen quantile of all magnitudes.
        if self.quantile_dropout > 0:
            do_q = torch.empty(1).uniform_(0, self.quantile_dropout)
            q = torch.quantile(
                torch.cat([v.flatten().abs() for v in new_weights + new_biases]), q=do_q
            )
            new_weights = [torch.where(w.abs() < q, 0, w) for w in new_weights]
            new_biases = [torch.where(b.abs() < q, 0, b) for b in new_biases]

        return tuple(new_weights), tuple(new_biases)

    def __getitem__(self, idx):
        """Load checkpoint ``idx`` and return it as a ``Batch``.

        Weights are transposed, optionally augmented, given a trailing
        singleton axis and optionally normalized.
        """
        # Load onto the CPU like the statistics: the augmentation creates CPU
        # tensors, so a CUDA-saved checkpoint would otherwise mix devices.
        # NOTE(security): torch.load unpickles the file - only load trusted files.
        state_dict = torch.load(self.files[idx], map_location="cpu")
        weights = tuple(
            v.permute(1, 0) for name, v in state_dict.items() if "weight" in name
        )
        biases = tuple(v for name, v in state_dict.items() if "bias" in name)

        if self.augmentation:
            weights, biases = self._augment(weights, biases)

        weights = tuple(w.unsqueeze(-1) for w in weights)
        biases = tuple(b.unsqueeze(-1) for b in biases)

        if self.normalize:
            weights, biases = self._normalize(weights, biases)

        return Batch(weights=weights, biases=biases, label=self.labels[idx])

    def __len__(self):
        """Number of checkpoints selected by ``split``."""
        return len(self.files)


class SirenDataset(Dataset):
    """Dataset of SIREN checkpoints indexed by the number in their file name.

    Files matching ``{data_path}/{prefix}_*/*.pth`` are collected. The index is
    read from ``net{idx}.pth`` and the label from the first ``_{digit}s`` found
    in the path. The indices must form the contiguous range ``0..N-1``.
    """

    def __init__(self, data_path, prefix="randinit_test"):
        idx_re = re.compile(r"net(\d+)\.pth")
        label_re = re.compile(r"_(\d)s")
        self.idx_to_path = {}
        self.idx_to_label = {}
        search_pattern = os.path.join(data_path, f"{prefix}_*/*.pth")
        for ckpt_path in glob.glob(search_pattern):
            net_idx = int(idx_re.search(ckpt_path).group(1))
            self.idx_to_path[net_idx] = ckpt_path
            self.idx_to_label[net_idx] = int(label_re.search(ckpt_path).group(1))
        # Every index from 0 to N-1 must be present exactly once.
        found = sorted(self.idx_to_path)
        assert found == list(range(len(found)))

    def __getitem__(self, idx):
        """Return ``((weights, biases), label)`` for checkpoint ``idx``."""
        state = torch.load(self.idx_to_path[idx])
        weights, biases = state_dict_to_tensors(state)
        label = self.idx_to_label[idx]
        return (weights, biases), label

    def __len__(self):
        """Number of checkpoints found."""
        return len(self.idx_to_path)


# Default image transform: convert to a tensor, then normalize with mean 0.5
# and std 0.5.
DEF_TFM = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=torch.Tensor([0.5]), std=torch.Tensor([0.5])),
    ]
)


def increase_contrast(img):
    """Raise the contrast of a BGR image with CLAHE on its lightness channel.

    The image is converted to LAB, the L channel is equalized with CLAHE
    (clip limit 1.0, 3x3 tile grid; feel free to try other values), and the
    result is converted back to BGR.
    """
    # https://stackoverflow.com/questions/39308030/how-do-i-increase-the-contrast-of-an-image-in-python-opencv
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2LAB))
    equalizer = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(3, 3))
    # Only the L channel is equalized; a and b are merged back unchanged.
    merged = cv2.merge((equalizer.apply(lightness), chan_a, chan_b))
    # Convert from the LAB color model back to the BGR color space.
    return cv2.cvtColor(merged, cv2.COLOR_LAB2BGR)


class SirenAndOriginalDataset(Dataset):
    """Pairs each SIREN checkpoint with the original image at the same index.

    MNIST is used when ``"mnist"`` occurs in ``siren_path``, CIFAR10 otherwise.
    In both cases the train split is followed by the test split.
    """

    def __init__(self, siren_path, siren_prefix, data_path, data_tfm=DEF_TFM):
        self.siren_dset = SirenDataset(siren_path, prefix=siren_prefix)
        is_mnist = "mnist" in siren_path
        if is_mnist:
            self.data_type, source_cls = "mnist", MNIST
            print("Loading MNIST")
        else:
            self.data_type, source_cls = "cifar", CIFAR10
            print("Loading CIFAR10")

        # Train split first, then test split.
        parts = [
            source_cls(data_path, transform=data_tfm, train=is_train, download=True)
            for is_train in (True, False)
        ]
        merged = ConcatDataset(parts)
        # Only the MNIST branch truncates the images to the number of SIRENs.
        if is_mnist:
            merged = Subset(merged, range(len(self.siren_dset)))
        self.dset = merged

        assert len(self.siren_dset) == len(
            self.dset
        ), f"{len(self.siren_dset)} != {len(self.dset)}"

    def __len__(self):
        """Number of (SIREN, image) pairs."""
        return len(self.dset)

    def __getitem__(self, idx):
        """Return ``(params, img, label)``; both sources must agree on the label."""
        params, siren_label = self.siren_dset[idx]
        img, data_label = self.dset[idx]
        assert siren_label == data_label
        return params, img, data_label
