from typing import Tuple

import numpy as np
import torch
import torch.nn as nn


def get_activation(activation: str) -> nn.Module:
    '''
    Build the activation module that corresponds to a name.

    param activation: name of the activation; 'softmax' or 'identity'
    return a freshly constructed nn.Softmax(dim=1) for 'softmax' or
           nn.Identity() for 'identity'
    raises ValueError if the name is not one of the supported values
    '''
    # (name, factory) pairs; the factory is only invoked for the match so a
    # new module instance is created per call.
    known = (
        ('softmax', lambda: nn.Softmax(dim=1)),
        ('identity', nn.Identity),
    )
    for name, build in known:
        if activation == name:
            return build()
    raise ValueError(f"Invalid activation function: {activation}")


class Attention(nn.Module):
    '''
    Attention layer that updates a prompt Z as

        Z + (1 / n) * P @ Z @ M @ activation(Z^T @ Q @ Z)

    M is the (n+1) x (n+1) identity with its last diagonal entry set to 0,
    so the last column of P @ Z is dropped before it is multiplied by the
    attention scores.

    With constrained=True the layer learns A (2d x 2d) and u (1 x 2d) and
    assembles P and Q from them on every forward pass:
        P is zero except for its last row, which is [u, 1]
        Q is zero except for its top-left 2d x 2d block, which is A
    With constrained=False the full (2d+1) x (2d+1) matrices P and Q are
    learned directly.

    The parameters are allocated with torch.empty and are NOT initialised
    here; the owner of the layer has to initialise them before use.
    '''

    def __init__(self,
                 d: int,
                 n: int,
                 activation: str,
                 constrained: bool):
        '''
        param d: feature dimension
        param n: context length
        param activation: activation function (e.g. identity, softmax)
        param constrained: whether to use constrained attention parameterization
        '''
        super().__init__()
        self.d = d
        self.n = n
        self.activation = get_activation(activation)
        self.constrained = constrained

        # Registered as a buffer so that .to() / .cuda() / .double() move and
        # cast the mask together with the parameters. It is non-persistent so
        # the keys of state_dict() stay exactly the parameters.
        self.register_buffer('M', self._make_mask(n), persistent=False)

        if constrained:
            self.A = nn.Parameter(torch.empty(2 * d, 2 * d))
            self.u = nn.Parameter(torch.empty(1, 2 * d))
        else:
            self.P = nn.Parameter(torch.empty(2 * d + 1, 2 * d + 1))
            self.Q = nn.Parameter(torch.empty(2 * d + 1, 2 * d + 1))

    @staticmethod
    def _make_mask(n: int, like: torch.Tensor = None) -> torch.Tensor:
        '''
        param n: context length
        param like: optional tensor whose dtype and device the mask copies
        return the (n+1) x (n+1) identity with the last diagonal entry zeroed
        '''
        if like is None:
            mask = torch.eye(n + 1)
        else:
            mask = torch.eye(n + 1, dtype=like.dtype, device=like.device)
        mask[-1, -1] = 0
        return mask

    def forward(self, Z: torch.Tensor) -> torch.Tensor:
        '''
        param Z: prompt of shape (batch, 2*d+1, n+1)
        return the updated prompt of shape (batch, 2*d+1, n+1)
        '''
        if self.constrained:
            size = 2 * self.d + 1
            # new_zeros inherits dtype and device from the parameters, so the
            # assembled matrices live wherever the module lives. The in-place
            # writes below keep the autograd link to u and A.
            P = self.u.new_zeros(size, size)
            P[-1, -1] = 1
            P[-1, :2 * self.d] = self.u
            Q = self.A.new_zeros(size, size)
            Q[:2 * self.d, :2 * self.d] = self.A
        else:
            P = self.P
            Q = self.Q

        # Attention scores between all pairs of columns of Z.
        X = Z.transpose(-2, -1) @ Q @ Z
        X = self.activation(X)
        return Z + (1.0 / self.n) * P @ Z @ self.M @ X

    def reset_ctxt_len(self, n: int) -> None:
        '''
        param n: int (new context length)
        modify the context length of the attention
        '''
        self.n = n
        # Rebuild the mask with the dtype and device of the current one so a
        # module that was moved or cast keeps working after the reset.
        self.M = self._make_mask(n, like=self.M)


class Transformer(nn.Module):
    '''
    Stack of Attention layers applied to a prompt of shape
    (batch, 2*d+1, n+1).

    mode == 'auto': a single Attention layer is created and applied l times,
                    so all l steps share the same weights.
    mode == 'sequential': l separate Attention layers are created and each
                    is applied once, in order.
    '''

    def __init__(self,
                 d: int,
                 n: int,
                 l: int,
                 activation: str,
                 mode: str,
                 constrained: bool):
        '''
        param d: feature dimension
        param n: context length
        param l: number of layers
        param activation: activation function (e.g. identity, softmax)
        param mode: 'auto' or 'sequential'
        param constrained: whether to use constrained attention parameterization
        raises ValueError if mode is neither 'auto' nor 'sequential'
        '''
        super().__init__()
        self.d = d
        self.n = n
        self.l = l
        self.activation = activation
        self.mode = mode
        self.constrained = constrained

        if mode == 'auto':
            num_layers = 1
        elif mode == 'sequential':
            num_layers = l
        else:
            raise ValueError('mode must be either auto or sequential')
        self.layers = nn.ModuleList([Attention(d, n, activation, constrained)
                                     for _ in range(num_layers)])

        # Every learnable matrix gets a Xavier-normal init with gain 0.1 / l.
        # The gain is evaluated inside the loop on purpose: with an empty
        # layer list (sequential, l == 0) no division takes place.
        param_names = ('A', 'u') if constrained else ('P', 'Q')
        for attn in self.layers:
            for name in param_names:
                nn.init.xavier_normal_(getattr(attn, name), gain=0.1 / l)

    def _zeros_like_params(self, *shape: int) -> torch.Tensor:
        '''
        return a zero tensor of the given shape that uses the dtype and
        device of the model parameters (torch defaults if there are none)
        '''
        ref = next(self.parameters(), None)
        if ref is None:
            return torch.zeros(shape)
        return ref.new_zeros(shape)

    def forward(self, Z: torch.Tensor) -> torch.Tensor:
        '''
        param Z: prompt of shape (batch, 2*d+1, n+1)
        return the final embedding of shape (batch, 2*d+1, n+1)
        '''
        if self.mode == 'auto':
            # Weight sharing: the same layer is applied l times.
            attn = self.layers[0]
            for _ in range(self.l):
                Z = attn(Z)
        else:
            for attn in self.layers:
                Z = attn(Z)
        return Z

    def fit_value_func(self,
                       context: torch.Tensor,
                       phi: torch.Tensor) -> torch.Tensor:
        '''
        param context: the context of shape (2*d+1, n)
        param phi: features of shape (s, d)
        returns the fitted value function given the context in shape (s, 1)

        One prompt is built per row of phi: the first n columns are the
        shared context and the last column holds that row of phi in its first
        d entries (the remaining entries of the column are zero).
        '''
        assert context.shape[0] == 2 * self.d + 1
        assert context.shape[1] == self.n
        assert phi.shape[1] == self.d

        s = phi.shape[0]
        # Allocate with the model's dtype/device so the prompts can be fed to
        # the layers directly; the assignments below copy (and cast) the
        # inputs into that tensor.
        Zs = self._zeros_like_params(s, 2 * self.d + 1, self.n + 1)
        Zs[:, :, :self.n] = context  # broadcast over the s prompts
        Zs[:, :self.d, self.n] = phi
        return self.pred_v(Zs)

    def pred_v(self, Z: torch.Tensor) -> torch.Tensor:
        '''
        param Z: prompt of shape (batch, 2*d+1, n+1)
        predict the value of the query features (TF(Z))
        return the negated last-row, last-column entry of the transformed
               prompt, in shape (batch, 1)
        '''
        Z_tf = self.forward(Z)
        # Indexing with [-1] (a list) keeps the trailing dimension.
        return -Z_tf[:, -1, [-1]]

    def to_numpy(self) -> Tuple[np.ndarray, np.ndarray]:
        '''
        return the P and Q matrices in numpy format, stacked over the layers
        (each of shape (number of layers, 2*d+1, 2*d+1))

        In constrained mode P and Q are assembled from u and A in the same
        layout the attention layer uses.
        '''
        size = 2 * self.d + 1
        Ps = []
        Qs = []
        for attn in self.layers:
            if self.constrained:
                P = np.zeros((size, size))
                P[-1, -1] = 1
                P[-1, :2 * self.d] = attn.u.cpu().detach().numpy()
                Q = np.zeros((size, size))
                Q[:2 * self.d, :2 * self.d] = attn.A.cpu().detach().numpy()
            else:
                P = attn.P.cpu().detach().numpy()
                Q = attn.Q.cpu().detach().numpy()
            Ps.append(P)
            Qs.append(Q)
        return np.stack(Ps), np.stack(Qs)

    def reset_ctxt_len(self, n: int) -> None:
        '''
        param n: int (new context length)
        modify the context length of the transformer
        '''
        assert n > 0
        self.n = n
        for attn in self.layers:
            attn.reset_ctxt_len(n)

    def copy(self) -> 'Transformer':
        '''
        return a copy of the transformer

        The copy is built with the same hyper-parameters, placed on the
        device and dtype of this model's parameters, and then receives a copy
        of every entry of this model's state_dict.
        '''
        tf = Transformer(self.d, self.n, self.l,
                         self.activation, self.mode, self.constrained)
        ref = next(self.parameters(), None)
        if ref is not None:
            tf.to(device=ref.device, dtype=ref.dtype)
        # load_state_dict copies the values into the new parameters, so the
        # two models do not share storage.
        tf.load_state_dict(self.state_dict())
        return tf
