import os
import csv
import json
import yaml
from argparse import ArgumentParser, Namespace
import torch
import pandas as pd
import datasets
from peft import PeftModel
import sys
import random
import numpy as np
from torch.utils.data import DataLoader, Subset

current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)

from train import load_model_and_tokenizer
from data_modules.base_data import load_arxiv_train_dataset, custom_data_collator_arxiv as arxiv_data_collator, custom_data_collator_with_indices
from evaluate_methods import eval_mia_arxiv, eval_rouge, eval_ps
from data_modules.data_module import UnwatermarkedTextDataset
import utils

def load_trainable_model(model, path):
    """Load a (possibly partial) checkpoint of parameters into ``model``.

    The file at ``path`` is read with ``torch.load`` and treated as a flat
    mapping from parameter name to tensor. Entries whose tensor holds no
    elements are dropped (and reported) before loading. The remaining
    entries are applied with ``strict=False``, so parameters that are absent
    from the checkpoint keep the values the model already has.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``;
            it is updated in place.
        path: Checkpoint location handed to ``torch.load``.

    Returns:
        None.
    """
    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints from a trusted source.
    # map_location="cpu" lets a checkpoint that was saved from a GPU be read
    # on a host without that device.
    checkpoint = torch.load(path, map_location="cpu")

    kept = {}
    for name, tensor in checkpoint.items():
        # numel() rather than shape[0]: a 0-dim (scalar) tensor has an empty
        # shape, so indexing shape[0] would raise IndexError.
        if tensor.numel() == 0:
            print(f'Discard {name} because of 0 parameters')
            continue
        kept[name] = tensor

    all_layers = set(model.state_dict().keys())
    num_match_layers = len(all_layers.intersection(kept))
    print('Load trainable parameters for {}/{} layers'.format(num_match_layers, len(all_layers)))

    # strict=False: checkpoint keys unknown to the model, and model keys
    # missing from the checkpoint, are tolerated.
    model.load_state_dict(kept, strict=False)
    return

if __name__ == "__main__":
    # Evaluation entry point: loads a PEFT adapter on top of the configured
    # base model, builds the forget / retain / holdout datasets, and runs the
    # evaluations selected on the command line (MIA, ROUGE, PS). Each
    # evaluation writes its results under --output_dir.

    # ------------------------------------------------------------------
    # Command-line interface
    # ------------------------------------------------------------------
    parser = ArgumentParser()
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--model_path", type=str, default="main_results/model.pt",
                        help="Path of the unlearned model")
    # Both config files are opened unconditionally below, so they are marked
    # required: a missing flag now yields a usage error instead of a
    # TypeError from open(None).
    parser.add_argument('--data_config_path', type=str, required=True,
                        help='Path to dataset and split config')
    parser.add_argument('--model_config_path', type=str, required=True,
                        help='Path to model config')
    parser.add_argument('--output_dir', type=str, default='results/eval/',
                        help='Directory to save results and models')
    parser.add_argument("--run_mia", action="store_true",
                        help="Whether to run mia evaluation")
    parser.add_argument("--run_rouge", action="store_true",
                        help="Whether to run rouge evaluation")
    parser.add_argument("--run_knowmem", action="store_true",
                        help="Whether to run knowmem evaluation")
    parser.add_argument("--run_ps", action="store_true",
                        help="Whether to run ps evaluation")
    parser.add_argument("--gen_path", type=str, default=None,
                        help="Directory to the duplicate forget set")
    args = parser.parse_args()

    if args.run_knowmem:
        # The flag is accepted for compatibility, but nothing in this script
        # consumes it; say so instead of silently doing nothing.
        print("Warning: --run_knowmem was given but this script does not "
              "implement a knowmem evaluation; the flag is ignored.",
              file=sys.stderr)

    utils.set_seed(args.seed)

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------
    # Dataset / split config; the CLI seed and (optionally) gen_path override
    # whatever the YAML file contains.
    with open(args.data_config_path, 'r') as f:
        data_config = Namespace(**yaml.safe_load(f))
    print('data_config:', vars(data_config))
    data_config.seed = args.seed  # set seed

    if args.gen_path is not None:
        data_config.gen_path = args.gen_path

    # Model config.
    with open(args.model_config_path, 'r') as f:
        config = Namespace(**yaml.safe_load(f))
    print('model_config:', vars(config))

    # ------------------------------------------------------------------
    # Model and data
    # ------------------------------------------------------------------
    # Base model + tokenizer come from the project loader; the adapter stored
    # at --model_path is attached on top of the base model.
    base_model, tokenizer = load_model_and_tokenizer(config)
    model = PeftModel.from_pretrained(base_model, args.model_path)
    print(f'Loaded model from {args.model_path}')

    # Load the dataset splits used by the corresponding unlearning algorithm
    # and wrap each one for tokenized access.
    train_data, forget_data, retain_data = load_arxiv_train_dataset(**vars(data_config))

    def get_encoded_dataset(data):
        """Wrap ``data`` in an UnwatermarkedTextDataset using the loaded tokenizer."""
        return UnwatermarkedTextDataset(
            data,
            tokenizer,
            max_length=config.max_seq_length,
        )

    train_data = get_encoded_dataset(train_data)
    forget_data = get_encoded_dataset(forget_data)
    retain_data = get_encoded_dataset(retain_data)
    print('train_data_len:', len(train_data))
    print('forget_data_len:', len(forget_data))
    print('retain_data_len:', len(retain_data))

    # Holdout split: the "full" split of the "holdout" configuration of the
    # configured Hugging Face dataset.
    holdout_data = datasets.load_dataset(data_config.hf_dataset_name, "holdout")["full"]
    holdout_data = UnwatermarkedTextDataset(data=holdout_data, tokenizer=tokenizer,
                                            max_length=config.max_seq_length)

    os.makedirs(args.output_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Membership Inference Attack
    # ------------------------------------------------------------------
    if args.run_mia:
        # Only the first returned value is persisted; the second one is
        # currently unused.
        mia_auc, _mia_log = eval_mia_arxiv.eval_mia(
            forget_data, retain_data, holdout_data, model, tokenizer, arxiv_data_collator
        )
        print("MIA AUC: ", mia_auc)
        with open(os.path.join(args.output_dir, "mia.json"), "w") as f:
            json.dump(mia_auc, f)

    # ------------------------------------------------------------------
    # ROUGE evaluation
    # ------------------------------------------------------------------
    if args.run_rouge:
        # Evaluate retain on at most as many samples as the forget set has.
        # The downsampled view is kept in a local name so that the PS section
        # below sees the same retain_data whether or not --run_rouge is set.
        rouge_retain_data = retain_data
        target_size = len(forget_data)
        if len(retain_data) > target_size:
            print(f"Downsampling retain_data from {len(retain_data)} to {target_size}")
            # .tolist() so Subset indexes the wrapped dataset with plain
            # Python ints rather than 0-dim tensors.
            keep = torch.randperm(len(retain_data))[:target_size].tolist()
            rouge_retain_data = Subset(retain_data, keep)

        aggregated_eval_rouge = {}
        rouge_tasks = (
            ("eval_rouge", rouge_retain_data),
            ("eval_rouge_forget", forget_data),
        )
        for eval_task, data in rouge_tasks:
            print(f'Working on eval task {eval_task}')

            dataloader = DataLoader(
                data, batch_size=config.eval_batch_size, collate_fn=custom_data_collator_with_indices
            )
            rouge = eval_rouge.eval_gen(model, tokenizer, dataloader)

            with open(os.path.join(args.output_dir, f'{eval_task}.json'), "w") as f:
                json.dump(rouge, f, indent=4)
            aggregated_eval_rouge[f'{eval_task}.json'] = rouge

        aggregated_eval_log_filename = os.path.join(args.output_dir, "eval_rouge_aggregated.json")
        with open(aggregated_eval_log_filename, "w") as f:
            json.dump(aggregated_eval_rouge, f, indent=4)

        # Read the file back so aggregate_results receives the same
        # JSON-round-tripped structure as before (now with the handle closed).
        with open(aggregated_eval_log_filename) as f:
            ckpt_result = json.load(f)

        aggregated_results = eval_rouge.aggregate_results(ckpt_result)

        save_file = os.path.join(args.output_dir, "rouge.csv")
        print('saveing file to ', save_file)
        # newline='' is what the csv module asks for on files it writes;
        # without it, platforms that translate "\n" emit blank rows.
        with open(save_file, 'w', newline='') as f:
            w = csv.DictWriter(f, aggregated_results.keys())
            w.writeheader()
            w.writerow(aggregated_results)

    # ------------------------------------------------------------------
    # PS metric
    # ------------------------------------------------------------------
    if args.run_ps:
        print('=' * 10, 'Running PS Metric', '=' * 10)

        ps_sample_size = 50  # upper bound on samples evaluated per split
        ps_results = {}

        # Draw the samples for every split first, then evaluate, so the
        # numpy RNG is consumed in a fixed order (forget, then retain).
        ps_datasets = {}
        for split_name, split_data in (('forget', forget_data), ('retain', retain_data)):
            if len(split_data) > ps_sample_size:
                chosen = np.random.choice(len(split_data), ps_sample_size, replace=False)
                # .tolist() yields plain Python ints for Subset indexing.
                split_data = Subset(split_data, chosen.tolist())
            ps_datasets[split_name] = split_data

        for data_split_name, raw_data in ps_datasets.items():
            loader = DataLoader(raw_data, batch_size=ps_sample_size, shuffle=False)

            for ps_type in ['exact']:
                key = f"ps_{data_split_name}_{ps_type}"
                try:
                    score = eval_ps.get_ps_score_arxiv(model, tokenizer, loader, ps_type=ps_type)
                    ps_results[key] = score
                    print(f"PS Metric | {key}: {score:.4f}")
                except Exception as e:
                    # Best effort: a failure on one split must not discard
                    # the others.
                    # NOTE(review): the 0.0 written here is a placeholder and
                    # cannot be told apart from a genuine zero score in
                    # ps_metric.json; check the log output for "Error
                    # calculating" before trusting a 0.0.
                    print(f"Error calculating {key}: {e}")
                    ps_results[key] = 0.0

        path = os.path.join(args.output_dir, 'ps_metric.json')
        with open(path, 'w') as f:
            json.dump(ps_results, f, indent=4)
