"""Training and evaluation functions."""

import json
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
from collections import defaultdict

from metrics import compute_dice_score
from src.dataset import (
    MedVLSMDataset, 
    get_dataloader, 
    preprocess_batch, 
    save_model_checkpoint,
    serialize_config
)


def evaluate_task(model, processor, data_root, dataset_name, device, config, prompt_strategy, task_id):
    """Score ``model`` on the ``'val'`` split of one dataset.

    The model is switched to the modality registered for ``task_id`` in
    ``model.modality_manager.task_modality_id`` (modality 0 if none is
    registered) and put in eval mode; it is left in that state on return.
    Gradients are disabled for the whole pass.

    Returns:
        ``(mean_dice, mean_loss)``: ``mean_dice`` averages
        ``compute_dice_score`` over every validation sample using argmax
        predictions, and ``mean_loss`` averages the per-batch cross-entropy.
        Both are 0.0 when the loader yields no batches.
    """
    val_set = MedVLSMDataset(data_root, dataset_name, 'val', prompt_strategy, config)
    val_loader = get_dataloader(
        val_set,
        config['eval_batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        seed=42,
    )

    model.set_modality(model.modality_manager.task_modality_id.get(task_id, 0))
    model.eval()

    dice_sum = loss_sum = 0.0
    sample_count = batch_count = 0
    with torch.no_grad():
        for batch in val_loader:
            prompt_list = batch['prompt']
            image_t = batch['image'].to(device, non_blocking=True)
            target = batch['mask'].to(device, non_blocking=True).long()
            model_inputs = preprocess_batch(image_t, prompt_list, processor, device, config)
            logits = model(model_inputs, prompts=prompt_list, task_id=task_id)['logits']

            loss_sum += F.cross_entropy(logits, target).item()
            batch_count += 1

            # Dice is scored per sample, then averaged over all samples seen.
            for pred_mask, true_mask in zip(torch.argmax(logits, dim=1), target):
                dice_sum += compute_dice_score(pred_mask, true_mask)
                sample_count += 1

    return dice_sum / max(sample_count, 1), loss_sum / max(batch_count, 1)


def train_step(model, batch, optimizer, config, task_id):
    """Run one optimisation step on a single batch and report each loss term.

    The objective is a weighted sum of pixel-wise cross-entropy, a soft Dice
    loss on class 1, and the model's EWC penalty, with weights taken from
    ``config['loss_weights']``. Gradients are clipped to
    ``config['gradient_clip_norm']`` before ``optimizer.step()``.

    Returns:
        dict with entries ``'total'``, ``'seg'``, ``'dice'`` and ``'ewc'``
        (Python numbers; ``'ewc'`` is passed through unchanged when
        ``model.get_ewc_loss()`` does not return a tensor).
    """
    dev = config['device']
    image_t = batch['image'].to(dev, non_blocking=True)
    target = batch['mask'].to(dev, non_blocking=True).long()
    prompt_list = batch['prompt']

    model_inputs = preprocess_batch(image_t, prompt_list, model.processor, dev, config)
    scores = model(model_inputs, prompts=prompt_list, task_id=task_id)['logits']

    ce_term = F.cross_entropy(scores, target)

    # Soft Dice on the foreground channel (index 1), summed over dims 1 and 2
    # of the per-class probability map; 1e-8 only guards the division.
    fg_prob = F.softmax(scores, dim=1)[:, 1]
    fg_true = (target == 1).float()
    spatial = (1, 2)
    overlap = (fg_prob * fg_true).sum(dim=spatial)
    denom = fg_prob.sum(dim=spatial) + fg_true.sum(dim=spatial)
    dice_term = (1 - 2.0 * overlap / (denom + 1e-8)).mean()

    ewc_term = model.get_ewc_loss()

    weights = config['loss_weights']
    objective = (weights['segmentation'] * ce_term
                 + weights['dice'] * dice_term
                 + weights['ewc'] * ewc_term)

    optimizer.zero_grad()
    objective.backward()
    # NOTE(review): clipping uses get_trainable_parameters() without a modality
    # id while the optimizer is built from get_trainable_parameters(modality_id);
    # confirm both refer to the same parameter set.
    torch.nn.utils.clip_grad_norm_(model.get_trainable_parameters(), config['gradient_clip_norm'])
    optimizer.step()

    if torch.is_tensor(ewc_term):
        ewc_value = ewc_term.item()
    else:
        ewc_value = ewc_term
    return {
        'total': objective.item(),
        'seg': ce_term.item(),
        'dice': dice_term.item(),
        'ewc': ewc_value,
    }


def save_tsne_data(model, datasets, config, order_name, max_samples=200):
    """Save prompt embeddings (with task/modality labels) for t-SNE plots.

    For each task in ``datasets`` (its task id is its list position, matching
    how ``train_model`` numbers tasks), up to ``max_samples`` prompts from the
    ``'train'`` split are embedded with
    ``model.modality_manager.extract_prompt_embedding``. The stacked
    embeddings, per-row task names and modality ids, and the CRP ``alpha`` are
    written to ``<output_dir>/<order_name>/tsne_data.npz``.

    Args:
        model: model exposing ``modality_manager`` (``task_modality_id``,
            ``extract_prompt_embedding``, ``alpha``).
        datasets: ordered list of dicts, each with a ``'name'`` key.
        config: dict with ``'data_root'`` and ``'output_dir'`` (and optionally
            ``'prompt_strategy'``, default ``'basic'``).
        order_name: sub-directory name for the output file.
        max_samples: per-task cap on the number of prompts embedded.

    Returns:
        None. When no prompts are collected, nothing is written and a message
        is printed instead of raising. This keeps the caller's later
        checkpoint/JSON saving from being lost.
    """
    prompt_strategy = config.get('prompt_strategy', 'basic')
    manager = model.modality_manager

    all_embeddings = []
    all_task_names = []
    all_modality_ids = []

    # The embeddings are only stored, never back-propagated through. no_grad
    # avoids building autograd graphs, and detach() lets .numpy() succeed even
    # if the returned tensor requires grad.
    with torch.no_grad():
        for task_idx, ds_info in enumerate(datasets):
            task_name = ds_info['name']
            modality_id = manager.task_modality_id.get(task_idx, 0)

            train_ds = MedVLSMDataset(config['data_root'], task_name, 'train', prompt_strategy, config)

            for i in range(min(len(train_ds), max_samples)):
                emb = manager.extract_prompt_embedding(train_ds[i]['prompt'])
                all_embeddings.append(emb.detach().cpu().numpy())
                all_task_names.append(task_name)
                all_modality_ids.append(modality_id)

    if not all_embeddings:
        # np.stack([]) would raise here; skip the file instead.
        print(f"\nt-SNE data not saved for '{order_name}': no prompts collected")
        return

    embeddings = np.stack(all_embeddings)

    save_path = Path(config['output_dir']) / order_name / 'tsne_data.npz'
    save_path.parent.mkdir(parents=True, exist_ok=True)

    np.savez(
        save_path,
        embeddings=embeddings,
        task_names=all_task_names,
        modality_ids=all_modality_ids,
        alpha=manager.alpha
    )

    print(f"\nt-SNE data saved: {save_path}, total samples: {len(all_embeddings)}")


def _to_json_compatible(obj):
    """``json.dump`` ``default=`` hook: convert numpy/torch values to JSON types.

    numpy integers become ``int``, numpy floats become ``float``, numpy bools
    become ``bool``, and arrays/tensors become nested lists. Anything else
    falls back to ``str(obj)``.
    """
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if torch.is_tensor(obj):
        return obj.detach().cpu().tolist()
    return str(obj)


def train_model(model, processor, config, datasets, order_name):
    """Train ``model`` sequentially on ``datasets`` and report continual-learning metrics.

    For each task, in list order, this function:
      1. builds a training loader and registers the task via ``model.start_new_task``;
      2. trains ``model.get_trainable_parameters(modality_id)`` with AdamW and
         cosine annealing, using validation-Dice early stopping
         (``min_delta`` / ``patience`` / ``min_epochs``);
      3. calls ``model.finish_task``;
      4. re-evaluates every task seen so far to track forgetting.

    After the last task it prints per-modality results and summary statistics,
    then saves t-SNE prompt-embedding data. It optionally saves a final
    checkpoint (``config['save_checkpoints']``, default True) and writes
    ``modality_analysis.json`` under ``<output_dir>/<order_name>``.

    Args:
        model: continual-learning segmentation model exposing
            ``modality_manager``, ``start_new_task``, ``finish_task``,
            ``get_trainable_parameters`` and ``get_analysis_info``.
        processor: input processor forwarded to evaluation and the checkpoint.
        config: training configuration dict (optimiser, epochs, patience, paths, ...).
        datasets: ordered list of dicts, each with at least a ``'name'`` key.
        order_name: label for this task ordering; used in output paths.

    Returns:
        ``(performance_matrix, task_results, avg_forgetting, analysis)``:
          - ``performance_matrix`` maps ``"after_task_k"`` to
            ``{task_name: dice}`` for tasks 1..k.
          - ``task_results`` holds per-task training summaries.
          - ``avg_forgetting`` is the mean drop (clipped at 0) from each
            task's Dice right after it was trained to its final Dice.
          - ``analysis`` is the final ``model.get_analysis_info()`` dict.
    """
    print(f"\n{'='*60}\nTraining: {order_name.upper()} (Adaptive CRP + LoRA + EWC)\n{'='*60}")

    performance_matrix = {}
    task_results = {}
    checkpoints_dir = Path(config['output_dir']) / 'checkpoints' / order_name
    checkpoints_dir.mkdir(parents=True, exist_ok=True)

    prompt_strategy = config.get('prompt_strategy', 'basic')
    # Highest Dice observed for each task across all evaluation rounds so far.
    best_task_performance = {}

    for task_idx, ds_info in enumerate(datasets):
        print(f"\n--- Task {task_idx + 1}/{len(datasets)}: {ds_info['name']} ---")

        train_ds = MedVLSMDataset(config['data_root'], ds_info['name'], 'train', prompt_strategy, config)
        train_loader = get_dataloader(train_ds, config['batch_size'], shuffle=True, num_workers=config['num_workers'], seed=42)

        # NOTE(review): the evaluation loops below use the list position as the
        # task id; this assumes start_new_task returns task_id == task_idx.
        # Confirm against the model implementation.
        task_id, modality_id = model.start_new_task(train_ds[0]['prompt'], ds_info['name'], train_loader)

        optimizer = torch.optim.AdamW(
            model.get_trainable_parameters(modality_id),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['max_epochs_per_task'])

        info = model.get_analysis_info()
        print(f"Trainable: LoRA_M{modality_id}={info['lora_params_per_modality']:,}, Task={info['task_params']:,}")

        task_losses, val_dices = [], []
        # best_dice is the early-stopping reference. It only moves when the
        # improvement exceeds min_delta, so it can lag behind max(val_dices).
        best_dice, patience_cnt = 0.0, 0
        # Avoid ZeroDivisionError if the loader yields no batches
        # (e.g. drop_last with a dataset smaller than one batch).
        num_batches = max(len(train_loader), 1)

        for epoch in range(config['max_epochs_per_task']):
            model.train()
            losses = {'total': 0, 'seg': 0, 'dice': 0, 'ewc': 0}
            for batch in train_loader:
                step_losses = train_step(model, batch, optimizer, config, task_id)
                for k in losses:
                    losses[k] += step_losses[k]
            for k in losses:
                losses[k] /= num_batches
            task_losses.append(losses['total'])
            scheduler.step()

            val_dice, val_loss = evaluate_task(model, processor, config['data_root'], ds_info['name'], config['device'], config, prompt_strategy, task_id)
            val_dices.append(val_dice)

            if (epoch + 1) % config['print_interval'] == 0 or epoch == 0:
                print(f"  E{epoch+1}: loss={losses['total']:.4f} (seg={losses['seg']:.3f}, dice={losses['dice']:.3f}, ewc={losses['ewc']:.3f}) | val={val_dice:.4f}")

            if val_dice > best_dice + config['min_delta']:
                best_dice, patience_cnt = val_dice, 0
            else:
                patience_cnt += 1

            if epoch >= config['min_epochs'] and patience_cnt >= config['patience']:
                print(f"  Early stop at epoch {epoch+1}")
                break

        model.finish_task(train_loader, config['device'], task_id, modality_id)

        task_results[ds_info['name']] = {
            'final_loss': task_losses[-1] if task_losses else float('inf'),
            'val_dices': val_dices,
            # Report the true best validation Dice, not the min_delta-gated
            # early-stopping reference.
            'best_val_dice': max(val_dices, default=0.0),
            'total_epochs': len(task_losses),
            'modality': modality_id
        }

        print("\nAll tasks evaluation:")
        perf = {}
        for eval_tid in range(task_idx + 1):
            dice, _ = evaluate_task(model, processor, config['data_root'], datasets[eval_tid]['name'], config['device'], config, prompt_strategy, eval_tid)
            task_name = datasets[eval_tid]['name']
            task_modality = model.modality_manager.task_modality_id.get(eval_tid, -1)
            perf[task_name] = dice
            print(f"  {task_name} (M{task_modality}): {dice:.4f}")

            best_task_performance[task_name] = max(best_task_performance.get(task_name, dice), dice)

        if task_idx > 0:
            # Forgetting on earlier tasks: best Dice seen so far minus current Dice.
            forgetting_values = [
                max(0, best_task_performance[datasets[i]['name']] - perf[datasets[i]['name']])
                for i in range(task_idx)
            ]
            forgetting = np.mean(forgetting_values)
            print(f"  Forgetting: {forgetting:.4f}")

        performance_matrix[f"after_task_{task_idx+1}"] = perf

    analysis = model.get_analysis_info()
    modality_info = analysis['modality_analysis']

    final_perfs = {}
    print(f"\n{'='*60}\nFINAL RESULTS\n{'='*60}")

    modality_tasks = defaultdict(list)
    for i, ds in enumerate(datasets):
        task_modality = model.modality_manager.task_modality_id.get(i, -1)
        modality_tasks[task_modality].append((i, ds['name']))

    print("\n--- Results by Modality ---")
    for mid in sorted(modality_tasks.keys()):
        prompts = model.modality_manager.modality_prompts.get(mid, [])
        print(f"\nModality {mid}:")
        print(f"  Representative prompts: {prompts[:3]}")

        for tid, task_name in modality_tasks[mid]:
            dice, loss = evaluate_task(model, processor, config['data_root'], task_name, config['device'], config, prompt_strategy, tid)
            final_perfs[task_name] = {'dice_score': dice, 'loss': loss, 'modality': mid}
            print(f"  - {task_name}: {dice:.4f}")

    avg_dice = np.mean([p['dice_score'] for p in final_perfs.values()])
    # Dice of each task measured right after it was trained vs. at the end.
    first_perfs = [performance_matrix[f"after_task_{i+1}"].get(datasets[i]['name'], 0) for i in range(len(datasets))]
    final_list = [final_perfs[datasets[i]['name']]['dice_score'] for i in range(len(datasets))]
    avg_forgetting = np.mean([max(0, f - l) for f, l in zip(first_perfs, final_list)])

    print("\n--- Summary ---")
    print(f"  Average Dice: {avg_dice:.4f}")
    print(f"  Average Forgetting: {avg_forgetting:.4f}")
    print(f"  Number of modalities: {modality_info['num_modalities']}")
    print(f"  CRP alpha: {model.modality_manager.alpha}")
    print(f"  Trainable params: {analysis['trainable_params']:,}")
    print("  Method: Adaptive-CRP + LoRA + EWC")

    print("\n--- Learned Similarity Distributions ---")
    print(f"  Intra-modality: mean={modality_info['intra_sim_stats']['mean']:.3f}, std={modality_info['intra_sim_stats']['std']:.3f}, n={modality_info['intra_sim_stats']['n']}")
    print(f"  Inter-modality: mean={modality_info['inter_sim_stats']['mean']:.3f}, std={modality_info['inter_sim_stats']['std']:.3f}, n={modality_info['inter_sim_stats']['n']}")

    save_tsne_data(model, datasets, config, order_name)

    if config.get('save_checkpoints', True):
        final_checkpoint_path = checkpoints_dir / "final_model"
        final_checkpoint_info = {
            "order": order_name,
            "num_tasks": len(datasets),
            "tasks_trained": [d['name'] for d in datasets],
            "final_performances": final_perfs,
            "average_dice": avg_dice,
            "average_forgetting": avg_forgetting,
            "performance_matrix": performance_matrix,
            "method": "Adaptive-CRP + LoRA + EWC",
            "crp_alpha": model.modality_manager.alpha,
            "trainable_params": analysis['trainable_params'],
            "total_params": analysis['total_params'],
            "modality_analysis": modality_info
        }
        save_model_checkpoint(model, processor, final_checkpoint_path, final_checkpoint_info, serialize_config(config))
        print(f"\nFinal model saved: {final_checkpoint_path}")

    results_dir = Path(config['output_dir']) / order_name
    results_dir.mkdir(parents=True, exist_ok=True)

    with open(results_dir / 'modality_analysis.json', 'w', encoding='utf-8') as f:
        json.dump(modality_info, f, indent=2, default=_to_json_compatible)

    return performance_matrix, task_results, avg_forgetting, analysis