import math
import numpy as np
import torch
import torch.nn as nn
import random


def squared_error(ys_pred, ys):
    """Return the element-wise squared difference between targets and predictions.

    Both arguments must support subtraction, and the difference must expose a
    ``.square()`` method (as torch tensors do). No reduction is applied.
    """
    residual = ys - ys_pred
    return residual.square()


class Task:
    """Base class for tasks; concrete tasks override the stubbed methods below."""

    def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None):
        """Store the task configuration.

        ``pool_dict`` and ``seeds`` are mutually exclusive: at most one of them
        may be provided (checked with an assertion after the attributes are set).
        """
        self.n_dims, self.b_size = n_dims, batch_size
        self.pool_dict, self.seeds = pool_dict, seeds
        assert not (pool_dict is not None and seeds is not None)

    def evaluate(self, xs):
        """Compute task outputs for ``xs``; must be provided by subclasses."""
        raise NotImplementedError()

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks):
        """Build a pool of pre-sampled task parameters; must be provided by subclasses."""
        raise NotImplementedError()

    @staticmethod
    def get_metric():
        """Return the evaluation metric callable; must be provided by subclasses."""
        raise NotImplementedError()

    @staticmethod
    def get_training_metric():
        """Return the training loss callable; must be provided by subclasses."""
        raise NotImplementedError()


def get_task_sampler(
    task_name, n_dims, batch_size, pool_dict=None, num_tasks=None, **kwargs
):
    """Return a factory that builds instances of the task registered as ``task_name``.

    Args:
        task_name: Key of the task class in the registry below.
        n_dims: Forwarded to the task constructor (and to ``generate_pool_dict``).
        batch_size: Forwarded to the task constructor.
        pool_dict: Optional pre-generated pool forwarded to the task constructor.
        num_tasks: If given, a pool is generated via
            ``task_cls.generate_pool_dict(n_dims, num_tasks, **kwargs)``.
        **kwargs: Extra keyword arguments forwarded to both the pool generator
            and the task constructor.

    Returns:
        A callable accepting keyword arguments only; each call constructs a new
        task as ``task_cls(n_dims, batch_size, pool_dict, **args, **kwargs)``.

    Raises:
        NotImplementedError: If ``task_name`` is not registered. The message
            names the offending task.
        ValueError: If both ``pool_dict`` and ``num_tasks`` are supplied.
    """
    task_names_to_classes = {
        "relu_nn_regression": ReluNNRegression,
        "long_chain": LongChain,
        "cot_skill_chain": CoTSkillChain,
        "relu_nn_regression_asymmetric": ReluNNRegressionAsymmetric,
    }
    task_cls = task_names_to_classes.get(task_name)
    if task_cls is None:
        # Carry the name in the exception itself instead of printing it, so
        # the information survives when stdout is not captured.
        raise NotImplementedError(f"Unknown task: {task_name!r}")

    if num_tasks is not None:
        if pool_dict is not None:
            raise ValueError("Either pool_dict or num_tasks should be None.")
        pool_dict = task_cls.generate_pool_dict(n_dims, num_tasks, **kwargs)

    return lambda **args: task_cls(n_dims, batch_size, pool_dict, **args, **kwargs)

# TODO
class CoTSkillChain(Task):
    """Task that pushes inputs through a chain of orthogonal linear "skills".

    A bank of matrices ``func_set`` with shape
    ``(n_skills, n_funcs, n_dims, n_dims)`` is generated once from ``seed``.
    ``evaluate`` picks, per batch element, a sequence of skills and one function
    per skill, applies them in order, and returns selected pieces of the
    resulting chain together with the matching skill ids.
    """

    def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None,
        n_skills=10, n_funcs=20, noise=0.1, seed=0, scale=1,  
        min_chain_length=1, max_chain_length=2, chain_length=5, ordered_chain=False, mode='linear'):
        """Build the function bank and store the configuration.

        Raises:
            ValueError: if ``mode`` is anything other than ``'linear'``.

        NOTE(review): the ``chain_length`` and ``noise`` arguments are accepted
        but currently ignored; the attributes are hard-coded to 5 and 0 below.
        Confirm whether that override is intentional.
        """
        super(CoTSkillChain, self).__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        self.n_skills = n_skills
        self.n_funcs = n_funcs
        self.min_chain_length = min_chain_length
        self.max_chain_length = max_chain_length
        # Hard-coded; the constructor argument is not used.
        self.chain_length = 5 #chain_length
        self.ordered_chain = ordered_chain
        self.mode = mode
        # Hard-coded; the constructor argument is not used, so the noise term
        # in __get_chain is multiplied by zero.
        self.noise = 0 # noise

        # Dedicated generator so the function bank depends only on `seed`.
        rng = np.random.default_rng(seed)
        # self.func_set = torch.zeros(n_skills, n_funcs, n_dims, n_dims)
        # for i in range(n_skills):
        #     for j in range(n_funcs):
        #         self.func_set[i,j] = torch.linalg.svd(torch.randn(n_dims, n_dims, generator=generator))[0]
        # Keep only U from the SVD of each Gaussian matrix, i.e. one orthogonal
        # (n_dims x n_dims) matrix per (skill, function) pair.
        self.func_set = np.linalg.svd(rng.standard_normal((n_skills, n_funcs, n_dims, n_dims)))[0]
        # print(self.func_set[0,1].dot( self.func_set[0,1].T))
        
        if mode == 'linear':
            # Identity activation: the chain is purely linear.
            def func(x):
                return x
            self.act_func = func
        else:
            raise ValueError('No vaild mode found')


    def __get_chain(self, xs, func_ids, skill_ids, eval):
        """Run the chain on ``xs`` and gather sub-chains per in-context point.

        Args:
            xs: NumPy array unpacked as ``(bsize, n_points, n_dims)``.
            func_ids: integer array indexed as ``func_ids[:, i]`` for each step.
            skill_ids: integer array; its column count defines the local
                ``chain_length`` used for all index arithmetic here.
            eval: if true, use fixed-length (2) sub-chains with shuffled start
                positions; otherwise sample sub-chain lengths from {1, 2}.

        Returns:
            ``(ys_, ids_skill)``: the gathered chain states (one row of
            ``n_points`` entries per batch element) and a list with the
            corresponding skill ids per batch element.

        NOTE(review): the index arithmetic uses the local ``chain_length``
        (columns of ``skill_ids``) while the application loop runs
        ``self.chain_length`` times. As called from ``evaluate`` the local
        value is ``self.chain_length + 1``, which matches the
        ``self.chain_length + 2`` states stacked per point - confirm before
        calling with differently shaped ``skill_ids``.
        """
        chain_length = skill_ids.shape[1] 

        bsize, n_points, n_dims = xs.shape
        # Append a -1 column; index `chain_length` then maps to -1 in ids_skill.
        skill_ids = np.concatenate([skill_ids, -np.ones((bsize,1), dtype=int)], axis=1)
        if not eval:
            # n_sample = list(np.random.randint(self.min_chain_length, self.max_chain_length+1, (bsize, n_points)))
            # Sub-chain length per (batch, point): 1 with prob. 0.2, 2 with prob. 0.8.
            n_sample = np.random.choice([1,2], size=(bsize, n_points), p=(0.2, 0.8))
            # Random start so that start + length never exceeds chain_length.
            start_ids = list(map(lambda n:  np.random.randint(chain_length-n+1), n_sample))
        else:
            # Evaluation: every sub-chain has length 2.
            n_sample = list(np.ones((bsize, n_points), dtype=int)*2)
            # Tile 0..chain_length-2, truncate to n_points, copy for each batch row.
            start_ids = np.arange(chain_length-1).reshape(1,-1).repeat(n_points,axis=0).reshape(1,-1)[:,:n_points].repeat(bsize,axis=0)
            # Shuffle start positions within each complete group of
            # (chain_length-1) columns; the shuffle acts in place on a
            # transposed view of start_ids.
            for i in range((n_points//(chain_length-1))):
                np.random.shuffle(start_ids[:,i*(chain_length-1):(i+1)*(chain_length-1)].T)
            start_ids = list(start_ids)
        end_ids = list(map(lambda id, n: id + n, start_ids, n_sample))

        # ids[b][j] = list of chain positions start..end-1 for batch b, point j.
        ids = [list(map(lambda sid, eid: list(range(sid, eid)), sid, eid)) for sid, eid in zip(start_ids, end_ids)]
        # Flat indices into the stacked chain: point j is offset by
        # (chain_length+1)*j and contributes positions start..end (inclusive).
        # The concatenation is truncated to n_points entries per batch element.
        ids_func = np.array([np.concatenate(list(map(lambda i, j: (chain_length+1)*j+np.array(i + [i[-1] + 1]), id, range(n_points))))[:n_points] for id in ids])
        random_noise = np.random.randn(chain_length,bsize,n_points,n_dims)
        # ys collects the input followed by the state after every step.
        ys = [xs]
        for i in range(self.chain_length):
            # One (n_dims x n_dims) matrix per batch element for step i.
            temp = self.func_set[skill_ids[:,i],func_ids[:,i]]
            xs = self.act_func(xs @ temp + self.noise * random_noise[i])
            ys += [xs]
        # ys += [np.zeros_like(xs)]
        # The final state is appended a second time.
        ys += [xs]
        # Stack the states along a new axis 2, then flatten points and states
        # into one axis so ids_func can index it.
        ys = np.stack(ys, axis=2)
        ys = ys.reshape(bsize, -1, n_dims)
        ys_ = np.array(list(map(lambda i, id: ys[i,id], range(bsize), ids_func)))
        # Positions of each sub-chain followed by index `chain_length` (the -1
        # column), truncated to n_points, then mapped to actual skill ids.
        ids_skill = [np.concatenate(list(map(lambda i: list(i) + [chain_length], id)))[:n_points] for id in ids]
        ids_skill = [oid[id] for oid,id in zip(skill_ids, ids_skill)]

        return ys_, ids_skill


    def evaluate(self, xs_b, eval=False):
        """Sample skill/function chains and evaluate them on ``xs_b``.

        Args:
            xs_b: torch tensor; it is converted with ``.numpy()`` and unpacked
                as ``(bsize, n_points, dim)``.
            eval: forwarded to ``__get_chain`` to select evaluation sampling.

        Returns:
            ``(ys, ids, skill_ids, func_ids)``: ``ys`` as a float tensor,
            ``ids`` as a tensor of skill ids, and the sampled ``skill_ids`` /
            ``func_ids`` as NumPy arrays.

        Raises:
            ValueError: if ``ordered_chain`` is set and
                ``n_skills != chain_length``.
        """
        xs_b = xs_b.numpy()
        n_skills = self.n_skills
        n_funcs = self.n_funcs
        chain_length = self.chain_length
        bsize, n_points, dim = xs_b.shape

        if self.ordered_chain:
            if n_skills != chain_length:
                raise ValueError(f"number of skills ({n_skills}) != chain length ({chain_length})")
            # Fixed order 0..n_skills for every batch element.
            skill_ids = np.array([np.arange(n_skills+1) for _ in range(bsize)])
        else:
            # Random distinct skills per batch element, followed by a trailing
            # column holding the value n_skills.
            skill_ids = np.array([np.random.permutation(n_skills)[:chain_length] for _ in range(bsize)])
            skill_ids = np.concatenate([skill_ids, np.ones((bsize,1), dtype=int)*n_skills], axis=1)
        # One function index per (batch element, chain step).
        func_ids = np.random.randint(n_funcs, size=(bsize, chain_length,))

        ys_b, ids_b = self.__get_chain(xs_b, func_ids, skill_ids, eval)   
        # print(ys_b.shape, np.array(ids_b).shape, skill_ids.shape, func_ids.shape)
        # print(ys_b[0,:,0], np.array(ids_b)[0],skill_ids[0], func_ids[0])
        # exit()

        return torch.from_numpy(ys_b).float(), torch.from_numpy(np.array(ids_b)), skill_ids, func_ids

    @staticmethod
    def get_metric():
        """Return the evaluation metric (element-wise squared error)."""
        return squared_error

    @staticmethod
    def get_training_metric():
        """Return the training loss (element-wise squared error)."""
        return squared_error


class LongChain(Task):
    """Chain of randomly chosen orthogonal linear maps with corrupted targets.

    A fixed bank of ``n_funcs`` (4) orthogonal ``n_dims x n_dims`` matrices is
    built from a constant seed. ``evaluate`` draws six maps per batch element,
    applies five of them in sequence to produce intermediate activations and
    the sixth to produce the clean final output.

    As written, the outputs returned to the caller are deliberately
    corrupted: exposed intermediate steps other than steps 0 and 1, and the
    final output, are multiplied by freshly sampled orthogonal "garbage"
    matrices instead of the true maps. The uncorrupted values are available
    through ``return_correct=True``.
    """

    # Intermediate steps (out of 0..4) exposed for each supported n_layers.
    _SAMPLE_IDS = {
        5: [0, 1, 2, 3, 4],
        2: [1, 3],
        1: [2],
        0: [],
    }
    # Five intermediate maps plus the final one.
    _N_MAPS = 6
    # Steps whose exposed activation is left uncorrupted.
    _CLEAN_STEPS = (0, 1)

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        hidden_layer_size=4,
        n_layers=5,
        mode='relu'
    ):
        """Store the configuration and build the fixed bank of orthogonal maps.

        Args:
            scale: stored on the instance; not used by ``evaluate``.
            hidden_layer_size: must equal ``n_dims``.
            n_layers: number of exposed intermediate steps; ``evaluate``
                supports 0, 1, 2 and 5.
            mode: accepted for signature compatibility; not used.

        Raises:
            ValueError: if ``hidden_layer_size != n_dims``.
        """
        super(LongChain, self).__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        self.hidden_layer_size = hidden_layer_size
        self.n_layers = n_layers
        if hidden_layer_size != n_dims:
            raise ValueError(f"hidden layer size ({hidden_layer_size}) != dimension ({n_dims})")

        self.n_funcs = 4
        # Constant seed: every instance shares the same function bank.
        rng = np.random.default_rng(0)
        # U from the SVD of a Gaussian matrix is orthogonal.
        bank = np.linalg.svd(rng.standard_normal((self.n_funcs, n_dims, n_dims)))[0]
        self.func_set = torch.tensor(bank).float()

    def evaluate(self, xs_b, return_correct=False):
        """Apply a random chain of maps to ``xs_b``.

        Args:
            xs_b: tensor whose first dimension is the batch size and whose
                last dimension is ``n_dims``.
            return_correct: if true, also return the uncorrupted final output
                and intermediate activations.

        Returns:
            ``(ys, layer_activations)`` or, with ``return_correct=True``,
            ``(ys, layer_activations, ys_correct, layer_activations_correct)``.

        Raises:
            ValueError: if ``self.n_layers`` is not one of 0, 1, 2, 5.
        """
        bsize = xs_b.shape[0]
        device = xs_b.device
        # Drawn before validating n_layers so the global RNG stream is
        # consumed in the same order as before.
        order_ids = np.random.randint(0, self.n_funcs, size=(bsize, self._N_MAPS))

        sample_ids = self._SAMPLE_IDS.get(self.n_layers)
        if sample_ids is None:
            raise ValueError(f"n_layers={self.n_layers} is not applicable.")

        # Fresh orthogonal matrices used to corrupt the exposed outputs.
        # They are sized by the actual batch (previously hard-coded to 64,
        # which only worked for batches of exactly 64) and moved to the
        # input's device so the matmuls below do not mix devices.
        rng = np.random.default_rng(np.random.randint(10000000))
        garbage = np.linalg.svd(
            rng.standard_normal((self._N_MAPS, bsize, self.n_dims, self.n_dims))
        )[0]
        garbage = torch.tensor(garbage).float().to(device)

        # Per batch element: the six selected maps, shape (bsize, 6, n, n).
        Ws = self.func_set.to(device)[order_ids]

        layer_activations = []
        layer_activations_c = []
        activ = xs_b
        for i in range(self._N_MAPS - 1):
            activ = activ @ Ws[:, i]
            if i not in sample_ids:
                continue
            layer_activations_c.append(activ)
            if i in self._CLEAN_STEPS:
                layer_activations.append(activ)
            else:
                layer_activations.append(activ @ garbage[i])

        ys_b_nn_c = activ @ Ws[:, -1]
        # The exposed final output always goes through a garbage map.
        ys_b_nn = activ @ garbage[self._N_MAPS - 1]

        if return_correct:
            return ys_b_nn, layer_activations, ys_b_nn_c, layer_activations_c
        return ys_b_nn, layer_activations

    @staticmethod
    def get_metric():
        """Return the evaluation metric (element-wise squared error)."""
        return squared_error

    @staticmethod
    def get_training_metric():
        """Return the training loss (element-wise squared error)."""
        return squared_error


class ReluNNRegression(Task):
    """Random multi-layer network regression task with per-batch-element weights.

    Weights are sampled once at construction: an input layer ``W_init``,
    ``n_layers - 2`` square hidden layers ``Ws`` and a readout vector ``v``.

    NOTE(review): ``evaluate`` currently hard-wires a corruption switch, so the
    returned final output comes from a fresh random input and a fresh random
    readout rather than from the network's true output. The true output is
    only available through ``return_correct=True``.
    """

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        hidden_layer_size=4,
        n_layers=5,
        mode='relu'
    ):
        """Sample the network weights and select the activation.

        Args:
            scale: constant multiplied into every layer's activation.
            hidden_layer_size: width of every hidden layer.
            n_layers: total layer count; must be at least 2.
            mode: ``'relu'`` or ``'tanh'``.

        Raises:
            ValueError: if ``n_layers < 2``.
            NotImplementedError: if ``mode`` is not ``'relu'`` or ``'tanh'``.
        """
        super().__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        self.hidden_layer_size = hidden_layer_size
        self.n_layers = n_layers
        if n_layers < 2:
            raise ValueError("Number of layers should not be smaller than 2.")

        # Sampling order (input layer, hidden layers, readout) is kept fixed
        # so seeded runs draw the same weights.
        width = hidden_layer_size
        self.W_init = torch.randn(self.b_size, self.n_dims, width)
        self.Ws = torch.randn(self.n_layers - 2, self.b_size, width, width)
        self.v = torch.randn(self.b_size, width, 1)

        if mode not in ('relu', 'tanh'):
            raise NotImplementedError
        self.act_func = torch.nn.ReLU() if mode == 'relu' else torch.nn.Tanh()

    def evaluate(self, xs_b, return_correct=False):
        """Run the network on ``xs_b`` and return outputs plus layer activations.

        Returns:
            ``(ys, layer_activations)`` or, with ``return_correct=True``,
            ``(ys, layer_activations, ys_correct, layer_activations_correct)``.

        NOTE(review): ``layer_activations_correct`` holds only the first
        layer's activation, whereas ``layer_activations`` holds every layer -
        confirm whether that asymmetry is intended.
        """
        device = xs_b.device
        first_w = self.W_init.to(device)
        hidden_ws = self.Ws.to(device)
        readout = self.v.to(device)

        def dense(inputs, weight):
            # Activation followed by the sqrt(2 / width) factor and the user scale.
            return self.act_func(inputs @ weight) * math.sqrt(2 / self.hidden_layer_size) * self.scale

        # Hard-wired experiment switch (previously the constant `temp = 1`):
        # True  -> first activation is genuine, final output is random;
        # False -> first activation is zeroed, final output is genuine.
        corrupt_output = True

        activ = dense(xs_b, first_w)
        layer_activations_c = [activ]
        layer_activations = [activ] if corrupt_output else [torch.zeros_like(activ)]

        for layer_idx in range(self.n_layers - 2):
            activ = dense(activ, hidden_ws[layer_idx])
            layer_activations.append(activ)

        ys_b_nn_c = (activ @ readout)[:, :, 0]
        if corrupt_output:
            # Random input through the real first layer, then a random readout.
            decoy_hidden = dense(torch.randn_like(xs_b), first_w)
            ys_b_nn = (decoy_hidden @ torch.randn_like(readout))[:, :, 0]
        else:
            ys_b_nn = (activ @ readout)[:, :, 0]

        if not return_correct:
            return ys_b_nn, layer_activations
        return ys_b_nn, layer_activations, ys_b_nn_c, layer_activations_c

    @staticmethod
    def get_metric():
        """Return the evaluation metric (element-wise squared error)."""
        return squared_error

    @staticmethod
    def get_training_metric():
        """Return the training loss (element-wise squared error)."""
        return squared_error


class ReluNNRegressionAsymmetric(Task):
    """Random multi-layer network regression task whose last hidden layer is twice as wide.

    Every hidden layer has width ``hidden_layer_size`` except the last one,
    which has width ``2 * hidden_layer_size``. Weights are sampled once at
    construction, one set per batch element.
    """

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        hidden_layer_size=4,
        n_layers=4,
        mode='relu'
    ):
        """Sample the network weights and select the activation.

        Args:
            scale: constant multiplied into every layer's activation.
            hidden_layer_size: base width of the hidden layers.
            n_layers: total layer count; must be at least 2.
            mode: ``'relu'`` or ``'tanh'``.

        Raises:
            ValueError: if ``n_layers < 2``.
            NotImplementedError: if ``mode`` is not ``'relu'`` or ``'tanh'``.
        """
        # Zero-argument super(): the previous call named ReluNNRegression,
        # which is not a base of this class and raised TypeError on construction.
        super().__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        self.hidden_layer_size = hidden_layer_size
        self.n_layers = n_layers
        if n_layers < 2:
            raise ValueError("Number of layers should not be smaller than 2.")

        hidden_layer_sizes = [hidden_layer_size] * (n_layers - 1)
        # Make it asymmetric by making the last layer twice as wide.
        hidden_layer_sizes[-1] = 2 * hidden_layer_size
        self.hidden_layer_sizes = hidden_layer_sizes

        # Input width of each layer: n_dims first, then the previous layer's width.
        input_sizes = [self.n_dims] + hidden_layer_sizes[:-1]
        self.Ws = [
            torch.randn(self.b_size, fan_in, fan_out)
            for fan_in, fan_out in zip(input_sizes, hidden_layer_sizes)
        ]
        self.v = torch.randn(self.b_size, hidden_layer_sizes[-1], 1)

        if mode == 'relu':
            self.act_func = torch.nn.ReLU()
        elif mode == 'tanh':
            self.act_func = torch.nn.Tanh()
        else:
            raise NotImplementedError

    def evaluate(self, xs_b):
        """Run the network on ``xs_b``.

        Returns:
            ``(ys, layer_activations)`` where ``ys`` is the readout with its
            trailing singleton dimension dropped and ``layer_activations``
            lists the activation after every hidden layer.
        """
        device = xs_b.device
        v = self.v.to(device)

        layer_activations = []
        activ = xs_b
        for W, width in zip(self.Ws, self.hidden_layer_sizes):
            # Tensor.to returns the moved tensor rather than moving in place,
            # so its result must be used (it was previously discarded).
            activ = self.act_func(activ @ W.to(device)) * math.sqrt(2 / width) * self.scale
            layer_activations.append(activ)
        ys_b_nn = (activ @ v)[:, :, 0]

        return ys_b_nn, layer_activations

    @staticmethod
    def get_metric():
        """Return the evaluation metric (element-wise squared error)."""
        return squared_error

    @staticmethod
    def get_training_metric():
        """Return the training loss (element-wise squared error)."""
        return squared_error


# TODO: @ks: need to check with @yingcong if we need this
class TaskFamily:
    """Compose registered tasks into a chain, feeding each output to the next task.

    ``task_mapping`` maps an integer identifier to the task name, the input
    dimension that task expects, and optional keyword arguments for it.
    """

    def __init__(self):
        super().__init__()
        # TODO: define tasks whose inputs can have different dims (not only
        # 1-dim); then out_dim can be included in task_kwargs.
        self.task_mapping = {
            0: {
                'task_name' : 'linear_regression',
                'n_dims' : 10,
                'task_kwargs' : None
            }
        }

    def evaluate(self, xs_b, ids):
        """Apply the tasks named by ``ids`` in sequence.

        Given identifiers and an input x, produce the outputs of the chain,
        e.g. ``(x, 1, 2, 3) -> (f1(x), f2(f1(x)), f3(f2(f1(x))))``.

        Args:
            xs_b: tensor indexed as ``(batch, points, dims)``.
            ids: iterable of keys into ``task_mapping``.

        Returns:
            List with the output of every task in the chain, in order.
        """
        bsize, points = xs_b.shape[0], xs_b.shape[1]
        ys_b = []
        for task_id in ids:
            spec = self.task_mapping[task_id]
            task_name = spec['task_name']
            n_dims = spec['n_dims']
            # A stored None means "no extra arguments"; ** needs a mapping.
            task_kwargs = spec['task_kwargs'] or {}

            # Adapt the current input to the width the next task expects:
            # zero-pad when too narrow, truncate when too wide. The width is
            # re-read each iteration because the previous task may have
            # changed it.
            dims = xs_b.shape[2]
            if n_dims > dims:
                padding = torch.zeros(bsize, points, n_dims - dims, device=xs_b.device)
                xs_b = torch.cat((xs_b, padding), dim=2)
            elif n_dims < dims:
                xs_b = xs_b[:, :, :n_dims]

            task_sampler = get_task_sampler(task_name, n_dims, bsize, **task_kwargs)
            task = task_sampler()
            xs_b = task.evaluate(xs_b)
            ys_b.append(xs_b)

        return ys_b

    @staticmethod
    def get_metric():
        """Return a metric computing squared error averaged over the last axis."""
        def squared_error(ys_pred, ys):
            return (ys - ys_pred).square().mean(-1)
        return squared_error

    @staticmethod
    def get_training_metric():
        """Return a loss computing squared error averaged over all elements."""
        # Defined locally: no `mean_squared_error` exists at module level, so
        # returning that name raised NameError.
        def mean_squared_error(ys_pred, ys):
            return (ys - ys_pred).square().mean()
        return mean_squared_error

