"""
Crystal String Binary Classification using Transformers
Classifies whether a crystal string was permuted or not
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import wandb
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import argparse
from pathlib import Path

from my_access_data import get_crystal_string, CifDataset

def get_model_options():
    """Return the model-size presets selectable on the command line.

    Returns:
        dict: Maps each preset key (``"nano"`` ... ``"large"``, in increasing
        nominal size) to a dict with two string entries: ``"name"``, the
        Hugging Face model identifier, and ``"description"``, a human-readable
        label. A fresh dict is built on every call.
    """
    return {
        "nano": {
            # NOTE(review): the published DialoGPT checkpoints appear to be
            # small/medium/large only; confirm this id resolves on the hub
            # (or replace it with a real ~33M-parameter model).
            "name": "microsoft/DialoGPT-tiny",  # ~33M parameters
            "description": "Nano model (~33M params) - SMALLEST"
        },
        "micro": {
            "name": "distilbert-base-uncased",  # ~66M parameters
            "description": "Micro model (~66M params)"
        },
        "tiny": {
            "name": "bert-base-uncased",  # ~110M parameters
            "description": "Tiny model (~110M params)"
        },
        "small": {
            "name": "roberta-base",  # ~125M parameters
            "description": "Small model (~125M params)"
        },
        "medium": {
            "name": "roberta-large",  # ~355M parameters
            "description": "Medium model (~355M params)"
        },
        "large": {
            # DialoGPT-medium is roughly the same size as the "medium" preset;
            # the ~774M figure advertised here matches the -large checkpoint.
            "name": "microsoft/DialoGPT-large",  # ~774M parameters
            "description": "Large model (~774M params)"
        }
    }

class CrystalClassificationDataset(Dataset):
    """Dataset for binary classification of permuted vs non-permuted crystal strings.

    Every entry of the underlying ``CifDataset`` contributes two samples: the
    string obtained by indexing the dataset (label 0) and a string produced by
    ``get_crystal_string(..., permute_atoms=True)`` from the same entry's CIF
    (label 1). The two classes are therefore balanced by construction.
    """

    def __init__(self, csv_fn, tokenizer, max_length=512, balance_classes=True):
        """Build and shuffle the paired original/permuted samples.

        Args:
            csv_fn: CSV file name passed straight to ``CifDataset``.
            tokenizer: Callable invoked per sample in ``__getitem__`` with
                ``truncation``, ``padding``, ``max_length`` and
                ``return_tensors='pt'``; its result must expose
                ``'input_ids'`` and ``'attention_mask'``.
            max_length: Length every sample is padded/truncated to.
            balance_classes: Currently unused; the dataset is always built
                with one permuted sample per original sample. Kept so
                existing callers keep working.
        """
        super().__init__()

        # Load original dataset
        self.original_dataset = CifDataset(csv_fn)
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Create balanced dataset with 50% permuted, 50% non-permuted
        self.data = []
        self.labels = []

        # Index-based access is deliberate: both the dataset item and the
        # parallel ``inputs`` record are needed for the same position.
        for i in range(len(self.original_dataset)):
            # Item as served by CifDataset (treated as the non-permuted string).
            original_crystal_str = self.original_dataset[i]

            # Re-render the same CIF with atom permutation enabled.
            input_dict = self.original_dataset.inputs[i]
            k = 'cif' if 'cif' in input_dict else 'cif_str'
            permuted_crystal_str = get_crystal_string(input_dict[k], permute_atoms=True)

            self.data.extend((original_crystal_str, permuted_crystal_str))
            self.labels.extend((0, 1))  # 0 = non-permuted, 1 = permuted

        # Shuffle data and labels with one shared permutation so they stay aligned.
        indices = np.random.permutation(len(self.data))
        self.data = [self.data[i] for i in indices]
        self.labels = [self.labels[i] for i in indices]

        print(f"Dataset created with {len(self.data)} samples")
        print(f"Class distribution: {np.bincount(self.labels)}")

    def __len__(self):
        """Return the number of samples (two per source crystal)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return model-ready tensors.

        Returns:
            dict: ``'input_ids'`` and ``'attention_mask'`` with the tokenizer's
            leading batch axis removed, plus ``'labels'`` as a scalar
            ``torch.long`` tensor.
        """
        crystal_str = self.data[idx]
        label = self.labels[idx]

        # Tokenize the crystal string
        encoding = self.tokenizer(
            crystal_str,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Drop only the leading batch axis. A bare ``squeeze()`` would also
        # remove the sequence axis whenever it has length 1, yielding a 0-d
        # tensor that no longer collates into (batch, seq).
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

class CrystalClassifier(pl.LightningModule):
    """PyTorch Lightning module for crystal classification.

    Wraps a Hugging Face sequence-classification model with two output labels
    and logs loss and accuracy for the training and validation loops.
    """

    def __init__(self, model_name="distilbert-base-uncased", learning_rate=2e-5):
        """Load the pre-trained model and record hyperparameters.

        Args:
            model_name: Hugging Face identifier passed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            learning_rate: Learning rate handed to AdamW.
        """
        super().__init__()
        self.save_hyperparameters()

        # Load pre-trained transformer model
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=2
        )

        # Decoder-only checkpoints (e.g. GPT-2 based) ship without a pad token;
        # their classification heads cannot handle padded batches larger than
        # one unless ``pad_token_id`` is set. Fall back to EOS, the same
        # choice commonly made on the tokenizer side for such models.
        config = self.model.config
        if getattr(config, "pad_token_id", None) is None:
            eos_id = getattr(config, "eos_token_id", None)
            if isinstance(eos_id, (list, tuple)):
                eos_id = eos_id[0] if eos_id else None
            if eos_id is not None:
                config.pad_token_id = eos_id

        self.learning_rate = learning_rate

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model; the output carries ``loss`` when labels are given."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        return outputs

    def _shared_step(self, batch):
        """Compute loss and batch accuracy for one batch of tokenized samples."""
        outputs = self.forward(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        preds = torch.argmax(outputs.logits, dim=1)
        acc = (preds == batch['labels']).float().mean()
        return outputs.loss, acc

    def training_step(self, batch, batch_idx):
        """Return the training loss and log per-step and per-epoch metrics."""
        loss, acc = self._shared_step(batch)

        # Log metrics
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss and log epoch-level metrics."""
        loss, acc = self._shared_step(batch)

        # Log metrics ('val_acc' is the quantity monitored for checkpointing)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', acc, on_epoch=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

def setup_data_module(data_path, batch_size=16, max_length=512,
                      model_name="distilbert-base-uncased", num_workers=4):
    """Create the Lightning data module serving train/validation loaders.

    Args:
        data_path: Directory containing ``train.csv`` (``str`` or ``Path``).
        batch_size: Batch size for both loaders.
        max_length: Token length each sample is padded/truncated to.
        model_name: Hugging Face identifier used to load the tokenizer.
        num_workers: Worker processes per ``DataLoader``.

    Returns:
        A ``pl.LightningDataModule`` whose ``setup`` builds an 80/20 random
        train/validation split of ``train.csv``.
    """

    class CrystalDataModule(pl.LightningDataModule):
        def __init__(self, data_path, batch_size, max_length, model_name, num_workers):
            super().__init__()
            # Accept plain strings as well; the "/" join below needs a Path.
            self.data_path = Path(data_path)
            self.batch_size = batch_size
            self.max_length = max_length
            self.model_name = model_name
            self.num_workers = num_workers
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.train_dataset = None
            self.val_dataset = None

            # Add padding token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

        def setup(self, stage=None):
            # The validation loader is needed for a standalone validate run
            # too, not only for fitting.
            if stage not in (None, 'fit', 'validate'):
                return
            # Build once: re-running the unseeded random split would move
            # training samples into the validation set on a later call.
            if self.train_dataset is not None and self.val_dataset is not None:
                return

            # Load full dataset
            full_dataset = CrystalClassificationDataset(
                str(self.data_path / "train.csv"),
                self.tokenizer,
                self.max_length
            )

            # Split into train/val
            # NOTE(review): the dataset holds an original and a permuted
            # sample per crystal and the split is per sample, so the two
            # variants of one crystal can land on opposite sides — confirm
            # this is acceptable for the validation metric.
            train_size = int(0.8 * len(full_dataset))
            val_size = len(full_dataset) - train_size
            self.train_dataset, self.val_dataset = random_split(
                full_dataset, [train_size, val_size]
            )

            print(f"Train size: {len(self.train_dataset)}")
            print(f"Val size: {len(self.val_dataset)}")

        def train_dataloader(self):
            return DataLoader(
                self.train_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=self.num_workers
            )

        def val_dataloader(self):
            return DataLoader(
                self.val_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers
            )

    return CrystalDataModule(data_path, batch_size, max_length, model_name, num_workers)

def main(args):
    """Train the classifier, then evaluate it on the validation split.

    Args:
        args: Parsed command-line namespace. Reads ``model_size``,
            ``model_name``, ``run_name``, ``data_path``, ``batch_size``,
            ``max_length``, ``learning_rate``, ``num_epochs``, ``use_amp``
            and ``grad_accum``.
    """
    # Resolve the preset once; any non-preset size (e.g. "custom") falls back
    # to the explicitly supplied model name.
    preset = get_model_options().get(args.model_size)
    if preset is not None:
        model_name = preset["name"]
        print(f"Using {args.model_size} model: {model_name}")
    else:
        model_name = args.model_name

    # Initialize wandb
    wandb_logger = WandbLogger(
        project="crystal-classification",
        name=args.run_name,
        log_model=True
    )

    # Setup data
    data_module = setup_data_module(
        args.data_path,
        batch_size=args.batch_size,
        max_length=args.max_length,
        model_name=model_name
    )

    # Setup model
    model = CrystalClassifier(
        model_name=model_name,
        learning_rate=args.learning_rate
    )

    # Keep the three checkpoints with the highest validation accuracy.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_acc',
        dirpath=f'checkpoints/{args.run_name}',
        filename='best-{epoch:02d}-{val_acc:.3f}',
        save_top_k=3,
        mode='max'
    )

    # Setup trainer
    trainer = pl.Trainer(
        max_epochs=args.num_epochs,
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
        accelerator='auto',
        devices=1,
        precision=16 if args.use_amp else 32,
        gradient_clip_val=1.0,
        accumulate_grad_batches=args.grad_accum
    )

    try:
        # Train
        trainer.fit(model, datamodule=data_module)

        # Evaluate the final weights on the validation split
        trainer.validate(model, datamodule=data_module)
    except Exception:
        # Close the wandb run as failed instead of leaving it open when this
        # function is driven from a longer-lived process.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()

if __name__ == "__main__":
    # Command-line entry point: parse options, optionally list the model
    # presets, otherwise launch training.
    parser = argparse.ArgumentParser()
    # Required for training, but enforced after parsing so that
    # "--model-size list" works without a run name.
    parser.add_argument("--run-name", type=str, default=None)
    parser.add_argument("--data-path", type=Path, default=Path("data/basic"))

    # Model selection
    model_options = get_model_options()
    parser.add_argument("--model-size", type=str,
                        # "list" must be an accepted choice, otherwise argparse
                        # rejects it before the listing branch can run.
                        choices=list(model_options.keys()) + ["custom", "list"],
                        default="small",
                        help="Model size preset ('list' prints the presets and exits)")
    parser.add_argument("--model-name", type=str, default="distilbert-base-uncased",
                        help="Custom model name (used if model-size=custom)")

    # Training parameters
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--max-length", type=int, default=512)
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--grad-accum", type=int, default=1)
    parser.add_argument("--use-amp", action="store_true", default=False)

    args = parser.parse_args()

    # Print model options if requested
    if args.model_size == "list":
        print("Available model sizes:")
        for size, info in model_options.items():
            print(f"  {size}: {info['name']} - {info['description']}")
        # SystemExit instead of the site-provided exit() helper, which is not
        # guaranteed to exist in every interpreter configuration.
        raise SystemExit(0)

    if args.run_name is None:
        parser.error("the following arguments are required: --run-name")

    main(args)