import json
import pickle
import csv
from typing import Dict, List, Any, Optional, Union
from pathlib import Path
import pandas as pd
from rdkit import Chem
from ...core.representation import MolecularGraph, Motif, Connection, ConnectionSite


class DataLoader:
    """Static helpers for reading and writing SMILES datasets, motif libraries,
    molecular graphs, training configs and result files.

    Every method is a ``@staticmethod``; the class serves only as a namespace.
    Text files are read and written as UTF-8.
    """

    # ------------------------------------------------------------------
    # SMILES datasets
    # ------------------------------------------------------------------

    @staticmethod
    def _select_properties(record: Dict[str, Any], columns: List[str]) -> Dict[str, Any]:
        """Return ``{col: record[col]}`` for each *col* in *columns* that *record* contains."""
        return {col: record[col] for col in columns if col in record}

    @staticmethod
    def load_smiles_dataset(file_path: Union[str, Path], smiles_column: str = 'smiles',
                            properties_columns: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Load molecule entries from a ``.csv``, ``.json`` or ``.smi`` file.

        Args:
            file_path: Path to the dataset; the suffix (case-insensitive)
                selects the parser.
            smiles_column: CSV column / JSON key holding the SMILES string.
                Ignored for ``.smi`` files, whose first field is the SMILES.
            properties_columns: Optional CSV columns / JSON keys copied into an
                ``entry['properties']`` dict. Names absent from a record are
                skipped.

        Returns:
            One dict per molecule with a ``'smiles'`` key, plus ``'properties'``
            when *properties_columns* is given (CSV/JSON). For ``.smi`` files,
            the rest of the line after the SMILES becomes ``'name'``.

        Raises:
            FileNotFoundError: If *file_path* does not exist.
            KeyError: If a CSV file has no *smiles_column* column.
            ValueError: If the file suffix is not supported.
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        suffix = file_path.suffix.lower()
        dataset: List[Dict[str, Any]] = []

        if suffix == '.csv':
            df = pd.read_csv(file_path)
            # Fail up front with a clear message instead of on the first row.
            if smiles_column not in df.columns:
                raise KeyError(f"Column '{smiles_column}' not found in {file_path}")
            # One plain dict per row; avoids building a Series per row as
            # iterrows() does.
            for record in df.to_dict(orient='records'):
                entry: Dict[str, Any] = {'smiles': record[smiles_column]}
                if properties_columns:
                    entry['properties'] = DataLoader._select_properties(record, properties_columns)
                dataset.append(entry)

        elif suffix == '.json':
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Only a top-level list of objects is understood; items that lack
            # the SMILES key are skipped.
            if isinstance(data, list):
                for item in data:
                    if not (isinstance(item, dict) and smiles_column in item):
                        continue
                    entry = {'smiles': item[smiles_column]}
                    if properties_columns:
                        entry['properties'] = DataLoader._select_properties(item, properties_columns)
                    dataset.append(entry)

        elif suffix == '.smi':
            with open(file_path, 'r', encoding='utf-8') as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line or line.startswith('#'):
                        continue
                    # Split only once: everything after the SMILES is the name,
                    # so names containing spaces survive a save/load round trip.
                    parts = line.split(maxsplit=1)
                    entry = {'smiles': parts[0]}
                    if len(parts) > 1:
                        entry['name'] = parts[1]
                    dataset.append(entry)

        else:
            raise ValueError(f"Unsupported dataset format: {file_path.suffix}")

        return dataset

    @staticmethod
    def save_smiles_dataset(dataset: List[Dict[str, Any]], file_path: Union[str, Path],
                            format: str = 'csv'):
        """Write *dataset* to *file_path* as CSV, JSON or SMI.

        Args:
            dataset: Entries shaped like those returned by ``load_smiles_dataset``.
            file_path: Destination; parent directories are created as needed.
            format: ``'csv'``, ``'json'`` or ``'smi'`` (case-insensitive). The
                file suffix is not consulted. (The name shadows the builtin but
                is kept for API compatibility.)

        Notes:
            CSV output goes through ``pd.json_normalize``, so a nested
            ``properties`` dict is flattened into ``properties.<name>`` columns.
            SMI output writes ``smiles<TAB>name`` (or just ``smiles``) per line
            and skips entries with an empty/missing SMILES.

        Raises:
            ValueError: If *format* is not supported.
        """
        fmt = format.lower()
        # Validate before creating directories so a bad format has no side effect.
        if fmt not in ('csv', 'json', 'smi'):
            raise ValueError(f"Unsupported dataset format: {format}")

        file_path = Path(file_path)
        file_path.parent.mkdir(parents=True, exist_ok=True)

        if fmt == 'csv':
            pd.json_normalize(dataset).to_csv(file_path, index=False)
        elif fmt == 'json':
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(dataset, f, indent=2)
        else:  # 'smi'
            with open(file_path, 'w', encoding='utf-8') as f:
                for entry in dataset:
                    smiles = entry.get('smiles', '')
                    if not smiles:
                        continue
                    name = entry.get('name', '')
                    f.write(f"{smiles}\t{name}\n" if name else f"{smiles}\n")

    @staticmethod
    def validate_smiles_dataset(dataset: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Parse every entry's SMILES with RDKit and summarise validity.

        An entry is invalid when its ``'smiles'`` value is missing, not a
        string, blank, or rejected by ``Chem.MolFromSmiles``.

        Returns:
            Dict with ``total_molecules``, ``valid_molecules``,
            ``invalid_molecules``, ``validity_rate`` (0.0 for an empty dataset)
            and ``invalid_entries`` (the first 10 invalid entries, each with its
            index in *dataset*).
        """
        valid_count = 0
        invalid_smiles = []

        for i, entry in enumerate(dataset):
            smiles = entry.get('smiles', '')
            # Screen before parsing: RDKit turns '' into an empty zero-atom Mol
            # rather than None, and non-str values (None, a NaN read from a CSV)
            # are not SMILES strings at all.
            is_valid = (
                isinstance(smiles, str)
                and bool(smiles.strip())
                and Chem.MolFromSmiles(smiles) is not None
            )
            if is_valid:
                valid_count += 1
            else:
                invalid_smiles.append({
                    'index': i,
                    'smiles': smiles,
                    'entry': entry
                })

        return {
            'total_molecules': len(dataset),
            'valid_molecules': valid_count,
            'invalid_molecules': len(invalid_smiles),
            'validity_rate': valid_count / len(dataset) if dataset else 0.0,
            'invalid_entries': invalid_smiles[:10]  # Show first 10 invalid entries
        }

    # ------------------------------------------------------------------
    # Motif / graph (de)serialization helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _site_to_dict(site: "ConnectionSite") -> Dict[str, Any]:
        """Serialize a ConnectionSite into a JSON-compatible dict."""
        return {
            'site_id': site.site_id,
            'atom_idx': site.atom_idx,
            'site_type': site.site_type,
            'chemical_environment': site.chemical_environment,
            # Sorted so repeated saves of the same data give identical JSON;
            # set iteration order can vary between interpreter runs.
            'allowed_bond_types': sorted(site.allowed_bond_types, key=str),
            'is_aromatic': site.is_aromatic,
        }

    @staticmethod
    def _site_from_dict(site_data: Dict[str, Any]) -> "ConnectionSite":
        """Rebuild a ConnectionSite from the dict produced by ``_site_to_dict``."""
        return ConnectionSite(
            site_id=site_data['site_id'],
            atom_idx=site_data['atom_idx'],
            site_type=site_data['site_type'],
            chemical_environment=site_data['chemical_environment'],
            allowed_bond_types=set(site_data['allowed_bond_types']),
            is_aromatic=site_data.get('is_aromatic', False),
        )

    @staticmethod
    def _motif_to_dict(motif: "Motif") -> Dict[str, Any]:
        """Serialize a Motif (without its RDKit mol) into a JSON-compatible dict."""
        return {
            'motif_id': motif.motif_id,
            'smiles': motif.smiles,
            'connection_sites': [DataLoader._site_to_dict(s) for s in motif.connection_sites],
            'properties': motif.properties,
            'is_aromatic': motif.is_aromatic,
            'ring_info': motif.ring_info,
            'functional_groups': motif.functional_groups,
        }

    @staticmethod
    def _motif_from_dict(motif_data: Dict[str, Any]) -> "Motif":
        """Rebuild a Motif from the dict produced by ``_motif_to_dict``.

        The RDKit mol is re-derived from the stored SMILES; a falsy SMILES
        yields ``mol=None``.
        """
        smiles = motif_data['smiles']
        sites = [DataLoader._site_from_dict(s) for s in motif_data.get('connection_sites', [])]
        return Motif(
            motif_id=motif_data['motif_id'],
            smiles=smiles,
            mol=Chem.MolFromSmiles(smiles) if smiles else None,
            connection_sites=sites,
            properties=motif_data.get('properties', {}),
            is_aromatic=motif_data.get('is_aromatic', False),
            ring_info=motif_data.get('ring_info', {}),
            functional_groups=motif_data.get('functional_groups', []),
        )

    # ------------------------------------------------------------------
    # Motif libraries
    # ------------------------------------------------------------------

    @staticmethod
    def load_motif_library(file_path: Union[str, Path]) -> List["Motif"]:
        """Load a list of Motifs from a ``.json`` or ``.pkl`` file.

        Raises:
            FileNotFoundError: If *file_path* does not exist.
            ValueError: If the suffix is neither ``.json`` nor ``.pkl``.
        """
        file_path = Path(file_path)

        if not file_path.exists():
            raise FileNotFoundError(f"Motif library not found: {file_path}")

        suffix = file_path.suffix.lower()
        if suffix == '.json':
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return [DataLoader._motif_from_dict(motif_data) for motif_data in data]

        if suffix == '.pkl':
            # SECURITY: pickle.load can execute arbitrary code; only load
            # motif libraries from trusted sources.
            with open(file_path, 'rb') as f:
                return pickle.load(f)

        raise ValueError(f"Unsupported file format: {file_path.suffix}")

    @staticmethod
    def save_motif_library(motifs: List["Motif"], file_path: Union[str, Path]):
        """Write *motifs* to a ``.json`` or ``.pkl`` file, creating parent dirs.

        Raises:
            ValueError: If the suffix is neither ``.json`` nor ``.pkl``
                (checked before any directory is created).
        """
        file_path = Path(file_path)
        suffix = file_path.suffix.lower()
        if suffix not in ('.json', '.pkl'):
            raise ValueError(f"Unsupported file format: {file_path.suffix}")

        file_path.parent.mkdir(parents=True, exist_ok=True)

        if suffix == '.json':
            motifs_data = [DataLoader._motif_to_dict(motif) for motif in motifs]
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(motifs_data, f, indent=2)
        else:
            with open(file_path, 'wb') as f:
                pickle.dump(motifs, f)

    # ------------------------------------------------------------------
    # Molecular graphs
    # ------------------------------------------------------------------

    @staticmethod
    def load_molecular_graph(file_path: Union[str, Path]) -> "MolecularGraph":
        """Rebuild a MolecularGraph from JSON written by ``save_molecular_graph``.

        The JSON must contain a ``'motifs'`` list; ``'connections'`` is optional.
        Connections are re-added through ``MolecularGraph.add_connection``.
        """
        file_path = Path(file_path)

        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        motifs = [DataLoader._motif_from_dict(motif_data) for motif_data in data['motifs']]
        mol_graph = MolecularGraph(motifs)

        for conn_data in data.get('connections', []):
            mol_graph.add_connection(Connection(
                source_motif=conn_data['source_motif'],
                source_site=conn_data['source_site'],
                target_motif=conn_data['target_motif'],
                target_site=conn_data['target_site'],
                bond_type=conn_data['bond_type'],
            ))

        return mol_graph

    @staticmethod
    def save_molecular_graph(mol_graph: "MolecularGraph", file_path: Union[str, Path]):
        """Serialize *mol_graph* (its motifs and connections) to JSON."""
        file_path = Path(file_path)
        file_path.parent.mkdir(parents=True, exist_ok=True)

        data = {
            'motifs': [DataLoader._motif_to_dict(motif) for motif in mol_graph.motifs.values()],
            'connections': [
                {
                    'source_motif': connection.source_motif,
                    'source_site': connection.source_site,
                    'target_motif': connection.target_motif,
                    'target_site': connection.target_site,
                    'bond_type': connection.bond_type,
                }
                for connection in mol_graph.connections
            ],
        }

        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    # ------------------------------------------------------------------
    # Configs and results
    # ------------------------------------------------------------------

    @staticmethod
    def load_training_config(file_path: Union[str, Path]) -> Dict[str, Any]:
        """Load a training configuration from a ``.json`` file.

        Raises:
            ValueError: If the suffix is not ``.json`` (checked before opening).
            FileNotFoundError: If the JSON file does not exist.
        """
        file_path = Path(file_path)
        # Check the format first so an unsupported suffix is reported as such
        # rather than as a missing or unreadable file.
        if file_path.suffix.lower() != '.json':
            raise ValueError(f"Unsupported config format: {file_path.suffix}")

        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def save_training_config(config: Dict[str, Any], file_path: Union[str, Path]):
        """Write *config* as indented JSON, creating parent directories."""
        file_path = Path(file_path)
        file_path.parent.mkdir(parents=True, exist_ok=True)

        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

    @staticmethod
    def load_results(file_path: Union[str, Path]) -> Dict[str, Any]:
        """Load results from a ``.json`` or ``.pkl`` file.

        Raises:
            ValueError: If the suffix is neither ``.json`` nor ``.pkl``.
        """
        file_path = Path(file_path)
        suffix = file_path.suffix.lower()

        if suffix == '.json':
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        if suffix == '.pkl':
            # SECURITY: pickle.load can execute arbitrary code; only load
            # result files from trusted sources.
            with open(file_path, 'rb') as f:
                return pickle.load(f)
        raise ValueError(f"Unsupported results format: {file_path.suffix}")

    @staticmethod
    def save_results(results: Dict[str, Any], file_path: Union[str, Path]):
        """Write *results* to ``.json`` or ``.pkl``, creating parent directories.

        JSON output stringifies values json cannot encode (``default=str``).

        Raises:
            ValueError: If the suffix is neither ``.json`` nor ``.pkl``
                (checked before any directory is created).
        """
        file_path = Path(file_path)
        suffix = file_path.suffix.lower()
        if suffix not in ('.json', '.pkl'):
            raise ValueError(f"Unsupported results format: {file_path.suffix}")

        file_path.parent.mkdir(parents=True, exist_ok=True)

        if suffix == '.json':
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2, default=str)  # default=str for non-serializable objects
        else:
            with open(file_path, 'wb') as f:
                pickle.dump(results, f)

    # ------------------------------------------------------------------
    # Splitting
    # ------------------------------------------------------------------

    @staticmethod
    def create_dataset_split(dataset: List[Any], train_ratio: float = 0.8,
                             val_ratio: float = 0.1, test_ratio: float = 0.1,
                             random_seed: int = 42) -> Dict[str, List[Any]]:
        """Shuffle a copy of *dataset* and split it into train/val/test lists.

        The shuffle uses a private ``random.Random(random_seed)``. It is
        reproducible for a given seed and leaves the global ``random`` state
        untouched. Train and val sizes are truncated with ``int``; the test
        split receives every remaining item.

        Raises:
            ValueError: If any ratio is negative or the ratios do not sum to 1.0
                (within 1e-6).
        """
        import random

        ratios = (train_ratio, val_ratio, test_ratio)
        if any(ratio < 0 for ratio in ratios):
            raise ValueError("Split ratios must be non-negative")
        if abs(sum(ratios) - 1.0) > 1e-6:
            raise ValueError("Split ratios must sum to 1.0")

        # list() accepts any iterable and never mutates the caller's data.
        shuffled = list(dataset)
        random.Random(random_seed).shuffle(shuffled)

        n_total = len(shuffled)
        n_train = int(n_total * train_ratio)
        n_val = int(n_total * val_ratio)

        return {
            'train': shuffled[:n_train],
            'val': shuffled[n_train:n_train + n_val],
            'test': shuffled[n_train + n_val:],
        }