import os
from pathlib import Path
import argparse
import numpy as np
import torch
import time
import sys
import torch.distributed as dist

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from SFT.anollm import AnoLLM
from SFT.train_anollm import get_run_name
from evaluate.data_utils import load_data, DATA_MAP, get_text_columns, get_max_length_dict

def get_args():
	"""Parse the command-line options for scoring a fine-tuned AnoLLM model.

	The short SmolLM aliases accepted by ``--model`` ('smol', 'smol-360',
	'smol-1.7b') are expanded to their full HuggingFace hub ids; every other
	choice ('gpt2', 'distilgpt2') is returned unchanged.

	Returns:
		argparse.Namespace holding the parsed, alias-expanded options.
	"""
	p = argparse.ArgumentParser()

	# experiment selection
	p.add_argument("--dataset", type=str, default='wine', choices=[name.lower() for name in DATA_MAP.keys()],
				help="Name of datasets in the ODDS benchmark")
	p.add_argument("--exp_dir", type=str, default=None)
	p.add_argument("--setting", type=str, default='semi_supervised', choices=['semi_supervised', 'unsupervised'], help="semi_supervised:an uncontaminated, unsupervised setting; unsupervised:a contaminated, unsupervised setting")

	# dataset hyperparameters
	p.add_argument("--data_dir", type=str, default='data')
	p.add_argument("--n_splits", type=int, default=5)
	p.add_argument("--split_idx", type=int, default=None)  # expected range: 0 .. n_splits-1
	# feature binning
	p.add_argument("--binning", type=str, choices=['quantile', 'equal_width', 'language', 'none', 'standard'], default='standard')
	p.add_argument("--n_buckets", type=int, default=10)
	p.add_argument("--remove_feature_name", action='store_true')

	# model hyperparameters (only used to reconstruct the run/model name)
	p.add_argument("--model", type=str, choices=['gpt2', 'distilgpt2', 'smol', 'smol-360', 'smol-1.7b'], default='smol')
	p.add_argument("--lora", action='store_true', default=False)
	p.add_argument("--lr", type=float, default=5e-5)
	p.add_argument("--random_init", action='store_true', default=False)
	p.add_argument("--no_random_permutation", action='store_true', default=False)

	# inference
	p.add_argument("--batch_size", type=int, default=128)  # per process
	p.add_argument("--n_permutations", type=int, default=100)  # total; main() splits it across processes
	parsed = p.parse_args()

	# Expand short aliases to hub ids; unknown keys fall through untouched.
	hub_ids = {
		'smol': 'HuggingFaceTB/SmolLM-135M',
		'smol-360': 'HuggingFaceTB/SmolLM-360M',
		'smol-1.7b': 'HuggingFaceTB/SmolLM-1.7B',
	}
	parsed.model = hub_ids.get(parsed.model, parsed.model)
	return parsed

def _permutations_for_rank(n_permutations, world_size, rank):
	"""Return how many of ``n_permutations`` the process with global ``rank`` runs.

	Permutations are split as evenly as possible over ``world_size`` processes:
	the first ``n_permutations % world_size`` global ranks take one extra, so
	the per-rank counts always sum to exactly ``n_permutations``.
	"""
	base, remainder = divmod(n_permutations, world_size)
	return base + 1 if rank < remainder else base


def _resolve_model_path(exp_dir, run_name):
	"""Locate the saved HuggingFace-format model directory for this run.

	A model saved directly in ``exp_dir`` (it contains ``config.json``) wins;
	otherwise ``exp_dir / 'models' / run_name`` is used.

	Raises:
		FileNotFoundError: if the selected directory has no ``config.json``.
	"""
	if (exp_dir / 'config.json').exists():
		model_path = exp_dir
	else:
		model_path = exp_dir / 'models' / run_name

	if not (model_path / 'config.json').exists():
		raise FileNotFoundError(f"Model not found: {model_path}")
	return model_path


def _report_metrics(args, run_name, y_test_flat, mean_scores):
	"""Best-effort: compute detection metrics, print them, append them to a shared log.

	Failures are reported as warnings and never raised, because the scores
	have already been saved by the time this runs.
	"""
	try:
		sys.path.insert(0, str(Path(__file__).parent.parent / 'evaluate'))
		from evaluate.compute_metrics import compute_detection_metrics as tabular_metrics
	except Exception as e:  # any import-time failure; keep best-effort
		print(f"[WARN] Could not import tabular_metrics: {e}")
		return

	try:
		auc_roc, auc_pr, f1, precision, recall = tabular_metrics(y_test_flat, mean_scores)
		print("-" * 100)
		result_line = f"{run_name:30s}: AUC-ROC: {auc_roc:.4f} ( 1), AUC-PR: {auc_pr:.4f} ( 1), F1: {f1:.4f} ( 1), P: {precision:.4f} ( 1), R: {recall:.4f} ( 1)"
		print(result_line)

		# Log file lives two levels above exp_dir when possible, so several
		# splits/settings append to the same file.
		exp_root = args.exp_dir.parents[1] if len(args.exp_dir.parents) >= 2 else args.exp_dir.parent
		evaluate_file = exp_root / 'evaluation_dispat.txt'
		print(f"Saving evaluation metrics to: {evaluate_file}")
		with open(evaluate_file, 'a') as f:
			f.write(result_line + '\n')
		print(f"Metrics saved to {evaluate_file}")
	except Exception as e:  # metric computation or file write failed; scores are safe
		print(f"[WARN] Could not compute or save metrics: {e}")


def main():
	"""Score the test split with a fine-tuned AnoLLM model, optionally across processes.

	Each process runs its share of ``--n_permutations`` on the full test set.
	Rank 0 gathers the per-process scores, saves the averaged and the raw
	scores under ``<exp_dir>/scores``, records the inference time, and reports
	metrics. If the averaged score file already exists, nothing is recomputed.
	The process group (if any) is destroyed on exit, including on errors.

	Raises:
		ValueError: if the experiment directory does not exist.
		FileNotFoundError: if no saved model can be found for this run.
	"""
	# LOCAL_RANK picks the GPU on this node. The *global* rank decides the
	# work split and which process writes outputs.
	local_rank = int(os.environ.get("LOCAL_RANK", 0))
	distributed = dist.is_available() and dist.is_initialized()
	world_size = dist.get_world_size() if distributed else 1
	rank = dist.get_rank() if distributed else 0
	use_cuda = torch.cuda.is_available()
	if use_cuda:
		torch.cuda.set_device(local_rank)

	try:
		args = get_args()

		if args.exp_dir is None:
			args.exp_dir = Path('exp') / args.dataset / args.setting / "split{}".format(args.n_splits) / "split{}".format(args.split_idx)
		else:
			args.exp_dir = Path(args.exp_dir)

		if not os.path.exists(args.exp_dir):
			raise ValueError("Experiment directory {} does not exist".format(args.exp_dir))

		score_dir = args.exp_dir / 'scores'
		run_name = get_run_name(args)
		score_path = score_dir / "{}.npy".format(run_name)
		print("score_path:", score_path)
		if rank == 0:
			os.makedirs(score_dir, exist_ok=True)

		_, X_test, _, y_test = load_data(args)

		if os.path.exists(score_path):
			# Scores for this run were already computed.
			return

		model_path_hf = _resolve_model_path(args.exp_dir, run_name)

		efficient_finetuning = 'lora' if args.lora else ''
		max_length_dict = get_max_length_dict(args.dataset)
		text_columns = get_text_columns(args.dataset)
		model = AnoLLM(
			args.model,
			efficient_finetuning=efficient_finetuning,
			max_length_dict=max_length_dict,
			textual_columns=text_columns,
			no_random_permutation=args.no_random_permutation,
		)
		print(text_columns, max_length_dict)

		print(f"Loading HuggingFace format model from {model_path_hf}")
		model.load_from_state_dict(model_path_hf)

		device = torch.device(f"cuda:{local_rank}") if use_cuda else torch.device("cpu")
		model.model.to(device)

		if distributed and world_size > 1:
			# NOTE(review): every rank scores independently and results are
			# combined with all_gather_object, so DDP is not needed for this
			# inference-only path. The wrapper also hides HF attributes behind
			# `.module`. Confirm AnoLLM.decision_function copes with it.
			if use_cuda:
				model.model = torch.nn.parallel.DistributedDataParallel(
					model.model, device_ids=[local_rank], output_device=local_rank
				)
			else:
				model.model = torch.nn.parallel.DistributedDataParallel(model.model)

		# Split by *global* rank. Using local_rank would hand the remainder to
		# one process per node and overshoot n_permutations on multi-node runs.
		n_perm = _permutations_for_rank(args.n_permutations, world_size, rank)

		start_time = time.time()
		scores = model.decision_function(
			X_test,
			n_permutations=n_perm,
			batch_size=args.batch_size,
			# plain "cuda" means the current device, selected above via set_device
			device="cuda" if use_cuda else "cpu",
		)
		end_time = time.time()

		if distributed and world_size > 1:
			all_scores = [None] * world_size
			dist.all_gather_object(all_scores, scores)
		else:
			all_scores = [scores]

		if rank != 0:
			return

		elapsed = end_time - start_time
		print("Inference time:", elapsed)
		run_time_dir = args.exp_dir / "run_time" / "test"
		os.makedirs(run_time_dir, exist_ok=True)
		with open(run_time_dir / f"{run_name}.txt", 'w') as f:
			f.write(str(elapsed))

		# Per-rank scores are concatenated along axis 1 and averaged over it.
		# This presumes a (n_test, n_perm) layout — TODO confirm against AnoLLM.
		combined_scores = np.concatenate(all_scores, axis=1) if len(all_scores) > 1 else all_scores[0]
		mean_scores = np.mean(combined_scores, axis=1)
		np.save(score_path, mean_scores)
		np.save(score_dir / f"raw_{run_name}.npy", combined_scores)
		print(f"Scores saved to {score_path}")

		y_test_flat = y_test.flatten() if len(y_test.shape) > 1 else y_test
		print(f"Mean score (normal): {np.mean(mean_scores[y_test_flat == 0]):.4f}")
		print(f"Mean score (anomaly): {np.mean(mean_scores[y_test_flat == 1]):.4f}")

		_report_metrics(args, run_name, y_test_flat, mean_scores)
	finally:
		if distributed:
			dist.destroy_process_group()
	
if __name__ == '__main__':
	# WORLD_SIZE is presumably exported by the launcher (e.g. torchrun). A
	# process group is created only when more than one process was started.
	if int(os.environ.get("WORLD_SIZE", "1")) > 1:
		dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
	main()