import torch.nn as nn
from src.modules.geo_modules import *
from src.chroma.layers.structure import sidechain
from src.modules.graph_transform import GraphTransform
from src.chroma.layers.structure.diffusion import DiffusionChainCov, ReconstructionLosses


class Encoder(nn.Module):
    """Embed node/edge features and refine them with stacked graph-attention blocks.

    Node and edge inputs are first projected to ``args.hidden_dim`` by two
    MLPs, then passed through ``args.enc_layers`` ``BlockGraphAttn`` layers.
    Only the final node states are returned.
    """

    def __init__(self, args,  **kwargs):
        """Build the embedding MLPs and the attention stack.

        Args:
            args: Config object; reads ``hidden_dim``, ``geo_layers``,
                ``edge_layers``, ``dropout`` and ``enc_layers``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        super().__init__()
        hidden_dim = args.hidden_dim
        # NOTE(review): the input widths (19 * 4 and 38 * 4 + 44) presumably
        # match the feature sizes produced by the featurizer -- confirm.
        self.node_emb = build_MLP(2, 19 * 4, hidden_dim, hidden_dim)
        self.edge_emb = build_MLP(2, 38 * 4 + 44, hidden_dim, hidden_dim)
        self.layers = nn.ModuleList([
            BlockGraphAttn(args.geo_layers, args.edge_layers, hidden_dim, dropout=args.dropout)
            for _ in range(args.enc_layers)
        ])

    def forward(self, edge_idx, batch_id, V=None, E=None, T_ts=None, batch_id_extend=None, edge_idx_extend=None):
        """Run the encoder and return the final node hidden states.

        Args:
            edge_idx: Edge index of the base graph; used only when
                ``edge_idx_extend`` is not given.
            batch_id: Graph id per node of the base graph; used only when
                ``batch_id_extend`` is not given.
            V: Node features fed to ``node_emb``.
            E: Edge features fed to ``edge_emb``.
            T_ts: Passed unchanged to every attention layer.
            batch_id_extend: Optional replacement for ``batch_id``.
            edge_idx_extend: Optional replacement for ``edge_idx``.

        Returns:
            The node states produced by the last layer.
        """
        # Prefer the extended graph when supplied, but do not discard the
        # caller's base indices when the optional arguments are left at None.
        if batch_id_extend is not None:
            batch_id = batch_id_extend
        if edge_idx_extend is not None:
            edge_idx = edge_idx_extend

        h_V, h_E = self.node_emb(V), self.edge_emb(E)
        for layer in self.layers:
            h_V, h_E = layer(h_V, h_E, T_ts, edge_idx=edge_idx, batch_id=batch_id)
        return h_V
    

class Decoder(nn.Module):
    """Iteratively update backbone coordinates with a stack of ``GeneralE3GNN`` layers.

    Chi-angle prediction is currently switched off: ``forward`` returns
    all-zero placeholders for the chi outputs.
    """

    def __init__(self, args, **kwargs):
        """Create the coordinate layers and the (currently unused) chi head.

        Args:
            args: Config object; reads ``hidden_dim`` and is kept as ``self.args``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        super().__init__()
        self.args = args
        width = args.hidden_dim
        self.embed_condition = nn.Linear(width, width)
        self.se3_layers = nn.ModuleList(
            [GeneralE3GNN(1, width, num_atoms=4) for _ in range(2)]
        )
        self.chi_predictor = ChiPredictor(hidden_dim=width)

    def forward(self, h_V, X, C, S, node_idx, batch_id, dec_topk=50):
        """Produce one coordinate prediction per layer plus chi placeholders.

        Args:
            h_V: Node hidden states given to every coordinate layer.
            X: Starting coordinates.
            C, S, node_idx: Not used by the current implementation.
            batch_id: Passed to every coordinate layer.
            dec_topk: Forwarded to every coordinate layer as ``topk``.

        Returns:
            ``(bb_preds, chi_probs, chi)``: ``bb_preds`` is a list with the
            output of each layer in order; ``chi_probs`` is zeros of shape
            ``(h_V.shape[0], 4, 20)`` and ``chi`` is zeros of shape
            ``(h_V.shape[0], 4)``, both on ``h_V``'s device.
        """
        bb_preds = []
        coords = X
        for se3 in self.se3_layers:
            # Each layer starts from detached coordinates, so no gradient
            # reaches earlier layers through the coordinate input.
            coords, _ = se3(coords.detach(), h_V, batch_id, topk=dec_topk)
            bb_preds.append(coords)

        # The chi head (self.chi_predictor(h_V.detach(), S)) is disabled;
        # zero tensors keep the return signature stable.
        num_rows = h_V.shape[0]
        chi_probs = torch.zeros(num_rows, 4, 20, device=h_V.device)
        chi = torch.zeros(num_rows, 4, device=h_V.device)
        return bb_preds, chi_probs, chi


class ChiPredictor(nn.Module):
    """Predict a binned distribution over four chi angles per residue.

    The interval ``[-pi, pi]`` is split into ``num_chi_bins`` equal bins; the
    returned angle is the probability-weighted mean of the bin centres.
    """

    def __init__(self, num_alphabet=20, num_chi_bins=20, hidden_dim=128):
        """Set up the bin grid, the sequence embedding and the output head.

        Args:
            num_alphabet: Number of distinct sequence tokens embedded by ``W_S``.
            num_chi_bins: Number of bins covering ``[-pi, pi]``.
            hidden_dim: Width of the hidden node states.
        """
        super().__init__()

        self.num_chi_bins = num_chi_bins
        edges = torch.linspace(-np.pi, np.pi, num_chi_bins + 1)
        self.bins = edges
        # Midpoint of each pair of neighbouring bin edges.
        self.bin_centers = 0.5 * (edges[:-1] + edges[1:])

        self.W_S = nn.Embedding(num_alphabet, hidden_dim)
        # W_chi is constructed but not used in forward().
        self.W_chi = nn.Linear(4, hidden_dim)
        self.chi_predictor = nn.Linear(hidden_dim, 4 * num_chi_bins)

    def forward(self, h_V, S):
        """Return per-bin probabilities and expected chi angles.

        Args:
            h_V: Hidden states; only the first ``S.shape[0]`` rows are used.
            S: Integer sequence tokens, one per residue.

        Returns:
            ``(chi_probs, chi)`` with shapes ``(S.shape[0], 4, num_chi_bins)``
            and ``(S.shape[0], 4)``.
        """
        num_res = S.shape[0]
        h_res = h_V[:num_res] + self.W_S(S)

        logits = self.chi_predictor(h_res).view(num_res, 4, self.num_chi_bins)
        chi_probs = logits.softmax(dim=-1)

        # bin_centers is a plain tensor attribute, so move it to the input's device.
        centers = self.bin_centers.to(h_res.device)
        chi = (chi_probs * centers).sum(dim=-1)
        return chi_probs, chi


class RefinerPR(nn.Module):
    """Encoder/decoder structure refiner together with its training losses."""

    def __init__(self, args, **kwargs):
        """Build the encoder, decoder and loss helpers.

        Args:
            args: Config object handed to ``Encoder`` and ``Decoder``; also
                read later for ``model_type`` and ``dec_topk``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        super().__init__()
        self.args = args
        self.encoder = Encoder(args)
        self.decoder = Decoder(args)
        chi_head = self.decoder.chi_predictor
        self.num_chi_bins = chi_head.num_chi_bins
        self.bin_centers = chi_head.bin_centers
        # Called before the helpers below are attached, so only the encoder
        # and decoder parameters are re-initialised here.
        self._init_params()

        self.X_to_chi = sidechain.ChiAngles_Graph()
        self.chi_to_X = sidechain.SideChainBuilder_Graph()
        self.rmsd_sidechain = sidechain.LossSideChainRMSD()
        self.clash_sidechain = sidechain.LossSidechainClashes()
        self.noise_perturb = DiffusionChainCov(
            noise_schedule='log_snr',
            beta_min=0.2,
            beta_max=70,
            log_snr_range=[-7.0, 13.5],
            covariance_model='globular',
            complex_scaling=True,
        )
        self.loss_diffusion = ReconstructionLosses(
            diffusion=self.noise_perturb, rmsd_method='symeig', loss_scale=10.0
        )

    def _init_params(self):
        """Xavier-uniform initialise (gain 0.9) every parameter with more than one dimension."""
        for param in self.parameters():
            if param.dim() <= 1:
                # Biases and other 1-D parameters keep their existing values.
                continue
            nn.init.xavier_uniform_(param, gain=0.9)

    def _expand(self, C, S, batch_size, num_residues, batch_id, device):
        """Scatter per-node ``C`` and ``S`` into zero-padded dense tensors.

        Returns two ``(batch_size, num_residues)`` tensors with the dtypes of
        ``C`` and ``S`` respectively.
        NOTE(review): the column index is built by concatenating one
        ``arange`` per graph, which presumably requires each graph's nodes to
        be contiguous and in graph order within ``batch_id`` -- confirm.
        """
        counts = scatter_sum(torch.ones_like(batch_id), batch_id)
        within_graph = torch.cat([torch.arange(0, n) for n in counts]).to(device)

        dense_C = torch.zeros(batch_size, num_residues, device=device, dtype=C.dtype)
        dense_S = torch.zeros(batch_size, num_residues, device=device, dtype=S.dtype)
        dense_C[batch_id, within_graph] = C
        dense_S[batch_id, within_graph] = S
        return dense_C, dense_S

    def compute_loss(self, preds, chi_probs, chi, X, C, S, batch_id):
        """Aggregate reconstruction losses over all decoder predictions.

        Args:
            preds: Sequence of coordinate predictions, one per decoder layer.
            chi_probs: Not used by the current implementation.
            chi: Predicted chi angles; compared to ground truth only when
                ``X`` carries more than four atoms per residue.
            X: Ground-truth coordinates in sparse (per-node) layout.
            C, S: Per-node tensors, densified via ``_expand``.
            batch_id: Graph id per node.

        Returns:
            Dict of the diffusion-loss terms averaged over ``preds``,
            ``batch_chi_mse``, optional side-chain terms selected by
            ``args.model_type``, and ``loss``: the sum of every entry except
            ``batch_elbo``.
        """
        metric_keys = (
            "batch_elbo",
            "batch_global_mse",
            "batch_fragment_mse",
            "batch_pair_mse",
            "batch_neighborhood_mse",
            "batch_distance_mse",
        )
        all_results = {key: 0 for key in metric_keys}
        num_preds = len(preds)

        # The chi loss is computed on the sparse inputs, before densifying.
        chi_mse_loss = 0.
        if X.shape[-2] > 4:
            gt_chi, _ = self.X_to_chi(X, C, S)
            chi_mse_loss = F.mse_loss(chi, gt_chi)

        X, mask = GraphTransform.sparse2dense_node(X, batch_id)
        t = torch.ones(X.shape[0], device=X.device)
        C, S = self._expand(C, S, X.shape[0], X.shape[1], batch_id, X.device)

        for pred in preds:
            X_pred, _ = GraphTransform.sparse2dense_node(pred, batch_id)
            # Backbone only: the first four atoms of every residue.
            results = self.loss_diffusion(X_pred[:, :, :4], X[:, :, :4], mask, t)
            for key in metric_keys:
                all_results[key] += results[key] / num_preds

        all_results['batch_chi_mse'] = chi_mse_loss

        # Side-chain terms are evaluated on the final prediction only.
        X_pred, _ = GraphTransform.sparse2dense_node(preds[-1], batch_id)

        if X.shape[-2] > 4:
            model_type = self.args.model_type
            if model_type in (5, 7):
                all_results['batch_sidechain_rmsd'] = self.rmsd_sidechain(X_pred, X, C, S).mean()
            if model_type in (6, 7):
                all_results['batch_sidechain_clash'] = self.clash_sidechain(X_pred, C, S).mean()

        # batch_elbo is reported but excluded from the optimised total.
        loss = sum(value for key, value in all_results.items() if key != 'batch_elbo')
        all_results['loss'] = loss
        return all_results

    def forward(self, X, C, S, node_idx, edge_idx, batch_id):
        """Encode the input structure and decode backbone predictions.

        The decoder starts from uniform random coordinates shaped like
        ``X[:, :4]``, not from ``X`` itself.

        Returns:
            ``(bb_preds, chi_probs, chi)`` as produced by the decoder.
        """
        features = GeoFeaturizer.from_X_to_features(X, edge_idx, batch_id)
        V, E, _, T_ts, batch_id_extend, edge_idx_extend = features
        h_V = self.encoder(edge_idx, batch_id, V, E, T_ts, batch_id_extend, edge_idx_extend)

        X_init = torch.rand_like(X[:, :4])
        bb_preds, chi_probs, chi = self.decoder(
            h_V, X_init, C, S, node_idx, batch_id, dec_topk=self.args.dec_topk
        )
        return bb_preds, chi_probs, chi