from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import linalg as LA
import numpy as np
from abc import ABC, abstractclassmethod
from utils.display import wandb_mat_image, mat_message

# Accepted `trans_type` values for the class-matrix transitions, keyed by
# transition family. None selects the identity matrix.
TRANS_TYPES = dict(
    class_matrix=['uniform', 'normal', 'pairwise', 'uncon', None],
)
# Other candidate types: ['uniform', 'normal', 'pairwise', 'batch_normal', 'instance_normal']

def init_prob_transition(num_classes, trans_rate, trans_type):
    r"""
    Mix the identity with a noise matrix and row-normalise the result.

    'uniform' spreads ``trans_rate`` evenly over all classes (the true class
    included); 'normal' uses random non-negative weights from ``torch.rand``
    (uniform on [0, 1), despite the name).

    Args:
        num_classes[int]: number of classes.
        trans_rate[float]: the rate of the samples to be noisy.
        trans_type[str]: 'uniform' or 'normal'.
    Returns:
        T[torch.tensor, num_classes * num_classes], i=Y, j= \hat{Y}
        T_{ij} = P(\hat{Y}=j | Y=i), \sum_{j}T_{ij} = 1
    """
    assert trans_type in ['uniform', 'normal'], f'The type {trans_type} for prob transistion is not supported.'
    T = torch.eye(num_classes)

    if trans_type == 'uniform':
        noise = torch.ones(num_classes, num_classes) / num_classes
    else:
        # torch.rand is already non-negative; no abs() needed.
        noise = torch.rand(num_classes, num_classes)

    T = T * (1 - trans_rate) + noise * trans_rate
    # Random noise rows do not sum to 1, so renormalise every row.
    return T / T.sum(dim=1, keepdim=True)

def init_pairwise_transition(num_classes, trans_rate, trans_type):
    r"""
    Build a pair-flip transition matrix.

    Each class ``i`` keeps its label with probability ``1 - trans_rate`` and
    flips to exactly one partner class with probability ``trans_rate``.

    Args:
        num_classes[int]: number of classes.
        trans_rate[float]: the rate of the samples to be noisy.
        trans_type[str]: 'pairwise_next' (or its alias 'pairwise') flips
            ``i -> (i + 1) % num_classes``; 'pairwise_random' flips along a
            random permutation without fixed points.
    Returns:
        T[torch.tensor, num_classes * num_classes], i=Y, j= \hat{Y}
        T_{ij} = P(\hat{Y}=j | Y=i), \sum_{j}T_{ij} = 1
    Raises:
        ValueError: for 'pairwise_random' with a single class, where no
            fixed-point-free permutation exists.
    """
    # init_transistion_matrix dispatches the plain 'pairwise' type here.
    if trans_type == 'pairwise':
        trans_type = 'pairwise_next'
    assert trans_type in ['pairwise_next', 'pairwise_random'], f'The type {trans_type} for pairwise transistion is not supported.'
    ori_inds = torch.arange(num_classes)
    if trans_type == 'pairwise_next':
        # [1, 2, ..., n-1, 0]
        flip_inds = torch.roll(ori_inds, -1)
    else:
        if num_classes == 1:
            # The rejection loop below would never terminate.
            raise ValueError("'pairwise_random' needs at least two classes.")
        flip_inds = ori_inds
        # Rejection-sample a permutation with no class mapped to itself.
        while (flip_inds == ori_inds).any():
            flip_inds = torch.randperm(num_classes)
    T_flip = torch.zeros(num_classes, num_classes)
    T_flip[ori_inds, flip_inds] = trans_rate
    return torch.eye(num_classes) * (1 - trans_rate) + T_flip

def init_transistion_matrix(num_classes, trans_rate, trans_type=None):
    """Build the initial ``num_classes x num_classes`` transition matrix.

    Supported ``trans_type`` values:
        'uniform' / 'normal': delegated to ``init_prob_transition``.
        'pairwise': next-class pair flip via ``init_pairwise_transition``.
        'class<k>': row ``k`` is one-hot on ``k``; every other row is
            ``eye + 1/num_classes`` normalised to sum to 1.
        'uncon': all-ones matrix.
        None: identity.
    """
    # 'class<k>' types are produced by KernelMatrixTransition._init_transition,
    # so they must pass validation alongside the fixed names in TRANS_TYPES.
    is_class_type = isinstance(trans_type, str) and trans_type.startswith('class')
    assert trans_type in TRANS_TYPES['class_matrix'] or is_class_type, f'The type {trans_type} for transistion is not supported.'

    if trans_type in ['uniform', 'normal']:
        return init_prob_transition(num_classes, trans_rate, trans_type)
    elif trans_type == 'pairwise':
        # init_pairwise_transition validates against its explicit variant names.
        return init_pairwise_transition(num_classes, trans_rate, 'pairwise_next')
    elif is_class_type:
        cls_idx = int(trans_type[len('class'):])
        trans_matrix = torch.eye(num_classes, num_classes) + torch.ones(num_classes, num_classes) / num_classes
        # Samples of class `cls_idx` are never flipped.
        trans_matrix[cls_idx, :] = 0
        trans_matrix[cls_idx, cls_idx] = 1
        return trans_matrix / trans_matrix.sum(dim=1, keepdim=True)
    elif trans_type == 'uncon':
        return torch.ones(num_classes, num_classes)
    else:
        return torch.eye(num_classes)


def transition_distance(true_trans_func, pred_trans_func, dist_type='fro'):
    """Distance between the matrices of two transition functions.

    Args:
        true_trans_func: object exposing ``get_matrix()`` (reference matrix).
        pred_trans_func: object exposing ``get_matrix()`` (estimated matrix).
        dist_type[str]: 'fro' for the Frobenius norm of the difference, or
            'l1' for the L1 difference relative to the L1 mass of the
            reference matrix.
    Returns:
        A scalar tensor.
    Raises:
        ValueError: for any other ``dist_type``.
    """
    t_matrix = true_trans_func.get_matrix()
    # The two functions may live on different devices.
    p_matrix = pred_trans_func.get_matrix().to(t_matrix.device)
    if dist_type == 'fro':
        return torch.norm(t_matrix - p_matrix, p='fro')
    if dist_type == 'l1':
        return torch.sum(torch.abs(t_matrix - p_matrix)) / torch.sum(torch.abs(t_matrix))
    raise ValueError(f"Unsupported dist_type {dist_type!r}; expected 'fro' or 'l1'.")
    

class TransitionFunction(nn.Module):
    """
    Base class for label-transition functionals; holds configuration only.

    A transition functional gives, for a sample ``x`` and clean label ``i``,
    the probability of observing label ``j``:
        T(x, i, j) = T_ij(x)
    Subclasses must implement ``_transit``.

    Args:
        num_classes[int]: number of classes.
        trans_rate[float]: the rate of the samples to be noisy.
        trans_type: identifier of the initial transition pattern.
    """
    def __init__(self, num_classes, trans_rate, trans_type) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.trans_rate = trans_rate
        self.trans_type = trans_type

    def forward(self, x, i, j=None) -> Any:
        """Evaluate the transition for sample(s) ``x``, label ``i`` and optional label ``j``."""
        return self._transit(x, i, j)

    def normalize(self, normalized_type='clamp'):
        """Project parameters back to valid probabilities; a no-op in the base class."""
        pass

    def _transit(self, x, i, j=None) -> Any:
        """
        If j = None, it will return a vector of probability of each class to be transited.
        If j != None, it will return a value of probability of the j-th class to be transited.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        # nn.Module is not an ABC, so an abstract decorator would not be
        # enforced; fail loudly instead of silently returning None.
        raise NotImplementedError(f'{type(self).__name__} must implement _transit().')


class ClassMatrixTransistion(TransitionFunction):
    """
    Class-conditional transition that ignores the sample: T(x, i, j) = T_ij.

    The matrix is the learnable parameter ``_T`` and is kept row-stochastic
    by calling ``normalize``. Only the default 'clamp' normalisation is
    meant to be used for now.
    """

    def __init__(self, num_classes, trans_rate, trans_type=None, normalized_type='clamp') -> None:
        super().__init__(num_classes, trans_rate, trans_type)
        initial = init_transistion_matrix(num_classes, trans_rate, trans_type)
        self._T = nn.Parameter(initial)
        self.normalized_type = normalized_type

    def _transit(self, x, i=None, j=None) -> int:
        """Row ``T[i]`` when ``j`` is None, otherwise the entry ``T[i, j]``; ``x`` is ignored."""
        return self._T[i] if j is None else self._T[i, j]

    def get_matrix(self):
        """Return the transition parameter itself (not a copy)."""
        return self._T

    def wandb_visualization(self):
        """Return a one-element list holding the wandb image of the matrix."""
        return [wandb_mat_image(self.get_matrix())]

    def __str__(self):
        return mat_message(self.get_matrix())

    def normalize(self, normalized_type='clamp'):
        """Rewrite ``_T`` in place as a row-normalised matrix.

        'clamp' clips entries to [0, 1]; any other value applies a sigmoid.
        An explicit None falls back to ``self.normalized_type``. Each row is
        then divided by its sum.
        """
        mode = self.normalized_type if normalized_type is None else normalized_type
        values = self._T.data
        if mode == 'clamp':
            values = values.clamp_(0, 1)
        else:
            values = torch.sigmoid(values)
        self._T.data = values / values.sum(dim=1, keepdim=True)

    def transit_pred(self, X, X_features=None) -> torch.tensor:
        """
        Map clean-class predictions to noisy-class predictions, ``X @ T``.

        Args:
            X[torch.tensor, N x num_classes]: the prediction from the classifier.
            X_features: unused; accepted to match the kernel transitions' signature.
        Return
            X_hat[torch.tensor, N x num_classes]: transited prediction results.
        """
        return torch.matmul(X, self._T)

    # TODO: a `compare(trans_func)` method dispatching on the transition type
    # could replace the matrix-only logic in `transition_distance`.

class UnconClassMatrixTransistion(TransitionFunction):
    """
    Unconstrained class-conditional transition (as in the original ROBOT code).

    The parameter ``_T`` holds unconstrained values. The transition matrix
    T(x, i, j) = T_ij is obtained on the fly by ``_normalize`` ('sigmoid' by
    default, or 'softmax'), and does not depend on ``x``.
    """

    def __init__(self, num_classes, trans_rate, trans_type=None, normalized_type='sigmoid') -> None:
        super().__init__(num_classes, trans_rate, trans_type)
        # Start from the chosen initial matrix scaled by -4.5.
        raw = init_transistion_matrix(num_classes, trans_rate, trans_type) * (-4.5)
        self._T = nn.Parameter(raw)
        # Fixed masks: `co` keeps off-diagonal entries, `identity` re-adds the diagonal.
        eye = torch.eye(num_classes)
        self.co = torch.ones(num_classes, num_classes) - eye
        self.identity = eye
        self.co.requires_grad = False
        self.identity.requires_grad = False
        self.normalized_type = normalized_type

    def _transit(self, x, i=None, j=None) -> int:
        """Raw row ``_T[i]`` when ``j`` is None, otherwise the raw entry ``_T[i, j]``."""
        return self._T[i] if j is None else self._T[i, j]

    def get_matrix(self):
        """Normalised transition matrix, computed without tracking gradients."""
        with torch.no_grad():
            return self._normalize(self._T)

    def wandb_visualization(self):
        """Return a one-element list holding the wandb image of the matrix."""
        return [wandb_mat_image(self.get_matrix())]

    def __str__(self):
        return mat_message(self.get_matrix())

    def _normalize(self, T, normalized_type=None):
        """Map unconstrained values ``T`` to a transition matrix.

        'sigmoid': sigmoid of the off-diagonal entries plus a diagonal of 1,
        each row then L1-normalised. 'softmax': row-wise softmax. Any other
        value returns ``T`` unchanged.
        """
        mode = normalized_type if normalized_type is not None else self.normalized_type
        dev = T.device
        if mode == 'sigmoid':
            off_diag = torch.sigmoid(T) * self.co.to(dev)
            return F.normalize(self.identity.to(dev) + off_diag, p=1, dim=1)
        if mode == 'softmax':
            return F.softmax(T, dim=1)
        return T

    def transit_pred(self, X, X_features=None) -> torch.tensor:
        """
        Map clean-class predictions to noisy-class predictions with the normalised matrix.

        Args:
            X[torch.tensor, N x num_classes]: the prediction from the classifier.
            X_features: unused; accepted to match the kernel transitions' signature.
        Return
            X_hat[torch.tensor, N x num_classes]: transited prediction results.
        """
        return torch.matmul(X, self._normalize(self._T))

def rbf_kernel(x, y, dim=1, sigma=1):
    """Gaussian kernel ``exp(-||x - y||_2^2 / sigma^2)``, reduced over ``dim``."""
    sq_dist = LA.vector_norm(x - y, ord=2, dim=dim).pow(2)
    return torch.exp(-sq_dist / sigma ** 2)

def polynomial_kernel(x, y, dim=1, c=0, d=2):
    """Polynomial kernel ``(<x, y> + c) ** d``, with the inner product taken over ``dim``."""
    inner = (x * y).sum(dim=dim)
    return (inner + c) ** d

def cos_kernel(x, y, dim=1):
    """Cosine-similarity kernel rescaled from [-1, 1] to [0, 1]."""
    similarity = F.cosine_similarity(x, y, dim=dim)
    return (similarity + 1) / 2

def exponential_kernel(x, y, dim=1, sigma=1):
    """Exponential (Laplacian-style) kernel ``exp(-||x - y||_2 / sigma)``, reduced over ``dim``.

    Takes the same ``(x, y, dim, sigma)`` arguments as ``rbf_kernel``, so it
    can be selected through ``KERNEL_DICT['exp']`` and called as
    ``kernel(x, y, dim=2, **kwargs)``.
    """
    dist = LA.vector_norm(x - y, ord=2, dim=dim)
    return torch.exp(-dist / sigma)


# Maps the `kernel_type` string of the kernel-based transitions to a kernel
# function; the transitions call it as `kernel(x, y, dim=2, **kwargs)`.
KERNEL_DICT = dict(
    rbf=rbf_kernel,
    exp=exponential_kernel,
    poly=polynomial_kernel,
    cos=cos_kernel,
)

def is_kernel_transition(transition_func):
    """Return True when ``transition_func`` is one of the feature-dependent (kernel) transitions."""
    kernel_classes = (KernelMatrixTransition, UnconKernelMatrixTransition)
    return isinstance(transition_func, kernel_classes)

class KernelMatrixTransition(TransitionFunction):
    r"""
    The transistion function that depend on class_i and x
    Example:
        T(x, i, j) = \sum_k w_k(x) * T^{(k)}_{ij},   w_k(x) = K(x, x_k) / \sum_m K(x, x_m)

    ``_Ts`` holds one transition matrix per anchor ``k`` (a learnable
    parameter kept valid by ``normalize``). ``_Xs`` holds one anchor feature
    per class; it is a plain tensor updated by ``update_Xs``.
    """
    def __init__(self, num_classes, num_features, kernel_type, trans_rate, trans_type=None, normalized_type='sigmoid') -> None:
        super().__init__(num_classes, trans_rate, trans_type)
        self._Ts, self._Xs = self._init_transition(num_classes, num_features, trans_rate, trans_type)
        self._Ts = nn.Parameter(self._Ts)
        self.normalized_type = normalized_type
        self._kernel_func = KERNEL_DICT[kernel_type]

    def _init_transition(self, num_classes, num_features, trans_rate, trans_type=None):
        """
        Return:
            Ts[torch.tensor, num_classes x num_classes x num_classes]
            Xs[torch.tensor, num_classes x num_features]
        """
        # 'class*' gives anchor k the 'class{k}' pattern; otherwise all anchors share trans_type.
        if trans_type is not None and trans_type.startswith('class'):
            trans_types = [f'class{cls_i}' for cls_i in range(num_classes)]
        else:
            trans_types = [trans_type] * num_classes
        Ts = torch.stack([
            init_transistion_matrix(num_classes, trans_rate, trans_type_)
            for trans_type_ in trans_types
        ])
        Xs = nn.init.normal_(torch.zeros(num_classes, num_features))
        return Ts, Xs

    def _kernel_weights(self, x, **kwargs):
        """Kernel similarity of each row of ``x`` to every anchor, normalised per row to sum to 1.

        Args:
            x[torch.tensor, batch_size x num_features]
        Return:
            weights[torch.tensor, batch_size x num_classes]
        """
        Xs = self._Xs.to(x.device)
        weights = self._kernel_func(x.unsqueeze(1), Xs.unsqueeze(0), dim=2, **kwargs)
        # NOTE(review): if every kernel value underflows to 0 (e.g. rbf with
        # far-apart features), this division yields NaN.
        return weights / weights.sum(dim=1, keepdim=True)

    def _transit(self, x, i, j=None, **kwargs):
        """
        Args:
            x[torch.tensor, batch_size x num_features]
            i: clean-label index.
            j: optional noisy-label index.
        Return:
            j is None: [batch_size x num_classes], T(x, i, :) for each sample.
            otherwise: [batch_size], T(x, i, j) for each sample.
        """
        weights = self._kernel_weights(x, **kwargs)
        # _Ts: [anchor k x clean i x noisy j]; `.data` keeps this path detached.
        Ts = self._Ts.data
        if j is None:
            # [B, K, 1] * [1, K, C] summed over K -> [B, C]
            return (weights.unsqueeze(2) * Ts[:, i, :].unsqueeze(0)).sum(dim=1)
        # [B, K] * [1, K] summed over K -> [B]
        return (weights * Ts[:, i, j].unsqueeze(0)).sum(dim=1)

    def get_trans_x_error(self):
        """Mean Euclidean distance over all unordered pairs of class anchors in ``_Xs``."""
        return torch.pdist(self._Xs).mean()

    def update_Xs(self, pred_feats, preds, labels, gamma=0.1, use_pred=False):
        """Move each class anchor towards the mean feature of its samples (moving average).

        new_X_c = (1 - gamma) * X_c + gamma * mean(pred_feats[labels == c])
        Classes with no sample in ``labels`` keep their current anchor.

        Args:
            pred_feats[torch.tensor, N x num_features]
            preds[torch.tensor, N x num_classes]: only read when ``use_pred``.
            labels[torch.tensor, N]
            gamma[float]: weight given to the new batch means.
            use_pred[bool]: use ``argmax(preds)`` as labels instead of ``labels``.
        """
        N = len(pred_feats)
        print("update num", N)
        if use_pred:
            labels = torch.argmax(preds, dim=1)
        feats = pred_feats.detach()
        labels = labels.detach().to(feats.device).long()
        old_Xs = self._Xs.to(feats.device)
        class_means = old_Xs.clone()
        for c in range(len(old_Xs)):
            mask = labels == c
            if mask.any():
                class_means[c] = feats[mask].mean(dim=0)
        self._Xs = old_Xs * (1 - gamma) + class_means * gamma

    def get_matrix(self):
        """Average of the per-anchor transition matrices."""
        return (self._Ts).sum(dim=0) / len(self._Ts)

    def wandb_visualization(self):
        with torch.no_grad():
            Ts = self._normalize(self._Ts)
        return [wandb_mat_image(T) for T in Ts]

    def __str__(self):
        with torch.no_grad():
            Ts = self._normalize(self._Ts)
        return '\n'.join([f'Class {i}:\n' + mat_message(T) for i, T in enumerate(Ts)])

    def _normalize(self, Ts, normalized_type=None):
        """Return a copy of ``Ts`` clamped to [0, 1] with each row (dim 2) summing to 1.

        Non-mutating counterpart of ``normalize``. ``normalized_type`` is
        accepted for signature parity with the unconstrained variant; only
        clamping is supported here.
        """
        Ts = Ts.clamp(0, 1)
        return Ts / Ts.sum(dim=2, keepdim=True)

    def normalize(self):
        """Clamp every per-anchor matrix to [0, 1] in place and renormalise its rows."""
        self._Ts.data = self._normalize(self._Ts.data)

    def transit_pred(self, X, X_features, **kwargs) -> torch.Tensor:
        """
        Transition prediction based on the prediction and transition matrix.
        Args:
            X[torch.tensor, batch_size x num_classes]: the prediction from the classifier.
            X_features[torch.tensor, batch_size x num_features]: the feature from the classifier.
        Return
            X_hat[torch.tensor, batch_size x num_classes]: transited prediction results.
        """
        X_features = X_features.detach()
        batch_size = X_features.size(0)
        # Keep the stored anchors on the input device for subsequent calls.
        self._Xs = self._Xs.to(X.device)
        weights = self._kernel_weights(X_features, **kwargs)
        num_classes = len(self._Ts)
        # Per-sample matrix: [B, K] @ [K, C*C] -> [B, C, C]
        Tx = torch.matmul(weights, self._Ts.view(num_classes, -1)).view(batch_size, num_classes, num_classes)
        # Per-sample row vector times matrix: [B, 1, C] @ [B, C, C] -> [B, C]
        return torch.bmm(X.view(batch_size, 1, -1), Tx).squeeze(1)

class UnconKernelMatrixTransition(TransitionFunction):
    r"""
    A sparse version for kernel based matrix transition. The transistion function that depend on class_i and x

    The per-anchor matrices ``_Ts`` are unconstrained parameters; ``_normalize``
    maps them to row-normalised matrices according to ``normalized_type``
    (a string containing 'sigmoid', 'softmax' or 'truncated', optionally
    also 'fix').

    Example:
        T(x, i, j) = \cat T_j * K(x - x_i)
    """
    def __init__(self, num_classes, num_features, kernel_type, trans_rate, trans_type=None, normalized_type='sigmoid') -> None:
        super().__init__(num_classes, trans_rate, trans_type)
        # Must be set before _init_transition, which reads it.
        self.normalized_type = normalized_type
        self._Ts, self._Xs = self._init_transition(num_classes, num_features, trans_rate, trans_type)
        self._Ts = nn.Parameter(self._Ts)
        # Fixed masks used by _normalize.
        self.co = torch.ones(num_classes, num_classes) - torch.eye(num_classes)
        self.identity = torch.eye(num_classes)
        # For 'fix' types: +3 added to entry (c, c, c) of the raw parameters.
        self.fix_class = torch.zeros(num_classes, num_classes, num_classes)
        fix_idxs = torch.arange(num_classes)
        self.fix_class[fix_idxs, fix_idxs, fix_idxs] = 3
        self.co.requires_grad, self.identity.requires_grad, self.fix_class.requires_grad = False, False, False
        self._kernel_func = KERNEL_DICT[kernel_type]

    def _init_transition(self, num_classes, num_features, trans_rate, trans_type=None):
        """
        Initial unconstrained parameters and anchors (``trans_rate`` and ``trans_type`` are unused).

        Return:
            Ts[torch.tensor, num_classes x num_classes x num_classes]
            Xs[torch.tensor, num_classes x num_features]
        Raises:
            ValueError: if ``self.normalized_type`` contains none of
                'sigmoid', 'softmax' or 'truncated'.
        """
        if 'sigmoid' in self.normalized_type:
            base = torch.ones(num_classes, num_classes) * -4.5
        elif 'softmax' in self.normalized_type:
            base = torch.eye(num_classes) * 4
        elif 'truncated' in self.normalized_type:
            base = torch.eye(num_classes) * 4 + torch.ones(num_classes, num_classes)
        else:
            raise ValueError(
                f'Unsupported normalized_type {self.normalized_type!r}; expected one '
                f"containing 'sigmoid', 'softmax' or 'truncated'.")
        Xs = nn.init.normal_(torch.zeros(num_classes, num_features))
        # Same initial matrix for every anchor: [num_classes x num_classes x num_classes]
        return base.repeat(num_classes, 1, 1), Xs

    def _kernel_weights(self, x, **kwargs):
        """Kernel similarity of each row of ``x`` to every anchor, normalised per row to sum to 1.

        Args:
            x[torch.tensor, batch_size x num_features]
        Return:
            weights[torch.tensor, batch_size x num_classes]
        """
        Xs = self._Xs.to(x.device)
        weights = self._kernel_func(x.unsqueeze(1), Xs.unsqueeze(0), dim=2, **kwargs)
        # NOTE(review): if every kernel value underflows to 0 (e.g. rbf with
        # far-apart features), this division yields NaN.
        return weights / weights.sum(dim=1, keepdim=True)

    def _transit(self, x, i, j=None, **kwargs):
        """
        Args:
            x[torch.tensor, batch_size x num_features]
            i: clean-label index.
            j: optional noisy-label index.
        Return:
            j is None: [batch_size x num_classes], normalised T(x, i, :) per sample.
            otherwise: [batch_size], normalised T(x, i, j) per sample.
        """
        weights = self._kernel_weights(x, **kwargs)
        # Raw _Ts are unconstrained; use the same normalised matrices as
        # transit_pred, detached as in the original `.data` access.
        with torch.no_grad():
            Ts = self._normalize(self._Ts)
        if j is None:
            # [B, K, 1] * [1, K, C] summed over K -> [B, C]
            return (weights.unsqueeze(2) * Ts[:, i, :].unsqueeze(0)).sum(dim=1)
        # [B, K] * [1, K] summed over K -> [B]
        return (weights * Ts[:, i, j].unsqueeze(0)).sum(dim=1)

    def get_matrix(self):
        """Average of the normalised per-anchor matrices, without gradient tracking."""
        with torch.no_grad():
            Ts = self._normalize(self._Ts)
        return (Ts).sum(dim=0) / len(Ts)

    def wandb_visualization(self):
        with torch.no_grad():
            Ts = self._normalize(self._Ts)
        return [wandb_mat_image(T) for T in Ts]

    def __str__(self):
        with torch.no_grad():
            Ts = self._normalize(self._Ts)
        return '\n'.join([f'Class {i}:\n' + mat_message(T) for i, T in enumerate(Ts)])

    def get_trans_x_error(self):
        """Mean Euclidean distance over all unordered pairs of class anchors in ``_Xs``."""
        return torch.pdist(self._Xs).mean()

    def update_Xs(self, pred_feats, preds, labels, gamma=0.1, use_pred=False):
        """Move each class anchor towards the mean feature of its samples (moving average).

        new_X_c = (1 - gamma) * X_c + gamma * mean(pred_feats[labels == c])
        Classes with no sample in ``labels`` keep their current anchor.

        Args:
            pred_feats[torch.tensor, N x num_features]
            preds[torch.tensor, N x num_classes]: only read when ``use_pred``.
            labels[torch.tensor, N]
            gamma[float]: weight given to the new batch means.
            use_pred[bool]: use ``argmax(preds)`` as labels instead of ``labels``.
        """
        N = len(pred_feats)
        print("update num", N)
        if use_pred:
            labels = torch.argmax(preds, dim=1)
        feats = pred_feats.detach()
        labels = labels.detach().to(feats.device).long()
        old_Xs = self._Xs.to(feats.device)
        class_means = old_Xs.clone()
        for c in range(len(old_Xs)):
            mask = labels == c
            if mask.any():
                class_means[c] = feats[mask].mean(dim=0)
        self._Xs = old_Xs * (1 - gamma) + class_means * gamma

    def normalize(self):
        """No-op: the parameters stay unconstrained and ``_normalize`` is applied on the fly."""
        pass

    def _normalize(self, Ts, normalized_type=None):
        """Map unconstrained parameters to row-normalised transition matrices.

        Args:
            Ts[num_classes x num_classes x num_classes]
            normalized_type[str]: defaults to ``self.normalized_type``.
                'fix' adds ``fix_class`` first; then 'sigmoid' (sigmoid off the
                diagonal, diagonal set to 1, L1-normalised rows), 'softmax'
                (row softmax) or 'truncated' (ReLU then L1-normalised rows).
                Otherwise ``Ts`` is returned unchanged.
        Return:
            Ts, same shape.
        """
        if normalized_type is None:
            normalized_type = self.normalized_type
        device = Ts.device
        num_classes = len(Ts)
        if 'fix' in normalized_type:
            Ts = Ts + self.fix_class.to(device)
        if 'sigmoid' in normalized_type:
            Ts = torch.sigmoid(Ts)
            Ts = self.identity.to(device).view(1, num_classes, num_classes) \
                + Ts * self.co.to(device).view(1, num_classes, num_classes)
            Ts = F.normalize(Ts, p=1, dim=2)
        elif 'softmax' in normalized_type:
            Ts = F.softmax(Ts, dim=2)
        elif 'truncated' in normalized_type:
            Ts = F.relu(Ts)
            Ts = F.normalize(Ts, p=1, dim=2)
        return Ts

    def transit_pred(self, X, X_features, **kwargs) -> torch.Tensor:
        """
        Transition prediction based on the prediction and transition matrix.
        Args:
            X[torch.tensor, batch_size x num_classes]: the prediction from the classifier.
            X_features[torch.tensor, batch_size x num_features]: the feature from the classifier.
        Return
            X_hat[torch.tensor, batch_size x num_classes]: transited prediction results.
        """
        X_features = X_features.detach()
        batch_size = X_features.size(0)
        # Keep the stored anchors on the input device for subsequent calls.
        self._Xs = self._Xs.to(X.device)
        weights = self._kernel_weights(X_features, **kwargs)
        num_classes = len(self._Ts)
        Ts = self._normalize(self._Ts)
        # Per-sample matrix: [B, K] @ [K, C*C] -> [B, C, C]
        Tx = torch.matmul(weights, Ts.view(num_classes, -1)).view(batch_size, num_classes, num_classes)
        # Per-sample row vector times matrix: [B, 1, C] @ [B, C, C] -> [B, C]
        return torch.bmm(X.view(batch_size, 1, -1), Tx).squeeze(1)

