import os
import random
import numpy as np
import torch
import wandb
import pickle
from torch.utils.data import DataLoader

from data.loaders.fakenews import FakeNewsTwitterDataset
from model.utils import SymmetricStabilizedBCEWithLogitsLoss
from config_fakenews import get_config
from model.GMANFakeNews import GMAN
from data.loaders.utils import distance_collate_fn_fakenews
from model.set_trainer import train_epoch, test_epoch

def _save_checkpoint(path, epoch, model, optimizer, scheduler, upload_to_wandb=False):
    """Save a training checkpoint to ``path`` and optionally upload it to wandb.

    The checkpoint is a dict with keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict`` and ``scheduler_state_dict``. The parent
    directory is created if it does not exist yet.

    Args:
        path: Destination file passed to ``torch.save``.
        epoch: Epoch index stored alongside the state dicts.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler whose ``state_dict()`` is saved.
        upload_to_wandb: When true, also call ``wandb.save(path)``.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, path)
    if upload_to_wandb:
        wandb.save(path)


def main():
    """Train and evaluate the GMAN model on the pre-split GossipCop graphs.

    Steps:
      1. Read the experiment config via ``get_config()`` and seed Python,
         NumPy and torch (including all CUDA devices when available).
      2. Load the train/val/test file lists from
         ``FakeNewsData/splits/split.pkl``.
      3. Load the processed graph data from
         ``FakeNewsData/data/gossipcop_graphs.pt``, then build datasets and
         DataLoaders.
      4. Optionally initialise a wandb run.
      5. Train for ``exp_config.epochs`` epochs with Adam and
         ReduceLROnPlateau (stepped on the validation loss). Each epoch
         evaluates on val and test, logs metrics, and saves checkpoints for
         best val loss, best val AUC, best val AUPRC and the last epoch.
      6. Run a final evaluation on the test set and log it.
    """
    # Get experiment configuration.
    exp_config = get_config()
    print(exp_config)

    # Set random seeds.
    SEED = exp_config.seed
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pre-computed split: a pickled dict with train/val/test file lists.
    # NOTE(review): pickle.load executes arbitrary code; only load trusted files.
    split_path = "FakeNewsData/splits/split.pkl"
    with open(split_path, 'rb') as f:
        data_split = pickle.load(f)

    train_files = data_split["train_files"]
    val_files = data_split["val_files"]
    test_files = data_split["test_files"]

    print("Loaded pre-computed splits:")
    print(f"Number of training samples: {len(train_files)}")
    print(f"Number of validation samples: {len(val_files)}")
    print(f"Number of test samples: {len(test_files)}")

    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which
    # may reject pickled graph objects; confirm and pass weights_only=False
    # for this trusted local file if loading fails.
    processed_path = 'FakeNewsData/data/gossipcop_graphs.pt'
    data = torch.load(processed_path)

    train_dataset = FakeNewsTwitterDataset(data=data, roots=train_files)
    val_dataset = FakeNewsTwitterDataset(data=data, roots=val_files)
    test_dataset = FakeNewsTwitterDataset(data=data, roots=test_files)

    # Create DataLoaders. Only the training loader shuffles.
    train_loader = DataLoader(
        train_dataset,
        batch_size=exp_config.batch_size,
        shuffle=True,
        collate_fn=distance_collate_fn_fakenews,
        num_workers=16,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=exp_config.batch_size,
        shuffle=False,
        collate_fn=distance_collate_fn_fakenews,
        num_workers=16,
        pin_memory=True,
        persistent_workers=True
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=exp_config.batch_size,
        shuffle=False,
        collate_fn=distance_collate_fn_fakenews,
        num_workers=16,
        pin_memory=True,
        persistent_workers=True
    )

    unique_run_name = f"layers-{exp_config['n_layers']}_hidden-{exp_config['hidden_channels']}_lr-{exp_config['lr']}_dropout-{exp_config['dropout']}_seed-{exp_config['seed']}"

    # Initialize wandb if enabled.
    if exp_config.wandb:
        config_dict = exp_config.copy()
        config_dict['device'] = torch.cuda.get_device_name() if torch.cuda.is_available() else 'cpu'
        # Records the config object's class name, not the model class.
        config_dict['model'] = exp_config.__class__.__name__
        wandb.init(
            project="DeepSetFakeNewsDataset",
            config=config_dict,
            entity='setgnan',
            settings=wandb.Settings(start_method='thread'),
            name=unique_run_name,
            reinit=True
        )

    model = GMAN(
        feature_groups=exp_config.feature_groups,
        out_channels=exp_config.out_channels,
        is_graph_task=exp_config.is_graph_task,
        batch_size=exp_config.batch_size,
        n_layers=exp_config.n_layers,
        hidden_channels=exp_config.hidden_channels,
        dropout=exp_config.dropout,
        device=device,
        normalize_rho=exp_config.normalize_rho,
        biomarker_groups=exp_config.biomarker_groups
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=exp_config.lr, weight_decay=exp_config.wd)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=20, min_lr=1e-8
    )

    # Offset added to the reported epoch index (always 0 here; no resume support).
    last_epoch = 0

    print(f"Started training on {device}")
    print(f"Running experiment: {exp_config.exp_name}")

    best_val_loss = float('inf')
    best_val_auc = float('-inf')
    best_val_auprc = float('-inf')

    # Bound once, outside the loop. Previously these were bound only inside
    # the loop/branches, which raised NameError when epochs == 0 (loss_fn) or
    # when no "best" branch fired in the first epoch, e.g. NaN val metrics
    # (checkpoint_dir). Assumes the loss object carries no per-epoch state.
    loss_fn = SymmetricStabilizedBCEWithLogitsLoss()
    checkpoint_dir = f"{exp_config.model_checkpoints_dir}/{unique_run_name}"

    for epoch in range(exp_config.epochs):
        global_epoch = last_epoch + epoch

        # NOTE(review): the scheduler is also handed to train_epoch; confirm it
        # is not stepped there in addition to scheduler.step(val_loss) below.
        train_loss, train_acc, train_auc, train_auprc = train_epoch(
            epoch=global_epoch,
            model=model,
            dloader=train_loader,
            loss_fn=loss_fn,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            writer=None
        )
        print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Train Accuracy={train_acc:.4f}, Train AUC={train_auc:.4f}, Train AUPRC={train_auprc:.4f}")

        val_loss, val_acc, val_auc, val_auprc = test_epoch(
            epoch=global_epoch,
            model=model,
            dloader=val_loader,
            loss_fn=loss_fn,
            device=device,
            writer=None
        )
        print(f"Epoch {epoch}: Val Loss={val_loss:.4f}, Val Accuracy={val_acc:.4f}, Val AUC={val_auc:.4f}, Val AUPRC={val_auprc:.4f}")

        # ReduceLROnPlateau reacts to the validation loss (mode='min').
        scheduler.step(val_loss)

        test_loss, test_acc, test_auc, test_auprc = test_epoch(
            epoch=global_epoch,
            model=model,
            dloader=test_loader,
            loss_fn=loss_fn,
            device=device,
            writer=None
        )
        print(f"Epoch {epoch}: Test Loss={test_loss:.4f}, Test Accuracy={test_acc:.4f}, Test AUC={test_auc:.4f}, Test AUPRC={test_auprc:.4f}")

        if exp_config.wandb:
            wandb.log({
                "epoch": epoch,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "train_auc": train_auc,
                "train_auprc": train_auprc,
                "val_loss": val_loss,
                "val_acc": val_acc,
                "val_auc": val_auc,
                "val_auprc": val_auprc,
                "test_loss": test_loss,
                "test_acc": test_acc,
                "test_auc": test_auc,
                "test_auprc": test_auprc,
            })

        # Keep one checkpoint per selection criterion, plus the latest epoch.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            _save_checkpoint(
                os.path.join(checkpoint_dir, 'best_params_by_val_loss.pth'),
                global_epoch, model, optimizer, scheduler, exp_config.wandb
            )

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            _save_checkpoint(
                os.path.join(checkpoint_dir, 'best_params_by_val_auc.pth'),
                global_epoch, model, optimizer, scheduler, exp_config.wandb
            )

        if val_auprc > best_val_auprc:
            best_val_auprc = val_auprc
            _save_checkpoint(
                os.path.join(checkpoint_dir, 'best_params_by_val_auprc.pth'),
                global_epoch, model, optimizer, scheduler, exp_config.wandb
            )

        _save_checkpoint(
            os.path.join(checkpoint_dir, 'last_epoch.pth'),
            global_epoch, model, optimizer, scheduler, exp_config.wandb
        )

    # Final evaluation uses the in-memory (last-epoch) weights; no best
    # checkpoint is reloaded, so this repeats the last epoch's test pass.
    test_loss, test_acc, test_auc, test_auprc = test_epoch(
        epoch=last_epoch + exp_config.epochs,
        model=model,
        dloader=test_loader,
        loss_fn=loss_fn,
        device=device,
        writer=None
    )
    print(f"Final Test Loss={test_loss:.4f}, Final Test Accuracy={test_acc:.4f}, Final Test AUC={test_auc:.4f}, Final Test AUPRC={test_auprc:.4f}")
    if exp_config.wandb:
        wandb.log({
            "final_test_loss": test_loss,
            "final_test_acc": test_acc,
            "final_test_auc": test_auc,
            "final_test_auprc": test_auprc
        })
        wandb.finish()

if __name__ == '__main__':
    # Launch training only when run as a script, never on import.
    main()