import torch
import copy
import numpy as np
from typing import List
from tqdm import tqdm
from scipy.spatial.transform import Rotation as R

from flowdock.dataset.complex_dataclasses import Ligand, ComplexBatch
from flowdock.dataset.pdbbind import PDBBind, apply_random_rotation_inplace
from flowdock.utils.transforms import apply_tor_changes_to_pos, get_bond_properties_for_angles, get_torsion_angles
from flowdock.utils.spyrmsd import compute_all_isomorphisms, get_symmetry_rmsd_with_isomorphisms


# Perturbation regimes used by randomize_ligand_all. Each entry pairs an
# exclusive upper bound on a uniform [0, 1) draw with the keyword arguments
# passed to randomize_ligand_rot_tor_only. Draws at or above the last bound
# fall through to _RANDOMIZATION_FALLBACK.
_RANDOMIZATION_SCHEDULE = (
    (0.25, {'noise_scale': 1.}),
    (0.5, {'noise_scale': 3.}),
    (0.65, {'random_rotations': True}),
    (0.8, {'shift_scale': 5., 'noise_scale': 3.}),
    (0.9, {'shift_scale': 15., 'noise_scale': 5.}),
)
_RANDOMIZATION_FALLBACK = {'shift_scale': 30., 'random_rotations': True}


def randomize_ligand_all(ligand: Ligand, complex_isomorphism=None):
    """
    Perturb ``ligand`` in place using one randomly selected noise regime.

    A single uniform draw picks the regime: small noise (25%), medium noise
    (25%), random rotation + uniform torsions (15%), moderate shift + medium
    noise (15%), large shift + strong noise (10%), or very large shift +
    random rotation (10%).

    Parameters:
    ----------
    ligand : Ligand
        The ligand to perturb; modified in place by randomize_ligand_rot_tor_only.
    complex_isomorphism : optional
        Forwarded unchanged to randomize_ligand_rot_tor_only.

    Returns:
    -------
    None
    """
    draw = np.random.rand()
    chosen = _RANDOMIZATION_FALLBACK
    for upper_bound, options in _RANDOMIZATION_SCHEDULE:
        if draw < upper_bound:
            chosen = options
            break
    randomize_ligand_rot_tor_only(ligand, complex_isomorphism=complex_isomorphism, **chosen)


def randomize_ligand_rot_tor_only(ligand: Ligand, complex_isomorphism=None, shift_scale: float = 1.0, noise_scale: float = 1.0,
                                  random_rotations: bool = False):
    """
    Perturb a ligand pose by a random translation, rotation and torsion change.

    Starting from ``ligand.orig_pos``, the ligand is rotated about its centroid,
    its centroid is moved by a Gaussian shift, and its rotatable bonds are
    twisted. The perturbed pose, the values that undo the perturbation, and the
    RMSD to ``orig_pos`` are written back onto ``ligand`` in place.

    Parameters:
    ----------
    ligand : Ligand
        The ligand to perturb. Reads ``orig_pos``, ``rotatable_bonds`` and
        ``mask_rotate``; writes ``pos``, ``init_tr``, ``final_tr``,
        ``final_rot``, ``final_tor``, ``rmsd`` and ``t``.
    complex_isomorphism : optional
        Precomputed graph isomorphisms passed to the symmetry-corrected RMSD.
    shift_scale : float
        Scales the per-axis std of the centroid shift, which is
        ``0.5 * small_noise * shift_scale`` with ``small_noise`` uniform in [0, 1).
    noise_scale : float
        Scales the std of the rotation angles and torsion changes when
        ``random_rotations`` is False.
    random_rotations : bool
        If True, draw a uniformly random rotation and torsion changes uniform
        in [-pi, pi) instead of small perturbations around the identity.

    Returns:
    -------
    None
    """
    pos = np.copy(ligand.orig_pos)
    num_rotatable_bonds = ligand.rotatable_bonds.shape[0]

    # Translation: a single uniform draw `small_noise` modulates both the shift
    # magnitude and (below) the rotation/torsion noise level.
    pos_mean = pos.mean(axis=0).reshape(1, 3)
    small_noise = np.random.rand()
    tr_shift = np.random.normal(0, 0.5 * small_noise * shift_scale, 3).astype(np.float32).reshape(1, 3)
    tr = pos_mean + tr_shift

    if random_rotations:
        rot = R.random().as_matrix().astype(np.float32)
        torsion_updates = np.random.uniform(-np.pi, np.pi, num_rotatable_bonds).astype(np.float32)
    else:
        # Small rotation around the identity. The denominator lies in (6, 60],
        # so angle_std ranges from noise_scale * 3 degrees (small_noise = 0) up
        # to just under noise_scale * 30 degrees (small_noise -> 1).
        angle_std = noise_scale * np.pi / (-54 * small_noise + 60)
        angles = np.random.normal(0, angle_std, 3).astype(np.float32)
        # Per-axis rotation matrices, composed as Rx @ Ry @ Rz
        Rx = R.from_euler('x', angles[0]).as_matrix()
        Ry = R.from_euler('y', angles[1]).as_matrix()
        Rz = R.from_euler('z', angles[2]).as_matrix()
        rot = (Rx @ Ry @ Rz).astype(np.float32)
        # Torsion changes use the same std as the rotation angles.
        torsion_updates = np.random.normal(0, angle_std, num_rotatable_bonds).astype(np.float32)

    # Rotate about the original centroid, then place the centroid at `tr`.
    pos = (pos - pos_mean) @ rot.T + tr
    pos, _ = apply_tor_changes_to_pos(pos, ligand.rotatable_bonds, ligand.mask_rotate,
                                      torsion_updates, is_reverse_order=True)

    ligand.init_tr = tr.reshape(1, 3)
    # Values that undo the perturbation: the original centroid, the inverse
    # (transpose) of the orthogonal rotation, and the negated torsion changes.
    ligand.final_tr = pos_mean.reshape(1, 3)
    ligand.final_rot = (rot.T)[None, :, :]
    ligand.final_tor = -torsion_updates
    ligand.pos = pos

    # RMSD to the original pose. Best effort: if the symmetry-corrected RMSD
    # fails for any reason, fall back to the plain per-atom RMSD without
    # symmetry correction.
    try:
        ligand.rmsd = torch.tensor([get_symmetry_rmsd_with_isomorphisms(ligand.orig_pos, pos, complex_isomorphism)]).float()
    except Exception:
        ligand.rmsd = torch.tensor([np.sqrt(((pos - ligand.orig_pos) ** 2).sum(axis=1).sum(axis=0) / ligand.pos.shape[0])]).float()
    ligand.t = torch.ones(1, dtype=torch.float32)


def init_ligand_for_scoring(ligand: "Ligand"):
    """
    Set identity pose-update fields on a ligand without moving it.

    The ligand keeps its current coordinates. Its centroid becomes
    ``init_tr``, the rotation target is the identity, and every torsion
    target is zero.

    Parameters:
    ----------
    ligand : Ligand
        The ligand to initialise in place. Reads ``pos`` and
        ``rotatable_bonds``; writes ``init_tr``, ``final_rot``, ``final_tor``,
        ``t`` and ``pos`` (replaced by a fresh copy of the same coordinates).

    Returns:
    -------
    Ligand
        The same ``ligand`` object that was passed in.
    """
    # Private copy, so the ligand no longer aliases the caller's array.
    pos = np.copy(ligand.pos)
    num_rotatable_bonds = ligand.rotatable_bonds.shape[0]

    tr = pos.mean(axis=0).reshape(1, 3)
    rot = np.eye(3, dtype=np.float32)
    torsion_updates = np.zeros(num_rotatable_bonds, dtype=np.float32)

    ligand.init_tr = tr.reshape(1, 3)
    # Same conventions as randomize_ligand_rot_tor_only: transposed rotation
    # with a leading axis, negated torsion updates.
    ligand.final_rot = (rot.T)[None, :, :]
    ligand.final_tor = -torsion_updates
    ligand.t = torch.ones(1, dtype=torch.float32)
    # `pos` is already a private copy; a second copy is unnecessary.
    ligand.pos = pos
    return ligand


class PDBBindForScoring(PDBBind):
    """
    PDBBind dataset whose items are augmented complexes carrying a randomly
    perturbed ligand pose together with its RMSD to the ground truth.
    """

    def get_random_complex(self, complex, complex_isomorphism=None):
        """
        Augment ``complex`` in place and replace its ligand pose with a random
        perturbation of the ground-truth pose.

        Parameters:
        ----------
        complex :
            A complex as returned by ``__get_nonrand_item__``; modified in place.
        complex_isomorphism : optional
            Precomputed ligand isomorphisms forwarded to ``randomize_ligand_all``.

        Returns:
        -------
        None
        """
        # 1. Random global rotation of the whole complex
        apply_random_rotation_inplace(complex)

        # 2. Gaussian coordinate noise: protein first, then ligand
        for part, std in ((complex.protein, self.std_protein_pos),
                          (complex.ligand, self.std_lig_pos)):
            part.pos += np.random.normal(0, std, part.pos.shape)

        # 3. Centre on the protein
        complex.shift_to_protein_center()

        # 4. Ground-truth values for the ligand
        complex.set_ground_truth_values()

        # 5. Randomly perturb the ligand pose and record its RMSD
        randomize_ligand_all(complex.ligand, complex_isomorphism)

        # 6. Re-centre, and express orig_pos relative to the original pocket centre
        complex.shift_to_protein_center()
        complex.ligand.orig_pos -= complex.original_pocket_center

        # 7. Randomly mask protein and ligand atoms
        complex.randomly_mask_complex(ligand_mask_ratio=self.ligand_mask_ratio,
                                      protein_mask_ratio=self.protein_mask_ratio)

        # 8. Sample neighbours of rotatable bonds and store their torsion angles
        complex.ligand.randomly_sample_neighbors_of_rotatable_bonds()
        bonds_ext = complex.ligand.rotatable_bonds_ext
        if bonds_ext.start.shape[0] > 0:
            bond_atoms = get_bond_properties_for_angles(bonds_ext)
            bonds_ext.angles = get_torsion_angles(np.copy(complex.ligand.pos),
                                                  bond_atoms_for_angles=bond_atoms)

    def __getitem__(self, idx):
        """Return the ``idx``-th complex after random augmentation and ligand perturbation."""
        item = self.__get_nonrand_item__(idx)
        self.get_random_complex(item)
        return item
    

class PDBBindForRanking(PDBBindForScoring):
    """
    Scoring dataset whose items are collated batches of independently
    perturbed copies of a single complex.
    """

    def __init__(self, **kwargs):
        """
        Parameters:
        ----------
        batch_size : int, optional
            Number of perturbed copies per item (consumed here, not forwarded).
        data_collator : callable, optional
            Called with the list of copies to build the returned item
            (consumed here, not forwarded).
        **kwargs
            Remaining arguments for the parent dataset.
        """
        # Remove this class's own options before the parent constructor sees them.
        batch_size = kwargs.pop('batch_size', None)
        data_collator = kwargs.pop('data_collator', None)

        super().__init__(**kwargs)

        # Ligand graph isomorphisms are precomputed once per complex name so
        # every perturbed copy can reuse them for the symmetry-corrected RMSD.
        isomorphisms = {}
        for entry in tqdm(self.complexes, desc='Computing isomorphisms'):
            isomorphisms[entry.name] = compute_all_isomorphisms(entry.ligand.orig_mol)
        self.mol2isomorphisms = isomorphisms

        self.batch_size = batch_size
        self.data_collator = data_collator

    def __getitem__(self, idx):
        """Return ``data_collator`` applied to ``batch_size`` perturbed copies of complex ``idx``."""
        copies = []
        for _ in range(self.batch_size):
            cplx = self.__get_nonrand_item__(idx)
            self.get_random_complex(cplx, self.mol2isomorphisms.get(cplx.name))
            copies.append(cplx)
        return self.data_collator(copies)


class PDBBindForScoringInferenceMixin:
    """
    Mixin that serves precomputed (predicted) ligand poses for scoring.

    Each item is a dataset complex whose ligand coordinates are replaced by one
    predicted pose. The class that follows this mixin in the MRO must provide
    ``complexes``, ``__get_nonrand_item__``, ``std_protein_pos``,
    ``std_lig_pos``, ``ligand_mask_ratio`` and ``protein_mask_ratio``.
    """

    def __init__(self, **kwargs):
        """
        Parameters:
        ----------
        predicted_complex_positions_path : str
            File passed to ``np.load(..., allow_pickle=True)``. Element ``[0]``
            of the loaded object must be a mapping from complex name to a list
            of sample dicts. Each sample has a ``'transformed_orig'`` entry
            (ligand coordinates) and an optional ``'symm_rmsd'`` entry
            (defaults to 0).
        **kwargs
            Forwarded to the next class in the MRO, with ``is_train_dataset``
            forced to False.

        Raises:
        ------
        ValueError
            If ``predicted_complex_positions_path`` is missing or None.
        """
        predicted_complex_positions_path = kwargs.pop('predicted_complex_positions_path', None)
        if predicted_complex_positions_path is None:
            # Fail before running the (potentially expensive) parent initialisation.
            raise ValueError('predicted_complex_positions_path is required')

        kwargs['is_train_dataset'] = False
        super().__init__(**kwargs)

        # SECURITY: allow_pickle=True can execute arbitrary code while
        # unpickling; only load prediction files from trusted sources.
        predictions = np.load(predicted_complex_positions_path, allow_pickle=True)[0]
        # Flatten to one (name, unique id, ligand coordinates, rmsd) record per pose.
        self.predicted_complex_positions = [
            (name, f'{name}_{i}', sample['transformed_orig'], sample.get('symm_rmsd', 0))
            for name, preds_list in predictions.items()
            for i, sample in enumerate(preds_list)
        ]
        self.name2index = {entry.name: idx for idx, entry in enumerate(self.complexes)}

    def __len__(self):
        return len(self.predicted_complex_positions)

    def _resolve_complex_index(self, uid):
        """
        Map a prediction's complex name to its index in ``self.complexes``.

        A name that is not known verbatim is retried with everything from its
        first ``'_conf'`` onward removed.

        Raises:
        ------
        KeyError
            If neither the name nor its stripped form is known.
        """
        if uid in self.name2index:
            return self.name2index[uid]
        return self.name2index[uid.split('_conf')[0]]

    def __getitem__(self, idx):
        """
        Build the complex for the ``idx``-th predicted pose.

        The returned complex carries the predicted ligand coordinates, identity
        pose-update fields, the stored RMSD as ``ligand.rmsd``, and
        ``name = f'{name}_{i}'``.
        """
        uid, uid_full, lig_pos, rmsd = self.predicted_complex_positions[idx]
        # Only name resolution falls back; errors from loading the complex
        # itself propagate instead of being retried under another name.
        cplx = self.__get_nonrand_item__(self._resolve_complex_index(uid))
        # Keep only as many predicted coordinates as the ligand has atoms
        # (presumably predictions may carry extra atoms — confirm).
        cplx.ligand.pos = np.copy(lig_pos[:len(cplx.ligand.pos)])

        # 1. Random global rotation of the whole complex
        apply_random_rotation_inplace(cplx)

        # 2. Gaussian coordinate noise on protein and ligand
        cplx.protein.pos += np.random.normal(0, self.std_protein_pos, cplx.protein.pos.shape)
        cplx.ligand.pos += np.random.normal(0, self.std_lig_pos, cplx.ligand.pos.shape)

        # 3. Centre on the protein
        cplx.shift_to_protein_center()

        # 4. Ground-truth values for the ligand
        cplx.set_ground_truth_values()

        # 5. Re-centre, and express orig_pos relative to the original pocket centre
        cplx.shift_to_protein_center()
        cplx.ligand.orig_pos -= cplx.original_pocket_center

        # 6. Randomly mask protein and ligand atoms
        cplx.randomly_mask_complex(ligand_mask_ratio=self.ligand_mask_ratio, protein_mask_ratio=self.protein_mask_ratio)

        # 7. Sample neighbours of rotatable bonds and store their torsion angles
        cplx.ligand.randomly_sample_neighbors_of_rotatable_bonds()
        if cplx.ligand.rotatable_bonds_ext.start.shape[0] > 0:
            bond_properties_for_angles = get_bond_properties_for_angles(cplx.ligand.rotatable_bonds_ext)
            cplx.ligand.rotatable_bonds_ext.angles = get_torsion_angles(np.copy(cplx.ligand.pos),
                                                                        bond_atoms_for_angles=bond_properties_for_angles)

        # 8. Identity pose-update fields (the predicted pose is scored as-is)
        init_ligand_for_scoring(cplx.ligand)
        # NOTE(review): identity augmentation rotation is recorded even though a
        # random rotation was applied in step 1 — presumably intentional for
        # scoring; confirm.
        cplx.original_augm_rot = np.eye(3, dtype=np.float32)[None, :, :]
        # Store the RMSD that came with the prediction file.
        cplx.ligand.rmsd = torch.tensor([rmsd]).float()
        cplx.name = uid_full
        return cplx


class PDBBindForScoringInference(PDBBindForScoringInferenceMixin, PDBBind):
    """PDBBind dataset serving predicted ligand poses via PDBBindForScoringInferenceMixin."""


def dummy_ranking_collate_fn(batch: List[ComplexBatch]) -> ComplexBatch:
    """
    Return the first element of ``batch`` unchanged; any other elements are ignored.

    Intended as a pass-through collate function for items that are already
    collated batches, such as those produced by ``PDBBindForRanking.__getitem__``
    (presumably used with a DataLoader batch size of 1 — confirm at the call site).
    """
    first_item = batch[0]
    return first_item
