import torch
import torch.nn as nn
import torch.optim as optim
from gym.spaces.space import Space
from gym.spaces.box import Box
from numpy import finfo, float32


class SequentialNetwork(nn.Module):
    """
    A stack of torch modules wrapped in nn.Sequential, bundled with its own optimiser.
    """

    def __init__(self, device, layers=None, code=None, preset=None, input_space=None, output_size=None, normaliser=None,
                 eval_only=False, optimiser=optim.Adam, lr=1e-3, clip_grads=False):
        """
        Build the network either from an explicit list of modules or from a net code.

        Net codes (parsed by code_parser):
        - (in, out)           = Linear (None for in/out means the input/output size)
        - "R"                 = ReLU
        - "LR"                = Leaky ReLU
        - "T"                 = Tanh
        - "S"                 = Softmax
        - ("D", p)            = Dropout
        - ("B", num_features) = Batch norm

        Args:
            device: Device the network is moved to (also handed to the normaliser).
            layers: Sequence of modules to chain. If None, they are built from `code`.
            code: Net code, only used when `layers` is None.
            preset: Not used by this constructor.
            input_space: List of gym spaces; required when `layers` is None or when
                normaliser="box_bounds".
            output_size: Size substituted for a None output in the net code.
            normaliser: None, or "box_bounds" to prepend a BoxNormalise layer.
            eval_only: If True the network is put in eval mode and no optimiser is created.
            optimiser: Optimiser class/factory, called as optimiser(parameters, lr=lr).
            lr: Learning rate passed to the optimiser.
            clip_grads: If True, gradients are clamped into [-1, 1] before each step.

        Raises:
            NotImplementedError: For any normaliser other than None or "box_bounds".
        """
        super(SequentialNetwork, self).__init__()
        if layers is None:
            assert input_space is not None and output_size is not None, "Must specify input_space and output_size."
            layers = code_parser(code, input_space, output_size)
        # Work on a private copy: prepending the normaliser must not modify a list owned by
        # the caller (who may reuse it, e.g. to build a second network).
        layers = list(layers)
        if normaliser == "box_bounds":
            layers.insert(0, BoxNormalise(space=input_space, device=device))
        elif normaliser is not None:
            raise NotImplementedError()
        self.layers = nn.Sequential(*layers)
        if eval_only:
            self.eval()
        else:
            self.optimiser = optimiser(self.parameters(), lr=lr)
            self.clip_grads = clip_grads
        self.to(device)

    def __repr__(self): return "Net"

    def forward(self, x):
        """Pass x through every layer in order."""
        return self.layers(x)

    def optimise(self, loss, do_backward=True, retain_graph=True):
        """
        Take one optimiser step.

        Args:
            loss: Scalar tensor to back-propagate (only used when do_backward is True).
            do_backward: If True, zero the gradients and back-propagate `loss` first;
                if False, step using whatever gradients are already stored.
            retain_graph: Forwarded to loss.backward().
        """
        assert self.training, "Network is in eval_only mode."
        if do_backward:
            self.optimiser.zero_grad()
            loss.backward(retain_graph=retain_graph)
        if self.clip_grads: # Optional gradient clipping.
            for param in self.parameters():
                # Frozen parameters, or ones the loss does not depend on, have no gradient.
                if param.grad is not None:
                    param.grad.data.clamp_(-1, 1)
        self.optimiser.step()

    def polyak(self, other, tau):
        """
        Use Polyak averaging to blend parameters with those of another network.

        Each parameter becomes tau * other + (1 - tau) * self, updated in place.
        Parameters are paired by iteration order, so both networks need the same structure.
        """
        for self_param, other_param in zip(self.parameters(), other.parameters()):
            self_param.data.copy_((other_param.data * tau) + (self_param.data * (1.0 - tau)))


def code_parser(code, input_space, output_size):
    """
    Translate a net code into a list of torch modules.

    Recognised entries of `code`:
    - (in, out)           = nn.Linear; a None `in` is replaced by the summed first
                            dimension of the input subspaces, a None `out` by output_size
    - "R"                 = nn.ReLU
    - "LR"                = nn.LeakyReLU
    - "T"                 = nn.Tanh
    - "S"                 = nn.Softmax over dim 1
    - ("D", p)            = nn.Dropout with probability p
    - ("B", num_features) = nn.BatchNorm1d

    Args:
        code: Iterable of entries as listed above.
        input_space: List of gym Spaces whose shape[0] values are summed to get the input size.
        output_size: Value substituted for a None output size.

    Returns:
        List of modules, in the order given by `code`.

    Raises:
        ValueError: If an entry is not one of the recognised codes.
    """
    # NOTE: Only works for a list of gym Spaces.
    assert isinstance(input_space, list) and all(isinstance(subspace, Space) for subspace in input_space)
    input_size = sum(subspace.shape[0] for subspace in input_space)
    activations = {"R": nn.ReLU, "LR": nn.LeakyReLU, "T": nn.Tanh}
    layers = []
    for entry in code:
        if isinstance(entry, str):
            if entry in activations:
                layers.append(activations[entry]())
            elif entry == "S":
                layers.append(nn.Softmax(dim=1))
            else:
                raise ValueError(f"Unrecognised net code entry: {entry!r}")
            continue
        if not isinstance(entry, (list, tuple)) or len(entry) < 2:
            raise ValueError(f"Unrecognised net code entry: {entry!r}")
        head, arg = entry[0], entry[1]
        # A string head marks a parameterised layer; it must be checked before the linear
        # case, otherwise ("D", p) / ("B", n) would be mistaken for (in, out) sizes.
        if isinstance(head, str):
            if head == "D":
                layers.append(nn.Dropout(p=arg))
            elif head == "B":
                # 1d variant: the Linear layers built here emit (batch, features) tensors.
                layers.append(nn.BatchNorm1d(arg))
            else:
                raise ValueError(f"Unrecognised net code entry: {entry!r}")
        else:
            in_size = input_size if head is None else head
            out_size = output_size if arg is None else arg
            layers.append(nn.Linear(in_size, out_size))
    return layers


class BoxNormalise(nn.Module):
    """
    Normalise into [-1, 1] using the bounds of a list of Box subspaces.

    For every dimension, half the width of the bounds (`range`) and their midpoint
    (`shift`) are precomputed; forward() returns (x - shift) / range.
    """

    # Largest half-width accepted; unbounded Box dimensions fail the check below.
    max_range = finfo(float32).max

    def __init__(self, space, device):
        """
        Args:
            space: List of gym Box subspaces, concatenated in order.
            device: Device on which the normalisation constants are created.
        """
        super(BoxNormalise, self).__init__()
        assert isinstance(space, list) and all(isinstance(subspace, Box) for subspace in space)
        rnge, shift = [], []
        for subspace in space:
            r = ((subspace.high - subspace.low) / 2.0)
            assert (r < self.max_range).all(), f"{subspace} has invalid range(s): {r}"
            rnge += list(r)
            shift += list(r + subspace.low)
        # NOTE(review): a dimension with low == high gives a zero range and hence a
        # division by zero in forward(); not guarded here.
        # Registered as buffers so that Module.to()/cuda()/double() carry the constants along
        # with the rest of the network; non-persistent so that state_dict keys are unchanged.
        self.register_buffer("range", torch.tensor(rnge, device=device), persistent=False)
        self.register_buffer("shift", torch.tensor(shift, device=device), persistent=False)

    def __repr__(self): return f"BoxNormalise(range={self.range}, shift={self.shift})"

    def forward(self, x):
        """Centre x on the midpoint of the bounds and scale by half their width."""
        return (x - self.shift) / self.range
