import torch
from torch import nn
import copy
import slide_encoder
from graph_models import ATT_learner, GraphEncoder
import numpy as np



def SNN_Block(dim1, dim2, dropout=0.25):
    r"""
    Build one self-normalizing layer: Linear -> ELU -> Alpha Dropout.

    args:
        dim1 (int): Dimension of input features
        dim2 (int): Dimension of output features
        dropout (float): Probability passed to ``nn.AlphaDropout``

    returns:
        nn.Sequential: the three layers, applied in the order listed above
    """
    layers = (
        nn.Linear(dim1, dim2),
        nn.ELU(),
        nn.AlphaDropout(p=dropout, inplace=False),
    )
    return nn.Sequential(*layers)


class Adaptor(nn.Module):
    """Project several omic feature groups to a common width and average them.

    One stack of ``SNN_Block`` layers is built per entry of ``omic_sizes``.
    In ``forward`` each group is passed through its own stack and the
    per-group outputs are averaged element-wise.

    Args:
        omic_sizes: Iterable of input dimensions, one per omic group.
        hidden: Output widths of the successive ``SNN_Block`` layers in every
            stack. Must contain at least one entry. Defaults to
            ``(768, 768)``, the widths that were previously hard-coded.
    """

    def __init__(self, omic_sizes, hidden=(768, 768)):
        super().__init__()
        hidden = list(hidden)
        sig_networks = []
        for input_dim in omic_sizes:
            # The first layer uses SNN_Block's default dropout; the following
            # layers chain consecutive hidden widths with dropout 0.25.
            fc_omic = [SNN_Block(dim1=input_dim, dim2=hidden[0])]
            fc_omic.extend(
                SNN_Block(dim1=width_in, dim2=width_out, dropout=0.25)
                for width_in, width_out in zip(hidden, hidden[1:])
            )
            sig_networks.append(nn.Sequential(*fc_omic))
        self.sig_networks = nn.ModuleList(sig_networks)

    def forward(self, x_omic):
        """Return the mean of the per-group projections.

        Args:
            x_omic: Sequence with one entry per omic group; each entry is a
                sequence of tensors that ``torch.stack`` can combine
                (presumably one tensor per sample -- confirm with the caller).

        Returns:
            Tensor: element-wise mean over groups of the projected features.

        Raises:
            IndexError: if ``x_omic`` has more groups than there are networks.
        """
        stacked_groups = [torch.stack(group) for group in x_omic]
        # Call the modules themselves (not ``.forward``) so that registered
        # forward hooks and pre-hooks are honoured.
        h_omic = [
            self.sig_networks[idx](sig_feat.float())
            for idx, sig_feat in enumerate(stacked_groups)
        ]
        return torch.mean(torch.stack(h_omic, dim=0), dim=0)



class CL_model(nn.Module):
    """Two-branch (online / target) model over a graph of fused slide features.

    The online branch encodes a slide view, generates an omic-like vector from
    it, concatenates both, inserts the result into a copy of the node-feature
    buffer, learns an adjacency, encodes the graph and predicts from the rows
    belonging to the current batch. The target branch mirrors this with frozen
    modules and real omic features; its slide encoder, graph learner and graph
    encoder follow the online ones through an exponential moving average.

    Args:
        args: Configuration object. Attributes read here: ``model_arch``,
            ``patch_in_chans``, ``in_chans``, ``leaner_layers``, ``leaner_k``,
            ``dropedge_rate``, ``sparse``, ``activation_learner``, ``nlayers``,
            ``hidden_dim``, ``rep_dim`` and ``dropout``.
        omic_sizes: Input dimensions forwarded to ``Adaptor``.
        byol_hidden_dim: Unused; kept for interface compatibility.
        dropout: Unused (``args.dropout`` is what the graph encoders receive);
            kept for interface compatibility.
        byol_pred_dim: Unused; kept for interface compatibility. The
            classifier input width is ``args.rep_dim`` (see ``self.fc``).
        class_num: Number of outputs of the final linear layer.
        momentum: Weight of the previous target value in the moving average.
        slide_encoder_pth: Passed as ``pretrained`` to
            ``slide_encoder.create_model`` for both slide encoders.
        init_graph: Path of a ``torch.save``-d dict providing the keys
            ``'WSI_features'`` and ``'WSI_RNA_features'``.

    Attributes:
        fused_buffer: Buffer with one row of fused features per graph node.
        node_id_to_idx: Maps a node id to its row in ``fused_buffer``.
    """

    def __init__(self, args, omic_sizes, byol_hidden_dim=512, dropout=.5, byol_pred_dim=256, class_num=3, momentum=0.999, slide_encoder_pth='', init_graph=''):
        super().__init__()

        # NOTE(security): torch.load unpickles the file; only pass trusted paths.
        graph_dict = torch.load(init_graph, map_location='cpu')
        WSI_features = graph_dict['WSI_features']
        WSI_RNA_features = graph_dict['WSI_RNA_features']

        # Node order follows the key order of WSI_features; the initial node
        # features themselves are taken from WSI_RNA_features.
        node_name = list(WSI_features)
        node_fused_features_np = np.asarray([WSI_RNA_features[key] for key in node_name])
        init_fused_feature = torch.tensor(node_fused_features_np, dtype=torch.float32)
        self.register_buffer('fused_buffer', init_fused_feature.clone())
        self.node_id_to_idx = {node_id: idx for idx, node_id in enumerate(node_name)}

        self.momentum = momentum

        slide_kwargs = dict(pretrained=slide_encoder_pth, model_arch=args.model_arch,
                            in_chans=args.patch_in_chans, global_pool=False)
        learner_kwargs = dict(nlayers=args.leaner_layers, isize=args.in_chans * 2, k=args.leaner_k,
                              i=6, dropedge_rate=args.dropedge_rate, sparse=args.sparse,
                              act=args.activation_learner)
        encoder_kwargs = dict(nlayers=args.nlayers, in_dim=args.in_chans * 2, hidden_dim=args.hidden_dim,
                              emb_dim=args.rep_dim, dropout=args.dropout, sparse=args.sparse)

        # Sub-modules are created in a fixed order so that random
        # initialisation and state_dict / parameter ordering stay stable.
        self.online_slide_encoder = slide_encoder.create_model(**slide_kwargs)
        self.online_omic_generator = SNN_Block(dim1=args.in_chans, dim2=args.in_chans)
        self.online_graph_learner = ATT_learner(**learner_kwargs)
        self.online_graph_encoder = GraphEncoder(**encoder_kwargs)
        self.norm = nn.LayerNorm(args.rep_dim)
        # The classifier consumes the output of ``self.norm``, whose last
        # dimension is ``args.rep_dim``; sizing it by ``byol_pred_dim`` only
        # worked when both values happened to be equal.
        self.fc = nn.Linear(args.rep_dim, class_num, bias=False)

        self.target_slide_encoder = slide_encoder.create_model(**slide_kwargs)
        self.target_omic_adaptor = Adaptor(omic_sizes)
        self.target_graph_learner = ATT_learner(**learner_kwargs)
        self.target_graph_encoder = GraphEncoder(**encoder_kwargs)

        # The target branch is never optimised directly.
        # NOTE(review): target_omic_adaptor is frozen and is not part of the
        # moving average below, so it keeps its initial weights unless they
        # are loaded from elsewhere -- confirm this is intended.
        # NOTE(review): the target graph learner/encoder are initialised
        # independently rather than copied from the online ones -- confirm.
        for frozen in (self.target_slide_encoder, self.target_omic_adaptor,
                       self.target_graph_learner, self.target_graph_encoder):
            for params in frozen.parameters():
                params.requires_grad = False

    @torch.no_grad()
    def _update_moving_average(self):
        """Move target parameters towards the online ones.

        For the slide encoder, graph learner and graph encoder pairs:
        ``target = target * momentum + online * (1 - momentum)``, applied
        in place. Only parameters are updated, not buffers.
        """
        pairs = (
            (self.online_slide_encoder, self.target_slide_encoder),
            (self.online_graph_learner, self.target_graph_learner),
            (self.online_graph_encoder, self.target_graph_encoder),
        )
        for online, target in pairs:
            for online_params, target_params in zip(online.parameters(), target.parameters()):
                target_params.mul_(self.momentum).add_(online_params.detach() * (1 - self.momentum))

    @staticmethod
    def _as_tensor(new_features):
        """Stack a list of per-node tensors along dim 0; pass tensors through."""
        if isinstance(new_features, list):
            return torch.stack(new_features, dim=0)
        return new_features

    def _node_indices(self, node_ids, device):
        """Return the buffer rows of ``node_ids`` as a long tensor on ``device``.

        Raises:
            KeyError: if an id is not present in ``node_id_to_idx``.
        """
        return torch.tensor([self.node_id_to_idx[nid] for nid in node_ids], dtype=torch.long, device=device)

    def _patched_node_features(self, node_ids, new_features):
        """Copy the buffer and overwrite the rows of ``node_ids`` in the copy.

        The buffer itself is left untouched. The copy is detached, so
        gradients can only flow through the inserted ``new_features`` rows.

        Returns:
            tuple: ``(node_features, indices)``.
        """
        node_features = self.fused_buffer.detach().clone()
        indices = self._node_indices(node_ids, node_features.device)
        node_features[indices] = self._as_tensor(new_features)
        return node_features, indices

    @torch.no_grad()
    def update_node_features(self, node_ids, new_features):
        """Overwrite the buffer rows of ``node_ids`` with ``new_features``.

        Args:
            node_ids: Iterable of node ids known to ``node_id_to_idx``.
            new_features: Tensor, or list of tensors that is stacked on dim 0,
                with one row per node id.
        """
        new_features = self._as_tensor(new_features)
        indices = self._node_indices(node_ids, self.fused_buffer.device)
        self.fused_buffer[indices] = new_features

    def upgrade_graph(self, node_ids, graph_learner, new_features):
        """Learn an adjacency over the buffer with the given rows replaced.

        Args:
            node_ids: Ids of the nodes whose features are replaced.
            graph_learner: Callable module that also provides
                ``graph_process``.
            new_features: Replacement features (tensor or list of tensors).

        Returns:
            tuple: ``(A, A_subset)`` -- the result of ``graph_process`` and
            that result indexed by the rows of ``node_ids``.
        """
        node_features, indices = self._patched_node_features(node_ids, new_features)
        learn_embedding = graph_learner(node_features)
        A = graph_learner.graph_process(learn_embedding)
        return A, A[indices]

    def upgrade_graph_feature(self, node_ids, Adj, graph_encoder, new_features):
        """Encode the graph with the given rows replaced.

        Args:
            node_ids: Ids of the nodes whose features are replaced.
            Adj: Adjacency passed as second argument to ``graph_encoder``.
            graph_encoder: Callable taking ``(node_features, Adj)``.
            new_features: Replacement features (tensor or list of tensors).

        Returns:
            tuple: ``(graph_features, sub_graph_features)`` -- the encoder
            output and its rows for ``node_ids``.
        """
        node_features, indices = self._patched_node_features(node_ids, new_features)
        graph_features = graph_encoder(node_features, Adj)
        return graph_features, graph_features[indices]

    def _encode_online(self, view, coords):
        """Return ``cat(generated_omic, slide_embedding)`` from the online branch."""
        slide_embedding = self.online_slide_encoder(view, coords)['cls_token']
        generate_omic = self.online_omic_generator(slide_embedding)
        return torch.cat((generate_omic, slide_embedding), dim=-1)

    def _predict(self, graph_features):
        """Apply LayerNorm and the linear head.

        Returns:
            tuple: ``(logits, hazards, S, Y_hat)`` where ``hazards`` is
            ``sigmoid(logits)``, ``S`` is the cumulative product of
            ``1 - hazards`` along dim 1 and ``Y_hat`` is the index of the
            largest logit.
        """
        logits = self.fc(self.norm(graph_features))
        Y_hat = torch.topk(logits, 1, dim=-1)[1]
        hazards = torch.sigmoid(logits)
        S = torch.cumprod(1 - hazards, dim=1)
        return logits, hazards, S, Y_hat

    def forward(self, view1, coords1, view2, coords2, rna_features, slide_ids):
        """Run the model.

        Args:
            view1, coords1: Inputs of the online slide encoder.
            view2, coords2: Inputs of the target slide encoder (training only).
            rna_features: Input of the target omic adaptor (training only).
            slide_ids: Node ids of the batch (training only); each must be a
                key of ``node_id_to_idx``.

        Returns:
            In training mode the 10-tuple ``(fused_features, sub_A_1,
            sub_G_1, fused_features_2, sub_A_2, sub_G_2, logits, hazards, S,
            Y_hat)``; the ``*_2`` entries come from the target branch and
            carry no gradient. In eval mode ``(logits, hazards, S, Y_hat)``.

        Side effects (training mode only): the target modules are moved
        towards the online ones and the buffer rows of ``slide_ids`` are
        replaced by the target branch's fused features.
        """
        if self.training:
            fused_features = self._encode_online(view1, coords1)
            A_1, sub_A_1 = self.upgrade_graph(slide_ids, self.online_graph_learner, fused_features)
            G_1, sub_G_1 = self.upgrade_graph_feature(slide_ids, A_1, self.online_graph_encoder, fused_features)
            logits, hazards, S, Y_hat = self._predict(sub_G_1)

            with torch.no_grad():
                # The target weights are refreshed before the target pass.
                self._update_moving_average()
                x_2 = self.target_slide_encoder(view2, coords2)['cls_token']
                omic_features = self.target_omic_adaptor(rna_features)
                fused_features_2 = torch.cat((omic_features, x_2), dim=-1)
                A_2, sub_A_2 = self.upgrade_graph(slide_ids, self.target_graph_learner, fused_features_2)
                G_2, sub_G_2 = self.upgrade_graph_feature(slide_ids, A_2, self.target_graph_encoder, fused_features_2)
                # Persist the target features only after both branches have
                # read the previous buffer contents.
                self.update_node_features(slide_ids, fused_features_2)

            return fused_features, sub_A_1, sub_G_1, fused_features_2, sub_A_2, sub_G_2, logits, hazards, S, Y_hat

        with torch.no_grad():
            fused_features = self._encode_online(view1, coords1)
            # Unseen slides are appended after the stored nodes.
            infere_graph = torch.cat((self.fused_buffer, fused_features), dim=0)
            learn_embedding = self.online_graph_learner(infere_graph)
            A = self.online_graph_learner.graph_process(learn_embedding)
            graph_features = self.online_graph_encoder(infere_graph, A)
            # Select exactly the appended rows. Slicing from the buffer length
            # (instead of ``[-batch_size:]``) stays correct for an empty batch.
            num_nodes = self.fused_buffer.shape[0]
            sub_graph_features = graph_features[num_nodes:]
            return self._predict(sub_graph_features)



        

