# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# TODO: https://code.roche.com/prescient-design/prescient/-/blob/master/nebula/src/nebula/sample_from_seed_main.py#L227

import copy
import os
import pathlib
from typing import Literal, Optional

import torch
import torch.nn as nn
import wget

from torchdrug.layers import GraphConstruction
from torchdrug.layers.geometry import (
    SequentialEdge,
    SpatialEdge,
    KNNEdge,
    AlphaCarbonNode,
)


def test_cc(protein):
    """Return the one-letter sequence of ``protein`` with ``.`` between connected components.

    :param protein: object exposing ``residue_type`` and ``connected_component_id``
        (objects with ``tolist()``), ``num_residue`` and an indexable
        ``id2residue_symbol`` -- presumably a torchdrug ``Protein``; TODO confirm.
    :returns: ``(sequence, restypes)``. ``sequence`` is the residue symbols joined
        into a string, with a ``.`` inserted before every residue whose component
        id is greater than the previous residue's. ``restypes`` is the list of
        residue-type ids of the first ``num_residue`` residues.
    """
    residue_type = protein.residue_type.tolist()
    cc_id = protein.connected_component_id.tolist()
    sequence = []
    restypes = []
    for i in range(protein.num_residue):
        # A step up in the component id marks the start of a new component.
        if i > 0 and cc_id[i] > cc_id[i - 1]:
            sequence.append(".")
        sequence.append(protein.id2residue_symbol[residue_type[i]])
        restypes.append(residue_type[i])
    return "".join(sequence), restypes


# Despite its ``test_`` prefix this is a helper, not a test. Without this flag
# pytest collects it (e.g. when it is imported into a test module) and errors
# because no ``protein`` fixture exists.
test_cc.__test__ = False
    

class GearNetStructureModel(nn.Module):
    """GearNet-Edge encoder producing zero-padded per-residue features for protein graphs.

    Wraps torchdrug's ``GeometryAwareRelationalGraphNeuralNetwork`` with fixed
    hyper-parameters. It optionally loads pretrained weights, and applies a
    ``GraphConstruction`` (alpha-carbon nodes; sequential, spatial and KNN
    edges; "gearnet" edge features) before encoding. ``forward`` always runs
    the encoder in eval mode under ``torch.no_grad()``.
    """

    # Width of each encoder layer and number of layers (see ``hidden_dims`` below).
    _HIDDEN_DIM = 512
    _NUM_HIDDEN_LAYERS = 6
    # Pre-training strategies with published checkpoints (see ``_weight_selector``).
    _PRETRAINED_NAMES = ("angle", "dihedral", "attr", "distance", "mc")

    def __init__(
        self,
        pre_trained_name: Literal["angle", "dihedral", "attr", "distance", "mc"] = "mc",
        checkpoint_dir: str = str(pathlib.Path(__file__).parent / "checkpoints"),
        device: str = "cuda",
        load: bool = True,
        concat_hidden: bool = True,
        max_seq_len: int = 301,
        by_chain: bool = False,
        max_heavy_len: int = 151,
        max_light_len: int = 150,
        num_layers: int = 6,
        checkpoint_file: Optional[str] = None,
    ):
        """
        :param pre_trained_name: pre-training strategy whose published weights are
            downloaded/loaded when ``load`` is True and ``checkpoint_file`` is None.
        :param checkpoint_dir: directory where downloaded weights are stored.
        :param device: ``map_location`` for ``torch.load`` and the device the
            encoder and graph transform are moved to.
        :param load: whether to load weights into the encoder.
        :param concat_hidden: whether the encoder concatenates the outputs of all
            layers; forced to True when ``num_layers > 1``.
        :param max_seq_len: rows per graph in the output when ``by_chain`` is False.
        :param by_chain: pad heavy-chain (``chain_id == 0``) and remaining residues
            separately instead of the whole sequence at once.
        :param max_heavy_len: rows reserved for the heavy chain when ``by_chain``.
        :param max_light_len: rows reserved for the light chain when ``by_chain``.
        :param num_layers: with ``concat_hidden``, number of final encoder layers
            whose outputs are kept in the returned features.
        :param checkpoint_file: explicit weights file; used instead of the
            downloaded checkpoint when given.
        :raises ValueError: if ``num_layers < 1``, or if weights must be downloaded
            and ``pre_trained_name`` is not a known strategy.
        """
        super().__init__()
        from torchdrug.models import GeometryAwareRelationalGraphNeuralNetwork

        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.pre_trained_name = pre_trained_name
        self.checkpoint_dir = pathlib.Path(checkpoint_dir)
        self.max_seq_len = max_seq_len
        self.by_chain = by_chain
        self.max_heavy_len = max_heavy_len
        self.max_light_len = max_light_len
        self.num_layers = num_layers
        # Keeping more than one layer requires the concatenated hidden states.
        if num_layers > 1:
            concat_hidden = True
        self.concat_hidden = concat_hidden

        # NOTE(review): torch.load unpickles the file, and the default checkpoint is
        # downloaded from an external URL -- consider weights_only=True if the
        # installed torch version supports it.
        checkpoint = None
        if load:
            if checkpoint_file is None:
                checkpoint = torch.load(self._weight_selector(), map_location=device)
            else:
                checkpoint = torch.load(checkpoint_file, map_location=device)
                print(f'loading checkpoint from file {checkpoint_file}')

        # Presumably these must match the architecture of the published
        # GearNet-Edge checkpoints for load_state_dict to succeed -- TODO confirm.
        encoder_hparams = {
            "hidden_dims": [self._HIDDEN_DIM] * self._NUM_HIDDEN_LAYERS,
            "input_dim": 21,
            "batch_norm": True,
            "concat_hidden": concat_hidden,
            "short_cut": True,
            "readout": "sum",
            "num_relation": 7,
            "edge_input_dim": 59,
            "num_angle_bin": 8,
        }

        self.encoder = GeometryAwareRelationalGraphNeuralNetwork(**encoder_hparams)
        if load:
            self.encoder.load_state_dict(checkpoint)
        self.encoder = self.encoder.to(device)

        # Feature width actually produced by forward(): with concat_hidden it keeps
        # the last ``num_layers`` layer outputs, and the slice is capped by the
        # number of layers the encoder has.
        if concat_hidden:
            self.n_outputs = self._HIDDEN_DIM * min(num_layers, self._NUM_HIDDEN_LAYERS)
        else:
            self.n_outputs = self._HIDDEN_DIM

        self.transform = GraphConstruction(
            node_layers=[AlphaCarbonNode()],
            edge_layers=[
                SequentialEdge(max_distance=2),
                SpatialEdge(radius=10.0),
                KNNEdge(k=10, min_distance=5),
            ],
            edge_feature="gearnet",
        ).to(device)

    def _weight_selector(self) -> pathlib.Path:
        """Return the local path of the pretrained weights, downloading them if missing.

        :raises ValueError: if ``pre_trained_name`` is not a known strategy.
        """
        if self.pre_trained_name not in self._PRETRAINED_NAMES:
            raise ValueError(f"Pretrained model {self.pre_trained_name} doesn't exist.")

        path = self.checkpoint_dir / f"{self.pre_trained_name}_gearnet_edge.pt"
        if not path.exists():
            self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
            url = f"https://zenodo.org/record/7593637/files/{self.pre_trained_name}_gearnet_edge.pth?download=1"
            wget.download(url, out=str(path))

        print(f"Loaded checkpoint {path}")
        return path

    @staticmethod
    def _pad_rows(feat: torch.Tensor, length: int) -> torch.Tensor:
        """Zero-pad the 2-D tensor ``feat`` along dim 0 up to ``length`` rows."""
        # (0, 0) leaves the last dim untouched; (0, n) appends n rows at the end.
        return nn.functional.pad(feat, (0, 0, 0, length - feat.shape[0]), value=0)

    def forward(self, graph) -> torch.Tensor:
        """Encode a batch of protein graphs into zero-padded per-residue features.

        :param graph: batched graph accepted by ``self.transform`` and exposing
            ``unpack()`` -- presumably a torchdrug packed protein; TODO confirm.
        :returns: tensor of shape ``(batch, max_seq_len, D)``, or
            ``(batch, max_heavy_len + max_light_len, D)`` when ``by_chain``,
            where ``D`` is the kept feature width.
        """
        self.encoder = self.encoder.eval()
        with torch.no_grad():
            graph = self.transform(graph)
            x = self.encoder(graph, graph.node_feature.float())["node_feature"]
            if self.concat_hidden:
                # Keep only the outputs of the last ``num_layers`` layers.
                x = x[:, -self.num_layers * self._HIDDEN_DIM:]
            graphs = graph.unpack()
            num_nodes = [g.num_node.item() for g in graphs]
            featlist = torch.split(x, num_nodes, dim=0)
            if self.by_chain:
                # The first ``h_len`` rows (h_len = count of chain_id == 0) are taken as
                # the heavy chain, the rest as the light chain.
                # NOTE(review): assumes heavy-chain residues precede light-chain ones.
                heavy_chain_lens = [g.chain_id[g.chain_id == 0].shape[0] for g in graphs]
                feats_heavy = torch.stack(
                    [self._pad_rows(feat[:h_len], self.max_heavy_len)
                     for feat, h_len in zip(featlist, heavy_chain_lens)],
                    dim=0,
                )
                feats_light = torch.stack(
                    [self._pad_rows(feat[h_len:], self.max_light_len)
                     for feat, h_len in zip(featlist, heavy_chain_lens)],
                    dim=0,
                )
                feats = torch.cat([feats_heavy, feats_light], dim=1)
            else:
                feats = torch.stack(
                    [self._pad_rows(feat, self.max_seq_len) for feat in featlist], dim=0
                )

        return feats

    def train(self, mode=False):
        """Set the training mode, keeping the encoder's BatchNorm layers in eval mode.

        NOTE(review): the default ``mode=False`` differs from ``nn.Module.train``
        (default True), so a bare ``model.train()`` leaves this module in eval
        mode; kept for backward compatibility.

        :returns: ``self``, as ``nn.Module.train`` does, so ``eval()`` and chained
            calls return the module.
        """
        super().train(mode)
        self.freeze_bn()
        return self

    def freeze_bn(self):
        """Put every ``BatchNorm1d``/``BatchNorm2d`` module of the encoder into eval mode.

        This stops running-statistics updates; it does not disable gradients for
        the BatchNorm affine parameters.
        """
        for m in self.encoder.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                m.eval()

