import torch
import argparse
import os
import pandas as pd
import numpy as np
import random
from models import TopoDistill
from utils.data_loader import create_manifold_data_loaders
from utils.train import TwoStage_Trainer
from causal_inference import build_pairwise_convergence_matrix
from evaluate import evaluation_causal

def set_seed(seed: int, deterministic: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Seed applied to ``random``, ``numpy.random`` and torch (CPU and
            all CUDA devices).
        deterministic: cuDNN mode selector. The default ``False`` keeps the
            previous behaviour: ``cudnn.deterministic=False`` and
            ``cudnn.benchmark=True``. That mode autotunes kernels for speed,
            so GPU results are not bit-for-bit reproducible even with a fixed
            seed. Pass ``True`` to request deterministic cuDNN kernels and
            disable benchmarking.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible CUDA device (including the current one); this is a
    # lazy no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    # Benchmark mode picks kernels by timing, which defeats determinism.
    torch.backends.cudnn.benchmark = not deterministic
    # Setting this at runtime does not change hash randomization of the
    # already-running interpreter; it only affects Python subprocesses
    # started afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)

def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--device', type=str, default='cuda:0', help='device')
    parser.add_argument('--window_size', type=int, default=3, help='window_size')  
    parser.add_argument('--window_step', type=int, default=1, help='downsample step in sliding window')
    parser.add_argument('--batch_time_steps', type=int, default=100, help='batch_size, means number of time steps in a batch') 
    parser.add_argument('--stride', type=int, default=1, help='stride between windows in a batch')
    parser.add_argument('--normalize', type=bool, default=True, help='whether to normalize the data')
    parser.add_argument('--global_win_last_n', type=int, default=1, help='number of last time steps to use for global context window')
    parser.add_argument('--global_horizon', type=int, default=1, help='global prediction horizon')
    parser.add_argument('--model_save_path', type=str, default='./model.pth', help='path to save the trained model')
    parser.add_argument('--need_train', type=bool, default=True, help='whether to train the model')
    parser.add_argument('--data_file', type=str, default="./data.csv", help='data_file path')
    parser.add_argument('--gt_file', type=str, default="./cm.csv", help='ground truth adjacency matrix file path')
    parser.add_argument('--embedding_dim', type=int, default=24, help='embedding_dim')   
    parser.add_argument('--global_embedding_dim', type=int, default=24, help='embedding_dim') 
    parser.add_argument('--tcn_n_layers', type=int, default=3, help='tcn_n_layers')
    parser.add_argument('--tcn_ch', type=int, default=64, help='tcn_ch')
    parser.add_argument('--tcn_ksize', type=int, default=3, help='tcn_ksize')
    # training hyperparameters
    parser.add_argument('--contrastive_coeff', type=float, default=0.3, help='contrastive_coeff') 
    parser.add_argument('--topology_coeff', type=float, default=0.7, help='topology_coeff') 
    
    parser.add_argument('--train_global_lr', type=float, default=5e-5, help='train_global_lr') 
    parser.add_argument('--train_individual_lr', type=float, default=1e-5, help='train_individual_lr')
    parser.add_argument('--weight_decay_global', type=float, default=1e-3, help='weight_decay_global')
    parser.add_argument('--weight_decay_individual', type=float, default=5e-5, help='weight_decay_individual')  
    # epochs
    parser.add_argument('--train_global_epoch', type=int, default=30, help='train_global_epoch')  
    parser.add_argument('--train_individual_epoch', type=int, default=100, help='train_individual_epoch')
    # contrastive learning params
    parser.add_argument('--pos_time_threshold', type=int, default=5, help='pos_time_threshold') 
    parser.add_argument('--neg_time_threshold', type=int, default=25, help='neg_time_threshold')
    parser.add_argument('--contrastive_temp', type=float, default=0.3, help='contrastive_temp')
    # CCM params 
    parser.add_argument('--k_ccm', type=int, default=3, help='k_ccm')
    parser.add_argument('--n_lib_points', type=int, default=10, help='n_lib_points') 
    parser.add_argument('--use_local_linear', type=bool, default=True, help='use_local_linear')

    args = parser.parse_args()
    
    return args

if __name__ == "__main__":
    # Seed everything once, before any data or model construction.
    SEED = 2026
    set_seed(SEED)

    # Number of repeated runs; each run contributes one row of metrics.
    N_RUNS = 1
    # Written relative to the working directory. The previous "/result.csv"
    # targeted the filesystem root, which is normally not writable.
    RESULTS_PATH = "./result.csv"

    # The arguments are identical for every run, so parse them once.
    args = parse_args()
    device = torch.device(args.device)
    print(f"Using device: {device}")

    all_results = []
    for _run in range(N_RUNS):

        # NOTE(review): args.normalize is parsed but not passed here; confirm
        # whether create_manifold_data_loaders should receive it.
        train_loader, test_loader, dataset_info = create_manifold_data_loaders(
            csv_file=args.data_file,
            window_size=args.window_size,
            window_step=args.window_step,
            batch_time_steps=args.batch_time_steps,
            stride=args.stride,
            batches_per_epoch=args.batches_per_epoch,
            global_win_last_n=args.global_win_last_n,
            global_horizon=args.global_horizon
        )

        num_sequences = dataset_info['num_series']
        window_size = dataset_info['window_size']

        model = TopoDistill(
            num_sequences=num_sequences,
            window_size=window_size,
            embedding_dim=args.embedding_dim,
            global_embedding_dim=args.global_embedding_dim,
            adapter_dim=args.adapter_dim,
            TCN_n_layers=args.tcn_n_layers,
            TCN_ch=args.tcn_ch,
            TCN_ksize=args.tcn_ksize,
            global_win_last_n=args.global_win_last_n
        ).to(device)

        if args.need_train:
            print("Starting training...")

            trainer = TwoStage_Trainer(model, train_loader, test_loader, args, alignment_coeff=args.alignment_coeff,
                                       contrastive_coeff=args.contrastive_coeff, reconstruction_coeff=args.reconstruction_coeff, topology_coeff=args.topology_coeff,
                                       train_global_lr=args.train_global_lr,
                                       train_individual_lr=args.train_individual_lr)
            trainer.train_global_phase(num_epochs=args.train_global_epoch)
            trainer.train_individual_phase(num_epochs=args.train_individual_epoch)

            # Create the checkpoint's parent directory so torch.save does not
            # fail when --model_save_path points into a new folder.
            save_dir = os.path.dirname(args.model_save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), args.model_save_path)
            print(f"Model saved to {args.model_save_path}")

        # Evaluation always uses the on-disk checkpoint: the one just written
        # above, or a pre-existing one when training is skipped.
        if not os.path.exists(args.model_save_path):
            raise FileNotFoundError(f"Checkpoint not found: {args.model_save_path}")
        # NOTE(review): only a state_dict is loaded, so weights_only=True
        # (PyTorch >= 1.13) would be safer for untrusted checkpoint files.
        state = torch.load(args.model_save_path, map_location=device)
        model.load_state_dict(state)

        model.set_inference_mode(True)
        model.eval()

        print("Building pairwise convergence matrix...")
        S_mat, col_order = build_pairwise_convergence_matrix(model, test_loader, args, device=device, k_ccm=args.k_ccm, n_lib_points=args.n_lib_points)
        print(S_mat)

        print("Evaluating against ground truth...")
        metrics = evaluation_causal(
            gt_path=args.gt_file,
            S_mat=S_mat,
            col_order=col_order
        )
        print("Evaluation Metrics:")
        for key, value in metrics.items():
            print(f"{key}: {value}")

        all_results.append(metrics)

        # Rewrite the CSV after every run so finished runs survive a later crash.
        pd.DataFrame(all_results).to_csv(RESULTS_PATH, index=True)