

import os
import glob
import yaml
import math
import random
import logging
import argparse
import csv
from datetime import datetime
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW, SGD
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR


from torch.utils.tensorboard import SummaryWriter


import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


from torch_geometric.data import Batch


from torch.cuda.amp import GradScaler, autocast



from interfacediff.models.interfacediff import InterFaceDiff


def seed_everything(seed: int):
    """Seed all RNGs used by training and make cuDNN deterministic.

    Seeds Python's ``random``, torch (CPU and all CUDA devices) and, when it is
    installed, NumPy's global RNG. Also sets ``PYTHONHASHSEED`` and disables
    cuDNN autotuning in favour of deterministic kernels.

    Args:
        seed: integer seed. NumPy only accepts seeds in ``[0, 2**32)``, so the
            value passed to NumPy is reduced modulo ``2**32`` (in-range seeds are
            passed through unchanged).
    """
    random.seed(seed)
    # Hash randomization is fixed at interpreter start-up, so this only affects
    # Python interpreters launched after this point (e.g. subprocesses).
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    try:
        import numpy as np
    except ImportError:
        # NumPy is optional here; seeding it is best-effort.
        pass
    else:
        # Previously any error (e.g. a negative seed) was swallowed and NumPy was
        # left unseeded; map the seed into NumPy's accepted range instead.
        np.random.seed(seed % (2 ** 32))

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def make_dirs(d):
    """Create directory ``d`` and any missing parents; do nothing if it already exists."""
    os.makedirs(name=d, exist_ok=True)


class PtDataset(Dataset):
    """Map-style dataset where each item is one ``torch.save``-d sample file.

    Args:
        files: sequence of file paths; item ``idx`` is the object loaded from
            ``files[idx]``.
    """

    def __init__(self, files):
        self.files = files

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        # Always materialize on CPU: DataLoader worker processes should not touch
        # CUDA, and pin_memory only accepts CPU tensors. Samples saved from a GPU
        # would otherwise be restored onto that GPU.
        # NOTE(review): torch.load unpickles arbitrary objects, so only use trusted
        # files. On torch>=2.6 the default weights_only=True may also reject pickled
        # torch_geometric Data objects; confirm against the installed version.
        return torch.load(path, map_location='cpu')

def collate_fn(samples):
    """Collate a list of sample dicts into one batch dict.

    Each sample must provide ``'chain_1_graph'``, ``'chain_2_graph'`` and
    ``'interface_graph'``; each of these is merged across samples with
    ``Batch.from_data_list``. ``'id'`` is the list of sample ids, falling back to
    ``'idx_<position in batch>'`` for samples without one.
    """
    graph_keys = ('chain_1_graph', 'chain_2_graph', 'interface_graph')
    ids = [sample.get('id', f'idx_{n}') for n, sample in enumerate(samples)]
    # Gather every per-key list first, then batch them.
    per_key = {key: [sample[key] for sample in samples] for key in graph_keys}

    batch = {'id': ids}
    batch.update({key: Batch.from_data_list(graphs) for key, graphs in per_key.items()})
    return batch


def aggregate_loss_dict(loss_dict):
    """Reduce a model's loss dict to a differentiable total plus float logs.

    Tensor values are reduced with ``.mean()``. Non-tensor values are turned
    into 0-dim tensors on the same device as the dict's tensor values (CPU if
    there are none).

    If the dict has a ``'loss'`` entry, that entry alone is the optimisation
    objective. The training loop likewise reports ``'loss'`` as the epoch loss
    and picks the best checkpoint on it, so the other keys are logging-only
    components. Summing them as well would double-count them. Without a
    ``'loss'`` key, the total is the sum of all entries.

    Args:
        loss_dict: mapping of name -> tensor or number.

    Returns:
        ``(total, scalar_dict)``: ``total`` is the tensor to back-propagate and
        ``scalar_dict`` maps every key to its reduced value as a Python float.

    Raises:
        ValueError: if ``loss_dict`` is empty.
    """
    if not loss_dict:
        raise ValueError("Model returned empty loss dict")

    # Keep constants on the model's device instead of assuming 'cuda'.
    device = next(
        (v.device for v in loss_dict.values() if isinstance(v, torch.Tensor)),
        torch.device('cpu'),
    )

    reduced = {}
    for k, v in loss_dict.items():
        if isinstance(v, torch.Tensor):
            reduced[k] = v.mean()
        else:
            reduced[k] = torch.tensor(float(v), device=device)

    scalar_dict = {k: float(t.detach().cpu().item()) for k, t in reduced.items()}

    if 'loss' in reduced:
        total = reduced['loss']
    else:
        total = None
        for t in reduced.values():
            total = t if total is None else total + t
    return total, scalar_dict


def train_one_epoch(model, loader, optimizer, scaler, device, cfg, epoch, writer=None, logger=None):
    """Run one training epoch and return per-key mean losses.

    For every batch the function:

    1. Moves the three graph entries to ``device``.
    2. Runs the model under autocast when AMP is active.
    3. Reduces the returned loss dict with ``aggregate_loss_dict``.
    4. Back-propagates, scaling through ``scaler`` when AMP is active.
    5. Optionally clips gradients.
    6. Steps the optimizer.

    Args:
        model: module whose forward takes the collated batch dict and returns a
            loss dict.
        loader: iterable of collated batch dicts.
        optimizer: optimizer over ``model``'s parameters.
        scaler: GradScaler used when AMP is active; may be None.
        device: ``torch.device`` batches are moved to.
        cfg: parsed config. Reads ``cfg['amp'].get('enabled')`` and
            ``cfg['train'].get('grad_clip')``; both default to off.
        epoch: epoch index, used for the progress-bar text and TensorBoard step.
        writer: optional SummaryWriter. Receives ``'train/iter_loss'`` (the
            running mean loss so far) each iteration.
        logger: not used; accepted for call-site compatibility.

    Returns:
        dict mapping each loss key to its mean over the epoch's batches
        (empty if ``loader`` yields nothing).
    """
    model.train()
    # Same condition main() uses when deciding whether to create a GradScaler;
    # .get() so configs without these keys don't crash here.
    use_amp = bool(cfg['amp'].get('enabled', False)) and device.type == 'cuda'
    # Loss scaling requires a scaler; without one, fall back to an unscaled backward.
    use_scaler = use_amp and scaler is not None
    grad_clip = cfg['train'].get('grad_clip')
    graph_keys = ('chain_1_graph', 'chain_2_graph', 'interface_graph')

    running = {}
    it = 0
    pbar = tqdm(loader, desc=f"Train Epoch {epoch}", leave=False)
    for batch in pbar:
        it += 1
        for k in graph_keys:
            if k in batch and hasattr(batch[k], 'to'):
                batch[k] = batch[k].to(device)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            loss_dict = model(batch)
            total_loss, scalar_dict = aggregate_loss_dict(loss_dict)

        if use_scaler:
            scaler.scale(total_loss).backward()
            if grad_clip is not None:
                # Unscale first so clipping operates on the true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            total_loss.backward()
            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        for k, v in scalar_dict.items():
            running[k] = running.get(k, 0.0) + v

        avg_loss = running['loss'] / it if 'loss' in running else sum(running.values()) / it
        pbar.set_postfix({'avg_loss': f'{avg_loss:.4f}'})
        if writer is not None:
            writer.add_scalar('train/iter_loss', avg_loss, epoch * len(loader) + it)

    return {k: v / it for k, v in running.items()}

@torch.no_grad()
def validate(model, loader, device, cfg, epoch, writer=None, logger=None):
    """Evaluate ``model`` on ``loader`` without gradients and return mean losses.

    Each batch's three graph entries are moved to ``device``. The model's loss
    dict is reduced via ``aggregate_loss_dict``, and per-key values are averaged
    over batches. When ``writer`` is given and a ``'loss'`` key is present, its
    mean is logged as ``'val/loss'`` at step ``epoch``.

    ``cfg`` and ``logger`` are accepted for signature parity but are not read.

    Returns:
        dict of per-key mean losses, or ``{}`` if ``loader`` yielded no batches.
    """
    model.eval()
    sums = {}
    n_batches = 0
    for batch in tqdm(loader, desc=f"Valid  Epoch {epoch}", leave=False):
        n_batches += 1
        for key in ('chain_1_graph', 'chain_2_graph', 'interface_graph'):
            if key in batch and hasattr(batch[key], 'to'):
                batch[key] = batch[key].to(device)
        _, scalars = aggregate_loss_dict(model(batch))
        for name, value in scalars.items():
            sums[name] = sums.get(name, 0.0) + value

    if n_batches == 0:
        epoch_stats = {}
    else:
        epoch_stats = {name: total / n_batches for name, total in sums.items()}

    if writer is not None and 'loss' in epoch_stats:
        writer.add_scalar('val/loss', epoch_stats['loss'], epoch)
    return epoch_stats


def plot_losses(train_losses: list, val_losses: list, save_path: str, start_epoch: int = 1):
    """Save a train/val loss-vs-epoch plot to ``save_path``.

    Args:
        train_losses: one training loss per epoch, in order.
        val_losses: validation losses aligned with ``train_losses``. NaN marks
            epochs without validation. May be empty. Only the first
            ``len(train_losses)`` entries are used.
        save_path: output image path.
        start_epoch: epoch number of the first entry (default 1).
    """
    fig, ax = plt.subplots()
    try:
        epochs = list(range(start_epoch, start_epoch + len(train_losses)))
        ax.plot(epochs, train_losses, label='train loss')
        # NaN entries split a matplotlib line into gaps. With eval_interval > 1,
        # every validated point would be isolated and nothing would be drawn.
        # Plot only the validated epochs so they form a connected curve.
        val_points = [(e, v) for e, v in zip(epochs, val_losses) if not math.isnan(v)]
        if val_points:
            val_epochs, val_values = zip(*val_points)
            ax.plot(val_epochs, val_values, marker='o', label='val loss')
        ax.set_xlabel('epoch')
        ax.set_ylabel('loss')
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Release the figure even if saving fails, so figures don't pile up across epochs.
        plt.close(fig)


def save_checkpoint(state: dict, checkpoint_dir: str, epoch: int, is_best=False):
    """Persist ``state`` as ``ckpt_epoch_{epoch}.pt`` (plus ``best_ckpt.pt`` if ``is_best``).

    Each file is first written to a ``.tmp`` sibling and then moved into place
    with ``os.replace``. An interrupted save therefore never leaves a truncated
    checkpoint under the final name.

    Args:
        state: object passed to ``torch.save``.
        checkpoint_dir: destination directory; created if missing.
        epoch: epoch number embedded in the per-epoch file name.
        is_best: also write ``best_ckpt.pt``.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    def _atomic_save(path):
        # Write-then-rename so readers only ever see complete files.
        tmp_path = path + '.tmp'
        try:
            torch.save(state, tmp_path)
            os.replace(tmp_path, path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    _atomic_save(os.path.join(checkpoint_dir, f'ckpt_epoch_{epoch}.pt'))
    if is_best:
        _atomic_save(os.path.join(checkpoint_dir, 'best_ckpt.pt'))

def load_checkpoint_if_any(resume_path: str, model, optimizer=None, scheduler=None, scaler=None, device=torch.device('cpu')):
    """Restore training state from ``resume_path`` if it exists.

    Always loads ``ckpt['model_state']`` into ``model``. ``optim_state``,
    ``sched_state`` and ``scaler`` are loaded into ``optimizer``, ``scheduler``
    and ``scaler`` respectively, each only when that object is given and the
    checkpoint contains the key.

    Args:
        resume_path: checkpoint path; empty/None means "do not resume".
        model: module to load weights into.
        optimizer: optional optimizer to restore.
        scheduler: optional LR scheduler to restore.
        scaler: optional GradScaler to restore.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        ``(start_epoch, best_val)``. ``start_epoch`` is the checkpoint's
        ``epoch`` + 1 (or 1 if absent); ``best_val`` is the stored best
        validation loss (or ``inf``). Returns ``(1, inf)`` when
        ``resume_path`` is empty or does not exist.
    """
    log = logging.getLogger(__name__)
    if not resume_path:
        return 1, float('inf')
    if not os.path.exists(resume_path):
        # Don't silently ignore a configured checkpoint that is missing; make the
        # restart from scratch visible.
        log.warning("Resume checkpoint '%s' not found; starting training from scratch", resume_path)
        return 1, float('inf')

    ckpt = torch.load(resume_path, map_location=device)
    model.load_state_dict(ckpt['model_state'])
    if optimizer is not None and 'optim_state' in ckpt:
        optimizer.load_state_dict(ckpt['optim_state'])
    if scheduler is not None and 'sched_state' in ckpt:
        scheduler.load_state_dict(ckpt['sched_state'])
    if scaler is not None and 'scaler' in ckpt:
        scaler.load_state_dict(ckpt['scaler'])

    start_epoch = ckpt.get('epoch', 0) + 1
    best_val = ckpt.get('best_val', float('inf'))
    log.info("Loaded checkpoint '%s' (starting epoch %d)", resume_path, start_epoch)
    return start_epoch, best_val


def main(config_path):
    """Train InterFaceDiff end-to-end from a YAML config.

    Setup steps:

    - Read the config at ``config_path``.
    - Configure logging to ``<log_dir>/<timestamp>/train.log`` plus a stream handler.
    - Seed all RNGs.
    - Build loaders over the ``*.pt`` files in ``data.train_dir`` / ``data.val_dir``.
    - Build the model, optimizer, LR scheduler and an optional AMP GradScaler.
    - Open a TensorBoard writer and a per-epoch ``losses.csv``.

    It then runs the epoch loop with periodic validation, loss-curve plotting
    and checkpointing. Finally it writes ``<checkpoint_dir>/final_model.pt``.

    Args:
        config_path: path to the YAML training configuration.
    """
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    checkpoint_dir = cfg['train'].get('checkpoint_dir', './checkpoints')
    log_dir = os.path.join(cfg['train'].get('log_dir', './logs'), timestamp)
    make_dirs(checkpoint_dir)
    make_dirs(log_dir)

    log_file = os.path.join(log_dir, 'train.log')
    logging.basicConfig(
        filename=log_file,
        level=logging.INFO,
        format='%(asctime)s %(levelname)s %(message)s'
    )
    logger = logging.getLogger()
    logger.addHandler(logging.StreamHandler())

    logger.info(f"Loaded config from {config_path}")
    logger.info(cfg)

    if cfg['train']['device'] == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(cfg['train']['device'])
    logger.info(f"Using device: {device}")

    seed_everything(cfg['train'].get('seed', 42))

    train_files = sorted(glob.glob(os.path.join(cfg['data']['train_dir'], '*.pt')))
    val_files = sorted(glob.glob(os.path.join(cfg['data']['val_dir'], '*.pt')))
    logger.info(f"Found {len(train_files)} train files, {len(val_files)} val files")

    # Shared loader settings; only shuffling differs between train and val.
    loader_kwargs = dict(
        batch_size=cfg['dataloader']['batch_size'],
        num_workers=cfg['data'].get('num_workers', 4),
        pin_memory=cfg['data'].get('pin_memory', True),
        collate_fn=collate_fn,
    )
    train_loader = DataLoader(PtDataset(train_files), shuffle=True, **loader_kwargs)
    val_loader = DataLoader(PtDataset(val_files), shuffle=False, **loader_kwargs)

    model = InterFaceDiff(cfg)
    model.to(device)
    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    opt_cfg = cfg['optim']
    if opt_cfg.get('name', 'AdamW') == 'AdamW':
        optimizer = AdamW(model.parameters(), lr=opt_cfg.get('lr', 1e-4),
                          weight_decay=opt_cfg.get('weight_decay', 0.0),
                          betas=tuple(opt_cfg.get('betas', [0.9, 0.999])))
    else:
        optimizer = SGD(model.parameters(), lr=opt_cfg.get('lr', 1e-3), momentum=0.9)

    sched_cfg = cfg.get('scheduler', {})
    if sched_cfg.get('name', '') == 'CosineAnnealingLR':
        scheduler = CosineAnnealingLR(optimizer, T_max=sched_cfg.get('T_max', cfg['train']['epochs']),
                                      eta_min=sched_cfg.get('eta_min', 0.0))
    else:
        scheduler = StepLR(optimizer, step_size=sched_cfg.get('step_size', 30), gamma=sched_cfg.get('gamma', 0.1))

    scaler = GradScaler() if cfg['amp'].get('enabled', False) and device.type == 'cuda' else None
    # 'misc' is optional; a missing section means "no resume, save best only".
    misc_cfg = cfg.get('misc') or {}

    writer = SummaryWriter(log_dir=log_dir)
    csv_path = os.path.join(log_dir, 'losses.csv')
    csv_file = open(csv_path, 'w', newline='')
    try:
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(['epoch', 'train_loss', 'val_loss'])

        start_epoch = 1
        best_val = float('inf')
        resume_from = misc_cfg.get('resume_from', "")
        if resume_from:
            start_epoch, best_val = load_checkpoint_if_any(resume_from, model, optimizer, scheduler, scaler, device)

        num_epochs = cfg['train']['epochs']
        eval_interval = cfg['train'].get('eval_interval', 1)
        save_every = cfg['train'].get('save_every_n_epochs', 10)
        save_best_only = misc_cfg.get('save_best_only', True)
        curve_path = os.path.join(log_dir, 'loss_curve.png')

        # NOTE: after a resume these lists start empty, so the plotted epoch axis restarts at 1.
        train_losses = []
        val_losses = []
        # Last completed epoch. Stays start_epoch - 1 if the loop never runs
        # (e.g. resuming an already-finished run), which previously left `epoch` unbound.
        last_epoch = start_epoch - 1
        for epoch in range(start_epoch, num_epochs + 1):
            train_stats = train_one_epoch(model, train_loader, optimizer, scaler, device, cfg, epoch, writer, logger)
            train_loss = train_stats.get('loss', sum(train_stats.values()))
            train_losses.append(train_loss)

            # Both supported schedulers (StepLR, CosineAnnealingLR) step once per epoch.
            scheduler.step()

            if epoch % eval_interval == 0:
                val_stats = validate(model, val_loader, device, cfg, epoch, writer, logger)
                val_loss = val_stats.get('loss', sum(val_stats.values())) if val_stats else float('nan')
            else:
                val_loss = float('nan')
            val_losses.append(val_loss)

            writer.add_scalar('train/epoch_loss', train_loss, epoch)
            if not math.isnan(val_loss):
                writer.add_scalar('val/epoch_loss', val_loss, epoch)

            logger.info(f"Epoch {epoch:03d} TrainLoss: {train_loss:.6f} ValLoss: {val_loss:.6f}")

            csv_writer.writerow([epoch, train_loss, val_loss])
            csv_file.flush()

            plot_losses(train_losses, val_losses, curve_path)

            # Update best_val BEFORE building the state, so a best checkpoint records
            # its own loss. Resuming from it then compares against the right baseline.
            is_best = not math.isnan(val_loss) and val_loss < best_val
            if is_best:
                best_val = float(val_loss)

            state = {
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optim_state': optimizer.state_dict(),
                'sched_state': scheduler.state_dict(),
                'best_val': best_val
            }
            if scaler is not None:
                state['scaler'] = scaler.state_dict()

            if save_best_only:
                wrote_epoch_ckpt = is_best
                if is_best:
                    save_checkpoint(state, checkpoint_dir, epoch, is_best=True)
            else:
                save_checkpoint(state, checkpoint_dir, epoch, is_best=is_best)
                wrote_epoch_ckpt = True

            # Periodic snapshot; skip if this epoch's file was already written above.
            if epoch % save_every == 0 and not wrote_epoch_ckpt:
                save_checkpoint(state, checkpoint_dir, epoch, is_best=False)

            last_epoch = epoch

        final_path = os.path.join(checkpoint_dir, 'final_model.pt')
        final_state = {
            'epoch': last_epoch,
            'model_state': model.state_dict(),
            'optim_state': optimizer.state_dict(),
            'sched_state': scheduler.state_dict(),
            'best_val': best_val
        }
        if scaler is not None:
            final_state['scaler'] = scaler.state_dict()
        torch.save(final_state, final_path)
    finally:
        # Release the CSV handle and flush TensorBoard even if training fails.
        csv_file.close()
        writer.close()

    logger.info("Training finished. Final model saved to: " + final_path)

if __name__ == '__main__':
    # CLI entry point: train from the YAML config given by --config.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--config',
        type=str,
        default='config.yaml',
        help='path to yaml config',
    )
    main(cli.parse_args().config)

