import os
import sys
import json
import torch
import networkx as nx
import random
import warnings
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from PIL import Image
from utils import calculate_distances, convert_to_edges, calculate_center_points

# Image side length; recorded as 'image_size' in the run config written by main().
# NOTE(review): presumably the resolution images are processed at — confirm.
SIZE = 560
# Directory containing this script; dataset and output paths are resolved from it.
BASE_DIR = os.path.dirname(__file__)
# Name of the dataset sub-directory under <BASE_DIR>/dataset.
# NOTE(review): with an empty string DATASET_ROOT is the 'dataset' directory itself
# (every sub-directory there is treated as a category) — confirm this is intended.
DATASET = ""
DATASET_ROOT = os.path.join(BASE_DIR, 'dataset', DATASET)
# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): not referenced in this part of the file; presumably toggles
# bidirectional feature matching elsewhere — confirm.
USE_BIDIRECTIONAL_MATCHING = True

# Suppress only the warning whose message matches 'xFormers is not available'.
warnings.filterwarnings('ignore', message='xFormers is not available')

# Append the 'feature_matching' and 'segmenter' directories to sys.path (after all
# existing entries). The project imports that follow rely on this setup; presumably
# it lets modules inside those packages import their siblings by bare name — confirm.
generate_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'feature_matching'))
segment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'segmenter'))
sys.path.append(segment_path)
sys.path.append(generate_path)

from feature_matching.generate_points import generate, loading_dino
from segmenter.segment import loading_seg, seg_main
from medpy.metric.binary import dc

def calculate_dice_coefficient(pred_mask, gt_mask):
    """Return the Dice coefficient between a predicted mask and a ground-truth mask.

    Both inputs are cast to boolean arrays first (any non-zero element counts
    as foreground). The score itself is computed by ``medpy.metric.binary.dc``.
    """
    binary_pred, binary_gt = (mask.astype(np.bool_) for mask in (pred_mask, gt_mask))
    return dc(binary_pred, binary_gt)

def save_agent_results(agents, current_episode, total_episodes, episode_results, output_path, prefix):
    """Append a progress entry for one episode to the run's training log.

    The log file is ``<output_path>/episode_results/training_log.txt``. The
    directory is created if needed, and each call appends one entry
    terminated by a dashed separator line.

    Args:
        agents: Sequence of agents; each must expose an ``epsilon`` attribute.
        current_episode: Index of the episode just finished. It is logged as
            ``current_episode + 1``, i.e. it is treated as zero-based.
        total_episodes: Total number of episodes planned for the run.
        episode_results: Sequence indexed in parallel with ``agents``. Item
            ``i`` must be a mapping holding a ``'total_reward'`` number.
        output_path: Root directory of the current run.
        prefix: Identifier of the data sample, written verbatim to the log.

    Raises:
        IndexError: If ``episode_results`` is shorter than ``agents``.
    """
    episode_dir = os.path.join(output_path, 'episode_results')
    os.makedirs(episode_dir, exist_ok=True)

    info_path = os.path.join(episode_dir, 'training_log.txt')
    # Explicit encoding so non-ASCII prefixes are written identically on every platform.
    with open(info_path, "a", encoding="utf-8") as f:
        f.write(f"\nEpisode Progress: {current_episode + 1}/{total_episodes}\n")
        f.write(f"Data Prefix: {prefix}\n")
        for i, agent in enumerate(agents):
            f.write(f"Agent {i} - Epsilon: {agent.epsilon:.4f}, "
                    f"Reward: {episode_results[i]['total_reward']:.4f}\n")
        f.write(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("-" * 50 + "\n")

def create_graph_environment(features, positive_indices, negative_indices, image_size, device, max_steps, pos_weights=None, neg_weights=None):
    """Build a multigraph over positive and negative point indices.

    Nodes:
        Each index in ``positive_indices`` becomes a node with
        ``category='pos'``, and each index in ``negative_indices`` a node with
        ``category='neg'``. Every node gets a ``weight`` attribute taken from
        ``pos_weights`` / ``neg_weights``, defaulting to 1.0.

    Edges:
        Edges come from the five distance sets returned by
        ``calculate_distances``, each converted by ``convert_to_edges``. Each
        set is added with ``add_weighted_edges_from`` under an attribute named
        after its relation: ``'feature_pos'``, ``'physical_pos'``,
        ``'physical_neg'``, ``'feature_cross'`` or ``'physical_cross'``.
        NOTE(review): presumably ``convert_to_edges`` yields (u, v, distance)
        triples — confirm.

    Args:
        features: Passed through to ``calculate_distances``.
        positive_indices: Tensor of positive point indices. Node ids are the
            elements of ``positive_indices.cpu().numpy()``, so they are numpy
            scalars.
        negative_indices: Tensor of negative point indices, handled the same way.
        image_size: Passed through to ``calculate_distances``.
        device: Passed through to ``calculate_distances``.
        max_steps: Accepted for interface compatibility; not used here.
        pos_weights: Optional mapping from a positive index to its node weight.
            Its keys must compare equal to the numpy index scalars.
        neg_weights: Optional mapping from a negative index to its node weight.

    Returns:
        networkx.MultiGraph: The constructed graph. Previously the graph was
        built but never returned, so callers received ``None``.
    """
    distances = calculate_distances(features, positive_indices, negative_indices, image_size, device)
    feature_pos_distances, feature_cross_distances, physical_pos_distances, physical_neg_distances, physical_cross_distances = distances

    edges = [
        convert_to_edges(positive_indices, positive_indices, feature_pos_distances),
        convert_to_edges(positive_indices, positive_indices, physical_pos_distances),
        convert_to_edges(negative_indices, negative_indices, physical_neg_distances),
        convert_to_edges(positive_indices, negative_indices, feature_cross_distances),
        convert_to_edges(positive_indices, negative_indices, physical_cross_distances)
    ]
    # Edge attribute names, in the same order as `edges` above.
    edge_types = ['feature_pos', 'physical_pos', 'physical_neg', 'feature_cross', 'physical_cross']

    G = nx.MultiGraph()

    pos_indices = positive_indices.cpu().numpy()
    neg_indices = negative_indices.cpu().numpy()

    for idx in pos_indices:
        weight = pos_weights.get(idx, 1.0) if pos_weights is not None else 1.0
        G.add_node(idx, category='pos', weight=weight)

    for idx in neg_indices:
        weight = neg_weights.get(idx, 1.0) if neg_weights is not None else 1.0
        G.add_node(idx, category='neg', weight=weight)

    for edge_list, weight_type in zip(edges, edge_types):
        # The third element of each edge tuple is stored under the attribute `weight_type`.
        G.add_weighted_edges_from(edge_list, weight=weight_type)

    return G
    

def main():
    """Validate the dataset layout and write the run configuration.

    The function works in three steps:

    1. Every sub-directory of ``DATASET_ROOT`` is treated as a category.
    2. A category is valid only if it contains the ``reference_images``,
       ``reference_masks``, ``target_images`` and ``target_masks``
       directories; other categories are reported and skipped.
    3. ``config.json`` is written into a new timestamped directory
       ``<BASE_DIR>/train/<YYYYmmdd_HHMM>``.

    If the dataset root is missing or no category is valid, the function
    prints an error and returns early. No run directory is created in that
    case.
    """
    episodes = 1000
    max_steps = 100

    if not os.path.isdir(DATASET_ROOT):
        print(f"Error: Dataset directory not found: {DATASET_ROOT}")
        return

    # Sorted because os.listdir order is unspecified; keeps config.json reproducible.
    classes = sorted(
        c for c in os.listdir(DATASET_ROOT)
        if os.path.isdir(os.path.join(DATASET_ROOT, c))
    )

    print(f"Found {len(classes)} categories in {DATASET} dataset")

    required_dirs = ['reference_images', 'reference_masks', 'target_images', 'target_masks']
    valid_category_paths = []
    for category in classes:
        category_path = os.path.join(DATASET_ROOT, category)
        if all(os.path.exists(os.path.join(category_path, dir_name)) for dir_name in required_dirs):
            valid_category_paths.append(category_path)
        else:
            print(f"Warning: Skipping category {category} - missing required directories")

    if not valid_category_paths:
        print("Error: No valid categories found!")
        return

    # Create the run directory only once there is something to record, so
    # failed runs do not leave empty timestamped folders behind.
    current_time = datetime.now().strftime("%Y%m%d_%H%M")
    output_path = os.path.join(BASE_DIR, 'train', current_time)
    os.makedirs(output_path, exist_ok=True)

    config = {
        'agent_types': ['feature', 'physical'],
        'episodes': episodes,
        'max_steps': max_steps,
        'dataset': DATASET,
        'valid_categories': [os.path.basename(path) for path in valid_category_paths],
        'image_size': SIZE,
    }
    with open(os.path.join(output_path, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)
    
# Script entry point: run main() only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
