from dataclasses import dataclass, field
from typing import Dict, List, Optional

import lightning as L
import mdtraj as md
import torch
import torch.nn.functional as F
from sklearn.preprocessing import MinMaxScaler
from torch import Tensor
from torch_cluster import radius
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn.pool import nearest

from src.modules.density_field import DensityFieldBase, RBFDensityField
from src.utils.constants import ATOM_ENCODING, ATOM_NUMBER
from src.utils.data_utils import random_rotation_matrix, rotate_point_cloud
from src.utils.torch_utils import sample_gaussians


@dataclass
class MDDataset(InMemoryDataset):
    """Per-frame point-cloud samples of a molecular-dynamics trajectory.

    Each item corresponds to one frame ``pos[idx]``. For that frame, query
    points are drawn in three groups:

    * ``n_additional_samples`` points uniformly in ``[0, 1)^d``;
    * ``n_atom_samples`` points per atom from ``sample_gaussians`` with
      std ``atom_sampling_std`` (if ``n_atom_samples > 0``);
    * the atom positions themselves (if ``atom_pos_in_query``).

    Each query point is then labelled with:

    * ``query_occ``: one-hot occupancy. Class 0 means no atom lies within
      ``occupancy_radius``; otherwise the class is the ``ATOM_ENCODING`` value
      of the nearest atom. This assumes no element is encoded as 0 — TODO
      confirm.
    * ``query_field``: the value of ``density_field`` at the point.

    Occupied query points (class != 0) also form the encoder input
    (``enc_*``). Query and encoder positions are multiplied by ``pos_scale``
    in the returned ``Data``; ``atom_pos`` is returned unscaled.
    """

    # Per-frame atom coordinates, indexed as pos[frame]. Presumably shaped
    # (frames, n_atoms, 3) and normalised into the unit cube — TODO confirm.
    pos: Tensor
    # Element symbol of every atom, in the same order as the coordinates.
    atoms: list
    # Number of one-hot occupancy classes; class 0 is the "empty" class.
    n_atom_types: int
    # Query-point budget. Exactly this many query points are produced when
    # atom_pos_in_query is True; otherwise n_atoms fewer are produced.
    n_points_sample: int = 500
    # Radius (in unscaled coordinates) within which a query point counts as occupied.
    occupancy_radius: float = 0.01
    # Maximum number of encoder points picked at random as supernodes per item.
    n_supernodes: int = 50
    # Not used by this class yet (see super_node_edge_index).
    supernodes_max_neighbours: int = 10
    supernodes_radius: float = 0.1
    # Factor applied to query/encoder positions in the returned Data.
    pos_scale: float = 200
    # Standard deviation passed to sample_gaussians for per-atom samples.
    atom_sampling_std: float = 0.01
    # Gaussian samples per atom (0 disables them).
    n_atom_samples: int = 10
    # Apply a fresh random rotation to each frame when it is fetched.
    rand_rotation: bool = False
    # Append the exact atom positions to the query points.
    atom_pos_in_query: bool = False
    # Derived in __post_init__.
    n_atoms: int = field(init=False)
    n_additional_samples: int = field(init=False)
    occ: Tensor = field(init=False)
    atoms_number: Tensor = field(init=False)
    # Field evaluated at the query points; receives the normalised atoms_number.
    density_field: DensityFieldBase = field(default_factory=RBFDensityField)

    def __post_init__(self):
        """Initialise the PyG base class and derive per-atom tensors and sample counts.

        Raises:
            ValueError: if ``n_points_sample`` is too small to hold one point
                plus ``n_atom_samples`` Gaussian samples for every atom.
        """
        super().__init__()
        self.atoms_tensor = torch.tensor([ATOM_ENCODING[atom_type] for atom_type in self.atoms])
        self.n_atoms = len(self.atoms_tensor)
        # The budget reserves n_atoms slots for the exact atom positions, even
        # when atom_pos_in_query is False.
        self.n_additional_samples = (
            self.n_points_sample - self.n_atoms - self.n_atom_samples * self.n_atoms
        )
        self.occ = F.one_hot(
            self.atoms_tensor.to(torch.long), num_classes=self.n_atom_types
        ).float()

        # Normalise each atom's ATOM_NUMBER entry by the sum over the whole table.
        # Out-of-place division: an in-place `/=` on an integer tensor raises
        # in PyTorch because the float result cannot be cast back to long.
        sum_of_atoms_number = torch.sum(torch.tensor(list(ATOM_NUMBER.values())))
        self.atoms_number = (
            torch.tensor([ATOM_NUMBER[a] for a in self.atoms]) / sum_of_atoms_number
        )

        assert self.n_atoms <= self.n_points_sample
        if self.n_additional_samples < 0:
            raise ValueError(
                f"n_points_sample={self.n_points_sample} is too small for "
                f"{self.n_atoms} atoms with n_atom_samples={self.n_atom_samples}; "
                f"need at least {self.n_atoms * (1 + self.n_atom_samples)}"
            )
        # No check on n_supernodes is needed: the randperm slice in __getitem__
        # returns min(n_supernodes, #encoder points) indices.

    def __len__(self) -> int:
        """Number of frames."""
        return len(self.pos)

    def __getitem__(self, idx) -> Data:
        """Sample query points for frame ``idx`` and label them with occupancy and field."""
        atom_pos = self.atom_pos(idx)
        query_pos = self.query_pos(atom_pos)
        query_field = self.query_field(query_pos, atom_pos)

        # Occupancy label: encoding of the nearest atom if it lies within
        # occupancy_radius, else 0 ("empty").
        nearest_neighbour = nearest(query_pos, atom_pos)
        query_dist = torch.norm(query_pos - atom_pos[nearest_neighbour], dim=-1)
        query_occ = (
            self.atoms_tensor[nearest_neighbour] * (query_dist < self.occupancy_radius).float()
        )
        query_occ = F.one_hot(query_occ.to(torch.long), num_classes=self.n_atom_types).float()

        # Non-occupied points are not encoded, only queried.
        enc_mask = query_occ[:, 0] == 0
        enc_pos = query_pos[enc_mask]
        enc_occ = query_occ[enc_mask]
        enc_field = query_field[enc_mask]

        # Random subset of encoder points used as supernodes.
        supernode_index = torch.randperm(len(enc_pos))[: self.n_supernodes]

        # Scale only after occupancy/field were computed in unscaled coordinates.
        query_pos = query_pos * self.pos_scale
        enc_pos = enc_pos * self.pos_scale

        return Data(
            enc_pos=enc_pos,
            enc_occ=enc_occ,
            enc_field=enc_field,
            query_pos=query_pos,
            query_occ=query_occ,
            query_field=query_field,
            atom_pos=atom_pos.tolist(),
            atom_type=self.atoms_tensor.tolist(),
            num_atoms=self.n_atoms,
            supernode_index=supernode_index,
            num_nodes=len(enc_pos),
        )

    def atom_pos(self, idx):
        """Return the atom coordinates of frame ``idx``, randomly rotated if enabled."""
        if self.rand_rotation:
            atom_pos = rotate_point_cloud(self.pos[idx], random_rotation_matrix())
        else:
            atom_pos = self.pos[idx]
        return atom_pos

    def query_pos(self, atom_pos):
        """Concatenate uniform, per-atom Gaussian and (optionally) exact-atom query points."""
        query_pos = torch.rand((self.n_additional_samples, atom_pos.shape[-1]))
        if self.n_atom_samples > 0:
            atom_dens_pos = sample_gaussians(
                points=atom_pos, std=self.atom_sampling_std, num_of_samples=self.n_atom_samples
            )
            query_pos = torch.cat([query_pos, atom_dens_pos])

        if self.atom_pos_in_query:
            query_pos = torch.cat([query_pos, atom_pos])

        return query_pos

    def query_field(self, query_pos, atom_pos):
        """Evaluate ``density_field`` at ``query_pos`` with ``atoms_number`` per atom."""
        query_field = self.density_field.create_field(
            query_pos,
            atom_pos,
            self.atoms_number,
        ).squeeze()
        return query_field

    def super_node_edge_index(self):
        """Not implemented yet."""
        raise NotImplementedError


@dataclass
class MDDataModule(L.LightningDataModule):
    """Lightning data module serving train/val ``MDDataset`` splits of one trajectory.

    The trajectory is loaded with mdtraj at construction time. All coordinates
    are min-max scaled jointly into ``[0.1, 0.9]``: one min/max over every
    frame, atom and axis, because the data is reshaped to a single column
    before fitting. The first ``n_train_frames`` frames form the training set
    and the remaining frames the validation set.

    Note: the scaler is fit on all frames, including the validation split.
    """

    traj_path: str
    top_path: str
    # Only its length is used here (as n_atom_types). The default is a copy so
    # that mutating it cannot alter the shared ATOM_ENCODING table.
    particle_type_encoding: Dict[str, int] = field(default_factory=lambda: dict(ATOM_ENCODING))
    # The following options are forwarded unchanged to both MDDataset splits.
    n_points_sample: int = 500
    occupancy_radius: float = 0.01
    n_supernodes: int = 50
    supernodes_max_neighbours: int = 10
    supernodes_radius: float = 0.1
    pos_scale: float = 200
    atom_sampling_std: float = 0.01
    n_atom_samples: int = 10
    rand_rotation: bool = False
    atom_pos_in_query: bool = False
    density_field: DensityFieldBase = field(default_factory=RBFDensityField)
    # Data attributes for which PyG's DataLoader builds `<name>_batch` vectors.
    follow_batch: Optional[List[str]] = field(
        default_factory=lambda: ["enc_pos", "query_pos", "supernode_index"]
    )
    batch_size: int = 4
    num_workers: int = 4
    pin_memory: bool = True
    # Number of leading frames used for training; the remaining frames are validation.
    n_train_frames: int = 800

    train_dataset: MDDataset = field(init=False)
    val_dataset: MDDataset = field(init=False)

    def __post_init__(self):
        """Load and scale the trajectory, then build the train/val datasets."""
        super().__init__()
        traj = md.load(self.traj_path, top=self.top_path)
        atoms = [atom.element.symbol for atom in traj.topology.atoms]

        # A single global min/max (one column), so relative geometry is kept.
        # The 0.1 margin presumably keeps atoms inside the unit cube where
        # MDDataset draws its uniform query points — TODO confirm.
        scaled_xyz = (
            MinMaxScaler(feature_range=(0.1, 0.9))
            .fit_transform(traj.xyz.reshape(-1, 1))
            .reshape(traj.xyz.shape)
        )
        pos = torch.tensor(scaled_xyz)

        self.train_dataset = self._make_dataset(pos[: self.n_train_frames], atoms)
        self.val_dataset = self._make_dataset(pos[self.n_train_frames :], atoms)

    def _make_dataset(self, pos: Tensor, atoms: List[str]) -> MDDataset:
        """Build an MDDataset over ``pos`` with this module's sampling options."""
        return MDDataset(
            pos=pos,
            atoms=atoms,
            n_atom_types=len(self.particle_type_encoding),
            n_points_sample=self.n_points_sample,
            occupancy_radius=self.occupancy_radius,
            n_supernodes=self.n_supernodes,
            supernodes_max_neighbours=self.supernodes_max_neighbours,
            supernodes_radius=self.supernodes_radius,
            pos_scale=self.pos_scale,
            atom_sampling_std=self.atom_sampling_std,
            n_atom_samples=self.n_atom_samples,
            rand_rotation=self.rand_rotation,
            atom_pos_in_query=self.atom_pos_in_query,
            density_field=self.density_field,
        )

    def _make_loader(self, dataset: MDDataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a PyG DataLoader using this module's loader settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            follow_batch=self.follow_batch,
        )

    def train_dataloader(self):
        """Shuffled loader over the training frames."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation frames."""
        return self._make_loader(self.val_dataset, shuffle=False)
