import argparse
import os
import time
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import balanced_accuracy_score

from config import ModelConfig
from utilities import DataPreprocessor, compute_checksum, evaluate_model, save_results, get_device
from optimizer import CocktailOptimizer

# Run single-threaded.  torch.set_num_threads is what caps torch's intra-op
# pool here; the OMP_NUM_THREADS variable is written after numpy/torch were
# imported, so it presumably only affects OpenMP runtimes / subprocesses that
# initialise later — confirm if thread counts matter for a given backend.
os.environ.update({"OMP_NUM_THREADS": "1"})
torch.set_num_threads(1)

class DropoutScheduler:
    """Compute per-layer dropout rates that follow a named "shape".

    ``max_rate`` is capped at 0.9 on construction and every individual rate
    is capped at 0.9 as well.  For ``n = num_layers`` and layer index ``i`` in
    ``[0, n)`` the supported shapes are:

    * ``'funnel'``      -- rises linearly: ``max_rate * (i + 1) / n``.
    * ``'long_funnel'`` -- ``max_rate`` for every layer but the last, which
      gets ``max_rate / 2``.
    * ``'diamond'``     -- rises then falls, peaking at exactly ``max_rate``
      on the middle layer(s).
    * ``'triangle'``    -- falls linearly: ``max_rate * (n - i) / n``.
    * anything else     -- uniform ``max_rate``.

    The resulting list is stored in ``self.rates`` (empty when
    ``num_layers <= 0``).
    """

    def __init__(self, shape: str, max_rate: float, num_layers: int = 3):
        self.shape = shape
        self.max_rate = min(max_rate, 0.9)
        self.num_layers = num_layers
        self.rates = self._get_rates()

    def _get_rates(self) -> list[float]:
        """Return one dropout rate per layer according to ``self.shape``."""
        n = self.num_layers
        if n <= 0:
            # No layers -> no rates (previously 'long_funnel' still returned
            # a stray one-element list here).
            return []
        if self.shape == 'funnel':
            return [min(self.max_rate * (i + 1) / n, 0.9) for i in range(n)]
        if self.shape == 'long_funnel':
            return [min(self.max_rate, 0.9)] * (n - 1) + [self.max_rate / 2]
        if self.shape == 'diamond':
            # min(i + 1, n - i) peaks at (n + 1) // 2; normalising by that peak
            # makes the middle layer(s) hit max_rate for odd and even n alike.
            # (Dividing by n // 2 over-scaled odd n -- e.g. 2 * max_rate in the
            # middle for n == 3 -- and divided by zero for n == 1.)
            peak = (n + 1) // 2
            return [min(self.max_rate * min(i + 1, n - i) / peak, 0.9)
                    for i in range(n)]
        if self.shape == 'triangle':
            return [min(self.max_rate * (n - i) / n, 0.9) for i in range(n)]
        # 'uniform' and any unrecognised shape.
        return [min(self.max_rate, 0.9)] * n

class RegularizedMLP(nn.Module):
    """MLP with a configurable regularization cocktail.

    Architecture: ``input_size -> 512 -> 256 -> num_classes``.  Each hidden
    block is ``Linear -> [BatchNorm1d] -> [Dropout] -> ReLU``; the final
    ``Linear`` emits raw logits.  With ``config.use_skip`` a projection of the
    raw input (a ``Linear``, or ``Identity`` when ``input_size ==
    num_classes``) is combined with the MLP output according to
    ``config.skip_type``:

    * ``'ShakeShake'`` -- training: ``lerp(identity, out, alpha)`` with a fresh
      ``alpha ~ U[0, 1)`` per forward pass; evaluation: its expectation,
      ``alpha = 0.5``.
    * ``'ShakeDrop'`` -- training: keep the MLP branch with probability
      ``shakedrop_prob`` (``out + identity``), otherwise drop it
      (``identity``); evaluation: its expectation,
      ``identity + shakedrop_prob * out``.
    * any other value -- plain residual sum ``out + identity``.
    """

    def __init__(self, config: 'ModelConfig'):
        super().__init__()
        self.config = config

        # Per-layer dropout rates; None when dropout is disabled.
        self.dropout = None
        if config.use_dropout:
            self.dropout = DropoutScheduler(
                shape=config.dropout_shape,
                max_rate=config.dropout_rate
            )

        # Two hidden layers (512, 256) between the input and the logits.
        self.sizes = [config.input_size, 512, 256, config.num_classes]
        num_hidden = len(self.sizes) - 2

        # Module order (Linear, BN, Dropout, ReLU) fixes the ``layers.<idx>``
        # state_dict keys, so it is kept stable.
        layers = []
        for i in range(len(self.sizes) - 1):
            linear = nn.Linear(self.sizes[i], self.sizes[i + 1])
            nn.init.kaiming_normal_(linear.weight)
            nn.init.zeros_(linear.bias)
            layers.append(linear)
            if i >= num_hidden:
                continue  # output layer: raw logits, no BN/dropout/activation
            if config.use_batch_norm:
                layers.append(nn.BatchNorm1d(self.sizes[i + 1]))
            if self.dropout:
                layers.append(nn.Dropout(p=self.dropout.rates[i]))
            layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)

        # Optional skip connection from the raw input straight to the logits.
        self.skip_type = None
        if config.use_skip:
            self.skip_type = config.skip_type
            self.shakedrop_prob = config.shakedrop_prob

            if config.input_size != config.num_classes:
                self.projection = nn.Linear(config.input_size, config.num_classes)
                nn.init.kaiming_normal_(self.projection.weight)
                nn.init.zeros_(self.projection.bias)
            else:
                self.projection = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits for ``x``, combining in the skip path if configured."""
        if not self.skip_type:
            return self.layers(x)

        out = self.layers(x)
        # Under autocast the two paths can come back in different dtypes (an
        # Identity projection keeps the float32 input while the Linear stack
        # emits half precision), so align everything to ``out`` up front.
        identity = self.projection(x).to(out.dtype)

        if self.skip_type == 'ShakeShake':
            if self.training:
                # alpha is created in out's dtype *after* the cast so lerp sees
                # matching dtypes for start, end and weight.
                alpha = torch.rand(1, device=x.device, dtype=out.dtype)
                return torch.lerp(identity, out, alpha)
            # Expected value of the training-time mix (E[alpha] = 0.5).
            return torch.lerp(identity, out, 0.5)

        if self.skip_type == 'ShakeDrop':
            if self.training:
                if torch.rand(1).item() < self.shakedrop_prob:
                    return out + identity
                return identity
            # Expected value of the training-time branch: kept with prob p.
            return identity + self.shakedrop_prob * out

        # Standard residual connection.
        return out + identity
    
def train_epoch(
    model: RegularizedMLP,
    train_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
    config: ModelConfig,
    scaler: GradScaler = None
) -> tuple[float, float]:
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: network to train; switched to training mode here.
        train_loader: yields ``(inputs, targets)`` batches.
        criterion: loss function.  Its value is treated as a per-batch mean
            and re-weighted by batch size (TODO confirm for non-default
            reductions).
        optimizer: stepped once per batch.
        device: device (or device string) the batches are moved to; its type
            also selects the autocast backend.
        config: read for ``use_amp`` and ``max_grad_norm``.
        scaler: AMP gradient scaler.  Mixed precision is used only when
            ``config.use_amp`` is true *and* a scaler is supplied.

    Returns:
        ``(mean_loss, balanced_accuracy)`` over the epoch; the loss is
        averaged over ``len(train_loader.dataset)``.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    device = torch.device(device)
    use_amp = bool(config.use_amp) and scaler is not None

    total_loss = 0.0
    predictions = []
    targets_all = []

    for inputs, targets in train_loader:
        # Inputs always enter the model as float32; autocast (if enabled)
        # chooses lower precision per-op inside the model.
        inputs = inputs.to(device, dtype=torch.float32)
        targets = targets.to(device)

        optimizer.zero_grad()

        try:
            if use_amp:
                # Follow the actual device rather than assuming CUDA.
                with autocast(device_type=device.type, dtype=torch.float16):
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)

                scaler.scale(loss).backward()

                # Unscale before clipping so max_grad_norm applies to the true
                # gradients.  On a NaN loss this is skipped; scaler.step()
                # still unscales and skips the update on non-finite grads.
                if not torch.isnan(loss):
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                optimizer.step()

            total_loss += loss.item() * inputs.size(0)
            # argmax over the num_classes logits is already a valid class index.
            predictions.extend(outputs.argmax(dim=1).cpu().numpy())
            targets_all.extend(targets.cpu().numpy())

        except Exception as e:
            print(f"Training error: {e}")
            print(f"Config: {config}")
            raise

    if not targets_all:
        raise ValueError("train_loader yielded no samples; cannot compute epoch metrics")

    return (
        total_loss / len(train_loader.dataset),
        balanced_accuracy_score(targets_all, predictions)
    )
    
def _stratified_split(
    y: np.ndarray, val_fraction: float, seed: int
) -> tuple[np.ndarray, np.ndarray]:
    """Split sample indices into sorted ``(train_idx, val_idx)`` class by class.

    Roughly ``val_fraction`` of each class is held out for validation, but
    every class keeps at least one training sample so the training labels
    still cover every class present in ``y``.
    """
    rng = np.random.default_rng(seed)
    train_parts, val_parts = [], []
    for label in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == label))
        n_val = min(int(round(len(idx) * val_fraction)), len(idx) - 1)
        val_parts.append(idx[:n_val])
        train_parts.append(idx[n_val:])
    if not train_parts:  # empty label array
        empty = np.array([], dtype=np.intp)
        return empty, empty
    return np.sort(np.concatenate(train_parts)), np.sort(np.concatenate(val_parts))


def _make_loader(X, y, *, shuffle: bool, pin_memory: bool) -> DataLoader:
    """Wrap features (as float32) and labels (as int64) in a batch-128 DataLoader."""
    return DataLoader(
        TensorDataset(
            torch.tensor(X, dtype=torch.float32),
            torch.tensor(y, dtype=torch.long)
        ),
        batch_size=128,
        shuffle=shuffle,
        pin_memory=pin_memory
    )


def train_mlp_cocktail(
    task_id: int,
    config: dict,
    epochs: int = 100,
    device: str = 'cpu',
    data_root: str = None,
    val_fraction: float = 0.2,
    seed: int = 0,
    patience: int = 100,
) -> dict:
    """Train one RegularizedMLP configuration on a task and report accuracies.

    Model selection is driven by a stratified validation split carved out of
    the training data: best-checkpoint tracking, LR plateau scheduling and
    early stopping all use it.  The reported test accuracy therefore comes
    from data that selection never saw.

    Args:
        task_id: forwarded to ``DataPreprocessor.get_data``.
        config: keyword arguments for ``ModelConfig``; ``use_amp`` and
            ``weight_decay`` are also read from it directly.
        epochs: maximum number of training epochs.
        device: torch device (string or ``torch.device``).
        data_root: forwarded to ``DataPreprocessor``.
        val_fraction: share of each training class held out for validation.
            ``0`` (or a split too small to hold anything out) falls back to
            selecting on the test set, in which case the validation and test
            metrics are not independent.
        seed: seed for the validation split.
        patience: epochs without validation improvement before stopping.

    Returns:
        dict with ``test_balanced_accuracy``, ``val_balanced_accuracy`` (best
        validation score, 0.0 if no epoch ran), ``epochs_trained`` and
        ``model_checksum`` (``None`` if no checkpoint was taken).
    """
    preprocessor = DataPreprocessor(data_root=data_root)
    X_train, X_test, y_train, y_test = preprocessor.get_data(task_id)
    # Arrays so the validation split can fancy-index them.
    X_train, y_train = np.asarray(X_train), np.asarray(y_train)

    device = torch.device(device)
    # Pinned host memory only helps host->CUDA copies.
    pin_memory = device.type == 'cuda'

    # Count classes on the full training labels, before any split.
    num_classes = len(np.unique(y_train))

    test_loader = _make_loader(X_test, y_test, shuffle=False, pin_memory=pin_memory)
    if val_fraction > 0:
        train_idx, val_idx = _stratified_split(y_train, val_fraction, seed)
    else:
        train_idx, val_idx = np.arange(len(y_train)), np.array([], dtype=np.intp)

    if len(val_idx):
        train_loader = _make_loader(X_train[train_idx], y_train[train_idx],
                                    shuffle=True, pin_memory=pin_memory)
        val_loader = _make_loader(X_train[val_idx], y_train[val_idx],
                                  shuffle=False, pin_memory=pin_memory)
    else:
        # Legacy fallback: nothing could be held out, so select on the test set.
        train_loader = _make_loader(X_train, y_train, shuffle=True, pin_memory=pin_memory)
        val_loader = test_loader

    model_config = ModelConfig(
        input_size=X_train.shape[1],
        num_classes=num_classes,
        **config
    )
    model_config.validate()

    model = RegularizedMLP(model_config).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=model_config.learning_rate,
        weight_decay=config.get('weight_decay', 0)
    )
    # mode='max': the monitored quantity is a balanced accuracy.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

    # AMP gradient scaling is only set up for CUDA devices.
    scaler = GradScaler() if config.get('use_amp', False) and device.type == 'cuda' else None

    # -inf so the first evaluated epoch is always checkpointed, even at 0 acc.
    best_val_acc = float('-inf')
    best_state = None
    best_checksum = None
    patience_counter = 0
    epochs_trained = 0  # stays defined even when epochs == 0

    for _ in range(epochs):
        train_epoch(
            model, train_loader, criterion,
            optimizer, device, model_config, scaler
        )
        epochs_trained += 1

        model.eval()
        with torch.no_grad():
            _, val_acc = evaluate_model(model, val_loader, criterion, device)
        scheduler.step(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # CPU copies so later training steps cannot mutate the checkpoint.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            best_checksum = compute_checksum(best_state)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                break

    # Restore the best checkpoint (if any) before the final test evaluation.
    if best_state is not None:
        model.load_state_dict(best_state)

    model.eval()
    with torch.no_grad():
        _, test_acc = evaluate_model(model, test_loader, criterion, device)

    return {
        'test_balanced_accuracy': float(test_acc),
        'val_balanced_accuracy': float(best_val_acc) if best_state is not None else 0.0,
        'epochs_trained': epochs_trained,
        'model_checksum': best_checksum
    }

def main():
    """Parse CLI flags, run the cocktail hyper-parameter search, and save results.

    The search is delegated to ``CocktailOptimizer``; each trial calls
    ``train_mlp_cocktail`` with the resolved device and ``--data_root``.
    Wall-clock time is added under ``results['timing']`` before the results
    are written to ``<output_dir>/task_<task_id>_<YYYYmmdd_HHMMSS>.json``.
    """
    parser = argparse.ArgumentParser()
    cli_options = (
        ('--task_id', {'type': int, 'required': True}),
        ('--n_trials', {'type': int, 'default': 100}),
        ('--epochs', {'type': int, 'default': 100}),
        ('--output_dir', {'type': str, 'default': './runs'}),
        ('--device', {'type': str, 'default': 'cpu'}),
        ('--storage', {'type': str, 'default': None}),
        ('--data_root', {'type': str, 'default': None,
                         'help': 'Root directory containing dataset folders'}),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    started = time.time()
    device = get_device(args.device)
    print(f'Using device: {device}')

    def run_trial(task_id, config, epochs):
        # Bind the CLI-level device and data location into every trial.
        return train_mlp_cocktail(
            task_id=task_id,
            config=config,
            epochs=epochs,
            device=device,
            data_root=args.data_root
        )

    search = CocktailOptimizer(
        train_fn=run_trial,
        task_id=args.task_id,
        max_epochs=args.epochs,
        n_trials=args.n_trials,
        storage=args.storage,
        output_dir=args.output_dir
    )
    results = search.optimize()

    # Elapsed wall-clock time, also rendered as zero-padded HH:MM:SS.
    elapsed = time.time() - started
    whole_hours, leftover = divmod(elapsed, 3600)
    whole_minutes, leftover_seconds = divmod(leftover, 60)
    hh_mm_ss = f"{int(whole_hours):02d}:{int(whole_minutes):02d}:{int(leftover_seconds):02d}"
    results['timing'] = {
        'total_seconds': elapsed,
        'formatted_time': hh_mm_ss
    }

    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    save_path = os.path.join(args.output_dir, f'task_{args.task_id}_{stamp}.json')
    search.save_results(results, save_path)

    print(f"\nOptimization completed in {hh_mm_ss}")
    print(f"Results saved to {save_path}")
    print(f"Best test accuracy: {results['best_result']['test_balanced_accuracy']:.4f}")


if __name__ == '__main__':
    main()