from abc import ABC, abstractmethod
import pickle
from typing import Literal, Any

import numpy as np
import torch
import torch.nn.functional as F
from torch import FloatTensor
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterDataPipe

from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.seed import isolate_rng
import warnings


from data.utils import *
from data.utils_fastfood import FastfoodWrapper
from models.utils import MLP

# The loaders below deliberately use the default worker count and datasets
# that define ``__len__``; silence the two resulting data-loading warnings
# (presumably emitted by Lightning's loader checks) so they do not flood logs.
for _message in (
    ".*does not have many workers.*",
    ".*`IterableDataset` has `__len__` defined.*",
):
    warnings.filterwarnings("ignore", message=_message)
del _message

import itertools

def init_weights(m, std=1.0):
    """Re-draw a ``BatchedLinear`` module's weight and bias from N(0, std**2).

    Any other module type is left untouched, which makes this usable as a
    callback for ``module.apply(...)`` (presumably its intended use).

    NOTE(review): ``BatchedLinear`` is not defined in this file; it is
    presumably provided by one of the star imports — verify.
    """
    if not isinstance(m, BatchedLinear):
        return
    for tensor in (m.weight, m.bias):
        torch.nn.init.normal_(tensor, std=std)


class ExpertsRegressionDataModule(LightningDataModule):
    """Lightning data module for compositional "experts" regression tasks.

    Every task is a sequence of ``depth`` expert indices drawn from
    ``range(n_experts)``.  All ``n_experts ** depth`` sequences are enumerated,
    shuffled once with ``torch.randperm`` and the same shuffled tensor is
    handed to every split, so the training and validation datasets agree on
    which sequences make up the pool controlled by ``ratio``.

    Args:
        n_experts: Number of experts available at each composition step.
        ratio: Forwarded to ``ExpertsRegressionDataset``; non-OOD task
            sampling draws from the first
            ``len - int((1 - ratio) * len)`` shuffled sequences.
        x_dim, y_dim: Input / output dimensionality of the target functions.
        min_context, max_context: Range of context-set sizes.
        batch_size: Tasks per training batch.
        train_size: Size of the training set (number of tasks per epoch is
            ``train_size // batch_size`` batches).
        finite_train: Draw a fixed training pool once instead of sampling
            fresh tasks for every batch.
        val_size: Number of tasks in each validation set (one batch each).
        noise: Std of the Gaussian noise added to targets.
        context_style: How query inputs relate to context inputs.
        ood_styles: One OOD validation set is built per entry, or ``None`` to
            build only the IID validation set.
        ood_intensity: Strength of the OOD input shift.
        kind_kwargs: Extra keyword arguments forwarded to every dataset.
        depth: Number of experts composed per task.
    """

    def __init__(
        self,
        n_experts: int,
        ratio: float,
        x_dim: int,
        y_dim: int,
        min_context: int,
        max_context: int,
        batch_size: int = 128,
        train_size: int = 10000,
        finite_train: bool = False,
        val_size: int = 100,
        noise: float = 0.5,
        context_style: str = "same",
        ood_styles: tuple[str, ...] | list[str] | None = ("far", "wide"),
        ood_intensity: float = 3.0,
        kind_kwargs: dict[str, Any] | None = None,
        depth: int = 5,
    ):
        super().__init__()
        # save_hyperparameters() collects the __init__ arguments from this
        # frame's locals, so normalise the immutable/sentinel defaults first:
        # the recorded hparams keep their list/dict form, and each instance
        # gets its own objects instead of one shared mutable default.
        ood_styles = None if ood_styles is None else list(ood_styles)
        kind_kwargs = {} if kind_kwargs is None else dict(kind_kwargs)
        self.save_hyperparameters()

        # Every length-`depth` sequence of expert indices, shuffled once and
        # shared by all splits below.
        permutations = torch.tensor(
            list(itertools.product(range(n_experts), repeat=depth))
        )
        permutations = permutations[torch.randperm(len(permutations))].long()

        # Arguments identical for every split.  kind_kwargs is still passed
        # separately so a clashing key raises instead of silently overriding.
        common = dict(
            x_dim=x_dim,
            y_dim=y_dim,
            min_context=min_context,
            max_context=max_context,
            noise=noise,
            context_style=context_style,
            n_experts=n_experts,
            permutations=permutations,
            ratio=ratio,
        )

        self.train_data = ExpertsRegressionDataset(
            batch_size=batch_size,
            data_size=train_size,
            finite=finite_train,
            **common,
            **kind_kwargs,
        )
        # Validation sets are finite with batch_size == data_size, i.e. a
        # single fixed batch each.  Construction order (iid first, then the
        # OOD styles in order) is kept because each finite dataset draws from
        # the global RNGs.
        self.val_data = {
            "iid": ExpertsRegressionDataset(
                batch_size=val_size,
                data_size=val_size,
                finite=True,
                ood=False,
                **common,
                **kind_kwargs,
            )
        }
        for style in ood_styles or ():
            self.val_data[style] = ExpertsRegressionDataset(
                batch_size=val_size,
                data_size=val_size,
                finite=True,
                ood=True,
                ood_style=style,
                ood_intensity=ood_intensity,
                **common,
                **kind_kwargs,
            )

    def train_dataloader(self):
        """Return the training loader.

        ``batch_size=None`` disables automatic batching: the dataset already
        yields complete batches.
        """
        return DataLoader(self.train_data, batch_size=None)

    def val_dataloader(self):
        """Return one loader per validation set: ``"iid"`` first, then each
        OOD style in the order given."""
        return [DataLoader(v, batch_size=None) for v in self.val_data.values()]


class ExpertsRegressionDataset(ABC, IterDataPipe):
    """Iterable of batches of compositional expert regression tasks.

    A task is one row of ``permutations``: a sequence of expert indices.  The
    target function applies one shared random linear layer
    (``x_dim -> y_dim * n_experts``) once per index in the sequence, keeps only
    the ``y_dim`` outputs belonging to the selected expert and applies
    ``tanh``.  Because the layer is re-applied to its own output, sequences of
    more than one expert require ``x_dim == y_dim``.

    Each yielded item is ``((x_c, y_c), (x_q, y_q), params)``: context and
    query sets for ``batch_size`` tasks plus the tasks' expert sequences as a
    float tensor.

    With ``finite=True`` a fixed pool of ``data_size`` tasks is drawn at
    construction and iterated in consecutive chunks of ``batch_size``;
    otherwise fresh tasks are sampled for every batch.

    Args:
        x_dim, y_dim: Input / output dimensionality.
        data_size: Number of tasks per pass (``len(self)`` is
            ``data_size // batch_size`` batches).
        min_context, max_context: Range of context-set sizes.
        batch_size: Tasks per batch.
        noise: Std of the Gaussian noise added to targets.
        ood: Shift the finite pool's query inputs according to ``ood_style``.
        context_style: ``"same"`` (independent query inputs) or ``"near"``
            (query = context + small noise); used by online sampling.
        ood_style: ``"wide"`` or ``"far"``; see :meth:`generate_finite_data`.
        ood_intensity: Strength of the OOD shift.
        finite: Use a fixed pool instead of online sampling.
        n_experts: Number of experts in the shared layer.
        permutations: Tensor of expert sequences, expected as
            ``(n_sequences, depth)`` integer indices; required.
        ratio: Non-OOD sampling uses the first
            ``len(permutations) - int((1 - ratio) * len(permutations))`` rows;
            required.

    Raises:
        TypeError: If ``permutations`` or ``ratio`` is missing.
        ValueError: If sequences compose several experts but
            ``x_dim != y_dim``.
    """

    def __init__(
        self,
        x_dim: int,
        y_dim: int,
        data_size: int,
        min_context: int,
        max_context: int,
        batch_size: int = 128,
        noise: float = 0.0,
        ood: bool = False,
        context_style: str = "same",
        ood_style: str = "far",
        ood_intensity: float = 3.0,
        finite: bool = False,
        n_experts: int = 5,
        permutations: torch.Tensor | None = None,
        ratio: float | None = None,
    ) -> None:
        super().__init__()
        # Both have defaults only for signature compatibility; the dataset
        # cannot sample tasks without them.
        if permutations is None:
            raise TypeError("ExpertsRegressionDataset requires `permutations`")
        if ratio is None:
            raise TypeError("ExpertsRegressionDataset requires `ratio`")
        # From the second composition step on, the layer's y_dim-wide output
        # is fed back into its x_dim-wide input; fail here rather than with a
        # shape error on the first batch.
        if x_dim != y_dim and permutations.ndim > 1 and permutations.shape[-1] > 1:
            raise ValueError(
                "composing more than one expert requires x_dim == y_dim "
                f"(got x_dim={x_dim}, y_dim={y_dim})"
            )

        self.x_dim = x_dim
        self.y_dim = y_dim
        self.min_context = min_context
        self.max_context = max_context
        self.batch_size = batch_size
        self.data_size = data_size
        self.noise = noise
        self.ood = ood
        self.context_style = context_style
        self.ood_style = ood_style
        self.ood_intensity = ood_intensity
        self.finite = finite
        self.n_params = n_experts
        self.n_experts = n_experts
        self.permutations = permutations
        # Size of the tail of `permutations` excluded from non-OOD sampling.
        self.n_eval = int((1.0 - ratio) * len(permutations))

        # Fixed seed inside isolate_rng: every instance (train and all
        # validation splits) builds the same expert weights while the global
        # RNG state is left as it was.
        with isolate_rng():
            torch.manual_seed(0)
            self.models = torch.nn.Linear(self.x_dim, self.y_dim * self.n_experts)
            self.models.weight.data.normal_()
            self.models.bias.data.normal_()

        if self.finite:
            self.generate_finite_data()

    def generate_finite_data(self):
        """Draw the fixed pool of ``data_size`` tasks used when ``finite=True``.

        Context and query inputs are standard normal with ``max_context``
        points per task.  When ``ood`` is set, query inputs are shifted:
        ``"wide"`` scales them by ``ood_intensity``; ``"far"`` shrinks them by
        0.1 and adds a random per-point direction of norm ``ood_intensity``.
        Any other ``ood_style`` leaves them unshifted.  Targets receive
        Gaussian noise with std ``noise``.

        NOTE(review): unlike :meth:`sample_x`, query inputs here ignore
        ``context_style``, and task parameters are always drawn with
        ``ood=True`` (i.e. from all sequences) even for the IID split —
        confirm both are intended.
        """
        self.fixed_x_c = torch.randn(self.data_size, self.max_context, self.x_dim)
        self.fixed_x_q = torch.randn(self.data_size, self.max_context, self.x_dim)
        if self.ood:
            if self.ood_style == "wide":
                self.fixed_x_q *= self.ood_intensity
            elif self.ood_style == "far":
                direction = torch.randn_like(self.fixed_x_q)
                self.fixed_x_q = (
                    self.fixed_x_q * 0.1
                    + self.ood_intensity
                    * direction
                    / direction.norm(dim=-1, keepdim=True)
                )
        self.fixed_params = self.sample_function_params(n=self.data_size, ood=True)
        self.fixed_y_c = self.function(self.fixed_x_c, self.fixed_params)
        self.fixed_y_q = self.function(self.fixed_x_q, self.fixed_params)
        self.fixed_y_c += self.noise * torch.randn_like(self.fixed_y_c)
        self.fixed_y_q += self.noise * torch.randn_like(self.fixed_y_q)

    def sample_finite_batch(self, n_context, return_vis=False, *, batch_idx=0):
        """Return batch ``batch_idx`` of the fixed pool, truncated to
        ``n_context`` points per task.

        Args:
            n_context: Number of context/query points kept per task.
            return_vis: Also return visualisation data; only honoured on the
                ``fixed_params is None`` path.
            batch_idx: Which consecutive chunk of ``batch_size`` tasks to
                return; 0 gives the first chunk.
        """
        start = batch_idx * self.batch_size
        rows = slice(start, start + self.batch_size)
        x_c = self.fixed_x_c[rows, :n_context]
        x_q = self.fixed_x_q[rows, :n_context]
        y_c = self.fixed_y_c[rows, :n_context]
        y_q = self.fixed_y_q[rows, :n_context]
        if self.fixed_params is not None:
            return (x_c, y_c), (x_q, y_q), self.fixed_params[rows].float()
        # generate_finite_data always sets fixed_params, so the paths below
        # are only reachable if something else clears it; fixed_x_vis /
        # fixed_y_vis are not created anywhere in this class.
        if return_vis:
            return (
                (x_c, y_c),
                (x_q, y_q),
                None,
                (self.fixed_x_vis[rows], self.fixed_y_vis[rows]),
            )
        return (x_c, y_c), (x_q, y_q), None

    def sample_x(self, n_context):
        """Sample context and query inputs of shape
        ``(batch_size, n_context, x_dim)`` according to ``context_style``.

        Raises:
            ValueError: If ``context_style`` is not ``"same"`` or ``"near"``.
        """
        x_c = torch.randn(self.batch_size, n_context, self.x_dim)
        if self.context_style == "same":
            x_q = torch.randn(self.batch_size, n_context, self.x_dim)
        elif self.context_style == "near":
            x_q = x_c + 0.1 * torch.randn_like(x_c)
        else:
            raise ValueError("Invalid context style")
        return x_c, x_q

    def sample_function_params(self, n: int | None = None, ood: bool = False) -> torch.Tensor:
        """Sample ``n`` rows of ``permutations`` (default ``batch_size``),
        with replacement, using the global NumPy RNG.

        ``ood=False`` restricts sampling to the first
        ``len(permutations) - n_eval`` rows; ``ood=True`` currently samples
        from all rows.  (Restricting OOD sampling to the held-out tail would
        use ``len(self.permutations) - self.n_eval + np.random.choice(self.n_eval, n)``.)
        The result has the dtype of ``permutations``.
        """
        n = self.batch_size if n is None else n
        if ood:
            pool_size = len(self.permutations)
        else:
            pool_size = len(self.permutations) - self.n_eval
        idx = torch.tensor(np.random.choice(pool_size, n))
        return self.permutations[idx]

    def function_params(self) -> FloatTensor:
        """Sample one batch of non-OOD expert sequences as a float tensor."""
        return self.sample_function_params().float()

    @torch.no_grad()
    def model(self, x, params):
        """Apply one task's expert sequence to that task's inputs.

        Called per task through ``torch.vmap`` in :meth:`function`, so ``x``
        holds one task's ``(n_samples, x_dim)`` inputs and ``params`` its
        expert sequence.  Each step runs the shared layer, keeps the selected
        expert's ``y_dim`` outputs via a one-hot mask and applies ``tanh``.
        """
        y = x
        for idx in params.long():
            out = self.models(y).view(y.size(0), self.y_dim, self.n_experts)
            mask = F.one_hot(idx, self.n_experts)
            y = torch.tanh((out * mask.view(1, 1, self.n_experts)).sum(-1))
        return y

    def function(self, x, params) -> FloatTensor:
        """Evaluate every task's target function on its own inputs.

        Args:
            x: ``(bsz, n_samples, x_dim)`` inputs.
            params: ``(bsz, ...)`` per-task expert sequences.

        Returns:
            ``(bsz, n_samples, y_dim)`` noiseless targets.
        """
        return torch.vmap(self.model, in_dims=(0, 0), out_dims=0)(x, params)

    def get_batch(self, n_context=None, *, batch_idx=0):
        """Return ``((x_c, y_c), (x_q, y_q), params)`` for one batch.

        In finite mode the context size is always the midpoint of
        ``[min_context, max_context]`` (``n_context`` is ignored) and the
        ``batch_idx``-th chunk of the fixed pool is returned.  Otherwise a
        context size is drawn uniformly from ``[min_context, max_context]``
        unless given, and fresh noisy tasks are sampled.
        """
        if self.finite:
            n_context = (self.min_context + self.max_context) // 2
            return self.sample_finite_batch(n_context, batch_idx=batch_idx)

        if n_context is None:
            n_context = np.random.randint(self.min_context, self.max_context + 1)

        x_c, x_q = self.sample_x(n_context)
        params = self.function_params()
        y_c, y_q = self.function(x_c, params), self.function(x_q, params)
        y_c += self.noise * torch.randn_like(y_c)
        y_q += self.noise * torch.randn_like(y_q)
        return (x_c, y_c), (x_q, y_q), params

    def __len__(self):
        """Number of batches per pass (incomplete trailing batch dropped)."""
        return self.data_size // self.batch_size

    def __iter__(self):
        # batch_idx walks through the fixed pool in finite mode; it is
        # ignored when tasks are sampled online.
        for batch_idx in range(len(self)):
            yield self.get_batch(batch_idx=batch_idx)