import torch
import torch.nn as nn
import torch.nn.functional as F


def list2vec(z1_list):
    """Flatten a list of batched tensors into one column-vector batch.

    Every tensor in ``z1_list`` is reshaped to ``(bsz, -1, 1)``, where ``bsz``
    is the leading dimension of the first tensor, and the pieces are
    concatenated along dim 1.

    Args:
        z1_list: Non-empty sequence of tensors sharing the same leading size.

    Returns:
        Tensor of shape ``(bsz, total_elements_per_sample, 1)``.
    """
    batch = z1_list[0].size(0)
    columns = []
    for tensor in z1_list:
        # Collapse everything but the batch axis into a single column.
        columns.append(tensor.reshape(batch, -1, 1))
    return torch.cat(columns, dim=1)


def vec2list(z1, cutoffs):
    """Split a flattened batch back into a list of shaped tensors.

    This is the inverse of ``list2vec``: consecutive chunks along dim 1 of
    ``z1`` are viewed with the per-sample shapes given in ``cutoffs``.

    Args:
        z1: Tensor whose dim 0 is the batch and whose dim 1 holds the
            concatenated, flattened pieces (e.g. ``(bsz, total, 1)``).
        cutoffs: Sequence of per-sample shapes, one per piece. Shapes may have
            any number of dimensions (the original implementation only
            supported exactly three).

    Returns:
        List of views into ``z1`` with shapes ``(bsz, *cutoffs[i])``. An empty
        ``cutoffs`` yields an empty list.
    """
    bsz = z1.shape[0]
    z1_list = []
    start_idx = 0
    for shape in cutoffs:
        # Number of elements per sample occupied by this piece.
        numel = 1
        for dim in shape:
            numel *= dim
        end_idx = start_idx + numel
        z1_list.append(z1[:, start_idx:end_idx].view(bsz, *shape))
        start_idx = end_idx
    return z1_list


class AlphaNet(nn.Module):
    """Abstract network that predicts Anderson mixing weights (alpha, beta).

    Subclasses must provide ``seq_fc`` and implement ``forward_RR``, which
    produces one ``ninner``-dimensional feature per residual. ``forward`` then
    maps a sequence of such features to per-step weights ``alpha`` that sum to
    one and, when ``learn_beta`` is set, a per-sample scalar ``beta``.
    """

    def __init__(self, nhid, learn_beta, ninner=100, alpha_rnn=False):
        """Build the temporal model and the alpha/beta prediction heads.

        Args:
            nhid: Input feature size (stored; consumed by subclasses' ``seq_fc``).
            learn_beta: If True, a ``beta_predictor`` head is created.
            ninner: Width of the per-step feature fed to the temporal model.
                Must be divisible by 5 when the TCN branch is used
                (``GroupNorm(5, ninner)``).
            alpha_rnn: If True, use a bidirectional GRU over the sequence;
                otherwise use a small residual temporal conv net.
        """
        super().__init__()
        self.nhid = nhid
        self.ninner = ninner
        self.learn_beta = learn_beta
        self.seq_fc = None   # Should be set by subclass

        if alpha_rnn:
            self.alpha_rnn = nn.GRU(ninner, ninner, batch_first=True, bidirectional=True)   # bi-directional?
            hidden_factor = 2   # forward + backward hidden states are concatenated
        else:
            self.alpha_tcn = nn.Sequential(nn.Conv1d(ninner, ninner, kernel_size=5, padding=2), nn.GroupNorm(5, ninner), nn.ReLU(),
                                           nn.Conv1d(ninner, ninner, kernel_size=5, padding=2))
            hidden_factor = 1

        self.alpha_predictor = nn.Sequential(nn.Linear(ninner*hidden_factor, ninner),
                                             nn.ReLU(),
                                             nn.Linear(ninner, 1))
        if learn_beta:
            print("Learning beta by beta=beta(G)!")
            self.beta_predictor = nn.Sequential(nn.Linear(ninner*hidden_factor, ninner), nn.ReLU(),
                                                nn.Linear(ninner, 1))

    def forward_RR(self, G, **kwargs):
        """Map a residual ``G`` to a ``ninner``-dim feature; subclasses override.

        Raises:
            NotImplementedError: Always, in this abstract base class.
        """
        # `NotImplemented` is a non-callable singleton; the exception type is
        # `NotImplementedError`.
        raise NotImplementedError("This function is not implemented in the abstract class AlphaNet")

    def forward(self, RR, n, k, m, up):
        """Predict mixing weights from the stored residual features.

        Args:
            RR: Residual features of shape ``(bsz, n, ninner)``.
            n: Number of past steps currently in use (``RR.shape[1]``).
            k: Current solver iteration index.
            m: Size of the Anderson history window.
            up: Slot index that will be overwritten next. Only used in eval
                mode, where the history buffer is circular.

        Returns:
            Tuple ``(alpha, beta)``: ``alpha`` has shape ``(bsz, n)`` and sums
            to 1 along dim 1; ``beta`` has shape ``(bsz, 1)`` when
            ``learn_beta`` is set and is ``None`` otherwise.
        """
        bsz = RR.shape[0]
        # 1. Ensure RR is chronologically organized
        if not self.training:
            # Re-organize the order because the indexing wraps around in eval mode
            if k >= m:
                RR = torch.cat([RR[:, up:], RR[:, :up]], dim=1)

        # 2. Predict alpha (chronological order)
        if hasattr(self, "alpha_tcn"):
            # Conv1d expects (bsz, channels, time); residual connection around the TCN.
            RR = RR.clone().transpose(1, 2)
            raw_ab = self.alpha_tcn(RR) + RR
            if self.learn_beta:
                # Average over time to get one summary vector per sample.
                pooled = F.avg_pool1d(raw_ab, kernel_size=raw_ab.shape[-1]).squeeze(-1)
                beta = self.beta_predictor(pooled).view(bsz, 1)
            else:
                beta = None
            raw_ab = raw_ab.transpose(1, 2)
        else:
            raw_ab, _ = self.alpha_rnn(RR.clone())
            if self.learn_beta:
                # Split the GRU output into its two directions and summarize the
                # sequence by the last forward state and the first backward state.
                raw_ab_bidir = raw_ab.view(bsz, -1, 2, self.ninner)
                summary = torch.cat([raw_ab_bidir[:, -1, 0], raw_ab_bidir[:, 0, 1]], dim=-1)
                beta = self.beta_predictor(summary).view(bsz, 1)
            else:
                beta = None
        raw_alpha = self.alpha_predictor(raw_ab).view(bsz, n)
        # Option 1: Have the model learn by itself
        # alpha = raw_alpha
        # Option 2: Softmax (which ensures all between 0 and 1, which may be bad)
        # alpha = raw_alpha.softmax(dim=1)
        # Option 3: Normalize by shifting
        alpha = raw_alpha + (1 - raw_alpha.sum(1, keepdim=True)) / n     # Make sure alpha sums to 1. Shape: (bsz x n)

        # 3. Put alpha back to wrap-around order if needed
        if not self.training:
            # Re-organize the order because the indexing wraps around in eval mode
            if k >= m:
                alpha = torch.cat([alpha[:, (n-up):], alpha[:, :(n-up)]], dim=1)
        return alpha, beta


class SequenceAlphaNet(AlphaNet):
    """AlphaNet variant for sequence data: features come from the last time step."""

    def __init__(self, nhid, learn_beta, ninner=100, alpha_rnn=False):
        """Create the network.

        Args:
            nhid: Channel size of the residual fed to ``forward_RR``.
            learn_beta: If True, beta is predicted as well.
            ninner: Width of the per-step feature.
            alpha_rnn: Forwarded to ``AlphaNet`` to select the GRU temporal
                model. (Previously accepted but silently dropped, so the TCN
                was always used; checkpoints saved under that behaviour hold
                ``alpha_tcn`` weights.)
        """
        super(SequenceAlphaNet, self).__init__(nhid=nhid, ninner=ninner, learn_beta=learn_beta,
                                               alpha_rnn=alpha_rnn)
        self.seq_fc = nn.Sequential(nn.Linear(nhid, ninner), nn.ReLU(), nn.Linear(ninner, ninner))

    def forward_RR(self, G, **kwargs):
        """Project the slice ``G[:, :, -1]`` to a ``ninner``-dim feature.

        Args:
            G: Residual tensor; assumed ``(bsz, nhid, L)`` so that the last
               index along dim 2 yields ``(bsz, nhid)`` -- TODO confirm with caller.
            **kwargs: Ignored; accepted for interface parity with other AlphaNets.

        Returns:
            Output of ``seq_fc`` applied to that slice.
        """
        out = self.seq_fc(G[:, :, -1])
        return out


class MultiscaleAlphaNet(AlphaNet):
    """AlphaNet variant for multiscale feature maps: features come from the last scale."""

    def __init__(self, nhid, learn_beta, ninner=100, alpha_rnn=False):
        """Create the network.

        Args:
            nhid: Channel count of the feature map passed through ``seq_fc``.
            learn_beta: If True, beta is predicted as well.
            ninner: Width of the per-step feature; must be divisible by 4
                (``GroupNorm(4, ninner)``).
            alpha_rnn: Forwarded to ``AlphaNet`` to select the GRU temporal
                model. (Previously accepted but silently dropped, so the TCN
                was always used; checkpoints saved under that behaviour hold
                ``alpha_tcn`` weights.)
        """
        super(MultiscaleAlphaNet, self).__init__(nhid=nhid, ninner=ninner, learn_beta=learn_beta,
                                                 alpha_rnn=alpha_rnn)
        self.seq_fc = nn.Sequential(nn.Conv2d(nhid, ninner, kernel_size=3), nn.GroupNorm(4, ninner), nn.ReLU(),
                                    nn.Conv2d(ninner, ninner, kernel_size=3))

    def forward_RR(self, G, cutoffs=None, **kwargs):
        """Turn a flattened multiscale residual into a ``ninner``-dim feature.

        Args:
            G: Flattened residual, split into per-scale maps via ``vec2list``.
            cutoffs: Per-scale shapes used to un-flatten ``G``. Required.
            **kwargs: Ignored.

        Returns:
            Tensor of shape ``(bsz, ninner)``: the last entry produced by
            ``vec2list`` is passed through ``seq_fc`` and globally
            average-pooled over its spatial dims.
        """
        # Message previously named the wrong class (SequenceAlphaNet).
        assert cutoffs is not None, "You have to provide cutoffs if you were to use MultiscaleAlphaNet!"
        G_list = vec2list(G, cutoffs)
        G_low = self.seq_fc(G_list[-1])
        bsz = G_low.shape[0]
        return F.avg_pool2d(G_low, kernel_size=G_low.shape[-2:]).view(bsz, self.ninner)


class LearnableAnderson(nn.Module):
    """Anderson-acceleration fixed-point solver with optionally learned weights.

    With ``learn_alpha=True`` the mixing weights come from an ``AlphaNet``;
    otherwise they are obtained by solving the regularized least-squares
    system of classic Anderson acceleration.
    """

    def __init__(self, alpha_net_dict, m=5, stop_mode='rel', learn_alpha=True, alpha_nhid=200, learn_beta=False, hyperload=""):
        """Configure the solver.

        Args:
            alpha_net_dict (dict): The type of alpha prediction network to use. This will differ based on data type
                                   (e.g., sequence, image, feature tensor, etc.), and has the following format:
                                   {
                                       'name': [CLS_NAME],
                                       'kwargs': {...}
                                   }
                                   'kwargs' may be omitted (treated as empty).
                                   This can be None if learn_alpha=False.
            m (int, optional): [Number of Anderson's past-step slots; must be > 2]. Defaults to 5.
            stop_mode (str, optional): [Whether residuals measured in absolute ("abs") or relative ("rel") mode].
                                       Defaults to 'rel'.
            learn_alpha (bool, optional): [If True, alpha of Anderson will be learned]. Defaults to True.
            alpha_nhid (int, optional): [Input dimension used to predict alpha]. Defaults to 200.
            learn_beta (bool, optional): [If True, beta of Anderson will be learned]. Defaults to False.
            hyperload (str, optional): [Path to load a pretrained hypersolver state dict]. Defaults to "".
        """
        super().__init__()
        assert m > 2, "You should have m > 2 to satisfy AA prototype"
        # Honor the caller's choice (this used to be hard-coded to 'rel').
        self.stop_mode = stop_mode
        self.m = m
        self.alternative_mode = 'rel' if stop_mode == 'abs' else 'abs'
        self.learn_alpha = learn_alpha
        self.learn_beta = learn_beta
        if learn_alpha:
            assert alpha_net_dict is not None and 'name' in alpha_net_dict, "alpha_net_dict CANNOT be None if learn_alpha"
            self.alpha_nhid = alpha_nhid
            # NOTE(security): eval() executes arbitrary code; 'name' must come
            # from trusted configuration only.
            alpha_net_cls = eval(alpha_net_dict['name'])
            alpha_net_kwargs = alpha_net_dict.get('kwargs') or {}
            self.alpha_net = alpha_net_cls(alpha_nhid, learn_beta, **alpha_net_kwargs)
            if not learn_beta:
                print("Not learning beta!")
                # Fixed per-iteration beta; indexed by iteration k, so at most
                # 100 iterations are supported on this path.
                self.beta = [1.0]*100
        elif learn_beta:
            # One learnable beta per iteration index (at most 100 iterations).
            self.beta = nn.Parameter(torch.zeros(100,)+1.0)
        else:
            print("Not learning beta!")
            self.beta = [1.0]*100

        if len(hyperload) > 0:
            # Load onto CPU first so GPU-saved checkpoints work on any host;
            # load_state_dict copies values into the existing parameters.
            self.load_state_dict(torch.load(hyperload, map_location='cpu'))

    @staticmethod
    def _linear_solve(A, B):
        """Solve the batched system ``A @ X = B`` across PyTorch versions."""
        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve'):
            return torch.linalg.solve(A, B)
        # Legacy API (removed from recent PyTorch): reversed arguments, returns (X, LU).
        return torch.solve(B, A)[0]

    def forward(self, f, x0, lam=1e-4, tol=1e-3, threshold=30, print_intermediate=False, print_galphas=True, **kwargs):
        """Run the (learnable) Anderson iteration for a fixed number of steps.

        Args:
            f (function): [layer's function form]
            x0 (torch.Tensor): [Initial estimate of the fixed point], 3-D ``(bsz, d, L)``.
            lam (float, optional): [Anderson's lambda]. Defaults to 1e-4.
            tol (float, optional): [Anderson's tolerance level; works with the stop_mode]. Defaults to 1e-3.
                                   Not consulted in this implementation: the loop always runs to `threshold`.
            threshold (int, optional): [Max number of forward iterations]. Defaults to 30.
            print_intermediate (bool, optional): [If True, returns intermediate estimates of the convergence]. Defaults to False.
            print_galphas (bool, optional): [If True, returns the ||G*alpha|| values]. Defaults to True.
            kwargs: [Cutoffs, etc. extra information that is to be passed into `forward_RR`]

        Returns:
            [dict]: [The result of fixed point solving by (learnable) Anderson], with keys
                    'result', 'X', 'Inits', 'rel_trace', 'abs_trace' and 'Galphas'.
        """
        bsz, d, L = x0.shape
        # NOTE(review): the original called `self.func.eval()` unconditionally,
        # but nothing in this class sets `func`; presumably it is attached
        # externally. Guarded so the solver also works without it -- confirm.
        func = getattr(self, 'func', None)
        if isinstance(func, nn.Module):
            func.eval()
        m = self.m
        # Training keeps every iterate (needed for backprop through history);
        # eval reuses a circular buffer of m slots.
        nslots = threshold if self.training else m
        X = torch.zeros(bsz, nslots, d*L, dtype=x0.dtype, device=x0.device)
        # Named FX (not F) to avoid shadowing the module-level `torch.nn.functional as F`.
        FX = torch.zeros(bsz, nslots, d*L, dtype=x0.dtype, device=x0.device)
        X[:,0], FX[:,0] = x0.reshape(bsz, -1), f(x0).reshape(bsz, -1)
        if (x0 == 0).all():
            # Started with all zeros (i.e., no initializer)
            X[:,1], FX[:,1] = FX[:,0], f(FX[:,0].reshape_as(x0)).reshape(bsz, -1)
            k_start = 2
        else:
            k_start = 1

        if self.learn_alpha:
            if hasattr(self.alpha_net, "alpha_rnn"):
                self.alpha_net.alpha_rnn.flatten_parameters()
            dim = self.alpha_net.ninner
            # R caches the alpha-net feature of every stored residual.
            R = torch.zeros(bsz, nslots, dim, dtype=x0.dtype, device=x0.device)
            for i in range(k_start):
                R[:,i] = self.alpha_net.forward_RR((FX[:,i] - X[:,i]).view_as(x0), **kwargs).reshape(bsz, -1)
        else:
            # Bordered system [[0, 1^T], [1, G G^T + lam I]] [nu; alpha] = [1; 0].
            H = torch.zeros(bsz, m+1, m+1, dtype=x0.dtype, device=x0.device)
            H[:,0,1:] = H[:,1:,0] = 1
            y = torch.zeros(bsz, m+1, 1, dtype=x0.dtype, device=x0.device)
            y[:,0] = 1

        trace_dict = {'abs': [], 'rel': []}
        FX = FX.detach()
        X = X.detach()
        Inits = Galphas = None
        Alphas = []   # collected when print_intermediate is set; not part of the returned dict
        if print_intermediate:
            Inits = [X[:,1].clone().detach()[None]]
        if print_galphas:
            Galphas = []

        for k in range(k_start, threshold):
            n = min(k, m)
            # X[:,max(k-m,0):k].register_hook(lambda grad: generic_back_hook(grad, f"X[:,max(k-m,0):k], k={k}, m={m}"))
            if self.training:
                FF, XX, up = FX[:,max(k-m,0):k], X[:,max(k-m,0):k], k
                if self.learn_alpha:
                    RR = R[:,max(k-m,0):k]
                else:
                    # Rotate the chronological window into wrap-around order.
                    FF, XX = torch.cat([FF[:,n-(k%m):], FF[:,:n-(k%m)]], dim=1), torch.cat([XX[:,n-(k%m):], XX[:,:n-(k%m)]], dim=1)
            else:
                FF, XX, up = FX[:,:n], X[:,:n], k%m
                if self.learn_alpha:
                    RR = R[:,:n]
            G = FF-XX   # chronological order if learn_alpha; wrap-around if not.

            if self.learn_alpha:
                # New learnable Anderson
                alpha, beta = self.alpha_net(RR, n, k, m, up)
                # if k == threshold-1 and torch.distributed.get_rank() == 0:
                #     print(alpha, beta)
                if not self.learn_beta:
                    beta = self.beta[k]
                if print_galphas:
                    # NOTE(review): `.norm(1)` is the p=1 norm over the whole
                    # tensor (a scalar), not a per-sample norm along dim 1;
                    # left unchanged -- confirm whether `norm(dim=1)` was meant.
                    Galphas.append(torch.einsum('bni,bn->bi', G, alpha).norm(1).mean(0))
            else:
                # Original Anderson
                H[:,1:n+1,1:n+1] = torch.bmm(G,G.transpose(1,2)) + lam*torch.eye(n, dtype=x0.dtype,device=x0.device)[None]
                alpha = self._linear_solve(H[:,:n+1,:n+1].clone(), y[:,:n+1].clone())[:,1:n+1,0]   # (bsz x n)
                beta = self.beta[k]

            # The forward pass here uses X[:,:n] and F[:,:n], which are CHANGED later in-place (i.e., X[:,k%m]=...)
            X[:,up] = beta * (alpha[:,None] @ FF)[:,0] + (1-beta)*(alpha[:,None] @ XX)[:,0]
            FX[:,up] = f(X[:,up].reshape_as(x0)).reshape(bsz, -1)
            temp = FX[:,up].clone()  # To avoid gradient inplace op error
            gx = (FX[:,up] - X[:,up]).view_as(x0)
            if self.learn_alpha:
                R[:,up] = self.alpha_net.forward_RR(gx, **kwargs).reshape(bsz, -1)   # (bsz, ninner)

            abs_diff = gx.norm()
            rel_diff = abs_diff / (1e-9 + temp.norm())
            trace_dict['abs'].append(abs_diff)
            trace_dict['rel'].append(rel_diff)
            if print_intermediate:
                Inits.append(X[:,up].clone().detach()[None])
                Alphas.append(alpha.clone().detach()[None])

        return {'result': X[:,up].view_as(x0),
                'X': X.view(bsz, -1, *x0.shape[1:]),
                'Inits': torch.cat(Inits, dim=0) if Inits else None,
                'rel_trace': torch.stack(trace_dict['rel']).view(1,-1),
                'abs_trace': torch.stack(trace_dict['abs']).view(1,-1),
                'Galphas': torch.stack(Galphas).view(1,-1) if Galphas else None}
