import torch
import time
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.nn import Parameter
from copy import deepcopy

from .quantizer import UniformQuantizer, LogSqrt2Quantizer


class QuantConv2d(nn.Conv2d):
    """
    2-D convolution whose input and weight can each be passed through a
    UniformQuantizer before the convolution is evaluated.

    Both quantization paths start disabled; switch them on or off with
    ``set_quant_state``.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 input_quant_params={},
                 weight_quant_params={}):
        """
        Build the convolution and its two quantizers.

        The first eight arguments are forwarded unchanged to ``nn.Conv2d``.

        Args:
            input_quant_params: keyword arguments for the input
                UniformQuantizer. ``n_bits`` and ``diaq`` are always
                overridden (to 8 and False) for this layer.
            weight_quant_params: keyword arguments for the weight
                UniformQuantizer, used as given.
        """
        super().__init__(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         dilation=dilation,
                         groups=groups,
                         bias=bias)

        # Override on a private copy so the caller's dict (and the shared
        # default) is never modified.
        conv_input_cfg = deepcopy(input_quant_params)
        conv_input_cfg.update(n_bits=8, diaq=False)
        self.input_quantizer = UniformQuantizer(**conv_input_cfg)
        self.weight_quantizer = UniformQuantizer(**weight_quant_params)

        self.use_input_quant = False
        self.use_weight_quant = False

    def __repr__(self):
        """Return the ``nn.Conv2d`` repr followed by the two quantization flags."""
        base = super().__repr__()
        return f"({base}input_quant={self.use_input_quant}, weight_quant={self.use_weight_quant})"

    def set_quant_state(self, input_quant=False, weight_quant=False):
        """Enable or disable input and weight quantization for ``forward``."""
        self.use_input_quant, self.use_weight_quant = input_quant, weight_quant

    def forward(self, x, **kwargs):
        """
        Convolve ``x`` with the layer weight, quantizing whichever of the
        input and weight is currently enabled. Extra keyword arguments are
        accepted and ignored.
        """
        if self.use_input_quant:
            x = self.input_quantizer(x)

        weight = self.weight_quantizer(self.weight) if self.use_weight_quant else self.weight

        return F.conv2d(x,
                        weight,
                        bias=self.bias,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups)

def rel_l2(x, y, dim=-1, eps=1e-14):
    """
    Accumulate the relative L2 error of ``x`` with respect to ``y``.

    For every vector along ``dim`` the error is ``||x - y|| / ||y||``.
    The denominator is clamped to at least ``eps`` so that an all-zero
    reference vector yields a finite value instead of NaN/inf, which would
    otherwise poison any running sum the caller keeps.

    Args:
        x: tensor being evaluated.
        y: reference tensor, same shape as ``x``.
        dim: dimension along which the norms are taken.
        eps: lower bound applied to ``||y||``.

    Returns:
        ``(total, count)``: the sum of the per-vector errors as a Python
        float and the number of vectors that were summed.
    """
    diff = torch.norm(x - y, dim=dim)
    # clamp (rather than add) so non-degenerate rows are left bit-identical
    denom = torch.clamp(torch.norm(y, dim=dim), min=eps)
    rel = diff / denom
    return rel.sum().item(), rel.numel()

def coserr(x, y, dim=-1):
    """
    Accumulate the cosine error (``1 - cosine similarity``) between ``x``
    and ``y`` for every vector along ``dim``.

    Returns:
        ``(total, count)``: the sum of the per-vector errors as a Python
        float and the number of vectors that were summed.
    """
    inner = (x * y).sum(dim=dim)
    # small constant keeps the division finite for zero-norm vectors
    norm_product = torch.norm(x, dim=dim) * torch.norm(y, dim=dim) + 1e-14
    cosine_error = 1 - inner / norm_product
    return cosine_error.sum().item(), cosine_error.numel()


class QuantLinear(nn.Linear):
    """
    Linear layer with optional quantization of its input and weight.

    When ``diaq`` is enabled, the (possibly quantized) input is rescaled so
    that its L2 norm along the last dimension matches the norm it had before
    quantization. When error tracking is enabled, each forward pass also
    accumulates relative-L2 and cosine errors into a caller-supplied dict.
    """
    def __init__(self,
                 in_features,
                 out_features,
                 input_quant_params={},
                 weight_quant_params={}):
        """
        Args:
            in_features: forwarded to ``nn.Linear``.
            out_features: forwarded to ``nn.Linear``.
            input_quant_params: keyword arguments for the input
                UniformQuantizer; its ``'diaq'`` entry, if present, also
                sets this layer's ``diaq`` flag.
            weight_quant_params: keyword arguments for the weight
                UniformQuantizer.
        """
        super(QuantLinear, self).__init__(in_features, out_features)

        self.input_quantizer = UniformQuantizer(**input_quant_params)
        self.weight_quantizer = UniformQuantizer(**weight_quant_params)

        # .get() so a config without 'diaq' (including the default {}) does
        # not raise KeyError; rescaling is simply off in that case.
        self.diaq = input_quant_params.get('diaq', False)

        self.use_input_quant = False
        self.use_weight_quant = False

        self.tracking_err = False

    def __repr__(self):
        """Return the ``nn.Linear`` repr followed by the two quantization flags."""
        s = super(QuantLinear, self).__repr__()
        s = "(" + s + "input_quant={}, weight_quant={})".format(self.use_input_quant, self.use_weight_quant)
        return s

    def set_quant_state(self, input_quant=False, weight_quant=False):
        """Enable or disable input and weight quantization for ``forward``."""
        self.use_input_quant = input_quant
        self.use_weight_quant = weight_quant

    def set_diaq(self, diaq=False):
        """Set the ``diaq`` flag on this layer and on its input quantizer."""
        self.diaq = diaq
        self.input_quantizer.diaq = diaq

    def track_err(self, err_dict=None):
        """
        Turn error tracking on or off.

        Args:
            err_dict: ``None`` disables tracking. Otherwise a mapping that
                ``forward`` updates in place; it must hold a mutable
                ``[total, count]`` pair under each of the keys
                ``{x,y}_{rtn,diaq}_{l2,cos}``.
        """
        if err_dict is None:
            self.tracking_err = False
            return

        self.tracking_err = True
        self.err_dict = err_dict

    def forward(self, x, **kwargs):
        """
        Apply the linear map to ``x`` using the currently enabled
        quantizers. Extra keyword arguments are accepted and ignored.
        """
        if self.diaq:
            # norm of the unquantized input, restored after quantization
            xlen = torch.norm(x, p=2, dim=-1, keepdim=True)

        if self.tracking_err:
            xraw = x.clone()
            # NOTE(review): `rtn` is presumably the quantizer's plain
            # round-to-nearest path -- confirm against the quantizer module.
            xrtn = self.input_quantizer.rtn(x)

        if self.use_input_quant:
            x = self.input_quantizer(x)

        if self.diaq:
            xqlen = torch.norm(x, p=2, dim=-1, keepdim=True)
            # epsilon avoids division by zero for all-zero rows
            x = x * (xlen / (xqlen+1e-12))

        if self.use_weight_quant:
            w = self.weight_quantizer(self.weight)
        else:
            w = self.weight

        out = F.linear(x, weight=w, bias=self.bias)

        if self.tracking_err:
            yraw = F.linear(xraw, weight=w, bias=self.bias)
            yrtn = F.linear(xrtn, weight=w, bias=self.bias)

            # Evaluate each metric exactly once; every helper returns a
            # (total, count) pair that is added onto the running entry.
            stats = {
                'x_rtn_l2': rel_l2(xrtn, xraw),
                'y_rtn_l2': rel_l2(yrtn, yraw),
                'x_rtn_cos': coserr(xrtn, xraw),
                'y_rtn_cos': coserr(yrtn, yraw),
                'x_diaq_l2': rel_l2(x, xraw),
                'x_diaq_cos': coserr(x, xraw),
                'y_diaq_l2': rel_l2(out, yraw),
                'y_diaq_cos': coserr(out, yraw),
            }
            for key, (total, count) in stats.items():
                self.err_dict[key][0] += total
                self.err_dict[key][1] += count

        return out
        

class QuantMatMul(nn.Module):
    """
    Matrix product ``A @ B`` with optional quantization of both operands.

    ``B`` always uses a UniformQuantizer. ``A`` uses a LogSqrt2Quantizer
    when the config contains the key ``'log_quant'`` and a UniformQuantizer
    otherwise. When scaling is enabled for an operand, the product is
    multiplied by the ratio of that operand's pre- and post-quantization L2
    norms (rows of ``A`` along dim -1, columns of ``B`` along dim -2).
    """
    def __init__(self,
                 input_quant_params={}):
        """
        Args:
            input_quant_params: keyword arguments for the operand
                quantizers. The key ``'log_quant'`` (whatever its value)
                selects the log quantizer for ``A`` and is removed before
                the quantizers are built; ``'diaq'`` sets the initial
                scaling flags.
        """
        super(QuantMatMul, self).__init__()

        # Private copies: the entries are popped/overridden below.
        input_quant_params_matmulA = deepcopy(input_quant_params)
        input_quant_params_matmulB = deepcopy(input_quant_params)

        if 'log_quant' in input_quant_params_matmulA:
            input_quant_params_matmulA.pop('log_quant')
            input_quant_params_matmulB.pop('log_quant')
            input_quant_params_matmulA['diaq'] = False
            self.quantizer_A = LogSqrt2Quantizer(**input_quant_params_matmulA)
            self.scaleA = False
            self.uniA = False
        else:
            self.quantizer_A = UniformQuantizer(**input_quant_params_matmulA)
            # .get() so a config without 'diaq' (including the default {})
            # does not raise KeyError.
            self.scaleA = input_quant_params_matmulA.get('diaq', False)
            self.uniA = True
        self.quantizer_B = UniformQuantizer(**input_quant_params_matmulB)
        self.scaleB = input_quant_params_matmulB.get('diaq', False)

        # Defined here so the attribute exists before set_diaq() is called.
        self.diaq = self.scaleB

        self.use_input_quant = False
        self.tracking_err = False

    def __repr__(self):
        """Return the ``nn.Module`` repr followed by the quantization flag."""
        s = super(QuantMatMul, self).__repr__()
        s = "(" + s + "input_quant={})".format(self.use_input_quant)
        return s

    def set_quant_state(self, input_quant=False, weight_quant=False):
        """
        Enable or disable operand quantization. ``weight_quant`` is accepted
        for signature parity with the other layers and is ignored.
        """
        self.use_input_quant = input_quant

    def set_diaq(self, diaq=False):
        """
        Set the ``diaq`` flag, the scaling flags and the quantizers'
        ``diaq``/``diaq_dim`` attributes. ``A`` is only touched when it
        uses the uniform quantizer.
        """
        self.diaq = diaq
        if self.uniA:
            self.quantizer_A.diaq = diaq
            self.quantizer_A.diaq_dim = -1
            self.scaleA = diaq
        self.quantizer_B.diaq = diaq
        self.quantizer_B.diaq_dim = -2
        self.scaleB = diaq

    def track_err(self, err_dict=None):
        """
        Turn error tracking on or off.

        Args:
            err_dict: ``None`` disables tracking. Otherwise a mapping that
                ``forward`` updates in place; it must hold a mutable
                ``[total, count]`` pair under each of the keys
                ``{x,y}_{rtn,diaq}_{l2,cos}``.
        """
        if err_dict is None:
            self.tracking_err = False
            return

        self.tracking_err = True
        self.err_dict = err_dict

    def forward(self, A, B, **kwargs):
        """
        Return ``A @ B`` using the currently enabled quantizers and scaling.
        Extra keyword arguments are accepted and ignored.
        """
        if self.tracking_err:
            if self.uniA:
                Araw = A.clone()
                # NOTE(review): `rtn` is presumably the quantizer's plain
                # round-to-nearest path -- confirm against the quantizer module.
                Artn = self.quantizer_A.rtn(A)
            else:
                # With the log quantizer, the quantized A is the reference,
                # so only B contributes to the tracked input error.
                Araw = self.quantizer_A(A)
                Artn = self.quantizer_A(A)
            Braw = B.clone()
            Brtn = self.quantizer_B.rtn(B)
            Yraw = Araw @ Braw
            Yrtn = Artn @ Brtn
            if self.uniA:
                # flatten the last two dims so output errors are per matrix
                Yraw = Yraw.reshape(*Yraw.shape[:-2], -1)
                Yrtn = Yrtn.reshape(*Yrtn.shape[:-2], -1)

        if self.scaleA:
            Alen = torch.norm(A, p=2, dim=-1, keepdim=True)
        if self.scaleB:
            Blen = torch.norm(B, p=2, dim=-2, keepdim=True)

        if self.use_input_quant:
            A = self.quantizer_A(A)
            B = self.quantizer_B(B)

        # epsilon avoids division by zero for all-zero rows/columns
        if self.scaleA:
            Aqlen = torch.norm(A, p=2, dim=-1, keepdim=True)+1e-12
        if self.scaleB:
            Bqlen = torch.norm(B, p=2, dim=-2, keepdim=True)+1e-12

        out = A @ B

        if self.scaleA:
            out *= (Alen / Aqlen)
        if self.scaleB:
            out *= (Blen / Bqlen)

        if self.tracking_err:
            Ydia = out.clone()
            if self.uniA:
                Ydia = Ydia.reshape(*Ydia.shape[:-2], -1)

            # The norm ratios only exist when scaling is enabled; fall back
            # to the unscaled operand so tracking also works with scaling
            # off (previously this raised UnboundLocalError).
            B_scaled = B * (Blen / Bqlen) if self.scaleB else B

            if self.uniA:
                A_scaled = A * (Alen / Aqlen) if self.scaleA else A
                A_rtn_l2 = rel_l2(Artn, Araw, dim=-1)
                A_rtn_cos = coserr(Artn, Araw, dim=-1)
                A_dia_l2 = rel_l2(A_scaled, Araw, dim=-1)
                # cosine error is scale-invariant, so the unscaled A is used
                A_dia_cos = coserr(A, Araw, dim=-1)
            B_rtn_l2 = rel_l2(Brtn, Braw, dim=-2)
            B_rtn_cos = coserr(Brtn, Braw, dim=-2)
            B_dia_l2 = rel_l2(B_scaled, Braw, dim=-2)
            B_dia_cos = coserr(B, Braw, dim=-2)

            if self.uniA:
                x_rtn_l2 = (A_rtn_l2[0] + B_rtn_l2[0], A_rtn_l2[1] + B_rtn_l2[1])
                x_rtn_cos = (A_rtn_cos[0] + B_rtn_cos[0], A_rtn_cos[1] + B_rtn_cos[1])
                x_dia_l2 = (A_dia_l2[0] + B_dia_l2[0], A_dia_l2[1] + B_dia_l2[1])
                x_dia_cos = (A_dia_cos[0] + B_dia_cos[0], A_dia_cos[1] + B_dia_cos[1])
            else:
                x_rtn_l2 = B_rtn_l2
                x_rtn_cos = B_rtn_cos
                x_dia_l2 = B_dia_l2
                x_dia_cos = B_dia_cos

            stats = {
                'x_rtn_l2': x_rtn_l2,
                'y_rtn_l2': rel_l2(Yrtn, Yraw),
                'x_rtn_cos': x_rtn_cos,
                'y_rtn_cos': coserr(Yrtn, Yraw),
                'x_diaq_l2': x_dia_l2,
                'x_diaq_cos': x_dia_cos,
                'y_diaq_l2': rel_l2(Ydia, Yraw),
                'y_diaq_cos': coserr(Ydia, Yraw),
            }
            # each value is a (total, count) pair added onto the running entry
            for key, (total, count) in stats.items():
                self.err_dict[key][0] += total
                self.err_dict[key][1] += count

        return out
