import math
import numpy as np
import torch
import torch.nn as nn
import random


def squared_error(ys_pred, ys):
    """Return the elementwise squared residual ``(ys - ys_pred) ** 2`` (no reduction)."""
    residual = ys - ys_pred
    return residual.square()


class Task:
    """Base class for tasks that turn sampled inputs into targets.

    Args:
        n_dims: input dimensionality.
        batch_size: number of task instances per batch (stored as ``b_size``).
        pool_dict: optional pool dictionary (e.g. produced by
            ``generate_pool_dict``); stored for subclasses.
        seeds: optional seeds; stored for subclasses.

    Raises:
        ValueError: if both ``pool_dict`` and ``seeds`` are given.

    Subclasses override ``evaluate`` and the static hooks below; the base
    implementations raise ``NotImplementedError``.
    """

    def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None):
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if pool_dict is not None and seeds is not None:
            raise ValueError("Either pool_dict or seeds should be None.")
        self.n_dims = n_dims
        self.b_size = batch_size
        self.pool_dict = pool_dict
        self.seeds = seeds

    def evaluate(self, xs):
        """Map a batch of inputs to targets; must be overridden."""
        raise NotImplementedError

    @staticmethod
    def generate_pool_dict(n_dims, num_tasks):
        """Build a pool of ``num_tasks`` task instances; must be overridden."""
        raise NotImplementedError

    @staticmethod
    def get_metric():
        """Return the evaluation metric callable; must be overridden."""
        raise NotImplementedError

    @staticmethod
    def get_training_metric():
        """Return the training loss callable; must be overridden."""
        raise NotImplementedError


def get_task_sampler(
    task_name, n_dims, batch_size, pool_dict=None, num_tasks=None, **kwargs
):
    """Return a factory that builds instances of the task named ``task_name``.

    Args:
        task_name: registry key of the task class.
        n_dims, batch_size: forwarded positionally to the task constructor.
        pool_dict: optional pool forwarded to the task constructor.
        num_tasks: if given, a pool is generated via the class's
            ``generate_pool_dict(n_dims, num_tasks, **kwargs)``; mutually
            exclusive with ``pool_dict``.
        **kwargs: forwarded to the task constructor (and to
            ``generate_pool_dict`` when ``num_tasks`` is given).

    Returns:
        A callable accepting extra keyword arguments for the task constructor.

    Raises:
        NotImplementedError: for an unregistered ``task_name``.
        ValueError: if both ``pool_dict`` and ``num_tasks`` are given.
    """
    registry = {
        "relu_nn_regression": ReluNNRegression,
        "cot_skill_chain": CoTSkillChain,
        "relu_nn_regression_asymmetric": ReluNNRegressionAsymmetric,
    }
    task_cls = registry.get(task_name)
    if task_cls is None:
        print("Unknown task")
        raise NotImplementedError

    if num_tasks is not None:
        if pool_dict is not None:
            raise ValueError("Either pool_dict or num_tasks should be None.")
        pool_dict = task_cls.generate_pool_dict(n_dims, num_tasks, **kwargs)

    def sample_task(**extra_kwargs):
        return task_cls(n_dims, batch_size, pool_dict, **extra_kwargs, **kwargs)

    return sample_task

# TODO
class CoTSkillChain(Task):
    """Chain-of-thought "skill chain" task built from a bank of orthogonal maps.

    At construction, ``n_skills * n_funcs`` matrices of shape
    ``(n_dims, n_dims)`` are taken as the left singular vectors of seeded
    Gaussian matrices, so every ``func_set[skill, func]`` is orthogonal.
    ``evaluate`` draws ``chain_length`` skills (and one function per skill)
    for each batch element and unrolls short sub-chains of them over the
    sequence (see ``get_chain``).

    Args (beyond those of ``Task``):
        n_skills, n_funcs: size of the function bank.
        noise: standard deviation of the Gaussian noise added after each step.
        seed: seed for the function bank only; per-call sampling uses the
            global torch / numpy RNGs.
        scale: stored but not read by this class.
        min_chain_length, max_chain_length: inclusive range for the number of
            emitted steps per sub-chain.
        chain_length: number of skills drawn per batch element.
        ordered_chain: if True, skills are a contiguous increasing id range;
            otherwise a random subset in random order.
        mode: 'linear', 'relu' or 'tanh' activation applied after each product.

    Raises:
        ValueError: for an unknown ``mode``.
    """

    def __init__(self, n_dims, batch_size, pool_dict=None, seeds=None,
                 n_skills=10, n_funcs=20, noise=0.1, seed=0, scale=1,
                 min_chain_length=1, max_chain_length=2, chain_length=5,
                 ordered_chain=False, mode='linear'):
        super().__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        self.n_skills = n_skills
        self.n_funcs = n_funcs
        self.min_chain_length = min_chain_length
        self.max_chain_length = max_chain_length
        self.chain_length = chain_length
        self.ordered_chain = ordered_chain
        self.mode = mode
        self.noise = noise

        # A private generator keeps the function bank reproducible independently
        # of the global RNG state.
        generator = torch.Generator()
        generator.manual_seed(seed)
        # U factor of a batched SVD: one orthogonal matrix per (skill, func).
        self.func_set = torch.linalg.svd(
            torch.randn(n_skills, n_funcs, n_dims, n_dims, generator=generator)
        )[0]

        if mode == 'linear':
            self.act_func = nn.Identity()
        elif mode == 'relu':
            self.act_func = nn.ReLU()
        elif mode == 'tanh':
            self.act_func = nn.Tanh()
        else:
            raise ValueError('No valid mode found')

    def _apply_skill(self, x, func_set, sid, fid):
        """One noisy step: ``act(x @ func_set[sid, fid]) + noise * N(0, 1)``."""
        x = self.act_func(x @ func_set[sid, fid])
        return x + self.noise * torch.randn_like(x)

    def get_chain(self, xs, func_ids, skill_ids):
        """Unroll sub-chains of skills over ``n_points`` sequence positions.

        For each input ``xs[i]``, a length ``n`` in
        ``[min_chain_length, max_chain_length]`` and a start offset ``s`` are
        drawn. The first ``s`` steps are applied without being emitted. Then
        the current value and the next ``n`` step results are emitted.
        ``sids`` holds, for every emitted value, the skill id that produces the
        next emitted value, or ``-1`` after the last value of a sub-chain.
        Sub-chains are concatenated and truncated to ``n_points`` rows.

        Args:
            xs: inputs of shape ``(n_points, n_dims)``.
            func_ids, skill_ids: length-``chain_length`` id sequences.

        Returns:
            ``(outputs, sids)`` of shapes ``(n_points, n_dims)`` and
            ``(n_points,)``, on ``xs``'s device.

        Raises:
            ValueError: if ``len(skill_ids) != chain_length``.
        """
        chain_length = self.chain_length
        if len(skill_ids) != chain_length:
            raise ValueError(f"number of skills({len(skill_ids)}) not equal to the chain length({chain_length})")

        n_points, _ = xs.shape
        if n_points == 0:
            # torch.stack([]) would raise; return correctly shaped empties instead.
            return (torch.zeros(0, xs.shape[1], device=xs.device),
                    torch.zeros(0, dtype=torch.long, device=xs.device))

        # Keep the function bank on the same device as the inputs.
        func_set = self.func_set.to(xs.device)
        n_sample = torch.randint(self.min_chain_length, self.max_chain_length + 1, (n_points,)).tolist()
        start_ids = [int(np.random.randint(chain_length - n + 1)) for n in n_sample]

        sids = []
        outputs = []
        for i in range(n_points):
            start, length = start_ids[i], n_sample[i]
            x = xs[i]
            # Hidden prefix: advance to the sub-chain's starting point.
            for j in range(start):
                x = self._apply_skill(x, func_set, int(skill_ids[j]), int(func_ids[j]))
            outputs.append(x)
            for j in range(start, start + length):
                sid = int(skill_ids[j])
                sids.append(sid)
                x = self._apply_skill(x, func_set, sid, int(func_ids[j]))
                outputs.append(x)
            sids.append(-1)  # end-of-sub-chain marker
            # outputs and sids grow in lockstep; stop once there are enough rows.
            if len(sids) >= n_points:
                break
        outputs = torch.stack(outputs)
        sids = torch.tensor(sids, device=xs.device)
        return outputs[:n_points], sids[:n_points]

    def evaluate(self, xs_b):
        """Draw skill/function ids per batch element and unroll their chains.

        Args:
            xs_b: inputs of shape ``(bsize, n_points, n_dims)``.

        Returns:
            ``(ys_b, ids_b, skill_ids, func_ids)``. ``ys_b`` holds the chain
            values, shape ``(bsize, n_points, n_dims)``. ``ids_b`` holds the
            step skill ids, shape ``(bsize, n_points)``. Both are on
            ``xs_b``'s device. ``skill_ids`` and ``func_ids`` are the sampled
            ``(bsize, chain_length)`` id tensors.
        """
        n_skills = self.n_skills
        chain_length = self.chain_length
        bsize, n_points, dim = xs_b.shape
        ys_b = torch.zeros((bsize, n_points, dim), device=xs_b.device)
        ids_b = torch.zeros((bsize, n_points), dtype=torch.long, device=xs_b.device)

        skill_ids = torch.zeros(bsize, chain_length, dtype=torch.long)
        for batch_idx in range(bsize):
            if self.ordered_chain:
                start_idx = np.random.randint(n_skills - chain_length + 1)
                skill_ids[batch_idx] = torch.arange(start_idx, start_idx + chain_length)
            else:
                skill_ids[batch_idx] = torch.randperm(n_skills)[:chain_length]

        func_ids = torch.randint(self.n_funcs, (bsize, chain_length))
        for batch_idx in range(bsize):
            ys_b[batch_idx], ids_b[batch_idx] = self.get_chain(
                xs_b[batch_idx], func_ids[batch_idx], skill_ids[batch_idx]
            )

        return ys_b, ids_b, skill_ids, func_ids

    @staticmethod
    def get_metric():
        """Elementwise squared error."""
        return squared_error

    @staticmethod
    def get_training_metric():
        """Elementwise squared error."""
        return squared_error


class ReluNNRegression(Task):
    """Random fully connected network used as a per-batch-element regression target.

    One network per batch element (``b_size`` of them) is sampled from the
    global torch RNG. Each consists of an input projection ``W_init``
    (n_dims -> hidden_layer_size), ``n_layers - 2`` square hidden projections
    stacked in ``Ws``, and a linear readout ``v`` (hidden_layer_size -> 1).
    """

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        hidden_layer_size=4,
        n_layers=5,
        mode='relu'
    ):
        """Sample the network weights.

        Args:
            scale: constant multiplier applied to every hidden activation
                (after the nonlinearity and the ``sqrt(2 / hidden_layer_size)``
                factor).
            hidden_layer_size: width of every hidden layer.
            n_layers: number of weight layers including the readout; >= 2.
            mode: 'relu' or 'tanh'.

        Raises:
            ValueError: if ``n_layers < 2``.
            NotImplementedError: for an unknown ``mode``.
        """
        super().__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        self.hidden_layer_size = hidden_layer_size
        self.n_layers = n_layers
        if n_layers < 2:
            raise ValueError("Number of layers should not be smaller than 2.")

        width = hidden_layer_size
        n_inner = n_layers - 2
        # The draw order (W_init, Ws, v) determines which RNG samples each tensor gets.
        self.W_init = torch.randn(self.b_size, self.n_dims, width)
        self.Ws = torch.randn(n_inner, self.b_size, width, width)
        self.v = torch.randn(self.b_size, width, 1)

        if mode == 'relu':
            activation_cls = torch.nn.ReLU
        elif mode == 'tanh':
            activation_cls = torch.nn.Tanh
        else:
            raise NotImplementedError
        self.act_func = activation_cls()

    def evaluate(self, xs_b):
        """Run the sampled networks on ``xs_b``.

        Returns:
            ``(ys_b_nn, layer_activations)``: the readout with its trailing unit
            axis dropped, and the list of post-activation hidden outputs.
        """
        device = xs_b.device
        first, inner_stack, readout = (
            t.to(device) for t in (self.W_init, self.Ws, self.v)
        )
        gain = math.sqrt(2 / self.hidden_layer_size)

        hidden = self.act_func(xs_b @ first) * gain * self.scale
        layer_activations = [hidden]
        # Iterating the stacked tensor yields Ws[0], Ws[1], ... in order.
        for weight in inner_stack:
            hidden = self.act_func(hidden @ weight) * gain * self.scale
            layer_activations.append(hidden)
        return (hidden @ readout)[:, :, 0], layer_activations

    @staticmethod
    def get_metric():
        """Elementwise squared error."""
        return squared_error

    @staticmethod
    def get_training_metric():
        """Elementwise squared error."""
        return squared_error


class ReluNNRegressionAsymmetric(Task):
    """Random fully connected regression network whose last hidden layer is widened.

    One network per batch element (``b_size`` of them) is sampled from the
    global torch RNG. There are ``n_layers - 1`` hidden layers of width
    ``hidden_layer_size``, except the last, which has ``2 * hidden_layer_size``
    units, followed by a linear readout ``v`` to a single output.
    """

    def __init__(
        self,
        n_dims,
        batch_size,
        pool_dict=None,
        seeds=None,
        scale=1,
        hidden_layer_size=4,
        n_layers=4,
        mode='relu'
    ):
        """Sample the network weights.

        Args:
            scale: constant multiplier applied to every hidden activation
                (after the nonlinearity and the ``sqrt(2 / width)`` factor).
            hidden_layer_size: base hidden width; the last hidden layer is
                twice as wide.
            n_layers: number of weight layers including the readout; >= 2.
            mode: 'relu' or 'tanh'.

        Raises:
            ValueError: if ``n_layers < 2``.
            NotImplementedError: for an unknown ``mode``.
        """
        # Zero-arg super(): this class is not a ReluNNRegression subclass, so
        # super(ReluNNRegression, self) would raise TypeError.
        super().__init__(n_dims, batch_size, pool_dict, seeds)
        self.scale = scale
        self.hidden_layer_size = hidden_layer_size
        self.n_layers = n_layers
        if n_layers < 2:
            raise ValueError("Number of layers should not be smaller than 2.")

        # Make the network asymmetric by doubling the width of the last hidden layer.
        hidden_layer_sizes = [hidden_layer_size] * (n_layers - 1)
        hidden_layer_sizes[-1] = 2 * hidden_layer_size
        self.hidden_layer_sizes = hidden_layer_sizes

        # One (b_size, fan_in, fan_out) weight tensor per hidden layer, drawn in layer order.
        fan_ins = [self.n_dims] + hidden_layer_sizes[:-1]
        self.Ws = [
            torch.randn(self.b_size, fan_in, fan_out)
            for fan_in, fan_out in zip(fan_ins, hidden_layer_sizes)
        ]
        self.v = torch.randn(self.b_size, hidden_layer_sizes[-1], 1)

        if mode == 'relu':
            self.act_func = torch.nn.ReLU()
        elif mode == 'tanh':
            self.act_func = torch.nn.Tanh()
        else:
            raise NotImplementedError

    def evaluate(self, xs_b):
        """Run the sampled networks on ``xs_b``.

        Returns:
            ``(ys_b_nn, layer_activations)``: the readout with its trailing unit
            axis dropped, and the list of post-activation hidden outputs.
        """
        # Tensor.to returns a new tensor rather than moving in place, so keep the results.
        Ws = [W.to(xs_b.device) for W in self.Ws]
        v = self.v.to(xs_b.device)

        layer_activations = []
        activ = xs_b
        for W, width in zip(Ws, self.hidden_layer_sizes):
            activ = self.act_func(activ @ W) * math.sqrt(2 / width) * self.scale
            layer_activations.append(activ)
        ys_b_nn = (activ @ v)[:, :, 0]

        return ys_b_nn, layer_activations

    @staticmethod
    def get_metric():
        """Elementwise squared error."""
        return squared_error

    @staticmethod
    def get_training_metric():
        """Elementwise squared error."""
        return squared_error


# TODO: @ks: need to check with @yingcong if we need this
class TaskFamily:
    """Compose tasks from ``task_mapping`` into a chain x -> f1(x) -> f2(f1(x)) -> ...

    NOTE(review): the default mapping's 'linear_regression' is not among the
    names registered in ``get_task_sampler`` in this file. The registered tasks'
    ``evaluate`` methods return tuples rather than a single tensor. Chaining
    therefore needs adapting before ``evaluate`` can run end to end.
    """

    def __init__(self):
        super().__init__()
        # TODO: define tasks that accept inputs of different dims (not only
        # 1-dim outputs); then out_dim can be included in task_kwargs.
        self.task_mapping = {
            0: {
                'task_name': 'linear_regression',
                'n_dims': 10,
                'task_kwargs': None,
            }
        }

    def evaluate(self, xs_b, ids):
        """Apply the tasks identified by ``ids`` in sequence.

        Before each step the feature axis (dim 2) of the running input is
        zero-padded or truncated to that task's ``n_dims``.

        Example: ids (1, 2, 3) gives [f1(x), f2(f1(x)), f3(f2(f1(x)))].

        Args:
            xs_b: inputs of shape ``(bsize, points, dims)``.
            ids: keys into ``self.task_mapping``.

        Returns:
            List with the output of every step, in order.
        """
        bsize, points = xs_b.shape[0], xs_b.shape[1]
        ys_b = []
        for task_id in ids:
            spec = self.task_mapping[task_id]
            task_name = spec['task_name']
            n_dims = spec['n_dims']
            # A mapping may store None for "no extra kwargs"; ** requires a mapping.
            task_kwargs = spec['task_kwargs'] or {}
            # Re-read the width every step: the previous step may have changed it.
            dims = xs_b.shape[2]
            # Adapt the previous output to this task's input width
            # (alternative: fail when it does not fit).
            if n_dims > dims:
                pad = torch.zeros(
                    bsize, points, n_dims - dims, device=xs_b.device, dtype=xs_b.dtype
                )
                xs_b = torch.cat((xs_b, pad), dim=2)
            elif n_dims < dims:
                xs_b = xs_b[:, :, :n_dims]
            task_sampler = get_task_sampler(task_name, n_dims, bsize, **task_kwargs)
            task = task_sampler()
            xs_b = task.evaluate(xs_b)
            ys_b.append(xs_b)

        return ys_b

    @staticmethod
    def get_metric():
        """Return squared error averaged over the last axis."""
        def squared_error(ys_pred, ys):
            return (ys - ys_pred).square().mean(-1)
        return squared_error

    @staticmethod
    def get_training_metric():
        """Return squared error averaged over all elements (a scalar loss)."""
        # Defined locally: no module-level ``mean_squared_error`` is visible here.
        # Full-mean semantics assumed from the name; confirm with the training loop.
        def mean_squared_error(ys_pred, ys):
            return (ys - ys_pred).square().mean()
        return mean_squared_error

