
import torch
import torch.nn as nn
import torch.nn.functional  as F
import random
import math
import numpy as np
from utils import set_seed


from einops import rearrange


def set_model_seed(seed):
    """Seed the torch and Python ``random`` generators before building a model.

    The seed is printed first so it appears in the run log. NumPy's global
    generator is not touched by this helper.

    Args:
        seed: Integer seed passed to ``torch.manual_seed`` and then to
            ``random.seed``.
    """
    print(f"model seed = {seed}")
    for seeder in (torch.manual_seed, random.seed):
        seeder(seed)


# for compatibility
class TF(nn.Module):
    """One attention layer, a per-token head, then a mean over tokens.

    Inputs are ``(batch, tokens, d)``. With ``class_num == 2`` the model
    returns the margin ``logit_0 - logit_1`` of shape ``(batch,)``; otherwise
    it returns ``(batch, class_num)`` pooled logits.

    Args:
        m: Hidden width of the ``'mlp2'`` head (unused by the ``'linear'`` head).
        d: Embedding dimension; ``Wq``, ``Wk`` and ``Wv`` are ``d x d``.
        q: Exponent of the ``relu(x) ** q`` activation (``'mlp2'`` head only).
        attn_type: ``'softmax'`` (row-normalised scores) or ``'linear'``
            (raw scaled scores).
        linear: If True the activation is the identity.
        residual: If True the input is added to the attention output.
        head_type: ``'linear'`` or ``'mlp2'``.
        fixed_qk: If True ``Wq`` and ``Wk`` are frozen.
        class_num: Number of head outputs per token.

    Raises:
        RuntimeError: If ``attn_type`` or ``head_type`` is unsupported.
    """

    _ATTN_TYPES = ('softmax', 'linear')
    _HEAD_TYPES = ('linear', 'mlp2')

    def __init__(self, m=50, d=1000, q=1,
                 attn_type='softmax',
                 linear=False,
                 residual=False,
                 head_type='linear',
                 fixed_qk=False,
                 class_num=2,
        ):
        super(TF, self).__init__()

        # Validate up front, reporting the offending value, instead of failing
        # with a bare error (attn_type used to be checked only in forward()).
        if attn_type not in self._ATTN_TYPES:
            raise RuntimeError(
                f"unsupported attn_type {attn_type!r}; expected one of {self._ATTN_TYPES}"
            )
        if head_type not in self._HEAD_TYPES:
            raise RuntimeError(
                f"unsupported head_type {head_type!r}; expected one of {self._HEAD_TYPES}"
            )

        self.dim = d
        self.m = m

        self.q = q
        self.linear_act = linear
        self.attn_type = attn_type
        self.residual = residual
        self.head_type = head_type
        self.fixed_qk = fixed_qk
        self.class_num = class_num

        # The randn values are overwritten by normal_ below; the draws are kept
        # so that seeded runs consume the RNG in exactly the same order.
        self.Wq = torch.nn.Parameter(torch.randn(d, d))
        self.Wq.requires_grad = not fixed_qk
        self.Wk = torch.nn.Parameter(torch.randn(d, d))
        self.Wk.requires_grad = not fixed_qk
        self.Wv = torch.nn.Parameter(torch.randn(d, d))
        self.Wv.requires_grad = True

        nn.init.normal_(self.Wq, std=0.003 / math.sqrt(d))
        nn.init.normal_(self.Wk, std=0.003 / math.sqrt(d))
        nn.init.normal_(self.Wv, std=0.003 / math.sqrt(d))

        if self.head_type == 'linear':
            self.W1 = torch.nn.Parameter(torch.randn(self.class_num, d))
            self.W1.requires_grad = True
            nn.init.normal_(self.W1, std=0.003 / math.sqrt(d))
        else:  # 'mlp2' (validated above)
            self.W1 = torch.nn.Parameter(torch.randn(m, d))
            self.W2 = torch.nn.Parameter(torch.randn(self.class_num, m))
            self.W1.requires_grad = True
            self.W2.requires_grad = True
            nn.init.normal_(self.W1, std=0.003 / math.sqrt(d))
            nn.init.normal_(self.W2, std=0.003 / math.sqrt(m))

    def head(self, input):
        """Map ``(..., d)`` features to ``(..., class_num)`` outputs."""
        if self.head_type == 'linear':
            return input.matmul(self.W1.T)
        elif self.head_type == 'mlp2':
            return self.act(input.matmul(self.W1.T)).matmul(self.W2.T)
        raise RuntimeError(f"unsupported head_type {self.head_type!r}")

    def act(self, input):
        """Identity when ``linear=True``; otherwise ``relu(input) ** q``."""
        if self.linear_act:
            return input

        return torch.pow(F.relu(input), self.q)

    def attn(self, x):
        """Single-head attention over ``x``; returns ``(outputs, weights)``.

        Scores are scaled by ``1 / sqrt(d)``; ``'softmax'`` normalises each
        row, ``'linear'`` uses the scaled scores directly.
        """
        Q, K, V = torch.matmul(x, self.Wq.T), torch.matmul(x, self.Wk.T), torch.matmul(x, self.Wv.T) # b, t, d
        scores = Q.matmul(K.transpose(-1, -2)) / math.sqrt(self.dim) # b, t, t
        if self.attn_type == 'softmax':
            attn_weights = torch.softmax(scores, dim=-1)
        elif self.attn_type == 'linear':
            attn_weights = scores
        else:
            raise RuntimeError(f"unsupported attn_type {self.attn_type!r}")
        attn_outputs = attn_weights.matmul(V) # b, t, d
        return attn_outputs, attn_weights

    def forward(self, x, verbose=False):
        """Return the pooled prediction for ``x``.

        ``verbose`` is accepted for compatibility and currently unused.
        """
        # attn part
        attn_outputs, attn_weights = self.attn(x) # b, t, d
        if self.residual:
            attn_outputs = x + attn_outputs

        # head, then average over tokens
        out = self.head(attn_outputs) # b, t, k
        Fpn = torch.mean(out, dim=1) # b, k
        if self.class_num == 2:
            # binary case: return the margin between the two outputs
            return Fpn[:, 0] - Fpn[:, 1]
        return Fpn
    

# for compatibility
def get_model_and_optimizer(n_hidden, n_dim, optim, lr=None, seed_model=0, *args, **kwargs):
    """Build a seeded ``TF`` model together with its optimizer.

    Args:
        n_hidden: Passed to ``TF`` as ``m``.
        n_dim: Passed to ``TF`` as ``d``.
        optim: ``"adam"`` (default lr 1e-3), ``"sgd"`` (default lr 0.1) or
            ``"sgdm"`` (SGD with momentum 0.9, default lr 0.1).
        lr: Learning rate; ``None`` selects the per-optimizer default.
        seed_model: Seed applied via ``set_model_seed`` before construction.
        *args, **kwargs: Forwarded to ``TF`` after ``m`` and ``d``.

    Returns:
        ``(model, optimizer)``.

    Raises:
        ValueError: If ``optim`` is not one of the supported names.
    """
    default_lrs = {"adam": 1e-3, "sgd": 0.1, "sgdm": 0.1}
    # Reject unknown optimizers before seeding/building; previously this fell
    # through and crashed with UnboundLocalError on `optimizer`.
    if optim not in default_lrs:
        raise ValueError(
            f"unknown optim {optim!r}; expected one of {sorted(default_lrs)}"
        )

    set_model_seed(seed_model)
    # m and d go positionally so any extra *args line up with TF's following
    # parameters (q, attn_type, ...). Passing m/d as keywords while also
    # splatting *args raised "got multiple values for argument 'm'".
    model = TF(n_hidden, n_dim, *args, **kwargs)

    if lr is None:
        lr = default_lrs[optim]
    if optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    else:  # "sgdm"
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    return model, optimizer


class Attention(nn.Module):
    """Multi-head attention with separate Q/K/V projections and no output projection.

    Attributes read from ``config``:
        hidden_size: width of the input ``x``.
        dk, dv: total query/key and value widths, split evenly over
            ``num_attention_heads``.
        attn_type: ``"softmax"``, ``"linear"`` (raw scaled scores) or
            ``"relu"`` (ReLU of the scaled scores).
        attn_scaling_type: scores are divided by ``1``, ``sqrt(dk)`` or ``dk``
            for ``"1"``, ``"sqrt"`` and ``"d"`` respectively.
        fixed_qk, fixed_v: freeze ``Wq``/``Wk`` and ``Wv`` respectively.
        val_init_type: ``"id"`` fills each half of ``Wv`` with a rectangular
            identity; any other value draws Gaussian weights.
        max_position_embeddings: size of the causal mask (1024 if ``None``).
        causal_attention: enable causal masking (``None`` means False).

    Raises:
        ValueError: on an unsupported ``attn_type`` / ``attn_scaling_type``,
            or when ``dk`` / ``dv`` is not divisible by the number of heads.
    """

    _ATTN_TYPES = ("softmax", "linear", "relu")
    _SCALING_TYPES = ("1", "sqrt", "d")

    def __init__(self, config):
        super(Attention, self).__init__()

        self.embed_dim = config.hidden_size
        self.dk = config.dk
        self.dv = config.dv
        self.num_heads = config.num_attention_heads
        self.head_dk = self.dk // self.num_heads
        self.head_dv = self.dv // self.num_heads
        self.attn_type = config.attn_type
        self.fixed_qk = config.fixed_qk
        self.fixed_v = config.fixed_v
        self.attn_scaling_type = config.attn_scaling_type
        if config.max_position_embeddings is not None:
            max_positions = config.max_position_embeddings
        else:
            max_positions = 1024
        if config.causal_attention is not None:
            self.causal_attention = config.causal_attention
        else:
            self.causal_attention = False
        print(f"causal_attention: {self.causal_attention}")

        # Explicit raises instead of `assert`: asserts are stripped under
        # `python -O`, after which a bad scaling type would only surface as an
        # UnboundLocalError inside _attn().
        if self.attn_scaling_type not in self._SCALING_TYPES:
            raise ValueError(
                f"attn_scaling_type must be one of {self._SCALING_TYPES}, "
                f"got {self.attn_scaling_type!r}"
            )
        if self.attn_type not in self._ATTN_TYPES:
            raise ValueError(
                f"attn_type must be one of {self._ATTN_TYPES}, got {self.attn_type!r}"
            )
        # _split_heads() reshapes dk/dv into (num_heads, head_size); an uneven
        # split would otherwise fail later with an opaque view() error.
        if self.dk % self.num_heads or self.dv % self.num_heads:
            raise ValueError(
                f"dk ({self.dk}) and dv ({self.dv}) must be divisible by "
                f"num_attention_heads ({self.num_heads})"
            )

        # Creation/initialisation order is unchanged so seeded runs draw the
        # same random numbers.
        self.Wq = torch.nn.Parameter(torch.randn(self.dk, self.embed_dim))
        self.Wq.requires_grad = not self.fixed_qk
        self.Wk = torch.nn.Parameter(torch.randn(self.dk, self.embed_dim))
        self.Wk.requires_grad = not self.fixed_qk
        self.Wv = torch.nn.Parameter(torch.randn(self.dv, self.embed_dim))
        self.Wv.requires_grad = not self.fixed_v

        nn.init.normal_(self.Wq, std=0.1 / math.sqrt(self.embed_dim))
        nn.init.normal_(self.Wk, std=0.1 / math.sqrt(self.embed_dim))
        if config.val_init_type == "id":
            m = self.dv // 2
            nn.init.eye_(self.Wv[:m])
            nn.init.eye_(self.Wv[m:])
        else:
            nn.init.normal_(self.Wv, std=0.1 / math.sqrt(self.embed_dim))

        # Lower-triangular boolean mask: True where a query may see a key.
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones((max_positions, max_positions), dtype=torch.bool)
            ).view(1, 1, max_positions, max_positions),
            persistent=False,
        )

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """Reshape ``(batch, seqlen, num_heads * head_size)`` into
        ``(batch, num_heads, seqlen, head_size)``."""
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seqlen, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """Inverse of ``_split_heads``: back to ``(batch, seqlen, num_heads * head_size)``."""
        tensor = tensor.permute(0, 2, 1, 3).contiguous() # (batch, seqlen, head, head_features)
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def _attn(self, query, key, value):
        """Scaled, optionally causal attention on per-head tensors.

        Returns ``(attn_output, attn_weights)``; the weights are taken after
        the softmax/ReLU (or left raw for ``"linear"``).
        """
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        # NOTE(review): the scale uses the *total* dk, not the per-head
        # head_dk. Identical for one head, but differs from the usual per-head
        # scaling when num_heads > 1 -- confirm this is intended.
        if self.attn_scaling_type == "1":
            attn_weights_scaler = 1.0
        elif self.attn_scaling_type == "sqrt":
            attn_weights_scaler = math.sqrt(self.dk)
        elif self.attn_scaling_type == "d":
            attn_weights_scaler = self.dk
        else:
            raise ValueError(f"unsupported attn_scaling_type {self.attn_scaling_type!r}")

        attn_weights = attn_weights / torch.full(
            [],
            attn_weights_scaler,
            dtype=attn_weights.dtype,
            device=attn_weights.device,
        )

        if self.causal_attention:
            query_length, key_length = query.size(-2), key.size(-2)
            max_positions = self.bias.size(-1)
            if key_length > max_positions:
                raise ValueError(
                    f"sequence length {key_length} exceeds the causal mask size "
                    f"{max_positions} (config.max_position_embeddings)"
                )
            causal_mask = self.bias[
                :, :, key_length - query_length : key_length, :key_length
            ]
            # Masked scores must become 0 after the nonlinearity: softmax and
            # ReLU both send the dtype's most negative value to 0, while
            # linear attention has no nonlinearity so the score is zeroed.
            if self.attn_type in ("softmax", "relu"):
                mask_value = torch.finfo(attn_weights.dtype).min
            elif self.attn_type == "linear":
                mask_value = 0.0
            else:
                raise ValueError(f"unsupported attn_type {self.attn_type!r}")

            # A 0-dim tensor with matching dtype/device, so torch.where does
            # not complain about mixed scalar types or devices.
            mask_value = torch.full(
                [], mask_value, dtype=attn_weights.dtype, device=attn_weights.device
            )
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if self.attn_type == 'softmax':
            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        elif self.attn_type == 'linear':
            pass
        elif self.attn_type == 'relu':
            attn_weights = nn.functional.relu(attn_weights)
        else:
            raise ValueError(f"unsupported attn_type {self.attn_type!r}")

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def forward(self, x, value_fn=None):
        """Attend over ``x``.

        Args:
            x: ``(batch, seqlen, hidden_size)`` input.
            value_fn: optional callable applied to the projected values
                before they are split into heads.

        Returns:
            ``(attn_output, attn_weights)`` of shapes ``(batch, seqlen, dv)``
            and ``(batch, num_heads, seqlen, seqlen)``.
        """
        query, key, value = torch.matmul(x, self.Wq.T), torch.matmul(x, self.Wk.T), torch.matmul(x, self.Wv.T) # (batch, seqlen, hidden)
        if value_fn is not None:
            value = value_fn(value)

        query = self._split_heads(query, self.num_heads, self.head_dk)
        key = self._split_heads(key, self.num_heads, self.head_dk)
        value = self._split_heads(value, self.num_heads, self.head_dv)

        attn_output, attn_weights = self._attn(query, key, value)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dv)

        return attn_output, attn_weights
    

def get_model(config):
    """Seed all RNGs with ``config.seed_model`` and build ``config.model_cls``.

    Args:
        config: Object whose ``model_cls`` names one of ``"AttnBinary"``,
            ``"AttnBinaryV"`` or ``"AttnFlipflop"``; it is also passed to
            that class's constructor.

    Raises:
        NotImplementedError: If ``config.model_cls`` is not a known name.
    """
    set_seed(config.seed_model)

    # Looked up at call time: these classes are defined later in the module.
    known_models = (
        ("AttnBinary", AttnBinary),
        ("AttnBinaryV", AttnBinaryV),
        ("AttnFlipflop", AttnFlipflop),
    )
    model_cls = next(
        (cls for name, cls in known_models if config.model_cls == name), None
    )
    if model_cls is None:
        raise NotImplementedError

    return model_cls(config)


def get_optimizer(model, opt_cls, **kwargs):
    """Instantiate ``opt_cls`` over every parameter of ``model``.

    Keyword arguments (``lr``, ``momentum``, ...) go straight to the
    optimizer constructor.
    """
    params = model.parameters()
    optimizer = opt_cls(params, **kwargs)
    return optimizer


class AttnBinary(nn.Module):
    def __init__(self, config):
        """Shared attention, nonlinearity after the weighted sum.

        A single ``Attention`` layer is built with twice ``config.dv`` value
        features. The first half feeds the positive score ``Fp`` and the
        second half feeds the negative score ``Fn``. ``config`` itself is not
        modified.
        """
        import copy

        super(AttnBinary, self).__init__()

        # Double dv on a shallow copy. Doing it in place changed the caller's
        # config, so reusing it (a second model, a saved/reloaded config)
        # silently doubled the value width again.
        attn_config = copy.copy(config)
        attn_config.dv *= 2
        self.attn = Attention(attn_config)
        self.linear = config.act_linear
        self.q = config.act_q
        self.name = "AttnBinary"
        self.patch_size = config.patch_size

    def act(self, input):
        """Identity when ``act_linear``; otherwise ``relu(input) ** act_q``."""
        if self.linear:
            return input
        return torch.pow(F.relu(input), self.q)

    def forward(self, x):
        """Score ``x`` as ``Fp - Fn``.

        Args:
            x: ``(batch, tokens, hidden)`` sequence, or a square
                ``(batch, H, W)`` input that is first cut into non-overlapping
                ``patch_size x patch_size`` patches (one token per patch).

        Returns:
            dict with ``out`` of shape ``(batch,)`` plus the attention
            weights/outputs (the ``*_p`` / ``*_n`` entries are ``None``).
        """
        # NOTE(review): any 3-D input whose last two sizes match is treated as
        # an image, including a token sequence with tokens == hidden -- confirm
        # that cannot happen for sequence inputs.
        if x.dim() == 3 and x.size(1) == x.size(2):
            x = rearrange(x, 'b (h p1) (w p2) -> b (h w) (p1 p2)', p1 = self.patch_size, p2 = self.patch_size)
        attn_output, attn_weights = self.attn(x) # b, t, dv
        m = self.attn.dv // 2
        acted_attn_output = self.act(attn_output)
        # Per token, average each half of the activated features; then sum
        # over tokens. (The former two-token special case computed exactly
        # this quantity, so a single code path covers every length.)
        Fp = torch.sum(torch.mean(acted_attn_output[:, :, :m], dim=-1), dim=-1)
        Fn = torch.sum(torch.mean(acted_attn_output[:, :, m:], dim=-1), dim=-1)
        out = Fp - Fn

        result = dict(
            out=out,
            attn_weights=attn_weights,
            attn_weights_p=None,
            attn_weights_n=None,
            attn_output=attn_output,
            attn_output_p=None,
            attn_output_n=None,
        )
        return result
    

class AttnBinaryV(AttnBinary):
    def __init__(self, config):
        """Shared attention, nonlinearity applied to the value vectors."""
        super(AttnBinaryV, self).__init__(config=config)
        self.name = "AttnBinaryV"

    def forward(self, x):
        """Score a two-token input as ``Fp - Fn``.

        ``self.act`` is passed to the attention as ``value_fn``, so it acts on
        the projected values before the weighted sum. The readout here is then
        linear: for each value half, the feature mean of token 0 plus that of
        token 1.
        """
        assert x.size(1) == 2
        attn_output, attn_weights = self.attn(x, value_fn=self.act) # b, t, dv
        half = self.attn.dv // 2
        tok0, tok1 = attn_output[:, 0], attn_output[:, 1]

        Fp = torch.mean(tok0[:, :half], 1) + torch.mean(tok1[:, :half], 1)
        Fn = torch.mean(tok0[:, half:], 1) + torch.mean(tok1[:, half:], 1)

        return {
            "out": Fp - Fn,
            "attn_weights": attn_weights,
            "attn_weights_p": None,
            "attn_weights_n": None,
            "attn_output": attn_output,
            "attn_output_p": None,
            "attn_output_n": None,
        }


class AttnBinaryNotShare(nn.Module):
    def __init__(self, config):
        """Separate attention per score, nonlinearity after the weighted sum.

        ``attn_p`` yields the positive score and ``attn_n`` the negative one;
        both layers are built from the same ``config``.
        """
        super(AttnBinaryNotShare, self).__init__()

        self.attn_p = Attention(config)
        self.attn_n = Attention(config)
        self.linear = config.act_linear
        self.q = config.act_q
        self.name = "AttnBinaryNotShare"

    def act(self, input):
        """Identity when ``act_linear``; otherwise ``relu(input) ** act_q``."""
        return input if self.linear else torch.pow(F.relu(input), self.q)

    def forward(self, x):
        """Score a two-token input as ``Fp - Fn`` using the two attention branches."""
        assert x.size(1) == 2
        attn_output_p, attn_weights_p = self.attn_p(x) # b, 2, dv
        attn_output_n, attn_weights_n = self.attn_n(x) # b, 2, dv

        def branch_score(h):
            # Mean activated feature of token 0 plus that of token 1.
            return torch.mean(self.act(h[:, 0]), 1) + torch.mean(self.act(h[:, 1]), 1)

        Fp = branch_score(attn_output_p)
        Fn = branch_score(attn_output_n)

        return {
            "out": Fp - Fn,
            "attn_weights": None,
            "attn_weights_p": attn_weights_p,
            "attn_weights_n": attn_weights_n,
            "attn_output": None,
            "attn_output_p": attn_output_p,
            "attn_output_n": attn_output_n,
        }
    

class AttnFlipflop(nn.Module):
    def __init__(self, config):
        """Frozen one-hot embedding, one ``Attention`` layer, linear head.

        Token ids in ``[0, 5)`` are embedded as the first five standard basis
        vectors of size ``config.hidden_size``. The embedding is not trained.
        The head maps the ``config.dv`` attention features of every position
        to ``config.vocab_size`` outputs.

        Raises:
            ValueError: if ``config.hidden_size`` is smaller than 5, which
                leaves no room for five distinct one-hot rows.
        """
        super(AttnFlipflop, self).__init__()
        self.name = "AttnFlipflop"
        self.config = config

        num_tokens = 5
        if config.hidden_size < num_tokens:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be >= {num_tokens} "
                "for one-hot token embeddings"
            )
        self.embed = nn.Embedding(num_tokens, config.hidden_size)
        # Replace the random init with one-hot rows and freeze them. This used
        # to set a misspelled `requires_data` attribute, which had no effect
        # and left the embedding trainable.
        with torch.no_grad():
            self.embed.weight.copy_(torch.eye(config.hidden_size)[:num_tokens])
        self.embed.weight.requires_grad_(False)

        self.attn = Attention(config)
        self.linear = config.act_linear
        self.q = config.act_q
        self.head = nn.Linear(config.dv, config.vocab_size)

    def act(self, input):
        """Identity when ``act_linear``; otherwise ``relu(input) ** act_q``.

        Not called by ``forward``.
        """
        if self.linear:
            return input
        return torch.pow(F.relu(input), self.q)

    def forward(self, input_ids, **kwargs):
        """Return per-position outputs for ``input_ids``.

        Args:
            input_ids: integer token ids in ``[0, 5)``; presumably shaped
                ``(batch, seq)`` -- TODO confirm against callers.
            **kwargs: accepted and ignored.

        Returns:
            dict with ``out`` (head outputs per position), ``attn_weights``
            and ``attn_output``.
        """
        x = self.embed(input_ids)
        attn_output, attn_weights = self.attn(x) # b, t, dv
        out = self.head(attn_output)

        result = dict(
            out=out,
            attn_weights=attn_weights,
            attn_output=attn_output,
        )
        return result
