import pprint
import os
import time
from core.data_utils.data_loader import load_dataset
from core.data_utils.data_loader import collect_cross_dataset_samples
from core.GNN.my_model import GAT, train, test, GATWithCrossAttention
from core.llm.prompt import OllamaPromptGenerator
from core.config import load_cfg
import numpy as np
import torch
import argparse
import random
import logging
# Import-time side effect: these two calls run as soon as this module is
# loaded, before any function below is invoked. Per the PyTorch docs they ask
# the CUDA allocator to hand back cached, currently unused memory
# (`empty_cache`) and to collect memory released through CUDA IPC
# (`ipc_collect`).
torch.cuda.empty_cache()
torch.cuda.ipc_collect()



def setup_logger(log_path, mode='w'):
    """Reset the root logger so it writes to ``log_path`` and to the console.

    The root logger's level is set to INFO, every handler currently attached
    to it is detached and closed, and two fresh handlers sharing one
    timestamped format are installed: a ``logging.FileHandler`` on
    ``log_path`` and a default ``logging.StreamHandler``.

    Args:
        log_path: Path of the log file to open.
        mode: File mode handed to ``logging.FileHandler`` (default ``'w'``).
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Close the handlers being replaced. Merely clearing the handler list
    # would leave a previously opened log file handle open every time this
    # function is called again.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter('[%(asctime)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')

    file_handler = logging.FileHandler(log_path, mode=mode)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)


def set_seed(seed=0):
    """Seed the RNGs used by this script and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA entry points),
    sets the cuDNN ``deterministic``/``benchmark`` flags, and stores the seed
    in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed applied to every generator (default 0).
    """
    # Same seed, same order as before: Python, NumPy, then the torch seeders.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)


def train_and_save_gat(cfg, data, label_num=None, in_channels=None, epochs=200):
    """Train a GAT node classifier on ``data`` and checkpoint its best state.

    After every epoch the model is evaluated on ``data.test_mask``; whenever
    that loss improves, the model's ``state_dict`` is written to
    ``cfg.gnn.model_path``.
    NOTE(review): checkpoint selection uses the *test* split, as in the
    original code — confirm this is intended rather than a validation split.

    Args:
        cfg: Configuration object. Reads ``cfg.model.type``,
            ``cfg.training.device``, ``cfg.training.lr``,
            ``cfg.training.weight_decay``, ``cfg.gnn.model_path`` and, for the
            defaults below, ``cfg.dataset.source_label_num`` and
            ``cfg.gnn.source_in_channels``.
        data: Graph object providing ``edge_index``, ``y``, ``train_mask``,
            ``test_mask`` and ``num_nodes``. ``data.edge_index`` is replaced
            in place by its ``torch.long`` version.
        label_num: Number of classes passed to the model; defaults to
            ``cfg.dataset.source_label_num``.
        in_channels: Input channel count passed to the model; defaults to
            ``cfg.gnn.source_in_channels``.
        epochs: Number of training epochs (default 200).

    Raises:
        ValueError: If ``cfg.model.type`` is not ``'gat'``.
    """
    if cfg.model.type != 'gat':
        raise ValueError(f"Unsupported model type: {cfg.model.type}")
    if label_num is None:
        label_num = cfg.dataset.source_label_num
    if in_channels is None:
        in_channels = cfg.gnn.source_in_channels

    # Use the configured device everywhere; the model and its inputs must live
    # on the same device, and a hard-coded "cuda" breaks any other setting.
    device = cfg.training.device
    gat_model = GAT(cfg, label_num, in_channels, data.num_nodes,
                    init_embedding=True,
                    is_classification=True)
    gat_model.to(device)
    optimizer = torch.optim.Adam(gat_model.parameters(), lr=cfg.training.lr,
                                 weight_decay=cfg.training.weight_decay)
    criterion = torch.nn.CrossEntropyLoss()

    data.edge_index = data.edge_index.to(torch.long)
    # Loop-invariant tensors are moved to the device once instead of per epoch.
    edge_index = data.edge_index.to(device)
    train_target = data.y[data.train_mask].to(device)
    test_target = data.y[data.test_mask].to(device)

    model_path = cfg.gnn.model_path
    model_dir = os.path.dirname(model_path)
    # Start from +inf so the first finite loss always produces a checkpoint.
    min_loss = float('inf')
    for epoch in range(epochs):
        gat_model.train()
        train_loss = criterion(gat_model(edge_index)[data.train_mask], train_target)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        gat_model.eval()
        # Evaluation needs no gradients; skipping them avoids building (and
        # retaining through ``min_loss``) an autograd graph every epoch.
        with torch.no_grad():
            test_loss = criterion(gat_model(edge_index)[data.test_mask], test_target).item()
        print('epoch: {}, traind_avg_loss: {}, test_avg_loss: {}'.format(epoch, train_loss.item(), test_loss))
        if test_loss < min_loss:
            min_loss = test_loss
            if model_dir:
                # ``os.makedirs('')`` raises, so only create a real directory.
                os.makedirs(model_dir, exist_ok=True)
            torch.save(gat_model.state_dict(), model_path)
            print(f"GAT model saved to {model_path}")

def parse_args():
    """Declare the training script's command-line options and parse them.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``, with
        one attribute per option declared below.
    """
    parser = argparse.ArgumentParser(description='Training Script')

    # (flag, type, default) for each plain typed option, in declaration order.
    typed_options = (
        ('--warmup', int, 20),
        ('--train_task', str, 'node'),
        ('--test_task', str, 'node'),
        ('--node_type', str, 'gat'),
        ('--node_embed_type', str, 'tfidf'),
        ('--edge_type', str, 'gcn'),
        ('--edge_embed_type', str, 'tfidf'),
        ('--lr', float, 1e-5),
        ('--weight_decay', float, 5e-8),
        ('--batch_size', int, 128),
        ('--reg', float, 0.1),
        ('--seqlen', int, 128),
        ('--temperature', float, 0.07),
        ('--neg_k', int, 1),
        ('--fixed_length', int, 20),
        ('--epochs', int, 30),
        ('--source_gnn_in_channels', int, 1433),
        ('--target_gnn_in_channels', int, 1433),
        ('--gnn_hidden_channels', int, 512),
        ('--gnn_out_channels', int, 384),
        ('--source_dataset_name', str, 'cora'),
        ('--target_dataset_name', str, 'cora'),
        ('--seed', int, 0),
        ('--source_label_num', int, 7),
        ('--target_label_num', int, 7),
        ('--source_gcn_in_channels', int, 1433),
        ('--target_gcn_in_channels', int, 1433),
        ('--source_save_dir', str, 'core/llm_llama3.2_response/cora'),
        ('--target_save_dir', str, 'core/llm_llama3.2_response/cora'),
        ('--drop_edge_ratio', float, 0.0),
        ('--drop_node_ratio', float, 0.0),
        ('--text_mask_ratio', float, 0.0),
    )
    for flag, value_type, default_value in typed_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    # Options that need more than a type and a default.
    parser.add_argument('--user_cross', action='store_true', help='Enable user cross')
    parser.add_argument('--num_layers', type=int, default=4)
    parser.add_argument('--shot_mode', type=str, default='zero-shot',
                        choices=['zero-shot', 'few-shot', 'full-supervised'],
                        help='Cross-domain training mode')

    return parser.parse_args()


def _attach_prompts(cfg, data, dataset_name, label_num, in_channels, save_dir, device):
    """Load the dataset's pretrained GAT (training it first if missing) and attach LLM prompts to ``data``.

    Points ``cfg.gnn.model_path`` at ``saved_model/<dataset_name>_GAT_neighbor``
    (and leaves it set), builds a GAT with the given sizes, loads the
    checkpoint into it and stores the result of
    ``OllamaPromptGenerator.generate_and_save_prompts`` in ``data.x_prompts``.
    """
    cfg.gnn.model_path = f'saved_model/{dataset_name}_GAT_neighbor'
    if not os.path.exists(cfg.gnn.model_path):
        logging.info(f"GAT model file not found at {cfg.gnn.model_path}. Training GAT model...")
        # train_and_save_gat sizes its model from cfg.dataset.source_label_num
        # and cfg.gnn.source_in_channels. Temporarily point those at this
        # dataset's sizes so the saved checkpoint matches the model built
        # below (otherwise a target dataset with different sizes cannot load
        # its own checkpoint), then restore them.
        saved_sizes = (cfg.dataset.source_label_num, cfg.gnn.source_in_channels)
        cfg.dataset.source_label_num, cfg.gnn.source_in_channels = label_num, in_channels
        try:
            train_and_save_gat(cfg, data)
        finally:
            cfg.dataset.source_label_num, cfg.gnn.source_in_channels = saved_sizes

    gat_model = GAT(cfg, label_num, in_channels, data.num_nodes,
                    init_embedding=True,
                    is_classification=True).to(device)
    # map_location lets a checkpoint saved on one device load on another.
    gat_model.load_state_dict(torch.load(cfg.gnn.model_path, map_location=device))
    gat_model.eval()
    prompt_generator = OllamaPromptGenerator(cfg, gat_model)
    data.x_prompts = prompt_generator.generate_and_save_prompts(data, save_dir)


def _log_split_metrics(epoch, split, *, node_loss, edge_loss, node_acc, node_f1,
                       edge_auc, edge_acc, edge_f1, edge_ap):
    """Log one epoch's node/edge metrics for ``split`` ('TRAIN', 'VAL' or 'TEST')."""
    logging.info(f"Epoch: {epoch + 1}, "
                 f'{split} Node Loss: {node_loss:.4f}, '
                 f'Edge Loss: {edge_loss:.4f}, '
                 f'Node Accuracy: {node_acc:.4f}, '
                 f'Node F1: {node_f1:.4f}, '
                 f'Edge AUC: {edge_auc:.4f}, '
                 f'Edge ACC: {edge_acc:.4f},'
                 f'Edge F1: {edge_f1:.4f},'
                 f'Edge AP: {edge_ap:.4f}, '
                 )


def main():
    """Run the full pipeline: configure, prepare prompts, train and report.

    Steps, in order:
      1. Load ``config.yaml`` and overwrite its fields with the CLI arguments.
      2. Seed the RNGs and set up logging to ``log/<source>_<target>.log``.
      3. Load the source and target datasets.
      4. For each dataset, load (or train) its GAT and attach LLM prompts.
      5. Train ``GATWithCrossAttention`` for ``cfg.training.epochs`` epochs,
         evaluating on the source dataset when source and target names are
         equal and on the target dataset otherwise.
      6. Log the test metrics recorded at the epochs with the best validation
         node accuracy and best validation edge accuracy.
    """
    cfg = load_cfg("config.yaml")
    args = parse_args()
    cfg.model.node_type = args.node_type
    cfg.model.node_embed_type = args.node_embed_type
    cfg.model.edge_type = args.edge_type
    cfg.model.edge_embed_type = args.edge_embed_type
    cfg.training.warmup = args.warmup
    cfg.training.lr = args.lr
    cfg.training.weight_decay = args.weight_decay
    cfg.training.batch_size = args.batch_size
    cfg.training.reg = args.reg
    cfg.training.seqlen = args.seqlen
    cfg.training.fixed_length = args.fixed_length
    cfg.training.epochs = args.epochs
    cfg.training.temperature = args.temperature
    cfg.training.neg_k = args.neg_k
    cfg.dataset.seed = args.seed
    cfg.dataset.drop_edge_ratio = args.drop_edge_ratio
    cfg.dataset.drop_node_ratio = args.drop_node_ratio
    cfg.dataset.text_mask_ratio = args.text_mask_ratio
    cfg.training.task = args.train_task
    cfg.testing.task = args.test_task
    cfg.gnn.source_in_channels = args.source_gnn_in_channels
    cfg.gnn.target_in_channels = args.target_gnn_in_channels
    cfg.gnn.hidden_channels = args.gnn_hidden_channels
    cfg.gnn.out_channels = args.gnn_out_channels
    cfg.gnn.num_layers = args.num_layers
    cfg.gcn.source_in_channels = args.source_gcn_in_channels
    cfg.gcn.target_in_channels = args.target_gcn_in_channels
    cfg.dataset.source_name = args.source_dataset_name
    cfg.dataset.target_name = args.target_dataset_name
    cfg.dataset.source_label_num = args.source_label_num
    cfg.dataset.target_label_num = args.target_label_num
    cfg.llm.source_save_dir = args.source_save_dir
    cfg.llm.target_save_dir = args.target_save_dir
    cfg.training.shot_mode = args.shot_mode
    cfg.training.user_cross = args.user_cross

    set_seed(cfg.dataset.seed)
    os.makedirs("log", exist_ok=True)
    log_file = os.path.join("log", f"{cfg.dataset.source_name}_{cfg.dataset.target_name}.log")
    setup_logger(log_file)
    logging.info(f"Starting training on dataset: {cfg.dataset.source_name}")
    logging.info("Complete training configuration:\n" + pprint.pformat(cfg))
    device = torch.device(cfg.training.device)

    source_data = load_dataset(cfg.dataset.source_name, use_text=cfg.dataset.use_text, seed=cfg.dataset.seed,
                               drop_edge_ratio=cfg.dataset.drop_edge_ratio,
                               drop_node_ratio=cfg.dataset.drop_node_ratio,
                               text_mask_ratio=cfg.dataset.text_mask_ratio)
    target_data = load_dataset(cfg.dataset.target_name, use_text=cfg.dataset.use_text, seed=cfg.dataset.seed,
                               drop_edge_ratio=cfg.dataset.drop_edge_ratio,
                               drop_node_ratio=cfg.dataset.drop_node_ratio,
                               text_mask_ratio=cfg.dataset.text_mask_ratio,
                               shot_mode=cfg.training.shot_mode)

    # Source first, then target: cfg.gnn.model_path ends up pointing at the
    # target checkpoint, exactly as before.
    cfg.gnn.source_vocab_size = source_data.num_nodes
    _attach_prompts(cfg, source_data, cfg.dataset.source_name, cfg.dataset.source_label_num,
                    cfg.gnn.source_in_channels, cfg.llm.source_save_dir, device)
    cfg.gnn.target_vocab_size = target_data.num_nodes
    _attach_prompts(cfg, target_data, cfg.dataset.target_name, cfg.dataset.target_label_num,
                    cfg.gnn.target_in_channels, cfg.llm.target_save_dir, device)

    cfg.gnn.vocab_size = source_data.num_nodes
    source_data = source_data.to(device)
    target_data = target_data.to(device)
    model = GATWithCrossAttention(cfg)
    model.to(device)

    same_dataset = cfg.dataset.source_name == cfg.dataset.target_name
    # Validation/test run on the source graph in the single-dataset setting
    # and on the target graph in the cross-dataset setting.
    eval_data = source_data if same_dataset else target_data

    best_node_acc = 0
    best_edge_acc = 0
    wrong_node_id, wrong_label, wrong_pred = None, None, None
    best_node_test = {"acc": 0, "f1": 0}
    best_edge_test = {"auc": 0, "ap": 0, "acc": 0, "f1": 0}
    for epoch in range(cfg.training.epochs):
        warmup = epoch < cfg.training.warmup
        (train_node_acc, train_edge_auc, train_node_f1, train_edge_ap,
         train_edge_acc, train_edge_f1, train_node_loss, train_edge_loss) = train(
            model, source_data, cfg, warmup, source_data.num_nodes)
        if not same_dataset:
            train(model, target_data, cfg, warmup, target_data.num_nodes, True)
        logging.info(
            "=====================================================================================================================")
        _log_split_metrics(epoch, "TRAIN",
                           node_loss=train_node_loss, edge_loss=train_edge_loss,
                           node_acc=train_node_acc, node_f1=train_node_f1,
                           edge_auc=train_edge_auc, edge_acc=train_edge_acc,
                           edge_f1=train_edge_f1, edge_ap=train_edge_ap)

        (val_node_acc, val_edge_auc, val_node_f1, val_edge_ap,
         val_edge_acc, val_edge_f1, val_node_loss, val_edge_loss,
         wrong_n, wrong_l, wrong_p) = test(model, eval_data, cfg, warmup, "val")
        _log_split_metrics(epoch, "VAL",
                           node_loss=val_node_loss, edge_loss=val_edge_loss,
                           node_acc=val_node_acc, node_f1=val_node_f1,
                           edge_auc=val_edge_auc, edge_acc=val_edge_acc,
                           edge_f1=val_edge_f1, edge_ap=val_edge_ap)

        (test_node_acc, test_edge_auc, test_node_f1, test_edge_ap,
         test_edge_acc, test_edge_f1, test_node_loss, test_edge_loss,
         _, _, _) = test(model, eval_data, cfg, warmup, "test")
        # The TEST line reports the test split's own Edge F1 / AP (it used to
        # repeat the validation values).
        _log_split_metrics(epoch, "TEST",
                           node_loss=test_node_loss, edge_loss=test_edge_loss,
                           node_acc=test_node_acc, node_f1=test_node_f1,
                           edge_auc=test_edge_auc, edge_acc=test_edge_acc,
                           edge_f1=test_edge_f1, edge_ap=test_edge_ap)

        # Model selection: remember the test metrics of the epoch whose
        # validation score is best so far (ties go to the later epoch).
        if val_node_acc >= best_node_acc:
            best_node_acc = val_node_acc
            best_node_test = {
                "acc": test_node_acc,
                "f1": test_node_f1,
            }
            if wrong_n is not None:
                wrong_node_id, wrong_label, wrong_pred = wrong_n, wrong_l, wrong_p

        if val_edge_acc >= best_edge_acc:
            best_edge_acc = val_edge_acc
            best_edge_test = {
                "auc": test_edge_auc,
                "ap": test_edge_ap,
                "acc": test_edge_acc,
                "f1": test_edge_f1,
            }

    logging.info("==== Best Test Metrics ====")
    logging.info(f"Node Accuracy  : {best_node_test['acc']:.4f}")
    logging.info(f"Node F1 Score  : {best_node_test['f1']:.4f}")
    logging.info(f"Edge AUC       : {best_edge_test['auc']:.4f}")
    logging.info(f"Edge AP Score  : {best_edge_test['ap']:.4f}")
    logging.info(f"Edge Accuracy  : {best_edge_test['acc']:.4f}")
    logging.info(f"Edge F1 Score  : {best_edge_test['f1']:.4f}")
    logging.info(f"wrong_node_id  : {wrong_node_id}")
    logging.info(f"wrong_label  : {wrong_label}")
    logging.info(f"wrong_pred  : {wrong_pred}")


# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()