import math

import torch
import numpy as np
from transformers import AutoTokenizer
import random

from torch.nn import CrossEntropyLoss


def squared_error(ys_pred, ys):
    """Return the element-wise squared difference between `ys` and `ys_pred`.

    No reduction is applied; the result has the broadcast shape of the inputs.
    """
    residual = ys - ys_pred
    return torch.square(residual)


def mean_squared_error(ys_pred, ys):
    """Return the mean of the squared differences between `ys` and `ys_pred`.

    The mean is taken over every element, producing a scalar tensor.
    """
    residual = ys - ys_pred
    return torch.mean(torch.square(residual))


def accuracy(ys_pred, ys):
    """Return a 0/1 float tensor marking where sign(`ys_pred`) equals `ys`.

    No reduction is applied. A prediction of exactly zero has sign 0 and
    therefore only matches a target that is also 0.
    """
    predicted_labels = ys_pred.sign()
    hits = predicted_labels == ys
    return hits.float()


def cross_entropy(ys_pred, ys):
    """Binary cross-entropy between sigmoid(`ys_pred`) and targets mapped via (ys + 1) / 2.

    The affine map sends a target of -1 to 0 and a target of +1 to 1 before
    the module-level `bce_loss` is applied.
    """
    return bce_loss(sigmoid(ys_pred), (ys + 1) / 2)


def cross_entropy_lm(ys_pred, ys):
    """Next-token cross-entropy: position t of `ys_pred` is scored against `ys[t + 1]`.

    The last position along the second-to-last axis of `ys_pred` and the first
    position along the last axis of `ys` are dropped, both are flattened, and
    the default `CrossEntropyLoss` is applied.
    """
    vocab_size = ys_pred.size(-1)
    shifted_logits = ys_pred[..., :-1, :].contiguous().view(-1, vocab_size)
    shifted_targets = ys[..., 1:].contiguous().view(-1)
    criterion = CrossEntropyLoss()
    return criterion(shifted_logits, shifted_targets)


def cross_entropy_zero_one(ys_pred, ys):
    """Binary cross-entropy between sigmoid(`ys_pred`) and `ys`, used as-is.

    Unlike `cross_entropy`, the targets are passed through unchanged to the
    module-level `bce_loss`.
    """
    probabilities = sigmoid(ys_pred)
    return bce_loss(probabilities, ys)


def cross_entropy_no_reduction(ys_pred, ys):
    """Per-element binary cross-entropy between sigmoid(`ys_pred`) and `ys`.

    Uses the module-level `bce_loss_no_reduce`, so no reduction is applied.
    """
    probabilities = sigmoid(ys_pred)
    return bce_loss_no_reduce(probabilities, ys)

# Shared module instances used by the cross-entropy helpers in this file.
sigmoid = torch.nn.Sigmoid()
# Per-element binary cross-entropy (no averaging).
bce_loss_no_reduce = torch.nn.BCELoss(reduction="none")
# Binary cross-entropy averaged over all elements ("mean" is BCELoss's default).
bce_loss = torch.nn.BCELoss(reduction="mean")


class Task:
    """Base class for a batch of sampled tasks.

    Subclasses are expected to override `evaluate`, `generate_pool_dict`,
    `get_metric` and `get_training_metric`; every one of them raises
    `NotImplementedError` here.
    """

    def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None):
        """Store the task configuration.

        `pool_dict` and `seeds` are mutually exclusive: at most one of them
        may be non-None.
        """
        self.n_dims, self.b_size = n_dims, batch_size
        self.pool_dict, self.seeds = pool_dict, seeds
        # A task is drawn from a pool or from seeds, never from both.
        assert not (pool_dict is not None and seeds is not None)

    def evaluate(self, xs):
        """Compute targets for the inputs `xs` (to be provided by subclasses)."""
        raise NotImplementedError

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks):
        """Build a pool of `num_tasks` task parameters (to be provided by subclasses)."""
        raise NotImplementedError

    @staticmethod
    def get_metric():
        """Return the evaluation metric callable (to be provided by subclasses)."""
        raise NotImplementedError

    @staticmethod
    def get_training_metric():
        """Return the training loss callable (to be provided by subclasses)."""
        raise NotImplementedError


def get_task_sampler(
    task_name, n_dims, batch_size, pool_dict=None, num_tasks=None, **kwargs
):
    """Return a factory that instantiates the task class registered under `task_name`.

    When `num_tasks` is given, a pool is generated once via the class's
    `generate_pool_dict` and shared by every task the factory creates;
    passing both `num_tasks` and `pool_dict` raises `ValueError`.
    The returned callable accepts keyword arguments only and forwards them,
    together with `kwargs`, to the task constructor.

    Raises:
        NotImplementedError: if `task_name` is not registered.
        ValueError: if both `pool_dict` and `num_tasks` are provided.
    """
    registry = {
        "linear_regression": LinearRegression,
        "probabilistic_logistic_regression": ProbabilisticLogisticRegression,
        "sparse_linear_regression": SparseLinearRegression,
        "linear_classification": LinearClassification,
        "noisy_linear_regression": NoisyLinearRegression,
        "quadratic_regression": QuadraticRegression,
        "probabilistic_tanh": ProbabilisticTanh,
        "relu_2nn_regression": Relu2nnRegression,
        "decision_tree": DecisionTree,
        "crf": CRF,
        "crf_ising": CRF_ISING,
        "three_nodes": ThreeNodeTree,
        "nl": NLSyntheticTask,
        "nlreal": NLRealTask,
        "nladap": NLSyntheticTaskAdaptor,
    }

    # Guard clause: bail out early on an unregistered name.
    task_cls = registry.get(task_name)
    if task_cls is None:
        print("Unknown task")
        raise NotImplementedError

    if num_tasks is not None:
        if pool_dict is not None:
            raise ValueError("Either pool_dict or num_tasks should be None.")
        pool_dict = task_cls.generate_pool_dict(n_dims, num_tasks, **kwargs)

    def sample_task(**args):
        return task_cls(n_dims, batch_size, pool_dict, **args, **kwargs)

    return sample_task


class LinearRegression(Task):
    """Noise-free linear regression: y = scale * <x, w> with one weight vector per batch element."""

    def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None, scale=1):
        """Sample or look up the per-task weights `w_b` of shape (batch, n_dims, 1).

        scale: a constant by which to scale the task outputs in `evaluate`.
        seeds: optional per-task seeds; each weight vector is drawn from a
            generator re-seeded with the corresponding seed.
        pool_dict: optional dict with key "w" holding a pool of weights from
            which a random subset of size `batch_size` is taken.
        """
        super().__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale

        if seeds is not None:
            # Deterministic weights: one fresh seed per batch element.
            self.w_b = torch.zeros(self.b_size, self.n_dims, 1)
            rng = torch.Generator()
            assert len(seeds) == self.b_size
            for row, seed in enumerate(seeds):
                rng.manual_seed(seed)
                self.w_b[row] = torch.randn(self.n_dims, 1, generator=rng)
        elif pool_dict is not None:
            # Draw distinct entries from the pre-generated pool.
            assert "w" in pool_dict
            pool = pool_dict["w"]
            chosen = torch.randperm(len(pool))[:batch_size]
            self.w_b = pool[chosen]
        else:
            self.w_b = torch.randn(self.b_size, self.n_dims, 1)

    def evaluate(self, xs_b):
        """Return scale * (xs_b @ w_b) with the trailing singleton axis removed."""
        weights = self.w_b.to(xs_b.device)
        projections = torch.matmul(xs_b, weights)
        return self.scale * projections[:, :, 0]

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks, **kwargs):  # ignore extra args
        """Return a pool of `num_tasks` standard-normal weight vectors under key "w"."""
        pool = torch.randn(num_tasks, n_dims, 1)
        return {"w": pool}

    @staticmethod
    def get_metric():
        """Return the per-element evaluation metric."""
        return squared_error

    @staticmethod
    def get_training_metric():
        """Return the scalar training loss."""
        return mean_squared_error


class NLRealTask(Task):
    """Probabilistic binary task rendered as "sentence: label" text prompts.

    Labels are sampled from a Bernoulli whose probability is the sigmoid of a
    linear function of the inputs. The caller-supplied sentences are paired
    with the sampled labels, concatenated into one prompt per batch element
    and tokenized.
    """

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        weight_multiplier=1,
        variable_noise=False,
        default_word="null",
        n_points=None,
        tokenizer_name="EleutherAI/gpt-neo-125M", 
    ):
        """Load the tokenizer and sample the per-task weights `w_b`.

        scale: a constant by which to scale the logits in `evaluate`.
        weight_multiplier: factor applied to the sampled weights when
            `variable_noise` is False.
        variable_noise: if True, each task's weights are instead multiplied by
            an integer drawn uniformly from 1..10.
        default_word: accepted for signature parity; not used by this class.
        n_points: stored on the instance; not otherwise used by this class.
        tokenizer_name: name passed to `AutoTokenizer.from_pretrained`.
        """
        super(NLRealTask, self).__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        self.weight_multiplier = weight_multiplier
        self.n_dims = n_dims
        self.n_points = n_points
        # NOTE(review): presumably the ids of " negative" / " positive" in the
        # default GPT-Neo vocabulary -- confirm if tokenizer_name changes.
        self.negative_token_id = 4633
        self.positive_token_id = 3967

        self.label_words = {0: "negative", 1: "positive"}
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self._tokenizer.pad_token = self._tokenizer.eos_token

        if pool_dict is None and seeds is None:
            if variable_noise:
                self.w_b = torch.randn(self.b_size, self.n_dims, 1)
                # One integer multiplier in 1..10 per task.
                self.w_b = self.w_b * torch.randint(1, 11, (self.b_size, 1, 1))
            else:
                self.w_b = torch.randn(self.b_size, self.n_dims, 1) * weight_multiplier

        elif seeds is not None:
            self.w_b = torch.zeros(self.b_size, self.n_dims, 1)
            generator = torch.Generator()
            assert len(seeds) == self.b_size
            for i, seed in enumerate(seeds):
                generator.manual_seed(seed)
                weights = torch.randn(self.n_dims, 1, generator=generator)
                if variable_noise:
                    # Draw a single multiplier for this task from the seeded
                    # generator. The previous (b_size, 1, 1) draw could not be
                    # assigned to one row and ignored the seed.
                    multiplier = torch.randint(1, 11, (1,), generator=generator)
                    self.w_b[i] = weights * multiplier
                else:
                    self.w_b[i] = weights * weight_multiplier
        else:
            assert "w" in pool_dict
            indices = torch.randperm(len(pool_dict["w"]))[:batch_size]
            self.w_b = pool_dict["w"][indices]

    def _construct_sequences(self, xs_b, ys_b):
        """Build and tokenize one "sentence: label , sentence: label ..." prompt per batch element.

        xs_b: a list of lists of sentences; each inner list is one sequence.
        ys_b: tensor indexed as ys_b[batch, point]; 0 maps to "negative",
            anything else to "positive".

        Returns the padded/truncated `input_ids` tensor produced by the tokenizer.
        """
        prompts = []
        for b_idx, sequence in enumerate(xs_b):
            entries = []
            for p_idx, sentence in enumerate(sequence):
                # Neutralise characters/words that would collide with the
                # "sentence: label , " prompt format.
                sentence = (
                    sentence.replace(":", "-")
                    .replace(",", ".")
                    .replace("positive", "pos")
                    .replace("negative", "neg")
                )
                label_idx = 0 if ys_b[b_idx, p_idx].item() == 0 else 1
                # Sentences are capped at 100 characters.
                entries.append(f"{sentence[0:100]}: {self.label_words[label_idx]}")
            prompts.append(" , ".join(entries).strip())
        tokenized_batch = self._tokenizer(
            prompts, padding=True, truncation=True, return_tensors="pt"
        ).input_ids
        return tokenized_batch

    def evaluate(self, xs_b, sentence_ids=None):
        """Sample labels for `xs_b` and render them with the given sentences.

        xs_b: input features; multiplied with `w_b` to obtain the logits.
        sentence_ids: list of lists of sentences, one inner list per batch
            element (required despite the default).

        Returns (ys_b, w_b, nl_batch), where `ys_b` is truncated to the number
        of label tokens found in the first tokenized sequence.

        Raises:
            ValueError: if `sentence_ids` is None.
        """
        if sentence_ids is None:
            raise ValueError("NLRealTask.evaluate requires sentence_ids")

        w_b = self.w_b.to(xs_b.device)
        probability = torch.sigmoid(self.scale * (xs_b @ w_b)[:, :, 0])
        ys_b = torch.bernoulli(probability)
        nl_batch = self._construct_sequences(sentence_ids, ys_b)
        # Count the label tokens in the first sequence after tokenization, so
        # ys_b can be cut to the examples actually present in the prompt.
        is_label = (nl_batch == self.positive_token_id) | (
            nl_batch == self.negative_token_id
        )
        total_samples = int(is_label[0].sum())
        ys_b = ys_b[:, :total_samples]
        return ys_b, w_b.detach(), nl_batch

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks, **kwargs):  # ignore extra args
        """Return a pool of `num_tasks` standard-normal weight vectors under key "w"."""
        return {"w": torch.randn(num_tasks, n_dims, 1)}

    @staticmethod
    def get_metric():
        """Return the per-element evaluation metric."""
        return cross_entropy_no_reduction

    @staticmethod
    def get_training_metric():
        """Return the scalar training loss."""
        return cross_entropy_zero_one


class NLSyntheticTaskAdaptor(Task):
    """Probabilistic binary task whose inputs and labels are verbalized as text.

    Each feature dimension equal to 1 is rendered as a fixed word, every other
    value as `default_word`; Bernoulli labels become "negative"/"positive".
    `evaluate` returns the raw sampled labels together with the tokenized
    prompts.
    """

    # Word for each feature dimension (index = dimension).
    _WORDS = [
        "sports", "love", "hate", "car", "school", "family", "work", "sleep",
        "water", "tree", "fox", "train", "random", "movie", "music", "book",
        "play", "house", "spell", "bar", "jump", "park", "run", "hill",
        "fast", "slow", "talk", "wallet", "orange", "apple", "ball", "cat",
    ]
    # Shorter list selected when the tokenizer name contains "pythia".
    _PYTHIA_WORDS = [
        "love", "car", "school", "family", "work", "sleep", "water", "tree",
        "fox", "train", "random", "movie", "music", "book", "play", "house",
        "bar", "jump", "park", "run", "hill", "fast", "slow", "talk",
        "orange", "apple", "ball", "cat",
    ]

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        weight_multiplier=1,
        variable_noise=False,
        default_word="null",
        n_points=None,
        tokenizer_name="EleutherAI/gpt-neo-125M", 
    ):
        """Load the tokenizers, pick the word list and sample the weights `w_b`.

        scale: a constant by which to scale the logits in `evaluate`.
        weight_multiplier: factor applied to the sampled weights when
            `variable_noise` is False, and to weights passed to `evaluate`.
        variable_noise: if True, each task's weights are multiplied by a
            random integer instead of `weight_multiplier`.
        default_word: word emitted for feature values other than 1.
        n_points: points per sequence; needed by
            `_construct_sequences_drop_null` to size the padding.
        tokenizer_name: name passed to `AutoTokenizer.from_pretrained`.
        """
        super(NLSyntheticTaskAdaptor, self).__init__(
            n_dims, batch_size, pool_dict, seeds
        )
        self.scale = scale
        self.weight_multiplier = weight_multiplier
        self.n_dims = n_dims
        self.n_points = n_points
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.positive_token_id_space = self.tokenizer(" positive").input_ids[0]
        self.negative_token_id_space = self.tokenizer(" negative").input_ids[0]
        # Vocabulary entries for which "<x> sports : " tokenizes to exactly
        # four ids, both with and without a leading space.
        self.valid_words = [
            x
            for x in self.tokenizer.vocab.keys()
            if len(self.tokenizer(f" {x} sports : ").input_ids) == 4
            and len(self.tokenizer(f"{x} sports : ").input_ids) == 4
        ]

        if "pythia" not in tokenizer_name:
            self.words = list(self._WORDS)
        else:
            self.words = list(self._PYTHIA_WORDS)

        self.label_words = {0: "negative", 1: "positive"}
        # Second tokenizer instance, configured for padding and left-side
        # truncation; `self.tokenizer` above keeps its default configuration.
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.default_word = default_word
        self._tokenizer.pad_token = self._tokenizer.eos_token
        self.vocabulary = list(self._tokenizer.vocab.keys())
        self._tokenizer.truncation_side = "left"
        self.model_name = tokenizer_name

        if pool_dict is None and seeds is None:
            if variable_noise:
                self.w_b = torch.randn(self.b_size, self.n_dims, 1)
                # One integer multiplier in 1..5 per task.
                self.w_b = self.w_b * torch.randint(1, 6, (self.b_size, 1, 1))
            else:
                self.w_b = torch.randn(self.b_size, self.n_dims, 1) * weight_multiplier

        elif seeds is not None:
            self.w_b = torch.zeros(self.b_size, self.n_dims, 1)
            generator = torch.Generator()
            assert len(seeds) == self.b_size
            for i, seed in enumerate(seeds):
                generator.manual_seed(seed)
                weights = torch.randn(self.n_dims, 1, generator=generator)
                if variable_noise:
                    # Draw a single multiplier for this task from the seeded
                    # generator. The previous (b_size, 1, 1) draw could not be
                    # assigned to one row and ignored the seed.
                    # NOTE(review): range is 1..10 here but 1..5 in the
                    # unseeded branch -- confirm which is intended.
                    multiplier = torch.randint(1, 11, (1,), generator=generator)
                    self.w_b[i] = weights * multiplier
                else:
                    self.w_b[i] = weights * weight_multiplier
        else:
            assert "w" in pool_dict
            indices = torch.randperm(len(pool_dict["w"]))[:batch_size]
            self.w_b = pool_dict["w"][indices]

    def _verbalize_points(self, xs_b, ys_b, drop_null=False):
        """Return, per batch element, a list of (words, label) string pairs.

        With `drop_null` False every dimension contributes a word (the
        dimension's word if the value is 1, else `default_word`). With
        `drop_null` True only dimensions equal to 1 contribute, and a point
        with no active dimension is rendered as "null".
        """
        # Convert once instead of indexing the tensor element by element.
        feature_rows = xs_b.tolist()
        batch = []
        for b_idx, points in enumerate(feature_rows):
            sequence = []
            for p_idx, features in enumerate(points):
                if drop_null:
                    word = [
                        self.words[d_idx]
                        for d_idx, value in enumerate(features)
                        if value == 1
                    ]
                    if not word:
                        word.append("null")
                else:
                    word = [
                        self.words[d_idx] if value == 1 else self.default_word
                        for d_idx, value in enumerate(features)
                    ]
                label_idx = 0 if ys_b[b_idx, p_idx].item() == 0 else 1
                sequence.append((" ".join(word), self.label_words[label_idx]))
            batch.append(sequence)
        return batch

    @staticmethod
    def _join_with_colons(sequence):
        """Render (words, label) pairs as "words : label , words : label ..."."""
        joined = " , ".join(f"{words} : {label}" for words, label in sequence)
        # strip(" , ") removes any leading/trailing spaces and commas.
        return joined.strip(" , ")

    def _construct_sequences(self, xs_b, ys_b):
        """Tokenize prompts of the form "words label words label ..." (space separated).

        Sequences are tokenized one by one without padding, so all of them
        must produce the same number of tokens for the final tensor to build.
        """
        tok_batch = []
        for sequence in self._verbalize_points(xs_b, ys_b):
            input_seq = " ".join(f"{words} {label}" for words, label in sequence)
            tok_batch.append(self._tokenizer(input_seq.strip()).input_ids)
        return torch.tensor(tok_batch)

    def _construct_sequences_colon(self, xs_b, ys_b):
        """Tokenize prompts of the form "words : label , words : label ...".

        Sequences are tokenized one by one without padding, so all of them
        must produce the same number of tokens for the final tensor to build.
        """
        tok_batch = []
        for sequence in self._verbalize_points(xs_b, ys_b):
            input_seq = self._join_with_colons(sequence)
            tok_batch.append(self._tokenizer(input_seq).input_ids)
        return torch.tensor(tok_batch)

    def _construct_sequences_drop_null(self, xs_b, ys_b):
        """Like `_construct_sequences_colon`, but inactive dimensions emit no word.

        The batch is tokenized in one call and padded/truncated to
        (n_dims + additional_tokens) * n_points tokens, so `n_points` must
        have been provided at construction time.
        """
        input_strings = [
            self._join_with_colons(sequence)
            for sequence in self._verbalize_points(xs_b, ys_b, drop_null=True)
        ]

        # Per-point token budget on top of the feature words.
        if "pythia" in self.model_name:
            additional_tokens = 2
        else:
            additional_tokens = 3
        max_length = (self.n_dims + additional_tokens) * self.n_points
        tokenized_batch = self._tokenizer(
            input_strings, padding="max_length", max_length=max_length, truncation=True
        ).input_ids
        return torch.tensor(tokenized_batch)

    def evaluate(self, xs_b, w_b=None):
        """Sample Bernoulli labels for `xs_b` and build the tokenized prompts.

        If `w_b` is given, it replaces the stored weights (after scaling by
        `weight_multiplier`) for this and later calls.

        Returns (ys_b, w_b, nl_batch): the sampled labels, the weights used
        (detached) and the token ids from `_construct_sequences_colon`.
        """
        if w_b is not None:
            self.w_b = w_b * self.weight_multiplier

        w_b = self.w_b.to(xs_b.device)
        probability = torch.sigmoid(self.scale * (xs_b @ w_b)[:, :, 0])
        ys_b = torch.bernoulli(probability)
        nl_batch = self._construct_sequences_colon(xs_b, ys_b)

        return ys_b, w_b.detach(), nl_batch

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks, **kwargs):  # ignore extra args
        """Return a pool of `num_tasks` standard-normal weight vectors under key "w"."""
        return {"w": torch.randn(num_tasks, n_dims, 1)}

    @staticmethod
    def get_metric():
        """Return the per-element evaluation metric."""
        return cross_entropy_no_reduction

    @staticmethod
    def get_training_metric():
        """Return the scalar training loss."""
        return cross_entropy_zero_one


class NLSyntheticTask(Task):
    """Probabilistic binary task verbalized as text and trained as language modelling.

    Each feature dimension equal to 1 is rendered as a fixed word, every other
    value as `default_word`; Bernoulli labels become "negative"/"positive".
    `evaluate` returns language-model targets: the label token ids at label
    positions and -100 everywhere else.
    """

    # Word for each feature dimension (index = dimension).
    _WORDS = [
        "sports", "love", "hate", "car", "school", "family", "work", "sleep",
        "water", "tree", "fox", "train", "random", "movie", "music", "book",
        "play", "house", "spell", "bar", "jump", "park", "run", "hill",
        "fast", "slow", "talk", "wallet", "orange", "apple", "ball", "cat",
    ]
    # Shorter list selected when the tokenizer name contains "pythia".
    _PYTHIA_WORDS = [
        "love", "car", "school", "family", "work", "sleep", "water", "tree",
        "fox", "train", "random", "movie", "music", "book", "play", "house",
        "bar", "jump", "park", "run", "hill", "fast", "slow", "talk",
        "orange", "apple", "ball", "cat",
    ]

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        weight_multiplier=1,
        variable_noise=False,
        default_word="null",
        n_points=None,
        tokenizer_name=None    
    ):
        """Load the tokenizers, pick the word list and sample the weights `w_b`.

        scale: a constant by which to scale the logits in `evaluate`.
        weight_multiplier: factor applied to the sampled weights when
            `variable_noise` is False, and to weights passed to `evaluate`.
        variable_noise: if True, each task's weights are multiplied by a
            random integer instead of `weight_multiplier`.
        default_word: word emitted for feature values other than 1.
        n_points: points per sequence; needed by
            `_construct_sequences_drop_null` to size the padding.
        tokenizer_name: name passed to `AutoTokenizer.from_pretrained`
            (required despite the None default).

        Raises:
            ValueError: if `tokenizer_name` is None.
        """
        super(NLSyntheticTask, self).__init__(n_dims, batch_size, pool_dict, seeds)
        if tokenizer_name is None:
            # Fail with a clear message instead of an opaque error from the
            # tokenizer loader or the substring test below.
            raise ValueError("NLSyntheticTask requires a tokenizer_name")
        self.scale = scale
        self.weight_multiplier = weight_multiplier
        self.n_dims = n_dims
        self.n_points = n_points
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.positive_token_id_space = self.tokenizer(" positive").input_ids[0]
        self.negative_token_id_space = self.tokenizer(" negative").input_ids[0]
        # Vocabulary entries for which "<x> sports : " tokenizes to exactly
        # four ids, both with and without a leading space.
        self.valid_words = [
            x
            for x in self.tokenizer.vocab.keys()
            if len(self.tokenizer(f" {x} sports : ").input_ids) == 4
            and len(self.tokenizer(f"{x} sports : ").input_ids) == 4
        ]

        # Only the final (longest) list of each branch was ever effective;
        # the shorter lists that used to be assigned first were dead code.
        if "pythia" not in tokenizer_name:
            self.words = list(self._WORDS)
        else:
            self.words = list(self._PYTHIA_WORDS)

        self.label_words = {0: "negative", 1: "positive"}
        # Second tokenizer instance, configured for padding and left-side
        # truncation; `self.tokenizer` above keeps its default configuration.
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.default_word = default_word
        self._tokenizer.pad_token = self._tokenizer.eos_token
        self.vocabulary = list(self._tokenizer.vocab.keys())
        self._tokenizer.truncation_side = "left"
        self.model_name = tokenizer_name

        if pool_dict is None and seeds is None:
            if variable_noise:
                self.w_b = torch.randn(self.b_size, self.n_dims, 1)
                # multiply each row of w_b by a weight sampled from [1, 2, ... 5]
                self.w_b = self.w_b * torch.randint(1, 6, (self.b_size, 1, 1))
            else:
                self.w_b = torch.randn(self.b_size, self.n_dims, 1) * weight_multiplier

        elif seeds is not None:
            self.w_b = torch.zeros(self.b_size, self.n_dims, 1)
            generator = torch.Generator()
            assert len(seeds) == self.b_size
            for i, seed in enumerate(seeds):
                generator.manual_seed(seed)
                weights = torch.randn(self.n_dims, 1, generator=generator)
                if variable_noise:
                    # Draw a single multiplier for this task from the seeded
                    # generator. The previous (b_size, 1, 1) draw could not be
                    # assigned to one row and ignored the seed.
                    # NOTE(review): range is 1..10 here but 1..5 in the
                    # unseeded branch -- confirm which is intended.
                    multiplier = torch.randint(1, 11, (1,), generator=generator)
                    self.w_b[i] = weights * multiplier
                else:
                    self.w_b[i] = weights * weight_multiplier
        else:
            assert "w" in pool_dict
            indices = torch.randperm(len(pool_dict["w"]))[:batch_size]
            self.w_b = pool_dict["w"][indices]

    def _verbalize_points(self, xs_b, ys_b, drop_null=False):
        """Return, per batch element, a list of (words, label) string pairs.

        With `drop_null` False every dimension contributes a word (the
        dimension's word if the value is 1, else `default_word`). With
        `drop_null` True only dimensions equal to 1 contribute, and a point
        with no active dimension is rendered as "null".
        """
        # Convert once instead of indexing the tensor element by element.
        feature_rows = xs_b.tolist()
        batch = []
        for b_idx, points in enumerate(feature_rows):
            sequence = []
            for p_idx, features in enumerate(points):
                if drop_null:
                    word = [
                        self.words[d_idx]
                        for d_idx, value in enumerate(features)
                        if value == 1
                    ]
                    if not word:
                        word.append("null")
                else:
                    word = [
                        self.words[d_idx] if value == 1 else self.default_word
                        for d_idx, value in enumerate(features)
                    ]
                label_idx = 0 if ys_b[b_idx, p_idx].item() == 0 else 1
                sequence.append((" ".join(word), self.label_words[label_idx]))
            batch.append(sequence)
        return batch

    @staticmethod
    def _join_with_colons(sequence):
        """Render (words, label) pairs as "words : label , words : label ..."."""
        joined = " , ".join(f"{words} : {label}" for words, label in sequence)
        # strip(" , ") removes any leading/trailing spaces and commas.
        return joined.strip(" , ")

    def _construct_sequences(self, xs_b, ys_b):
        """Tokenize prompts of the form "words label words label ..." (space separated).

        Sequences are tokenized one by one without padding, so all of them
        must produce the same number of tokens for the final tensor to build.
        """
        tok_batch = []
        for sequence in self._verbalize_points(xs_b, ys_b):
            input_seq = " ".join(f"{words} {label}" for words, label in sequence)
            tok_batch.append(self._tokenizer(input_seq.strip()).input_ids)
        return torch.tensor(tok_batch)

    def _construct_sequences_colon(self, xs_b, ys_b):
        """Tokenize prompts of the form "words : label , words : label ...".

        Sequences are tokenized one by one without padding, so all of them
        must produce the same number of tokens for the final tensor to build.
        """
        tok_batch = []
        for sequence in self._verbalize_points(xs_b, ys_b):
            input_seq = self._join_with_colons(sequence)
            tok_batch.append(self._tokenizer(input_seq).input_ids)
        return torch.tensor(tok_batch)

    def _construct_sequences_drop_null(self, xs_b, ys_b):
        """Like `_construct_sequences_colon`, but inactive dimensions emit no word.

        The batch is tokenized in one call and padded/truncated to
        (n_dims + additional_tokens) * n_points tokens, so `n_points` must
        have been provided at construction time.
        """
        input_strings = [
            self._join_with_colons(sequence)
            for sequence in self._verbalize_points(xs_b, ys_b, drop_null=True)
        ]

        # Per-point token budget on top of the feature words.
        if "pythia" in self.model_name:
            additional_tokens = 2
        else:
            additional_tokens = 3
        max_length = (self.n_dims + additional_tokens) * self.n_points
        tokenized_batch = self._tokenizer(
            input_strings, padding="max_length", max_length=max_length, truncation=True
        ).input_ids
        return torch.tensor(tokenized_batch)

    def evaluate(self, xs_b, w_b=None):
        """Sample Bernoulli labels for `xs_b` and build language-model targets.

        If `w_b` is given, it replaces the stored weights (after scaling by
        `weight_multiplier`) for this and later calls.

        Returns (ys_b, w_b, nl_batch): `nl_batch` holds the token ids from
        `_construct_sequences_colon`; `ys_b` has the same shape and keeps the
        token id wherever it is the " positive"/" negative" label token, with
        -100 at every other position.
        """
        if w_b is not None:
            self.w_b = w_b * self.weight_multiplier

        w_b = self.w_b.to(xs_b.device)
        probability = torch.sigmoid(self.scale * (xs_b @ w_b)[:, :, 0])
        ys_b = torch.bernoulli(probability)
        nl_batch = self._construct_sequences_colon(xs_b, ys_b)

        # -100 is the default ignore_index of CrossEntropyLoss, so only label
        # positions contribute to a loss computed on these targets.
        is_label = (nl_batch == self.positive_token_id_space) | (
            nl_batch == self.negative_token_id_space
        )
        ys_b = torch.where(is_label, nl_batch, torch.tensor(-100))
        return ys_b, w_b.detach(), nl_batch

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks, **kwargs):  # ignore extra args
        """Return a pool of `num_tasks` standard-normal weight vectors under key "w"."""
        return {"w": torch.randn(num_tasks, n_dims, 1)}

    @staticmethod
    def get_metric():
        """Return the per-element evaluation metric."""
        return cross_entropy_no_reduction

    @staticmethod
    def get_training_metric():
        """Return the scalar training loss (next-token cross-entropy)."""
        return cross_entropy_lm


class ProbabilisticLogisticRegression(Task):
    """Binary task with labels sampled from Bernoulli(sigmoid(scale * <x, w>))."""

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        weight_multiplier=1,
        variable_noise=False,
        n_points=None,
        tokenizer_name=None,
    ):
        """Sample or look up the per-task weights `w_b` of shape (batch, n_dims, 1).

        scale: a constant by which to scale the logits in `evaluate`.
        weight_multiplier: factor applied to the sampled weights when
            `variable_noise` is False, and to weights passed to `evaluate`.
        variable_noise: if True, each task's weights are multiplied by a
            random integer instead of `weight_multiplier`.
        n_points, tokenizer_name: accepted for signature parity; not used by
            this class.
        """
        super(ProbabilisticLogisticRegression, self).__init__(
            n_dims, batch_size, pool_dict, seeds
        )
        self.scale = scale
        self.weight_multiplier = weight_multiplier

        if pool_dict is None and seeds is None:
            if variable_noise:
                self.w_b = torch.randn(self.b_size, self.n_dims, 1)
                # multiply each row of w_b by a weight sampled from
                # [1, 2, ... weight_multiplier]
                self.w_b = self.w_b * torch.randint(1, weight_multiplier+1, (self.b_size, 1, 1))
            else:
                self.w_b = torch.randn(self.b_size, self.n_dims, 1) * weight_multiplier

        elif seeds is not None:
            self.w_b = torch.zeros(self.b_size, self.n_dims, 1)
            generator = torch.Generator()
            assert len(seeds) == self.b_size
            for i, seed in enumerate(seeds):
                generator.manual_seed(seed)
                weights = torch.randn(self.n_dims, 1, generator=generator)
                if variable_noise:
                    # Draw a single multiplier for this task from the seeded
                    # generator. The previous (b_size, 1, 1) draw could not be
                    # assigned to one row and ignored the seed.
                    # NOTE(review): range is fixed at 1..10 here but is
                    # 1..weight_multiplier in the unseeded branch -- confirm.
                    multiplier = torch.randint(1, 11, (1,), generator=generator)
                    self.w_b[i] = weights * multiplier
                else:
                    self.w_b[i] = weights * weight_multiplier
        else:
            assert "w" in pool_dict
            indices = torch.randperm(len(pool_dict["w"]))[:batch_size]
            self.w_b = pool_dict["w"][indices]

    def evaluate(self, xs_b, w_b=None):
        """Sample Bernoulli labels for `xs_b`.

        If `w_b` is given, it replaces the stored weights (after scaling by
        `weight_multiplier`) for this and later calls.

        Returns (ys_b, w_b): the sampled 0/1 labels and the weights used.
        """
        if w_b is not None:
            self.w_b = w_b * self.weight_multiplier

        w_b = self.w_b.to(xs_b.device)

        probability = torch.sigmoid(self.scale * (xs_b @ w_b)[:, :, 0])
        ys_b = torch.bernoulli(probability)
        return ys_b, w_b

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks, **kwargs):  # ignore extra args
        """Return a pool of `num_tasks` standard-normal weight vectors under key "w"."""
        return {"w": torch.randn(num_tasks, n_dims, 1)}

    @staticmethod
    def get_metric():
        """Return the per-element evaluation metric."""
        return cross_entropy_no_reduction

    @staticmethod
    def get_training_metric():
        """Return the scalar training loss."""
        return cross_entropy_zero_one


class ProbabilisticTanh(Task):
    """Binary task with labels sampled from Bernoulli((tanh(scale * <x, w>) + 1) / 2)."""

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        weight_multiplier=1,
        variable_noise=False,
    ):
        """Sample or look up the per-task weights `w_b` of shape (batch, n_dims, 1).

        scale: a constant by which to scale the pre-activation in `evaluate`.
        weight_multiplier: factor applied to the sampled weights when
            `variable_noise` is False.
        variable_noise: if True, each task's weights are instead multiplied by
            an integer drawn uniformly from 1..10.
        """
        super(ProbabilisticTanh, self).__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale

        if pool_dict is None and seeds is None:
            if variable_noise:
                self.w_b = torch.randn(self.b_size, self.n_dims, 1)
                # multiply each row of w_b by a weight sampled from [1, 2, ... 10]
                self.w_b = self.w_b * torch.randint(1, 11, (self.b_size, 1, 1))
            else:
                self.w_b = torch.randn(self.b_size, self.n_dims, 1) * weight_multiplier

        elif seeds is not None:
            self.w_b = torch.zeros(self.b_size, self.n_dims, 1)
            generator = torch.Generator()
            assert len(seeds) == self.b_size
            for i, seed in enumerate(seeds):
                generator.manual_seed(seed)
                weights = torch.randn(self.n_dims, 1, generator=generator)
                if variable_noise:
                    # Draw a single multiplier for this task from the seeded
                    # generator. The previous (b_size, 1, 1) draw could not be
                    # assigned to one row and ignored the seed.
                    multiplier = torch.randint(1, 11, (1,), generator=generator)
                    self.w_b[i] = weights * multiplier
                else:
                    self.w_b[i] = weights * weight_multiplier
        else:
            assert "w" in pool_dict
            indices = torch.randperm(len(pool_dict["w"]))[:batch_size]
            self.w_b = pool_dict["w"][indices]

    def evaluate(self, xs_b, w_b=None):
        """Sample Bernoulli labels for `xs_b` using the stored weights.

        The `w_b` argument is accepted but ignored; the stored `self.w_b` is
        always used.

        Returns (ys_b, w_b): the sampled 0/1 labels and the weights used.
        """
        w_b = self.w_b.to(xs_b.device)
        # Map tanh's (-1, 1) range onto a probability in (0, 1).
        probability = (torch.tanh(self.scale * (xs_b @ w_b)[:, :, 0]) + 1) / 2
        ys_b = torch.bernoulli(probability)
        return ys_b, w_b

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks, **kwargs):  # ignore extra args
        """Return a pool of `num_tasks` standard-normal weight vectors under key "w"."""
        return {"w": torch.randn(num_tasks, n_dims, 1)}

    @staticmethod
    def get_metric():
        """Return the per-element evaluation metric."""
        return cross_entropy_no_reduction

    @staticmethod
    def get_training_metric():
        """Return the scalar training loss."""
        return cross_entropy_zero_one


class CRF(Task):
    """Two binary labels where the second label's weights depend on the first.

    Each task holds three weight tensors.  ``w_b1`` drives the first label;
    the second label is drawn with ``w_b2`` where the first label is 1 and
    with ``w_b3`` otherwise.
    """

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        weight_multiplier=1,
        variable_noise=False,
    ):
        """Sample the three weight tensors.

        Args:
            n_dims: input dimensionality.
            batch_size: number of tasks in the batch.
            pool_dict: not supported; passing one (without seeds) raises.
            seeds: optional per-task seeds for reproducible weights.
            scale: constant multiplying the logits in ``evaluate``.
            weight_multiplier: constant multiplying the sampled weights.
            variable_noise: accepted but not used by this task.

        Raises:
            ValueError: if ``pool_dict`` is given.
        """
        super().__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        shape = (self.b_size, self.n_dims, 1)

        if seeds is not None:
            self.w_b1 = torch.zeros(shape)
            self.w_b2 = torch.zeros(shape)
            self.w_b3 = torch.zeros(shape)

            gen = torch.Generator()
            assert len(seeds) == self.b_size
            for row, seed in enumerate(seeds):
                # Re-seed per task, then draw the three vectors in order.
                gen.manual_seed(seed)
                for weights in (self.w_b1, self.w_b2, self.w_b3):
                    weights[row] = (
                        torch.randn(self.n_dims, 1, generator=gen) * weight_multiplier
                    )
        elif pool_dict is None:
            self.w_b1 = torch.randn(shape) * weight_multiplier
            self.w_b2 = torch.randn(shape) * weight_multiplier
            self.w_b3 = torch.randn(shape) * weight_multiplier
        else:
            raise ValueError("CRF does not support pool_dict")

    def evaluate(self, xs_b):
        """Sample both labels for ``xs_b``.

        Returns:
            A pair: the two label tensors stacked along a new leading
            dimension, and the list ``[w_b1, w_b2, w_b3]`` on the input's
            device.
        """
        device = xs_b.device

        w_b1 = self.w_b1.to(device)
        first_prob = torch.sigmoid(self.scale * (xs_b @ w_b1)[:, :, 0])
        ys_b1 = torch.bernoulli(first_prob)

        w_b2 = self.w_b2.to(device)
        w_b3 = self.w_b3.to(device)

        # Draw a candidate second label under each branch, then pick per
        # element according to the first label.
        if_first_on = torch.bernoulli(
            torch.sigmoid(self.scale * (xs_b @ w_b2)[:, :, 0])
        )
        if_first_off = torch.bernoulli(
            torch.sigmoid(self.scale * (xs_b @ w_b3)[:, :, 0])
        )
        ys_b2 = torch.where(ys_b1 == 1, if_first_on, if_first_off)

        return torch.stack([ys_b1, ys_b2], dim=0), [w_b1, w_b2, w_b3]

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks, **kwargs):  # extra kwargs are ignored
        """Draw ``num_tasks`` standard-normal weight vectors, stored under key "w"."""
        pool_shape = (num_tasks, n_dims, 1)
        return {"w": torch.randn(pool_shape)}

    @staticmethod
    def get_metric():
        """Return the per-element (unreduced) cross-entropy metric."""
        return cross_entropy_no_reduction

    @staticmethod
    def get_training_metric():
        """Return the reduced cross-entropy loss on {0, 1} targets."""
        return cross_entropy_zero_one


class CRF_ISING(Task):
    """Two coupled binary labels drawn from a normalised 2x2 joint table.

    Two weight tensors turn each input into probabilities ``p1`` and ``p2``.
    These are combined into unnormalised scores for the four joint outcomes
    and normalised; ``y1`` is sampled from its marginal and ``y2`` from its
    conditional given ``y1``.
    """

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        weight_multiplier=1,
    ):
        """Sample the two weight tensors.

        Args:
            n_dims: input dimensionality.
            batch_size: number of tasks in the batch.
            pool_dict: not supported; passing one (without seeds) raises.
            seeds: optional per-task seeds for reproducible weights.
            scale: constant multiplying the logits in ``evaluate``.
            weight_multiplier: constant multiplying the sampled weights.

        Raises:
            ValueError: if ``pool_dict`` is given.
        """
        super().__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        shape = (self.b_size, self.n_dims, 1)

        if seeds is not None:
            self.w_b1 = torch.zeros(shape)
            self.w_b2 = torch.zeros(shape)

            gen = torch.Generator()
            assert len(seeds) == self.b_size
            for row, seed in enumerate(seeds):
                # Re-seed per task, then draw the two vectors in order.
                gen.manual_seed(seed)
                for weights in (self.w_b1, self.w_b2):
                    weights[row] = (
                        torch.randn(self.n_dims, 1, generator=gen) * weight_multiplier
                    )
        elif pool_dict is None:
            self.w_b1 = torch.randn(shape) * weight_multiplier
            self.w_b2 = torch.randn(shape) * weight_multiplier
        else:
            raise ValueError("CRF_ISING does not support pool_dict")

    @staticmethod
    def _normalize(both_on, both_off, first_only, second_only):
        """Divide the four joint scores by their sum Z.

        The result keeps the argument order:
        (y1=1, y2=1), (y1=0, y2=0), (y1=1, y2=0), (y1=0, y2=1).
        """
        z = both_on + both_off + first_only + second_only
        return both_on / z, both_off / z, first_only / z, second_only / z

    def _normalized_distribution_hard(self, p1, p2):
        """Joint table with score ``a + b - sqrt(a * b)`` per outcome."""
        q1 = 1 - p1
        q2 = 1 - p2
        return self._normalize(
            p1 + p2 - torch.sqrt(p1 * p2),
            q1 + q2 - torch.sqrt(q1 * q2),
            p1 + q2 - torch.sqrt(p1 * q2),
            q1 + p2 - torch.sqrt(q1 * p2),
        )

    def _normalized_distribution_simple(self, p1, p2):
        """Joint table with score ``a * b`` per outcome."""
        q1 = 1 - p1
        q2 = 1 - p2
        return self._normalize(p1 * p2, q1 * q2, p1 * q2, q1 * p2)

    def _normalized_distribution_easy(self, p1, p2):
        """Joint table with score ``a + b`` per outcome."""
        q1 = 1 - p1
        q2 = 1 - p2
        return self._normalize(p1 + p2, q1 + q2, p1 + q2, q1 + p2)

    # NOTE: formerly medium_2
    def _normalized_distribution_medium(self, p1, p2):
        """Joint table with score ``a + b - a * b`` per outcome."""
        q1 = 1 - p1
        q2 = 1 - p2
        return self._normalize(
            p1 + p2 - p1 * p2,
            q1 + q2 - q1 * q2,
            p1 + q2 - p1 * q2,
            q1 + p2 - q1 * p2,
        )

    def _marginalized_distribution(self, p1, p2):
        """Return ``(P(y1=1), P(y2=1))`` under the medium joint table."""
        both_on, both_off, first_only, second_only = (
            self._normalized_distribution_medium(p1, p2)
        )

        y1_on = both_on + first_only
        y1_off = both_off + second_only
        marginal_1 = y1_on / (y1_off + y1_on)

        y2_on = both_on + second_only
        y2_off = both_off + first_only
        marginal_2 = y2_on / (y2_off + y2_on)

        return (marginal_1, marginal_2)

    def p2_conditional(self, p1, p2, y1):
        """Return ``P(y2=1 | y1)`` element-wise under the medium joint table."""
        both_on, both_off, first_only, second_only = (
            self._normalized_distribution_medium(p1, p2)
        )
        given_on = both_on / (both_on + first_only)
        given_off = second_only / (both_off + second_only)
        return torch.where(y1 == 1, given_on, given_off)

    def evaluate(self, xs_b):
        """Sample ``(y1, y2)`` for ``xs_b``.

        Returns:
            A pair: the two label tensors stacked along a new leading
            dimension, and the list ``[P(y1=1), P(y2=1 | y1)]`` that was used
            for sampling.
        """
        device = xs_b.device
        p1 = torch.sigmoid(self.scale * (xs_b @ self.w_b1.to(device))[:, :, 0])
        p2 = torch.sigmoid(self.scale * (xs_b @ self.w_b2.to(device))[:, :, 0])

        prob_p1 = self._marginalized_distribution(p1, p2)[0]
        y1 = torch.bernoulli(prob_p1)

        prob_p2_given_y1 = self.p2_conditional(p1, p2, y1)
        y2 = torch.bernoulli(prob_p2_given_y1)

        return torch.stack([y1, y2], dim=0), [prob_p1, prob_p2_given_y1]

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks, **kwargs):  # extra kwargs are ignored
        """Draw ``num_tasks`` standard-normal weight vectors, stored under key "w"."""
        pool_shape = (num_tasks, n_dims, 1)
        return {"w": torch.randn(pool_shape)}

    @staticmethod
    def get_metric():
        """Return the per-element (unreduced) cross-entropy metric."""
        return cross_entropy_no_reduction

    @staticmethod
    def get_training_metric():
        """Return the reduced cross-entropy loss on {0, 1} targets."""
        return cross_entropy_zero_one


class ThreeNodeTree(Task):
    """Three binary labels sampled sequentially from a normalised joint table.

    Three weight tensors turn each input into probabilities ``p1, p2, p3``.
    These are combined into a table over the eight outcomes of
    ``(y1, y2, y3)``; labels are then drawn as ``y1`` from its marginal,
    ``y2`` from ``p(y2 | y1)`` and ``y3`` from ``p(y3 | y1, y2)``.
    """

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        weight_multiplier=1,
        variable_noise=False,
    ):
        """Sample the three weight tensors.

        Args:
            n_dims: input dimensionality.
            batch_size: number of tasks in the batch.
            pool_dict: not supported; passing one (without seeds) raises.
            seeds: optional per-task seeds for reproducible weights.
            scale: constant multiplying the logits in ``evaluate``.
            weight_multiplier: constant multiplying the sampled weights.
            variable_noise: accepted but not used by this task.

        Raises:
            ValueError: if ``pool_dict`` is given.
        """
        super(ThreeNodeTree, self).__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        if pool_dict is None and seeds is None:
            self.w_b1 = torch.randn(self.b_size, self.n_dims, 1) * weight_multiplier
            self.w_b2 = torch.randn(self.b_size, self.n_dims, 1) * weight_multiplier
            self.w_b3 = torch.randn(self.b_size, self.n_dims, 1) * weight_multiplier

        elif seeds is not None:
            self.w_b1 = torch.zeros(self.b_size, self.n_dims, 1)
            self.w_b2 = torch.zeros(self.b_size, self.n_dims, 1)
            self.w_b3 = torch.zeros(self.b_size, self.n_dims, 1)

            generator = torch.Generator()
            assert len(seeds) == self.b_size
            for i, seed in enumerate(seeds):
                # Re-seed per task, then draw the three vectors in order.
                generator.manual_seed(seed)
                self.w_b1[i] = (
                    torch.randn(self.n_dims, 1, generator=generator) * weight_multiplier
                )
                self.w_b2[i] = (
                    torch.randn(self.n_dims, 1, generator=generator) * weight_multiplier
                )
                self.w_b3[i] = (
                    torch.randn(self.n_dims, 1, generator=generator) * weight_multiplier
                )

        else:
            raise ValueError("ThreeNodeTree does not support pool_dict")

    def _normalized_distribution_easy(self, p1, p2, p3):
        """Build the normalised joint table over ``(y1, y2, y3)``.

        The unnormalised score of an outcome is the sum of the three
        per-label terms, using ``p_k`` where ``y_k = 1`` and ``1 - p_k``
        where ``y_k = 0``.  (A product-based variant was considered in an
        earlier revision; the additive form is the one in use.)

        Returns:
            A tensor of shape ``p1.shape + (2, 2, 2)`` indexed by
            ``[..., y1, y2, y3]`` whose last three dimensions sum to 1.
        """
        on = (p1, p2, p3)
        off = (1 - p1, 1 - p2, 1 - p3)

        # Allocate on the inputs' device/dtype and with their actual leading
        # shape so the table works on GPU and with broadcast batches.
        scores = torch.zeros(*p1.shape, 2, 2, 2, device=p1.device, dtype=p1.dtype)
        for y1 in (0, 1):
            for y2 in (0, 1):
                for y3 in (0, 1):
                    scores[..., y1, y2, y3] = (
                        (on[0] if y1 else off[0])
                        + (on[1] if y2 else off[1])
                        + (on[2] if y3 else off[2])
                    )

        z = scores.sum(dim=(-3, -2, -1), keepdim=True)
        return scores / z

    def _marginalized_distribution(self, normalized_probs):
        """Return ``(P(y1=1), P(y2=1), P(y3=1))`` from the joint table.

        ``normalized_probs`` is indexed as ``[batch, point, y1, y2, y3]``.
        """
        py_1 = normalized_probs[:, :, 1, :, :].sum(dim=(2, 3))
        py_2 = normalized_probs[:, :, :, 1, :].sum(dim=(2, 3))
        py_3 = normalized_probs[:, :, :, :, 1].sum(dim=(2, 3))
        return (py_1, py_2, py_3)

    def py_2_y1_conditional(self, normalized_probs, y1):
        """Return ``P(y2=1 | y1)`` element-wise for the sampled ``y1``."""
        y2_on = normalized_probs[:, :, 1, 1, :].sum(dim=2)
        y2_off = normalized_probs[:, :, 1, 0, :].sum(dim=2)
        given_y1_on = torch.div(y2_on, y2_on + y2_off)

        y2_on = normalized_probs[:, :, 0, 1, :].sum(dim=2)
        y2_off = normalized_probs[:, :, 0, 0, :].sum(dim=2)
        given_y1_off = torch.div(y2_on, y2_on + y2_off)

        return torch.where(y1 == 1, given_y1_on, given_y1_off)

    def py_3_y1_conditional(self, normalized_probs, y1):
        """Return ``P(y3=1 | y1)`` element-wise for the sampled ``y1``."""
        y3_on = normalized_probs[:, :, 1, :, 1].sum(dim=2)
        y3_off = normalized_probs[:, :, 1, :, 0].sum(dim=2)
        given_y1_on = torch.div(y3_on, y3_on + y3_off)

        y3_on = normalized_probs[:, :, 0, :, 1].sum(dim=2)
        y3_off = normalized_probs[:, :, 0, :, 0].sum(dim=2)
        given_y1_off = torch.div(y3_on, y3_on + y3_off)

        return torch.where(y1 == 1, given_y1_on, given_y1_off)

    def py_3_y2_y1_conditional(self, normalized_probs, y1, y2):
        """Return ``P(y3=1 | y1, y2)`` element-wise for the sampled labels."""
        y3_on = normalized_probs[:, :, 1, 1, 1]
        y3_off = normalized_probs[:, :, 1, 1, 0]
        y3_1_1 = torch.div(y3_on, y3_on + y3_off)

        y3_on = normalized_probs[:, :, 0, 0, 1]
        y3_off = normalized_probs[:, :, 0, 0, 0]
        y3_0_0 = torch.div(y3_on, y3_on + y3_off)

        y3_on = normalized_probs[:, :, 1, 0, 1]
        y3_off = normalized_probs[:, :, 1, 0, 0]
        y3_1_0 = torch.div(y3_on, y3_on + y3_off)

        y3_on = normalized_probs[:, :, 0, 1, 1]
        y3_off = normalized_probs[:, :, 0, 1, 0]
        y3_0_1 = torch.div(y3_on, y3_on + y3_off)

        # Pick the conditional matching each element's (y1, y2); anything
        # that is not (1,1), (1,0) or (0,1) falls through to the (0,0) case.
        selected_tensors = torch.where(
            torch.logical_and(y1 == 1, y2 == 1),
            y3_1_1,
            torch.where(
                torch.logical_and(y1 == 1, y2 == 0),
                y3_1_0,
                torch.where(torch.logical_and(y1 == 0, y2 == 1), y3_0_1, y3_0_0),
            ),
        )

        return selected_tensors

    def evaluate(self, xs_b):
        """Sample ``(y1, y2, y3)`` for ``xs_b``.

        Returns:
            A pair: the three label tensors stacked along a new leading
            dimension, and the list
            ``[P(y1=1), P(y2=1 | y1), P(y3=1 | y1, y2)]``.
        """
        w_b1 = self.w_b1.to(xs_b.device)
        p1 = torch.sigmoid(self.scale * (xs_b @ w_b1)[:, :, 0])

        w_b2 = self.w_b2.to(xs_b.device)
        p2 = torch.sigmoid(self.scale * (xs_b @ w_b2)[:, :, 0])

        # Use the third weight tensor here (previously w_b2 was reused, which
        # made p3 identical to p2 and left w_b3 unused).
        w_b3 = self.w_b3.to(xs_b.device)
        p3 = torch.sigmoid(self.scale * (xs_b @ w_b3)[:, :, 0])

        norm_dist = self._normalized_distribution_easy(p1, p2, p3)

        prob_p1 = self._marginalized_distribution(norm_dist)

        # Cap the marginal at 1 so rounding error cannot produce an invalid
        # Bernoulli parameter.
        p1_dist = torch.min(prob_p1[0], torch.ones_like(prob_p1[0]))
        y1 = torch.bernoulli(p1_dist)

        prob_p2_given_y1 = self.py_2_y1_conditional(norm_dist, y1)
        y2 = torch.bernoulli(prob_p2_given_y1)

        prob_p3_given_y1_y2 = self.py_3_y2_y1_conditional(norm_dist, y1, y2)
        y3 = torch.bernoulli(prob_p3_given_y1_y2)

        return torch.stack([y1, y2, y3], axis=0), [
            prob_p1[0],
            prob_p2_given_y1,
            prob_p3_given_y1_y2,
        ]

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks, **kwargs):  # ignore extra args
        """Draw ``num_tasks`` standard-normal weight vectors, stored under key "w"."""
        return {"w": torch.randn(num_tasks, n_dims, 1)}

    @staticmethod
    def get_metric():
        """Return the per-element (unreduced) cross-entropy metric."""
        return cross_entropy_no_reduction

    @staticmethod
    def get_training_metric():
        """Return the reduced cross-entropy loss on {0, 1} targets."""
        return cross_entropy_zero_one


class SparseLinearRegression(LinearRegression):
    """Linear regression whose weight vectors keep at most ``sparsity`` non-zero entries."""

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        sparsity=3,
        valid_coords=None,
    ):
        """Build the parent's weights, then zero all but a random subset.

        Args:
            n_dims: input dimensionality.
            batch_size: number of tasks in the batch.
            pool_dict: forwarded to the parent constructor.
            seeds: optional per-task seeds; also seed the coordinate choice.
            scale: a constant by which to scale the randomly sampled weights.
            sparsity: number of coordinates left untouched per task.
            valid_coords: kept coordinates are chosen among the first
                ``valid_coords`` dimensions (defaults to ``n_dims``).
        """
        super().__init__(n_dims, batch_size, pool_dict, seeds, scale)
        self.sparsity = sparsity
        valid_coords = n_dims if valid_coords is None else valid_coords
        assert valid_coords <= n_dims

        for idx, weights in enumerate(self.w_b):
            if seeds is None:
                order = torch.randperm(valid_coords)
            else:
                gen = torch.Generator()
                gen.manual_seed(seeds[idx])
                order = torch.randperm(valid_coords, generator=gen)

            # True marks coordinates to zero; the first `sparsity` entries of
            # the permutation are the ones that survive.
            zero_out = torch.ones(n_dims).bool()
            zero_out[order[:sparsity]] = False
            weights[zero_out] = 0

    def evaluate(self, xs_b):
        """Return ``scale * (xs_b @ w_b)`` with the trailing unit dimension dropped."""
        weights = self.w_b.to(xs_b.device)
        return self.scale * (xs_b @ weights)[:, :, 0]

    @staticmethod
    def get_metric():
        """Return the element-wise squared-error metric."""
        return squared_error

    @staticmethod
    def get_training_metric():
        """Return the mean-squared-error training loss."""
        return mean_squared_error


class LinearClassification(LinearRegression):
    """Classification task whose labels are the sign of the linear-regression output."""

    def evaluate(self, xs_b):
        """Return the element-wise sign of the parent's regression targets."""
        return torch.sign(super().evaluate(xs_b))

    @staticmethod
    def get_metric():
        """Return the accuracy metric (label equals sign of prediction)."""
        return accuracy

    @staticmethod
    def get_training_metric():
        """Return the cross-entropy training loss for +/-1 labels."""
        return cross_entropy


class NoisyLinearRegression(LinearRegression):
    """Linear regression with Gaussian noise added to the targets."""

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        noise_std=0,
        renormalize_ys=False,
    ):
        """Store the noise settings on top of the parent's setup.

        Args:
            noise_std: standard deviation of noise added to the prediction.
            renormalize_ys: if true, rescale the noisy targets so that their
                overall standard deviation equals ``sqrt(n_dims)``.
        """
        super().__init__(n_dims, batch_size, pool_dict, seeds, scale)
        self.noise_std = noise_std
        self.renormalize_ys = renormalize_ys

    def evaluate(self, xs_b):
        """Return the parent's targets plus noise, optionally renormalised."""
        clean = super().evaluate(xs_b)
        noisy = clean + torch.randn_like(clean) * self.noise_std
        if not self.renormalize_ys:
            return noisy
        return noisy * math.sqrt(self.n_dims) / noisy.std()


class QuadraticRegression(LinearRegression):
    """Regression on element-wise squared inputs using the parent's weights."""

    def evaluate(self, xs_b):
        """Return ``scale * ((xs_b ** 2) @ w_b) / sqrt(3)``, trailing unit dim dropped."""
        weights = self.w_b.to(xs_b.device)
        quad = (torch.pow(xs_b, 2) @ weights)[:, :, 0]
        # Renormalize to Linear Regression Scale
        quad = quad / math.sqrt(3)
        return self.scale * quad


class Relu2nnRegression(Task):
    """Regression targets produced by a random two-layer ReLU network."""

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        hidden_layer_size=4,
    ):
        """Sample, seed, or draw from a pool the two weight matrices.

        Args:
            n_dims: input dimensionality.
            batch_size: number of tasks in the batch.
            pool_dict: optional dict with keys "W1" and "W2" to draw from.
            seeds: optional per-task seeds for reproducible weights.
            scale: constant multiplying the network output.
            hidden_layer_size: width of the hidden layer.
        """
        super().__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        self.hidden_layer_size = hidden_layer_size

        first_shape = (self.b_size, self.n_dims, hidden_layer_size)
        second_shape = (self.b_size, hidden_layer_size, 1)

        if seeds is not None:
            self.W1 = torch.zeros(first_shape)
            self.W2 = torch.zeros(second_shape)
            gen = torch.Generator()
            assert len(seeds) == self.b_size
            for row, seed in enumerate(seeds):
                # Re-seed per task, then draw W1 followed by W2.
                gen.manual_seed(seed)
                self.W1[row] = torch.randn(
                    self.n_dims, hidden_layer_size, generator=gen
                )
                self.W2[row] = torch.randn(hidden_layer_size, 1, generator=gen)
        elif pool_dict is None:
            self.W1 = torch.randn(first_shape)
            self.W2 = torch.randn(second_shape)
        else:
            assert "W1" in pool_dict and "W2" in pool_dict
            pool_size = len(pool_dict["W1"])
            assert pool_size == len(pool_dict["W2"])
            chosen = torch.randperm(pool_size)[:batch_size]
            self.W1 = pool_dict["W1"][chosen]
            self.W2 = pool_dict["W2"][chosen]

    def evaluate(self, xs_b):
        """Return ``scale * sqrt(2 / hidden) * relu(xs_b @ W1) @ W2``, unit dim dropped."""
        device = xs_b.device
        hidden = torch.relu(xs_b @ self.W1.to(device))
        out = (hidden @ self.W2.to(device))[:, :, 0]
        # Renormalize to Linear Regression Scale
        out = out * math.sqrt(2 / self.hidden_layer_size)
        return self.scale * out

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks, hidden_layer_size=4, **kwargs):
        """Draw ``num_tasks`` standard-normal weight pairs under keys "W1" and "W2"."""
        first = torch.randn(num_tasks, n_dims, hidden_layer_size)
        second = torch.randn(num_tasks, hidden_layer_size, 1)
        return {"W1": first, "W2": second}

    @staticmethod
    def get_metric():
        """Return the element-wise squared-error metric."""
        return squared_error

    @staticmethod
    def get_training_metric():
        """Return the mean-squared-error training loss."""
        return mean_squared_error


class DecisionTree(Task):
    """Regression targets produced by random fixed-depth binary decision trees.

    Each tree is stored as a flat array: the root is at index 0 and the
    children of node ``k`` are at ``2k + 1`` and ``2k + 2``.  ``dt_tensor``
    holds the input coordinate tested at each node (only non-leaf entries
    matter) and ``target_tensor`` holds the value at each node (only leaf
    entries matter).
    """

    def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None, depth=4):
        """Sample one tree per task, optionally from per-task seeds.

        Args:
            n_dims: input dimensionality; node coordinates lie in [0, n_dims).
            batch_size: number of trees.
            pool_dict: not supported.
            seeds: optional per-task seeds for reproducible trees.
            depth: number of decisions taken from the root to a leaf.

        Raises:
            NotImplementedError: if ``pool_dict`` is given.
        """
        super(DecisionTree, self).__init__(n_dims, batch_size, pool_dict, seeds)
        self.depth = depth
        n_nodes = 2 ** (depth + 1) - 1

        if pool_dict is not None:
            raise NotImplementedError("DecisionTree does not support pool_dict")

        if seeds is None:
            self.dt_tensor = torch.randint(
                low=0, high=n_dims, size=(batch_size, n_nodes)
            )
            self.target_tensor = torch.randn(self.dt_tensor.shape)
        else:
            # Seeded path.  It used to be unreachable because the unseeded
            # branch was guarded only by `pool_dict is None`.
            assert len(seeds) == self.b_size
            # Coordinates are used as indices, so they must be integer-typed.
            self.dt_tensor = torch.zeros(batch_size, n_nodes, dtype=torch.long)
            self.target_tensor = torch.zeros(batch_size, n_nodes)
            generator = torch.Generator()
            for i, seed in enumerate(seeds):
                generator.manual_seed(seed)
                # `high` is exclusive, so n_dims covers every coordinate.
                self.dt_tensor[i] = torch.randint(
                    low=0,
                    high=n_dims,
                    size=(n_nodes,),
                    generator=generator,
                )
                self.target_tensor[i] = torch.randn(n_nodes, generator=generator)

    def evaluate(self, xs_b):
        """Route every input point down its tree and return the leaf values.

        At each node the tested coordinate of the point is compared with 0:
        a positive value moves to child ``2k + 2``, otherwise to ``2k + 1``.
        When the task holds a single tree it is used for every batch row.
        """
        device = xs_b.device
        dt_tensor = self.dt_tensor.to(device)
        target_tensor = self.target_tensor.to(device)

        n_batch, n_points = xs_b.shape[0], xs_b.shape[1]
        ys_b = torch.zeros(n_batch, n_points, device=device)
        # Row indices for the per-point gather; invariant across the loops.
        point_idx = torch.arange(n_points, device=device)

        for i in range(n_batch):
            xs_bool = xs_b[i] > 0
            # If a single decision tree present, use it for all the xs in the batch.
            tree = 0 if self.b_size == 1 else i
            dt = dt_tensor[tree]
            target = target_tensor[tree]

            cur_nodes = torch.zeros(n_points, dtype=torch.long, device=device)
            for _ in range(self.depth):
                cur_coords = dt[cur_nodes]
                cur_decisions = xs_bool[point_idx, cur_coords]
                cur_nodes = 2 * cur_nodes + 1 + cur_decisions

            ys_b[i] = target[cur_nodes]

        return ys_b

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks, hidden_layer_size=4, **kwargs):
        """Task pools are not supported for decision trees."""
        raise NotImplementedError

    @staticmethod
    def get_metric():
        """Return the element-wise squared-error metric."""
        return squared_error

    @staticmethod
    def get_training_metric():
        """Return the mean-squared-error training loss."""
        return mean_squared_error
