import argparse
import os
import pickle
import random
import torch
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from utils import set_seed, make_dirs, get_experiment_name
from data import load_data
from models import SupGCLModel
from samplers import TeacherSampler, apply_virtual_knockdown
from train import train_epoch, test_epoch

# ------------- 1. Data Path Settings -------------
# Placeholder-style paths; these are the defaults of --tcga_path and --lincs_graphs.
DEFAULT_TCGA = 'path/to/tcga_graphs.pkl'
DEFAULT_LINCS = 'path/to/LINCS_KD_graphs.pkl'
# Both metadata pickles sit in the same folder, so the folder is spelled once.
# NOTE: the leading '~' is kept verbatim here; open() does not expand it.
_META_DIR = '~/SupGCL/data/meta_data/Breast'    # or 'Lung', 'Colorectal' in place of 'Breast'
DEFAULT_META1 = f'{_META_DIR}/LINCS_KD_graphs_metadata.pkl'
DEFAULT_META2 = f'{_META_DIR}/LINCS_sampleID_KDgene_metadata.pkl'

# ------------- 2. Argument Definitions -------------
def parse_args():
    """Parse this script's command-line options from ``sys.argv``.

    Every option is optional and takes a single typed value. The four path
    options default to the module-level ``DEFAULT_*`` constants; the rest
    default to the literals listed below.

    Returns:
        argparse.Namespace: one attribute per option, named after the flag
        without its leading dashes (e.g. ``--batch_size`` -> ``batch_size``).
    """
    # (flag, value type, default) in the order the options are registered.
    option_table = (
        # Input file locations.
        ('--tcga_path', str, DEFAULT_TCGA),
        ('--lincs_graphs', str, DEFAULT_LINCS),
        ('--lincs_meta_graphs', str, DEFAULT_META1),
        ('--lincs_meta_kd', str, DEFAULT_META2),
        # Data split / batching / optimisation settings.
        ('--split', float, 0.8),
        ('--batch_size', int, 4),
        ('--lr', float, 1e-4),
        ('--epochs', int, 3000),
        # Presumably loss temperatures, judging by the names -- TODO confirm.
        ('--tau_nce', float, 0.25),
        ('--tau_aug', float, 0.25),
        # Model layer sizes.
        ('--hid', int, 64),
        ('--out', int, 64),
        ('--proj_out', int, 64),
        # Sampler and reproducibility settings.
        ('--subsample_size', int, 8),
        ('--seed', int, 42),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()

# ------------- 3. Main Process -------------
if __name__ == '__main__':
    # Entry point: parse options, load the TCGA loaders and the LINCS pickles,
    # build model / sampler / optimizer, then train while logging losses to
    # TensorBoard and saving a checkpoint every 100 epochs.
    args = parse_args()
    # Set seed and device
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _load_pickle(path):
        """Return the object unpickled from ``path`` (a leading '~' is expanded)."""
        # open() does not expand '~', and the default metadata paths start
        # with it, so resolve the home directory explicitly.
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only point these options at files from a trusted source.
        with open(os.path.expanduser(path), 'rb') as f:
            return pickle.load(f)

    # Load data. The expanded path is kept in a local so that `args` (which
    # get_experiment_name receives below) is left exactly as parsed.
    train_loader, test_loader = load_data(
        os.path.expanduser(args.tcga_path), args.split, args.batch_size
    )
    # Load LINCS data
    lincs_graphs = _load_pickle(args.lincs_graphs)
    LINCS_KD_graphs_sampleID = _load_pickle(args.lincs_meta_graphs)
    knockdown_metadata = _load_pickle(args.lincs_meta_kd)
    # kd_gene -> sample_id list
    # (assumes the pickle holds a pandas DataFrame with 'kd_gene' and
    # 'sample_id' columns -- TODO confirm against the data preparation code)
    kd_gene_to_sample_ids = (
        knockdown_metadata.groupby('kd_gene')['sample_id'].apply(list).to_dict()
    )

    # Set experiment directories (both live under the current working directory)
    exp = get_experiment_name(args)
    base = os.getcwd()
    run_dir = os.path.join(base, 'runs', exp)
    ckpt_dir = os.path.join(base, 'checkpoints', exp)
    make_dirs(run_dir)
    make_dirs(ckpt_dir)
    writer = SummaryWriter(run_dir)

    # Everything after the writer is created runs under try/finally so the
    # writer is closed even when setup or training raises or is interrupted.
    try:
        # Initialize model, sampler, and optimizer
        # The input width is read from the first item of the training dataset.
        sample_graph = train_loader.dataset[0]
        in_c = sample_graph.x.size(1)
        model = SupGCLModel(
            in_c=in_c, hid=args.hid, out=args.out, proj_out=args.proj_out
        ).to(device)
        sampler = TeacherSampler(
            lincs_graphs=lincs_graphs,
            sample_ids_list=LINCS_KD_graphs_sampleID,
            kd_meta=kd_gene_to_sample_ids,
            model=model,
            device=device,
            subsample_size=args.subsample_size,
        )
        optimizer = AdamW(model.parameters(), lr=args.lr)

        # Training loop (epochs are numbered from 1)
        for epoch in range(1, args.epochs + 1):
            info, aug, tot = train_epoch(model, train_loader, sampler, optimizer, args, device)
            tinfo, taug, ttot = test_epoch(model, test_loader, sampler, args, device)
            writer.add_scalars('Loss',      {'info': info,  'aug': aug,  'total': tot},  epoch)
            writer.add_scalars('Loss_test', {'info': tinfo, 'aug': taug, 'total': ttot}, epoch)
            print(f"Epoch {epoch:04d} | Train INFO={info:.4f}, AUG={aug:.4f}, TOT={tot:.4f}"
                  f" | Test INFO={tinfo:.4f}, AUG={taug:.4f}, TOT={ttot:.4f}")
            # Checkpoint model and optimizer state every 100th epoch.
            if epoch % 100 == 0:
                torch.save({'model': model.state_dict(), 'opt': optimizer.state_dict()},
                           os.path.join(ckpt_dir, f'ep{epoch}.pt'))
    finally:
        writer.close()