import argparse
import copy
import torch
from exp.experiment_explanation_v2 import Explanation_Experiment
import random
import numpy as np

def get_args(argv=None):
    """Build the command-line parser and return the parsed experiment config.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour of a plain ``parse_args()`` call.

    Returns:
        argparse.Namespace holding every CLI option plus two derived fields:
        ``root_path`` (``"<data_root>/<dataset>"``) and ``is_training``
        (always ``False`` for this script).

    Note:
        ``--amp`` and ``--distil`` are ``store_false`` switches: both default
        to ``True`` and passing the flag turns the feature *off*.
    """
    parser = argparse.ArgumentParser()

    # SBM and InterpGN model hyperparameters
    parser.add_argument("--data", type=str, default="UEA", choices=['UEA', 'Monash'])
    parser.add_argument("--data_root", type=str, default="./data/UEA_multivariate")
    parser.add_argument("--model", type=str, default='SBM', choices=['SBM', 'LTS', 'InterpGN', 'DNN'])
    parser.add_argument("--dnn_type", type=str, default='FCN', choices=['FCN', 'Transformer', 'TimesNet', 'PatchTST', 'ResNet'])
    parser.add_argument("--dataset", type=str, default="BasicMotions")
    parser.add_argument("--lambda_reg", type=float, default=0.1)
    parser.add_argument("--lambda_div", type=float, default=0.1)
    parser.add_argument("--epsilon", type=float, default=1.)
    parser.add_argument("--num_shapelet", type=int, default=10)
    parser.add_argument("--gating_value", type=float, default=None)
    parser.add_argument("--pos_weight", action="store_true")
    parser.add_argument("--sbm_cls", type=str, default='linear')
    parser.add_argument("--distance_func", type=str, default='euclidean')
    parser.add_argument("--beta_schedule", type=str, default='constant')
    parser.add_argument("--memory_efficient", action="store_true")

    # Experiment config
    parser.add_argument("--lr", type=float, default=5e-3)
    parser.add_argument("--lr_decay", action="store_true")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--gradient_clip", type=float, default=0)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument('--log_interval', type=int, default=20)
    parser.add_argument("--min_epochs", type=int, default=0)
    parser.add_argument("--train_epochs", type=int, default=500)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--patience", type=int, default=50)
    parser.add_argument("--multi_gpu", action='store_true')
    parser.add_argument("--test_only", action='store_true')
    # -1 is a sentinel: the caller runs its built-in list of seeds instead of a single one.
    parser.add_argument("--seed", type=int, default=-1)
    # Inverted switch: AMP is on by default, passing --amp disables it.
    parser.add_argument("--amp", action='store_false', default=True)

    # basic config
    parser.add_argument('--task_name', type=str, default='explanation',
                        help='task name, options:[explanation]')
    parser.add_argument('--model_id', type=str, default='test', help='model id')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')

    # DNN model configs
    parser.add_argument('--top_k', type=int, default=5, help='for TimesBlock')
    parser.add_argument('--num_kernels', type=int, default=6, help='for Inception')
    parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
    parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
    parser.add_argument('--c_out', type=int, default=7, help='output size')
    parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
    parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
    parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
    parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
    parser.add_argument('--d_ff', type=int, default=2048, help='dimension of ff layers')
    parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
    parser.add_argument('--factor', type=int, default=1, help='attn factor')
    # Inverted switch: distilling is on by default, passing --distil disables it.
    parser.add_argument('--distil', action='store_false',
                        help='whether to use distilling in encoder, using this argument means not using distilling',
                        default=True)
    parser.add_argument('--dropout', type=float, default=0, help='dropout')
    parser.add_argument('--activation', type=str, default='gelu', help='activation')
    parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')

    # TimesNet specific
    parser.add_argument('--label_len', type=int, default=48, help='start token length')
    parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
    parser.add_argument('--seasonal_patterns', type=str, default='Monthly', help='subset for M4')
    parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False)

    # argv=None makes argparse read sys.argv[1:], matching the previous behaviour.
    args = parser.parse_args(argv)

    # Derived fields that are not exposed as CLI flags.
    args.root_path = f"{args.data_root}/{args.dataset}"
    args.is_training = False
    return args

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), then switches cuDNN to deterministic mode with its
    auto-tuner disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all CUDA devices, not only the current one, so multi-GPU runs are
    # covered as well; the call is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

if __name__ == "__main__":
    # Entry point: for each seed, build an experiment, report the test and
    # training-utility metrics, then generate attributions for every selected
    # explanation method.
    args = get_args()
    experiment_cls = Explanation_Experiment

    # --seed -1 (the default) sweeps the fixed seed list below; any other
    # value runs exactly that one seed.
    if args.seed == -1:
        seeds = [0, 42, 1234, 8237, 2023]
    else:
        seeds = [args.seed]

    # Banner fragments shared by the progress messages.
    rule = '=' * 5
    hashes = '#' * 10

    for run_idx, seed in enumerate(seeds):
        set_seed(seed)
        args.seed = seed  # expose the active seed to the experiment via args

        print(f"{rule} Experiment {run_idx} {rule} ", flush=True)
        experiment = experiment_cls(args=args)
        experiment.print_args()
        print()

        # NOTE(review): explicit checkpoint loading is disabled in this script
        # (experiment.model.load_state_dict(torch.load(f"{experiment.checkpoint_dir}/checkpoint.pth")));
        # presumably the experiment restores its weights elsewhere -- confirm.

        # Both evaluations below write their CSV output to the same folder.
        result_dir = f"./result/{args.model}"

        print(f"{rule} Test {rule} ")
        test_metrics = experiment.test(save_csv=True, result_dir=result_dir)
        print(f"Test | Loss {test_metrics.loss:.4f}")
        print(f"Test | Accuracy {test_metrics.accuracy:.4f}")
        print()
        torch.cuda.empty_cache()

        print(f"{rule} Training Utility {rule} ")
        utility_metrics = experiment.get_train_utility(save_csv=True, result_dir=result_dir)
        # NOTE(review): these lines are labelled "Test" although they report the
        # training-utility results -- looks like a copy-paste label; confirm
        # before changing, since log parsers may depend on the text.
        print(f"Test | Loss {utility_metrics.loss:.4f}")
        print(f"Test | Accuracy {utility_metrics.accuracy:.4f}")
        print()
        torch.cuda.empty_cache()

        experiment.model.eval()  # Testing mode

        # ======================== Explanation Generation and Evaluation ========================

        # Methods to run. Other names that appeared as disabled alternatives
        # here: 'Saliency', 'IG', 'SegIG_8', 'SegIG_16', 'SegIG_32', 'LIME'.
        explanation_methods = ['InputXGradient']

        for method in explanation_methods:
            print(f"\n{hashes} STARTING METHOD: {method} {hashes}")

            # Step A: generate the attribution maps for the test split.
            # (Per the original author's note this saves an .npy file,
            # e.g. "...-test_IG.npy".)
            experiment.time_explanation_with_expert_name(
                explanation_name=method,
                target_set='test',
            )

            # Disabled follow-up steps, kept for reference:
            #   experiment.explanation_with_expert_name_withnoise(explanation_name=method, target_set='test')
            #   experiment.evaluate_faithfulness_bottom_up(target_set='test', explanation_method=method, metric='accuracy')
            #   experiment.evaluate_faithfulness_bottom_up(target_set='test', explanation_method=method, metric='roc_auc')
            #   experiment.robustness_val(target_set='test', explanation_method=method, metric='roc_auc')

            torch.cuda.empty_cache()

        # Disabled ensemble stage, kept for reference:
        #   experiment.select_best_instance_explanation(target_set='test', methods=[...])
        #   experiment.evaluate_faithfulness_bottom_up(target_set='test', explanation_method='BestEnsemble', metric='accuracy')
        #   experiment.evaluate_faithfulness_bottom_up(target_set='test', explanation_method='BestEnsemble', metric='roc_auc')
        #   torch.cuda.empty_cache()
        # Candidate method lists that were used with it:
        #   ['LIME', 'KernelSHAP', 'IG', 'Saliency', 'InputXGradient',
        #    'KeystoneIG_0.05', 'KeystoneIG_0.1', 'KeystoneIG_0.15', 'KeystoneIG_0.2', 'KeystoneIG_0.25']
        #   ['IG', 'SegIG_8', 'SegIG_16', 'SegIG_32']
