from model import common_qipt
from model import common
from model import Quant
import math
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from einops import rearrange
import copy

from .lsq_plus import *
from ._quan_base_plus import *

from .attention_layer import *


def make_model(args, parent=False):
    """Build and return an ``ipt`` network configured from ``args``.

    Args:
        args: Option object forwarded unchanged to the ``ipt`` constructor.
        parent: Accepted for call-site compatibility; this function does not
            read it.

    Returns:
        A freshly constructed ``ipt`` module.
    """
    model = ipt(args)
    return model

class ipt(nn.Module):
    """Multi-scale image network with per-scale heads/tails and one shared body.

    For every entry of ``args.scale`` a separate head (a default convolution
    followed by two ``ResBlock`` modules built from ``conv``) and a separate
    tail (an ``Upsampler`` followed by a default convolution) is created.
    A single ``VisionTransformer`` body is shared by all of them.  The branch
    used by ``forward`` is selected through ``set_scale``.

    Args:
        args: Option object.  The constructor reads ``n_feats``, ``rgb_range``,
            ``n_colors``, ``scale``, ``patch_size``, ``patch_dim``,
            ``num_heads``, ``num_layers``, ``num_queries``, ``dropout_rate``,
            ``no_mlp``, ``pos_every``, ``no_pos`` and ``no_norm`` from it.
        conv: Convolution factory handed to the head ``ResBlock`` modules.
    """

    def __init__(self, args, conv=Conv2dLSQ):
        super().__init__()

        # Index into ``self.head`` / ``self.tail``; changed via ``set_scale``.
        self.scale_idx = 0
        self.args = args

        # Bit width forwarded to the ResBlocks and to the transformer body.
        nbits = 4
        kernel_size = 3
        n_feats = args.n_feats
        act = nn.ReLU(True)

        self.sub_mean = common_qipt.MeanShift(args.rgb_range)
        self.add_mean = common_qipt.MeanShift(args.rgb_range, sign=1)

        # One head per configured scale.
        head_branches = []
        for _ in args.scale:
            head_branches.append(
                nn.Sequential(
                    common_qipt.default_conv(args.n_colors, n_feats, kernel_size),
                    common_qipt.ResBlock(conv, n_feats, 5, act=act, nbits=nbits),
                    common_qipt.ResBlock(conv, n_feats, 5, act=act, nbits=nbits),
                )
            )
        self.head = nn.ModuleList(head_branches)

        # Each token is one flattened patch of the head's feature map.
        token_dim = n_feats * args.patch_dim * args.patch_dim
        self.body = VisionTransformer(
            img_dim=args.patch_size,
            patch_dim=args.patch_dim,
            num_channels=n_feats,
            embedding_dim=token_dim,
            num_heads=args.num_heads,
            num_layers=args.num_layers,
            hidden_dim=token_dim * 4,
            num_queries=args.num_queries,
            dropout_rate=args.dropout_rate,
            mlp=args.no_mlp,
            pos_every=args.pos_every,
            no_pos=args.no_pos,
            no_norm=args.no_norm,
            nbits=nbits,
        )

        # One tail per configured scale.  The tail is built from
        # ``common_qipt.default_conv`` rather than from the ``conv`` argument.
        tail_branches = []
        for s in args.scale:
            tail_branches.append(
                nn.Sequential(
                    common_qipt.Upsampler(common_qipt.default_conv, s, n_feats, act=False),
                    common_qipt.default_conv(n_feats, args.n_colors, kernel_size),
                )
            )
        self.tail = nn.ModuleList(tail_branches)

    def forward(self, x, con=False):
        """Run the branch selected by ``self.scale_idx``.

        Args:
            x: Input tensor passed to ``sub_mean`` and then to the active head.
            con: When truthy, the body is called with ``con`` as a third
                argument and its second return value is passed through.

        Returns:
            The output tensor, or ``(output, x_con)`` when ``con`` is truthy,
            where ``x_con`` is the extra value returned by the body.
        """
        idx = self.scale_idx
        feats = self.head[idx](self.sub_mean(x))

        if con:
            res, x_con = self.body(feats, idx, con)
        else:
            res = self.body(feats, idx)

        # Skip connection around the body (in-place on the body's output).
        res += feats

        out = self.add_mean(self.tail[idx](res))

        if con:
            return out, x_con
        return out

    def set_scale(self, scale_idx):
        """Select which head/tail pair ``forward`` uses."""
        self.scale_idx = scale_idx
        
class VisionTransformer(nn.Module):
    """Patch-based encoder/decoder transformer operating on feature maps.

    The input is cut into non-overlapping ``patch_dim x patch_dim`` patches
    (``unfold``), the resulting token sequence is passed through a
    ``TransformerEncoder`` and a ``TransformerDecoder``, and the tokens are
    assembled back into an ``img_dim x img_dim`` map (``fold``).

    Args:
        img_dim: Spatial size used as the ``fold`` output size; must be
            divisible by ``patch_dim``.
        patch_dim: Side length of one patch (kernel size and stride).
        num_channels: Channel count of the input feature map.
        embedding_dim: Token width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        num_layers: Number of encoder layers and of decoder layers.
        hidden_dim: Width of the feed-forward layers.
        num_queries: Number of rows of the query embedding table
            (indexed by ``query_idx`` in ``forward``).
        positional_encoding_type: Accepted but not read by this class.
        dropout_rate: Dropout probability used throughout.
        no_norm: Passed on to the encoder/decoder layers; when true the
            ``LinearLSQ`` weights are additionally re-initialised with a
            normal distribution of std ``1 / in_features``.
        mlp: When false (default) the linear patch encoding, the MLP head and
            the query embedding are created and used; when true they are
            skipped.  NOTE(review): the caller in this file passes
            ``args.no_mlp`` here, so a true value appears to mean "disable".
        pos_every: When true the positional encoding is handed to every
            encoder/decoder layer instead of being added once to the input.
        no_pos: When true no positional encoding is created or used.
        nbits: Bit width forwarded to the ``LinearLSQ`` layers and to the
            encoder/decoder layers.
    """

    def __init__(
        self,
        img_dim,
        patch_dim,
        num_channels,
        embedding_dim,
        num_heads,
        num_layers,
        hidden_dim,
        num_queries,
        positional_encoding_type="learned",
        dropout_rate=0,
        no_norm=False,
        mlp=False,
        pos_every=False,
        no_pos=False,
        nbits=4
    ):
        super(VisionTransformer, self).__init__()

        assert embedding_dim % num_heads == 0
        assert img_dim % patch_dim == 0
        self.no_norm = no_norm
        self.mlp = mlp
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.patch_dim = patch_dim
        self.num_channels = num_channels

        self.img_dim = img_dim
        self.pos_every = pos_every
        self.num_patches = int((img_dim // patch_dim) ** 2)
        self.seq_length = self.num_patches
        # Size of one flattened patch, both on the way in and on the way out.
        self.flatten_dim = patch_dim * patch_dim * num_channels
        self.out_dim = patch_dim * patch_dim * num_channels

        self.no_pos = no_pos

        if not self.mlp:
            self.linear_encoding = LinearLSQ(self.flatten_dim, embedding_dim, nbits_w=nbits)
            self.mlp_head = nn.Sequential(
                LinearLSQ(embedding_dim, hidden_dim, nbits_w=nbits),
                nn.Dropout(dropout_rate),
                nn.ReLU(),
                LinearLSQ(hidden_dim, self.out_dim, nbits_w=nbits),
                nn.Dropout(dropout_rate)
            )

            # One row per query index; each row holds seq_length embeddings.
            self.query_embed = nn.Embedding(num_queries, embedding_dim * self.seq_length)

        encoder_layer = TransformerEncoderLayer(embedding_dim, num_heads, hidden_dim, dropout_rate, self.no_norm, nbits=nbits)
        self.encoder = TransformerEncoder(encoder_layer, num_layers)

        decoder_layer = TransformerDecoderLayer(embedding_dim, num_heads, hidden_dim, dropout_rate, self.no_norm, nbits=nbits)
        self.decoder = TransformerDecoder(decoder_layer, num_layers)

        if not self.no_pos:
            self.position_encoding = LearnedPositionalEncoding(
                self.seq_length, self.embedding_dim, self.seq_length
            )

        self.dropout_layer1 = nn.Dropout(dropout_rate)

        if no_norm:
            for m in self.modules():
                if isinstance(m, LinearLSQ):
                    nn.init.normal_(m.weight, std=1 / m.weight.size(1))

    def forward(self, x, query_idx, con=False):
        """Tokenise ``x`` into patches, transform them and fold them back.

        Args:
            x: 4-D feature map accepted by ``F.unfold``.  The result is folded
                back to ``img_dim x img_dim``, so the input is assumed to have
                that spatial size — TODO confirm against callers.
            query_idx: Index selecting the row of the query embedding; also
                forwarded as ``task`` to the ``LinearLSQ`` layers, the encoder
                and the decoder.
            con: When truthy, also return the token tensor that is fed to
                ``fold``.

        Returns:
            The folded map, or ``(folded_map, tokens)`` when ``con`` is truthy,
            where ``tokens`` has shape ``(batch, -1, flatten_dim)``.
        """
        # (batch, C*p*p, L) -> (L, batch, C*p*p): sequence-first token layout.
        x = F.unfold(x, self.patch_dim, stride=self.patch_dim)
        x = x.transpose(1, 2).transpose(0, 1).contiguous()

        if not self.mlp:
            # Residual linear encoding; the addition requires the encoded
            # tokens and the raw patches to have compatible widths.
            x = self.dropout_layer1(self.linear_encoding(x, task=query_idx)) + x
            query_embed = (
                self.query_embed.weight[query_idx]
                .view(-1, 1, self.embedding_dim)
                .repeat(1, x.size(1), 1)
            )
        else:
            query_embed = None

        # Always bind ``pos``: previously it was left undefined when
        # ``no_pos`` was set, which crashed the ``pos_every`` branch below.
        pos = None if self.no_pos else self.position_encoding(x).transpose(0, 1)

        if self.pos_every:
            x = self.encoder(x, task=query_idx, pos=pos)
            x = self.decoder(x, x, task=query_idx, pos=pos, query_pos=query_embed)
        else:
            # Without ``pos_every`` the encoding (if any) is added once up front.
            src = x if pos is None else x + pos
            x = self.encoder(src, task=query_idx)
            x = self.decoder(x, x, task=query_idx, query_pos=query_embed)

        if not self.mlp:
            res = x
            for layer in self.mlp_head:
                # Only the LinearLSQ layers take the ``task`` keyword.
                x = layer(x, task=query_idx) if isinstance(layer, LinearLSQ) else layer(x)
            x = x + res

        # (L, batch, dim) -> (batch, L, flatten_dim)
        batch = x.size(1)
        x = x.transpose(0, 1).contiguous().view(batch, -1, self.flatten_dim)

        folded = F.fold(
            x.transpose(1, 2).contiguous(),
            int(self.img_dim),
            self.patch_dim,
            stride=self.patch_dim,
        )

        if con:
            return folded, x
        return folded

class LearnedPositionalEncoding(nn.Module):
    """Learned positional embedding looked up by integer position id.

    Args:
        max_position_embeddings: Number of rows in the embedding table.
        embedding_dim: Width of each positional embedding.
        seq_length: Number of default position ids (``0 .. seq_length - 1``)
            stored in the ``position_ids`` buffer.
    """

    def __init__(self, max_position_embeddings, embedding_dim, seq_length):
        super().__init__()
        self.pe = nn.Embedding(max_position_embeddings, embedding_dim)
        self.seq_length = seq_length

        # Default ids with shape (1, seq_length); kept as a buffer so they
        # follow the module across devices and appear in the state dict.
        default_ids = torch.arange(self.seq_length).expand((1, -1))
        self.register_buffer("position_ids", default_ids)

    def forward(self, x, position_ids=None):
        """Return the embeddings for ``position_ids``.

        Args:
            x: Not read; present so the module can be called on the tokens.
            position_ids: Optional integer tensor of ids.  When ``None`` the
                first ``seq_length`` ids of the stored buffer are used.

        Returns:
            The embedding rows for the selected ids.
        """
        if position_ids is not None:
            ids = position_ids
        else:
            ids = self.position_ids[:, : self.seq_length]
        return self.pe(ids)
    
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` independent copies of one encoder layer.

    Args:
        encoder_layer: Prototype layer; it is cloned, not used directly.
        num_layers: Number of clones to stack.
    """

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, src, task, pos=None):
        """Feed ``src`` through every layer in order.

        Args:
            src: Input to the first layer.
            task: Forwarded positionally to each layer.
            pos: Forwarded to each layer as the ``pos`` keyword.

        Returns:
            The output of the last layer (``src`` itself if there are none).
        """
        hidden = src
        for encoder_block in self.layers:
            hidden = encoder_block(hidden, task, pos=pos)
        return hidden
    
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention then feed-forward, each residual.

    Normalisation is applied to the input of each sub-block (before the
    attention and before the feed-forward part); the un-normalised tensor is
    what the residual is added to.

    Args:
        d_model: Token width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward part.
        dropout: Dropout probability for all dropout modules of the layer.
        no_norm: When true, ``nn.Identity`` replaces both ``nn.LayerNorm``s.
        activation: Name resolved through ``_get_activation_fn``.
        nbits: Bit width passed to the attention module (``n_bit``) and to
            the ``LinearLSQ`` layers (``nbits_w``).
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, no_norm=False,
                 activation="relu", nbits=4):
        super().__init__()

        def build_norm():
            # Identity stands in for LayerNorm when normalisation is disabled.
            return nn.Identity() if no_norm else nn.LayerNorm(d_model)

        self.self_attn = QuantMultiheadAttention(
            d_model, nhead, n_bit=nbits, dropout=dropout, bias=False, encoder=False
        )

        # Feed-forward part.
        self.linear1 = LinearLSQ(d_model, dim_feedforward, nbits_w=nbits)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = LinearLSQ(dim_feedforward, d_model, nbits_w=nbits)

        self.norm1 = build_norm()
        self.norm2 = build_norm()
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

        # Re-initialise the attention input projection.
        nn.init.kaiming_uniform_(self.self_attn.in_proj_weight, a=math.sqrt(5))

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, src, task, pos=None):
        """Apply the layer to ``src``.

        Args:
            src: Input tokens.
            task: Forwarded as the ``task`` keyword to the attention module
                and to both ``LinearLSQ`` layers.
            pos: Optional positional term added to the attention query and
                key (not to the value).

        Returns:
            The transformed tokens.
        """
        # --- self-attention sub-block ---
        normed = self.norm1(src)
        attn_in = self.with_pos_embed(normed, pos)
        attn_result = self.self_attn(attn_in, attn_in, normed, task=task)
        # The attention module returns a sequence; element 0 is the output.
        src = src + self.dropout1(attn_result[0])

        # --- feed-forward sub-block ---
        normed = self.norm2(src)
        hidden = self.linear1(normed, task=task)
        hidden = self.dropout(self.activation(hidden))
        ff_out = self.linear2(hidden, task=task)
        src = src + self.dropout2(ff_out)
        return src

    
class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` independent copies of one decoder layer.

    Args:
        decoder_layer: Prototype layer; it is cloned, not used directly.
        num_layers: Number of clones to stack.
    """

    def __init__(self, decoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, tgt, memory, task, pos=None, query_pos=None):
        """Feed ``tgt`` through every layer in order.

        Args:
            tgt: Input to the first layer.
            memory: Passed unchanged to every layer.
            task: Forwarded positionally to each layer.
            pos: Forwarded to each layer as the ``pos`` keyword.
            query_pos: Forwarded to each layer as the ``query_pos`` keyword.

        Returns:
            The output of the last layer (``tgt`` itself if there are none).
        """
        hidden = tgt
        for decoder_block in self.layers:
            hidden = decoder_block(hidden, memory, task, pos=pos, query_pos=query_pos)
        return hidden

    
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: self-attention, attention over ``memory``, feed-forward.

    Each of the three sub-blocks normalises its input first and adds its
    (dropped-out) result back onto the un-normalised tensor.

    Args:
        d_model: Token width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward part.
        dropout: Dropout probability for all dropout modules of the layer.
        no_norm: When true, ``nn.Identity`` replaces the three ``nn.LayerNorm``s.
        activation: Name resolved through ``_get_activation_fn``.
        nbits: Bit width passed to both attention modules (``n_bit``) and to
            the ``LinearLSQ`` layers (``nbits_w``).
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, no_norm=False,
                 activation="relu", nbits=4):
        super().__init__()

        def build_norm():
            # Identity stands in for LayerNorm when normalisation is disabled.
            return nn.Identity() if no_norm else nn.LayerNorm(d_model)

        # The two attention modules differ only in the ``encoder`` flag.
        self.self_attn = QuantMultiheadAttention(
            d_model, nhead, n_bit=nbits, dropout=dropout, bias=False, encoder=False
        )
        self.multihead_attn = QuantMultiheadAttention(
            d_model, nhead, n_bit=nbits, dropout=dropout, bias=False, encoder=True
        )

        # Feed-forward part.
        self.linear1 = LinearLSQ(d_model, dim_feedforward, nbits_w=nbits)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = LinearLSQ(dim_feedforward, d_model, nbits_w=nbits)

        self.norm1 = build_norm()
        self.norm2 = build_norm()
        self.norm3 = build_norm()
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, tgt, memory, task, pos=None, query_pos=None):
        """Apply the layer to ``tgt``.

        Args:
            tgt: Decoder input tokens.
            memory: Tokens attended to by the second attention module.
            task: Forwarded as the ``task`` keyword to both attention modules
                and to both ``LinearLSQ`` layers.
            pos: Optional term added to ``memory`` to form the key of the
                second attention (the value stays ``memory``).
            query_pos: Optional term added to the queries of both attentions
                and to the key of the self-attention.

        Returns:
            The transformed tokens.
        """
        # --- self-attention sub-block ---
        normed = self.norm1(tgt)
        self_qk = self.with_pos_embed(normed, query_pos)
        self_out = self.self_attn(self_qk, self_qk, value=normed, task=task)[0]
        tgt = tgt + self.dropout1(self_out)

        # --- attention over memory ---
        normed = self.norm2(tgt)
        cross_query = self.with_pos_embed(normed, query_pos)
        cross_key = self.with_pos_embed(memory, pos)
        cross_out = self.multihead_attn(cross_query, cross_key, value=memory, task=task)[0]
        tgt = tgt + self.dropout2(cross_out)

        # --- feed-forward sub-block ---
        normed = self.norm3(tgt)
        hidden = self.linear1(normed, task=task)
        hidden = self.dropout(self.activation(hidden))
        ff_out = self.linear2(hidden, task=task)
        tgt = tgt + self.dropout3(ff_out)
        return tgt


def _get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``.

    The prototype ``module`` itself is not placed in the list.
    """
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones


def _get_activation_fn(activation):
    """Return the ``torch.nn.functional`` activation named by ``activation``.

    Args:
        activation: One of ``"relu"``, ``"gelu"`` or ``"glu"``.

    Returns:
        The matching function: ``F.relu``, ``F.gelu`` or ``F.glu``.

    Raises:
        RuntimeError: If ``activation`` is not one of the supported names.
    """
    supported = (("relu", F.relu), ("gelu", F.gelu), ("glu", F.glu))
    for name, fn in supported:
        if activation == name:
            return fn
    # The message lists every accepted name ("glu" used to be missing).
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
