import copy
import numpy as np
import torch

from time import time
# from torch.profiler import profile, record_function, ProfilerActivity

import pytorch_lightning as pl

from torch_scatter import scatter_mean, scatter_sum

from core.config.config import Config
from core.models.pif4sbdd import PIF4SBDDScoreModel

import core.evaluation.utils.atom_num as atom_num
import core.utils.transforms as trans

from core.utils.train import get_optimizer, get_scheduler


def center_pos(protein_pos, ligand_pos, batch_protein, batch_ligand, mode="protein"):
    """Shift protein and ligand coordinates into a per-graph reference frame.

    Args:
        protein_pos: protein coordinates, one row per atom.
        ligand_pos: ligand coordinates, one row per atom.
        batch_protein: graph index of every protein atom.
        batch_ligand: graph index of every ligand atom.
        mode: ``"protein"`` subtracts the per-graph mean of ``protein_pos``,
            ``"ligand"`` subtracts the per-graph mean of ``ligand_pos``,
            ``"none"`` returns both inputs untouched.

    Returns:
        ``(protein_pos, ligand_pos, offset)`` where ``offset`` is the per-graph
        mean that was subtracted. For ``"none"`` the offset is the plain float
        ``0.0`` (not a tensor), so callers must not index it.

    Raises:
        NotImplementedError: if ``mode`` is none of the three values above.
    """
    # Nothing to do: hand the inputs back with a neutral offset.
    if mode == "none":
        return protein_pos, ligand_pos, 0.0

    # Pick which point cloud defines the origin of each graph.
    if mode == "protein":
        ref_pos, ref_batch = protein_pos, batch_protein
    elif mode == "ligand":
        ref_pos, ref_batch = ligand_pos, batch_ligand
    else:
        raise NotImplementedError

    centroid = scatter_mean(ref_pos, ref_batch, dim=0)
    centered_protein = protein_pos - centroid[batch_protein]
    centered_ligand = ligand_pos - centroid[batch_ligand]
    return centered_protein, centered_ligand, centroid


class SBDDTrainLoop(pl.LightningModule):
    def __init__(self, config: Config):
        """Create the score model and the bookkeeping used by the step hooks.

        Args:
            config: experiment configuration. Its ``dynamics`` section is
                expanded into keyword arguments for ``PIF4SBDDScoreModel`` and
                the whole configuration is stored as hyperparameters.
        """
        super().__init__()
        self.cfg = config
        # presumably the dynamics inputs are [time, h_t, pos_t, edge_index] — TODO confirm
        self.dynamics = PIF4SBDDScoreModel(**config.dynamics.todict())
        self.save_hyperparameters(config.todict())

        # Per-step losses, averaged and reset in `on_train_epoch_end`.
        self.train_losses = []

        # Optional wall-clock profiling of `training_step`: one row of six
        # timestamps is appended per step while `log_time` is True.
        self.log_time = False
        self.time_records = np.zeros(6)


    # def get_time_scheduler(self):
    #     w_pos = torch.softmax(self.time_scheduler_pos, dim=0)
    #     w_pos = torch.cumsum(w_pos, dim=0)
    #     w_pos = w_pos / w_pos[-1]

    #     # w_type = torch.softmax(self.time_scheduler_type, dim=0)
    #     # w_type = torch.cumsum(w_type, dim=0)
    #     # w_type = w_type / w_type[-1]
    #     # return w_pos, w_type
    #     return w_pos

    # def sample_t_from_scheduler(self, batch_ligand, num_graphs):
    #     # scheduler_pos,  scheduler_type= self.get_time_scheduler()  # shape: [S]
    #     scheduler_pos= self.get_time_scheduler()  # shape: [S]

    #     num_steps = scheduler_pos.shape[0]

    #     indices = torch.randint(0, len(scheduler_pos), (num_graphs,), device=self.device)

    #     # one time value per graph
    #     t_graph_pos = scheduler_pos[indices]  # shape: [num_graphs]
    #     t_pos = t_graph_pos.index_select(0, batch_ligand).unsqueeze(-1)  # shape: [N_ligand, 1]

    #     # t_graph_type = scheduler_type[indices]  # shape: [num_graphs]
    #     # t_type = t_graph_type.index_select(0, batch_ligand).unsqueeze(-1)  # shape: [N_ligand, 1]

    #     # return t_pos, t_type
    #     return t_pos

    def forward(self, x):
        """Placeholder that does nothing; the step hooks call ``self.dynamics`` directly."""
        return None

    def training_step(self, batch, batch_idx):
        """Compute and log the one-step training loss for a batch.

        The step unpacks the batch, optionally adds Gaussian noise to the
        protein coordinates and applies a random orthogonal transform, centers
        the coordinates with ``center_pos``, draws one random time per graph
        and evaluates ``self.dynamics.loss_one_step``.

        Args:
            batch: batched data object. ``ligand_pos``,
                ``ligand_atom_feature_full`` and ``ligand_element_batch`` are
                required; the protein and ``ligand_fc_bond_*`` attributes are
                optional and default to ``None``.
            batch_idx: index of the batch, only used in the out-of-memory
                message.

        Returns:
            The scalar loss ``c_loss + d_loss + e_loss``, or ``None`` when the
            batch is skipped (CUDA out-of-memory or non-finite loss).

        Raises:
            RuntimeError: any runtime error raised by the loss computation
                other than a CUDA out-of-memory error.
        """
        # Wall-clock checkpoints get their own `tic_*` names. They used to be
        # called t1/t2/..., which the sampled time tensors `t1`/`t2` below
        # silently overwrote, corrupting `time_records` when `log_time` is on.
        tic_start = time()
        protein_pos = getattr(batch, "protein_pos", None)
        protein_v = batch.protein_atom_feature.float() if hasattr(batch, "protein_atom_feature") else None
        batch_protein = getattr(batch, "protein_element_batch", None)
        ligand_pos = batch.ligand_pos
        ligand_v = batch.ligand_atom_feature_full
        batch_ligand = batch.ligand_element_batch
        # Shapes as noted by the original author (not checked here):
        # protein_pos: [N_pro,3]
        # protein_v: [N_pro,27]
        # batch_protein: [N_pro]
        # ligand_pos: [N_lig,3]
        # ligand_v: [N_lig,13]
        # protein_element_batch: [N_protein]

        # if self.current_epoch < 10:
        #     protein_pos = None

        tic_unwrapped = time()
        num_graphs = batch_ligand.max().item() + 1

        if protein_pos is not None:
            with torch.no_grad():
                if self.cfg.train.pos_noise_std > 0:
                    # add noise to protein_pos
                    protein_noise = torch.randn_like(protein_pos) * self.cfg.train.pos_noise_std
                    protein_pos = batch.protein_pos + protein_noise
                # random rotation as data aug
                if self.cfg.train.random_rot:
                    # NOTE(review): Q from a QR decomposition is orthogonal but
                    # its determinant can be -1, i.e. this may be a reflection
                    # rather than a pure rotation — confirm that is intended.
                    M = np.random.randn(3, 3)
                    Q, __ = np.linalg.qr(M)
                    Q = torch.from_numpy(Q.astype(np.float32)).to(ligand_pos.device)
                    protein_pos = protein_pos @ Q
                    ligand_pos = ligand_pos @ Q

            protein_pos, ligand_pos, _ = center_pos(
                protein_pos,
                ligand_pos,
                batch_protein,
                batch_ligand,
                mode=self.cfg.dynamics.center_pos_mode,
            )  # TODO: ugly
        else:
            # Without a protein the ligand doubles as the centering reference.
            _, ligand_pos, _ = center_pos(
                ligand_pos,
                ligand_pos,
                batch_ligand,
                batch_ligand,
                mode=self.cfg.dynamics.center_pos_mode,
            )
            # perturb_offset = torch.rand(1) * self.cfg.data.normalizer_dict.pos
            # perturb_offset = perturb_offset.to(ligand_pos.device)
            # ligand_pos = ligand_pos + perturb_offset

        tic_centered = time()

        # # Alternative (disabled): sample a single random t
        # t = torch.rand(
        #     [num_graphs, 1], dtype=ligand_pos.dtype, device=ligand_pos.device
        # ).index_select(
        #     0, batch_ligand
        # )  # different t for different molecules.  [N_ligand, 1]

        # # Alternative (disabled): sample two random t values
        # t_raw = torch.rand(
        #     [num_graphs, 2], dtype=ligand_pos.dtype, device=ligand_pos.device
        # )  # [G, 2]
        # # sort within each graph so that t1 <= t2
        # t_sorted, _ = torch.sort(t_raw, dim=1)   # [G, 2]
        # # expand to one value per atom via batch_ligand
        # t1 = t_sorted[:, 0].index_select(0, batch_ligand).unsqueeze(-1)   # [N_lig, 1]
        # t2 = t_sorted[:, 1].index_select(0, batch_ligand).unsqueeze(-1)   # [N_lig, 1]

        # Active scheme: one random t1 per graph, broadcast to its atoms;
        # t2 is currently identical to t1 (zero interval).
        t1 = torch.rand(
            [num_graphs, 1], dtype=ligand_pos.dtype, device=ligand_pos.device
        ).index_select(
            0, batch_ligand
        )  # different t for different molecules.  [N_ligand, 1]
        t2 = t1

        # # Alternative (disabled): random t1 and a t2 whose interval grows during training, loss5
        # delta_t = 1 / (self.trainer.estimated_stepping_batches - 1) * self.global_step # loss8
        # # delta_t = np.random.rand() * 0.1  # loss8_ref
        # # delta_t = np.clip(delta_t, a_min=None, a_max = 1.0)
        # delta_t = np.clip(delta_t, a_min=None, a_max = 0.1)  # loss8,t_next=t2
        # t1 = torch.rand(
        #     [num_graphs, 1], dtype=ligand_pos.dtype, device=ligand_pos.device
        # ).index_select(
        #     0, batch_ligand
        # ) * (1 - delta_t)  # different t for different molecules.  [N_ligand, 1]
        # t2 = t1 + delta_t
        # print(f't1_max={torch.max(t1)},delta_t={delta_t}')

        if not self.cfg.dynamics.use_discrete_t and not self.cfg.dynamics.destination_prediction:
            # clamp both times to [t_min, 1]
            t1 = torch.clamp(t1, min=self.dynamics.t_min, max=1.0)
            t2 = torch.clamp(t2, min=self.dynamics.t_min, max=1.0)

        tic_sampled = time()
        try:
            c_loss, d_loss, e_loss = self.dynamics.loss_one_step(
                t1,
                t2,
                protein_pos=protein_pos,
                protein_v=protein_v,
                batch_protein=batch_protein,
                ligand_pos=ligand_pos,
                ligand_v=ligand_v,
                batch_ligand=batch_ligand,
                ligand_bond_type=getattr(batch, "ligand_fc_bond_type", None),
                ligand_bond_index=getattr(batch, "ligand_fc_bond_index", None),
                batch_ligand_bond=getattr(batch, "ligand_fc_bond_type_batch", None),
            )
            loss = c_loss + d_loss + e_loss
        except RuntimeError as e:
            if 'CUDA out of memory' not in str(e):
                # Bare raise keeps the original traceback intact.
                raise
            print(f"Skipping batch {batch_idx} due to CUDA OOM.")
            for p in self.dynamics.parameters():
                if p.grad is not None:
                    del p.grad  # free some memory
            torch.cuda.empty_cache()
            return None

        tic_loss = time()
        self.log_dict(
            {
                'lr': self.get_last_lr(),
                'loss': loss.item(),
            },
            on_step=True,
            prog_bar=True,
            batch_size=self.cfg.train.batch_size,
        )
        self.log_dict(
            {
                'loss_pos': c_loss.item(),
                'loss_type': d_loss.item(),
                'loss_bond': e_loss.item(),
            },
            on_step=True,
            prog_bar=False,
            batch_size=self.cfg.train.batch_size,
        )

        # check if loss is finite, skip update if not
        if not torch.isfinite(loss):
            return None
        self.train_losses.append(loss.clone().detach().cpu())

        tic_end = time()

        if self.log_time:
            # Row layout: [end, start, unwrapped, centered, sampled, loss].
            self.time_records = np.vstack(
                (
                    self.time_records,
                    [tic_end, tic_start, tic_unwrapped, tic_centered, tic_sampled, tic_loss],
                )
            )
            print(f'step total time: {self.time_records[-1, 0] - self.time_records[-1, 1]}, batch size: {num_graphs}')
            print(f'\tpl call & data access: {self.time_records[-1, 1] - self.time_records[-2, 0]}')
            print(f'\tunwrap data: {self.time_records[-1, 2] - self.time_records[-1, 1]}')
            print(f'\tadd noise & center pos: {self.time_records[-1, 3] - self.time_records[-1, 2]}')
            print(f'\tsample t: {self.time_records[-1, 4] - self.time_records[-1, 3]}')
            print(f'\tget loss: {self.time_records[-1, 5] - self.time_records[-1, 4]}')
            print(f'\tlogging: {self.time_records[-1, 0] - self.time_records[-1, 5]}')
        return loss

    def validation_step(self, batch, batch_idx):
        """Sample one ligand per graph, reusing the reference atom counts.

        Returns:
            The list produced by ``shared_sampling_step`` in ``'ref'`` mode.
        """
        return self.shared_sampling_step(
            batch,
            batch_idx,
            sample_num_atoms='ref',
            desc='Val',
        )
    
    def test_step(self, batch, batch_idx):
        """Draw ``cfg.evaluation.num_samples`` samples for every graph in the batch.

        Each sampling pass yields one entry per graph. The passes are then
        interleaved so that all samples of the first graph come first, then all
        samples of the second graph, and so on.

        Returns:
            Flat list ordered graph-major, sample-minor; empty when
            ``num_samples`` is 0.
        """
        # TODO change order, samples of the same pocket should be together, reduce protein loading
        n_samples = self.cfg.evaluation.num_samples
        per_pass_outputs = []
        for sample_idx in range(n_samples):
            per_pass_outputs.append(
                self.shared_sampling_step(
                    batch,
                    batch_idx,
                    sample_num_atoms=self.cfg.evaluation.sample_num_atoms,
                    desc=f'Test-{sample_idx}/{n_samples}',
                )
            )

        # With zero passes there is no first element to take the length of.
        if not per_pass_outputs:
            return []

        n_graphs = len(per_pass_outputs[0])
        return [
            per_pass[graph_idx]
            for graph_idx in range(n_graphs)
            for per_pass in per_pass_outputs
        ]

    def shared_sampling_step(self, batch, batch_idx, sample_num_atoms, desc=''):
        """Sample ligands for every graph in ``batch`` and return them as a data list.

        Args:
            batch: batched data object. ``ligand_pos``,
                ``ligand_atom_feature_full`` and ``ligand_element_batch`` are
                required; protein, ``gen_flag_lig`` and ``ligand_fc_bond_*``
                attributes are optional.
            batch_idx: not used by this method.
            sample_num_atoms: ``'ref'`` reuses the reference ligand sizes and
                bond graph of the batch; ``'prior'`` draws each atom count from
                ``atom_num.sample_atom_num`` and builds a fully connected,
                self-loop-free bond graph.
            desc: label forwarded to ``self.dynamics.sample``.

        Returns:
            The result of ``to_data_list()`` on a copy of ``batch`` extended
            with ``x``, ``pos``, ``atom_type``, ``is_aromatic`` and, when a bond
            graph is available, ``bond`` and ``bond_index``.

        Raises:
            ValueError: if ``sample_num_atoms`` is neither ``'prior'`` nor ``'ref'``.
            NotImplementedError: if ``'prior'`` is requested without protein positions.
        """
        protein_pos = getattr(batch, "protein_pos", None)
        protein_v = batch.protein_atom_feature.float() if hasattr(batch, "protein_atom_feature") else None
        batch_protein = getattr(batch, "protein_element_batch", None)
        ligand_pos = batch.ligand_pos
        ligand_v = batch.ligand_atom_feature_full
        batch_ligand = batch.ligand_element_batch
        device = ligand_pos.device

        # if self.current_epoch < 10:
        #     protein_pos = None

        gen_flag_lig = getattr(batch, 'gen_flag_lig', None)

        num_graphs = batch_ligand.max().item() + 1  # B
        assert num_graphs == len(batch), f"num_graphs: {num_graphs} != len(batch): {len(batch)}"

        # move protein center to origin & ligand correspondingly
        if protein_pos is not None:
            protein_pos, ligand_pos, offset = center_pos(
                protein_pos,
                ligand_pos,
                batch_protein,
                batch_ligand,
                mode=self.cfg.dynamics.center_pos_mode,
            )
        else:
            # NOTE(review): the stand-in "protein" is all zeros, so in
            # "protein" mode the ligand is NOT centered here, whereas
            # `training_step` centers it on its own mean in the same
            # situation — confirm this asymmetry is intended.
            _, ligand_pos, offset = center_pos(
                torch.zeros_like(ligand_pos),
                ligand_pos,
                batch_ligand,
                batch_ligand,
                mode=self.cfg.dynamics.center_pos_mode,
            )

        # determine the number of atoms in the ligand
        if sample_num_atoms == 'prior':
            if protein_pos is None:
                raise NotImplementedError("No protein pos for prior sampling")

            ligand_num_atoms = []
            ligand_fc_bond_indices = []
            ligand_num_edges = []
            atom_offset = 0  # number of atoms in all previously handled graphs
            for data_id in range(len(batch)):
                data = batch[data_id]
                pocket_size = atom_num.get_space_size(
                    data.protein_pos.detach().cpu().numpy() * self.cfg.data.normalizer_dict.pos
                )
                n_atoms = atom_num.sample_atom_num(pocket_size).astype(int)
                ligand_num_atoms.append(n_atoms)

                # All ordered atom pairs of this graph except self-pairs.
                full_dst = torch.repeat_interleave(torch.arange(n_atoms), n_atoms)
                full_src = torch.arange(n_atoms).repeat(n_atoms)
                mask = full_dst != full_src
                # Shift the indices to this graph's position in the batch.
                full_dst = full_dst[mask] + atom_offset
                full_src = full_src[mask] + atom_offset
                atom_offset += int(n_atoms)

                graph_bond_index = torch.stack([full_src, full_dst], dim=0)
                assert graph_bond_index.size(0) == 2 and graph_bond_index.size(1) == n_atoms * (n_atoms - 1)
                ligand_fc_bond_indices.append(graph_bond_index)
                ligand_num_edges.append(graph_bond_index.size(1))

            batch_ligand = torch.repeat_interleave(
                torch.arange(len(batch)), torch.tensor(ligand_num_atoms)
            ).to(device)
            ligand_num_atoms = torch.tensor(ligand_num_atoms, dtype=torch.long, device=device)
            batch_ligand_bond = torch.repeat_interleave(
                torch.arange(len(batch)), torch.tensor(ligand_num_edges)
            ).to(device)
            ligand_fc_bond_index = torch.cat(ligand_fc_bond_indices, dim=1).to(device).long()
            assert ligand_fc_bond_index.size(1) == sum(ligand_num_edges)

            # Only use the bond graph when the batch itself carries bonds.
            if not hasattr(batch, "ligand_fc_bond_index"):
                ligand_fc_bond_index = None
                batch_ligand_bond = None

        elif sample_num_atoms == 'ref':
            batch_ligand = batch.ligand_element_batch
            ligand_num_atoms = scatter_sum(torch.ones_like(batch_ligand), batch_ligand, dim=0).to(device)
            if hasattr(batch, "ligand_fc_bond_index"):
                ligand_fc_bond_index = batch.ligand_fc_bond_index
                batch_ligand_bond = batch.ligand_fc_bond_type_batch
            else:
                ligand_fc_bond_index = None
                batch_ligand_bond = None
        else:
            raise ValueError(f"sample_num_atoms mode: {sample_num_atoms} not supported")

        ligand_cum_atoms = torch.cat([
            torch.tensor([0], dtype=torch.long, device=device),
            ligand_num_atoms.cumsum(dim=0),
        ])

        sample_chain = self.dynamics.sample(
            protein_pos=protein_pos,
            protein_v=protein_v,
            batch_protein=batch_protein,
            batch_ligand=batch_ligand,
            ligand_bond_index=ligand_fc_bond_index,
            batch_ligand_bond=batch_ligand_bond,
            sample_steps=self.cfg.evaluation.sample_steps,
            n_nodes=num_graphs,
            desc=desc,
            ligand_pos_ref=ligand_pos,
            ligand_v_ref=ligand_v,
            ligand_bond_index_ref=getattr(batch, "ligand_fc_bond_index", None),
            ligand_bond_type_ref=getattr(batch, "ligand_fc_bond_type", None),
            ligand_bond_batch_ref=getattr(batch, "ligand_fc_bond_type_batch", None),
            ligand_batch_ref=batch.ligand_element_batch,
            gen_flag_lig=gen_flag_lig,
        )

        # --- quantities that do not depend on the sampled state ---
        pos_scale = torch.tensor(
            self.cfg.data.normalizer_dict.pos, dtype=torch.float32, device=device
        )
        # `center_pos` returns the float 0.0 in "none" mode, which cannot be
        # indexed; only gather per-atom offsets when a tensor came back.
        ligand_offset = offset[batch_ligand] if torch.is_tensor(offset) else offset

        has_bond_graph = ligand_fc_bond_index is not None and batch_ligand_bond is not None
        if has_bond_graph:
            # dim_size keeps one entry per graph even when trailing graphs have
            # no edges, so the slice table below always has B + 1 entries.
            ligand_num_bonds = scatter_sum(
                torch.ones_like(batch_ligand_bond),
                batch_ligand_bond,
                dim_size=num_graphs,
            ).tolist()
            cum_bonds = np.cumsum([0] + ligand_num_bonds)
            # remove the per-graph atom offset to get graph-local bond indices;
            # kept under a new name so the batch-level index is never modified
            local_bond_index = ligand_fc_bond_index - ligand_cum_atoms[batch_ligand_bond]

        out_data_list_total = []
        # for final in sample_chain:
        for final in [sample_chain[-1]]:
            # undo the centering, then undo the position normalizer
            pred_pos, one_hot, pred_bond_pmf = final[0] + ligand_offset, final[1], final[2]
            pred_pos = pred_pos * pos_scale
            out_batch = copy.deepcopy(batch)

            if protein_pos is not None:
                out_batch.protein_pos = out_batch.protein_pos * pos_scale

            pred_v = one_hot.argmax(dim=-1)
            # TODO: ugly, should be done in metrics.py (but needs a way to make it compatible with pyg batch)
            pred_atom_type = trans.get_atomic_number_from_index(
                pred_v, mode=self.cfg.data.transform.ligand_atom_mode
            )  # List[int]

            # for visualization
            atom_type = [trans.MAP_ATOM_TYPE_ONLY_TO_INDEX[i] for i in pred_atom_type]  # List[int]
            atom_type = torch.tensor(atom_type, dtype=torch.long, device=device)  # [N_lig]

            # for reconstruction
            pred_aromatic = trans.is_aromatic_from_index(
                pred_v, mode=self.cfg.data.transform.ligand_atom_mode
            )  # List[bool]

            # add necessary dict to new pyg batch
            out_batch.x, out_batch.pos = atom_type, pred_pos
            out_batch.atom_type = torch.tensor(pred_atom_type, dtype=torch.long, device=device)
            out_batch.is_aromatic = torch.tensor(pred_aromatic, dtype=torch.long, device=device)
            # out_batch.mol = results

            _slice_dict = {
                "x": ligand_cum_atoms,
                "pos": ligand_cum_atoms,
                "atom_type": ligand_cum_atoms,
                "is_aromatic": ligand_cum_atoms,
                # "mol": out_batch._slice_dict["ligand_filename"],
            }

            _inc_dict = {
                "x": out_batch._inc_dict["ligand_element"],  # [0] * B,
                "pos": out_batch._inc_dict["ligand_pos"],
                "atom_type": out_batch._inc_dict["ligand_element"],
                "is_aromatic": out_batch._inc_dict["ligand_element"],
                # "mol": out_batch._inc_dict["ligand_filename"],
            }

            # for bond generation; skipped when no bond graph was sampled
            if has_bond_graph and pred_bond_pmf is not None:
                out_batch.bond = pred_bond_pmf.argmax(dim=-1)  # [N_lig * N_lig]
                _slice_dict["bond"] = cum_bonds
                _inc_dict["bond"] = out_batch._inc_dict["ligand_fc_bond_type"]
                out_batch.bond_index = local_bond_index
                _slice_dict["bond_index"] = cum_bonds
                _inc_dict["bond_index"] = out_batch._inc_dict["ligand_fc_bond_type"]

            out_batch._inc_dict.update(_inc_dict)
            out_batch._slice_dict.update(_slice_dict)
            out_data_list_total += out_batch.to_data_list()
        return out_data_list_total

    def on_train_epoch_end(self) -> None:
        """Print and log the mean of this epoch's step losses, then reset the buffer.

        When no loss was recorded during the epoch the reported value is ``0``.
        """
        if self.train_losses:
            epoch_loss = torch.stack(self.train_losses).mean()
        else:
            epoch_loss = 0

        print(f"epoch_loss: {epoch_loss}")
        self.log("epoch_loss", epoch_loss, batch_size=self.cfg.train.batch_size)

        # Start the next epoch with a fresh list.
        self.train_losses = []

    def configure_optimizers(self):
        """Build the optimizer and LR scheduler from ``cfg.train``.

        Besides returning them, the optimizer, the scheduler and the
        learning-rate getter returned by ``get_scheduler`` are stored on the
        module as ``optim``, ``scheduler`` and ``get_last_lr``
        (``training_step`` calls the latter when logging ``lr``).

        Returns:
            Dict with the keys ``'optimizer'`` and ``'lr_scheduler'``.
        """
        train_cfg = self.cfg.train
        self.optim = get_optimizer(train_cfg.optimizer, self)
        scheduler_and_getter = get_scheduler(train_cfg, self.optim)
        self.scheduler, self.get_last_lr = scheduler_and_getter

        return dict(optimizer=self.optim, lr_scheduler=self.scheduler)
