import os
import argparse
import gc
import numpy as np
import torch
import time
import random
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR
from model.S3Net import S3Net
from SleepDataLoader import SleepDataLoader


def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Exports ``PYTHONHASHSEED``, seeds Python's ``random`` module, NumPy and
    PyTorch (CPU generator plus the current and all CUDA devices), then puts
    cuDNN into deterministic mode with benchmark autotuning switched off.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def train_one_epoch(model, loader, criterion, optimizer, device, alpha=1.0):
    """Train ``model`` for one pass over ``loader`` and return averaged metrics.

    Every batch must be an ``(inputs, targets, stft)`` triple. The model is
    called as ``model(inputs, stft)`` and must return
    ``(logits, gate_logits)``. A sample's class label is
    ``targets.argmax(1)``; the gate is supervised with a binary target that
    is 1 when that label is >= 3 and 0 otherwise.

    Args:
        model: Network to optimise; it is switched to training mode.
        loader: Iterable yielding ``(inputs, targets, stft)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        device: Not used by this function (kept so existing calls keep
            working). Batches are passed to the model exactly as the loader
            yields them, so they must already live on the model's device.
        alpha: Multiplier applied to the gate loss before it is added to the
            classification loss.

    Returns:
        dict with ``'total_loss'``, ``'classification_loss'``,
        ``'gate_loss'`` (already scaled by ``alpha``) and ``'accuracy'``,
        each averaged over the number of samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples, because the averages
            would be undefined.
    """
    model.train()

    # The gate loss module is stateless, so build it once instead of per batch.
    gate_criterion = nn.CrossEntropyLoss()

    running_loss = 0.0
    running_classification_loss = 0.0
    running_gate_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets, stft in loader:
        batch_size = inputs.size(0)
        # Computed once and reused for both losses and the accuracy count.
        labels = targets.argmax(1)
        expert_targets = (labels >= 3).long()

        optimizer.zero_grad()

        logits, gate_logits = model(inputs, stft)

        classification_loss = criterion(logits, labels)
        gate_loss = gate_criterion(gate_logits, expert_targets) * alpha
        total_loss = classification_loss + gate_loss

        total_loss.backward()
        optimizer.step()

        # Losses are batch values; weight by batch size so the final
        # division by `total` gives a per-sample average even when the last
        # batch is smaller.
        running_loss += total_loss.item() * batch_size
        running_classification_loss += classification_loss.item() * batch_size
        running_gate_loss += gate_loss.item() * batch_size

        correct += (logits.argmax(1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")

    return {
        'total_loss': running_loss / total,
        'classification_loss': running_classification_loss / total,
        'gate_loss': running_gate_loss / total,
        'accuracy': correct / total
    }

def validate_one_epoch(model, loader, criterion, device, alpha=1.0):
    """Evaluate ``model`` over ``loader`` without gradients and return metrics.

    Every batch must be an ``(inputs, targets, stft)`` triple. The model is
    called as ``model(inputs, stft)`` and must return
    ``(logits, gate_logits)``. A sample's class label is
    ``targets.argmax(1)``; the gate target is 1 when that label is >= 3 and
    0 otherwise.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: Iterable yielding ``(inputs, targets, stft)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        device: Not used by this function (kept so existing calls keep
            working). Batches are passed to the model exactly as the loader
            yields them, so they must already live on the model's device.
        alpha: Multiplier applied to the gate loss before it is added to the
            classification loss.

    Returns:
        dict with ``'total_loss'``, ``'classification_loss'``,
        ``'gate_loss'`` (already scaled by ``alpha``) and ``'accuracy'``,
        each averaged over the number of samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples, because the averages
            would be undefined.
    """
    model.eval()

    # The gate loss module is stateless, so build it once instead of per batch.
    gate_criterion = nn.CrossEntropyLoss()

    running_loss = 0.0
    running_classification_loss = 0.0
    running_gate_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets, stft in loader:
            batch_size = inputs.size(0)
            # Computed once and reused for both losses and the accuracy count.
            labels = targets.argmax(1)
            expert_targets = (labels >= 3).long()

            logits, gate_logits = model(inputs, stft)

            classification_loss = criterion(logits, labels)
            gate_loss = gate_criterion(gate_logits, expert_targets) * alpha
            total_loss = classification_loss + gate_loss

            # Weight batch losses by batch size so the final division by
            # `total` gives a per-sample average.
            running_loss += total_loss.item() * batch_size
            running_classification_loss += classification_loss.item() * batch_size
            running_gate_loss += gate_loss.item() * batch_size

            correct += (logits.argmax(1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("validate_one_epoch: loader yielded no samples")

    return {
        'total_loss': running_loss / total,
        'classification_loss': running_classification_loss / total,
        'gate_loss': running_gate_loss / total,
        'accuracy': correct / total
    }

def train_single_fold(fold_idx, device, epochs=30, data_path='./data_s3', results_dir='./results'):
    """Train S3Net on one fold and checkpoint the best-validation weights.

    Loads the fold through ``SleepDataLoader(data_path).getFold(fold_idx)``,
    moves all six returned tensors to ``device``, and trains for ``epochs``
    epochs with AdamW. The learning rate is warmed up linearly for the first
    ``warmup_epochs`` epochs and then follows cosine annealing. Whenever the
    validation accuracy strictly exceeds the best seen so far, the model's
    ``state_dict`` is written to
    ``<results_dir>/fold_<fold_idx>/best_model.pth``.

    Args:
        fold_idx: Index of the fold passed to ``getFold``.
        device: Device the data and the model are moved to.
        epochs: Number of training epochs.
        data_path: Path handed to ``SleepDataLoader``.
        results_dir: Root directory for per-fold outputs (created if missing).

    Returns:
        None. Progress is printed once per epoch.
    """
    batch_size = 32
    embed_dim = 64
    lr = 3e-4
    weight_decay = 1e-5
    alpha = 1.0
    warmup_epochs = 6

    loader = SleepDataLoader(data_path)
    # The whole fold is placed on `device` up front; the epoch functions feed
    # batches to the model as-is and do not move them.
    train_x, train_y, train_stft, val_x, val_y, val_stft = (
        tensor.to(device) for tensor in loader.getFold(fold_idx)
    )

    train_loader = DataLoader(TensorDataset(train_x, train_y, train_stft),
                              batch_size=batch_size, num_workers=0, shuffle=True)
    val_loader = DataLoader(TensorDataset(val_x, val_y, val_stft),
                            batch_size=batch_size, num_workers=0, shuffle=False)

    model = S3Net(
        num_classes=5,
        in_chans=10,
        embed_dim=embed_dim,
        depths=[2, 4, 2],
        num_heads=[2, 2, 2],
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        patch_norm=False,
        num_experts=2
    )
    model = model.to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )

    # Both schedulers wrap the same optimizer; exactly one of them is stepped
    # per epoch (warm-up first, cosine afterwards).
    warmup_scheduler = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_epochs)
    cosine_scheduler = CosineAnnealingLR(optimizer, T_max=max(1, epochs - warmup_epochs), eta_min=1e-6)

    criterion = nn.CrossEntropyLoss()

    best_val_acc = 0.0
    fold_dir = os.path.join(results_dir, f"fold_{fold_idx}")
    os.makedirs(fold_dir, exist_ok=True)
    best_model_path = os.path.join(fold_dir, "best_model.pth")

    for epoch in range(epochs):
        start_time = time.time()

        # Read the LR before the scheduler advances so the log line reports
        # the rate this epoch was actually trained with, not the next one.
        current_lr = optimizer.param_groups[0]['lr']

        train_results = train_one_epoch(model, train_loader, criterion, optimizer, device, alpha)
        val_results = validate_one_epoch(model, val_loader, criterion, device, alpha)

        if epoch < warmup_epochs:
            warmup_scheduler.step()
        else:
            cosine_scheduler.step()

        end_time = time.time()
        epoch_duration = end_time - start_time

        log_message = (f'Fold {fold_idx} - Epoch {epoch+1}/{epochs} - '
                      f'total_loss: {train_results["total_loss"]:.4f} '
                      f'cls_loss: {train_results["classification_loss"]:.4f} '
                      f'gate_loss: {train_results["gate_loss"]:.6f} '
                      f'acc: {train_results["accuracy"]:.4f} - '
                      f'val_total_loss: {val_results["total_loss"]:.4f} '
                      f'val_cls_loss: {val_results["classification_loss"]:.4f} '
                      f'val_gate_loss: {val_results["gate_loss"]:.6f} '
                      f'val_acc: {val_results["accuracy"]:.4f} - '
                      f'time: {epoch_duration:.2f}s - lr: {current_lr:.9f}')
        print(log_message)

        if val_results['accuracy'] > best_val_acc:
            best_val_acc = val_results['accuracy']
            torch.save(model.state_dict(), best_model_path)
            print(f"new best val acc : {best_val_acc:.4f}")

    # Drop every local that can keep device memory alive. The schedulers hold
    # a reference to the optimizer (and through it the parameters), and the
    # fold tensors are referenced both directly and via the loaders.
    del model, optimizer, warmup_scheduler, cosine_scheduler, criterion
    del train_loader, val_loader, loader
    del train_x, train_y, train_stft, val_x, val_y, val_stft
    # Collect first so memory freed by the collector is already back in the
    # caching allocator when the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
    

def main():
    """Parse command-line options and train all ten folds in sequence.

    Options: ``--epochs`` (int, 30), ``--data_path`` (str, './data_s3'),
    ``--results_dir`` (str, './results') and ``--device`` (str, '0').
    ``--device cpu`` forces the CPU; any other value is used as a CUDA
    device index when CUDA is available, otherwise the CPU is used.
    """
    cli = argparse.ArgumentParser(description='S3Net training')
    option_table = (
        ('--epochs', int, 30),
        ('--data_path', str, './data_s3'),
        ('--results_dir', str, './results'),
        ('--device', str, '0'),
    )
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)

    args = cli.parse_args()
    seed_everything(42)

    # Short-circuit keeps the CUDA probe from running when CPU was requested.
    wants_cuda = args.device != 'cpu' and torch.cuda.is_available()
    device = f'cuda:{args.device}' if wants_cuda else 'cpu'

    shared_kwargs = dict(
        device=device,
        epochs=args.epochs,
        data_path=args.data_path,
        results_dir=args.results_dir,
    )
    for fold_idx in range(10):
        train_single_fold(fold_idx=fold_idx, **shared_kwargs)
    print("training finished")
        
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()