#!/usr/bin/env python
# coding: utf-8

import os
import re
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import pickle
import argparse

def parse_rmd_ground_truth(rmd_file):
    """Parse the pair table in ``causal_description.Rmd`` into ground-truth records.

    The table is located between a ``## Format`` heading and a ``## Source``
    heading (case-insensitive, each separated from the table by a blank line).
    Rows look like::

        | 001      | Altitude       |Temperature | DWD dataset |

    i.e. pair number, first variable, second variable, then at least one more
    column. Rows that do not start with a numeric pair id (header, separator)
    are ignored.

    The Rmd file does not encode causal direction, so every pair is labelled
    ``direction = 1`` (Var1 -> Var2). This is an ASSUMPTION; it must be checked
    against the dataset's ``pairmeta.txt`` metadata before the labels are trusted.

    Args:
        rmd_file (str): Path to the causal_description.Rmd file (read as UTF-8).

    Returns:
        dict | None: Mapping pair index (int) -> ``{"var1": str, "var2": str,
        "direction": int}``, where direction is 1 for Var1->Var2 and -1 for
        Var2->Var1 (currently always 1). If a pair index appears more than once,
        the last row wins. Returns None if the file cannot be read or decoded,
        the table section is missing, or no rows are parsed.
    """
    try:
        # Explicit UTF-8 so decoding does not depend on the platform's default
        # locale encoding (the table may contain non-ASCII variable names).
        with open(rmd_file, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error parsing Rmd file: {e}")
        return None

    # Find the table section
    table_match = re.search(r"""## Format.*?\n\n(.*?)\n\n## Source""", content, re.DOTALL | re.IGNORECASE)
    if not table_match:
        print("Warning: Could not find the table section in Rmd file.")
        return None

    # One match per table row. [ \t]* (rather than \s*) keeps each match on a
    # single line, and the trailing [ \t]* tolerates whitespace after the
    # closing pipe. The final ".*\|" absorbs any extra columns (e.g. source).
    row_pattern = re.compile(
        r"^\|[ \t]*(\d+)[ \t]*\|[ \t]*(.*?)[ \t]*\|[ \t]*(.*?)[ \t]*\|.*\|[ \t]*$",
        re.MULTILINE,
    )

    ground_truth = {}
    for match in row_pattern.finditer(table_match.group(1)):
        pair_index = int(match.group(1))
        ground_truth[pair_index] = {
            "var1": match.group(2).strip(),
            "var2": match.group(3).strip(),
            # *** CRITICAL ASSUMPTION ***: Var1 -> Var2 until verified against
            # the original pairmeta metadata.
            "direction": 1,  # 1: var1->var2, -1: var2->var1
        }

    if not ground_truth:
        print("Warning: No ground truth information extracted from Rmd file.")
        return None

    print(f"Extracted ground truth info for {len(ground_truth)} pairs (assuming Var1->Var2). Needs verification.")
    return ground_truth

def process_tuebingen_data(data_dir, output_dir, ground_truth_info):
    """Load, standardize, split and pickle the Tübingen cause-effect pairs.

    Every ``causal_tubingen<N>.csv`` file in ``<data_dir>/datasets`` whose pair
    index ``N`` has an entry in ``ground_truth_info`` is read and checked to have
    exactly two columns. Each column is standardized to zero mean and unit
    variance, and the pair is tagged with its variable names and direction.
    Pairs are then split by pair index (never by row) into roughly
    70% train / 10% validation / 20% test. The result is written to
    ``<output_dir>/tuebingen_pairs_processed.pkl``.

    Args:
        data_dir (str): Directory containing the downloaded 'datasets' folder.
        output_dir (str): Directory to save the processed data (created if missing).
        ground_truth_info (dict): Mapping pair index -> ``{"var1", "var2",
            "direction"}``, as produced by parse_rmd_ground_truth.

    Returns:
        None. Problems are reported with print. The function returns early,
        without writing the pickle, if the datasets directory is missing or too
        few pairs were loaded to form three splits.
    """
    datasets_path = os.path.join(data_dir, "datasets")
    if not os.path.isdir(datasets_path):
        print(f"Error: 'datasets' directory not found in {data_dir}")
        return

    os.makedirs(output_dir, exist_ok=True)

    # Collect (pair_index, filename) and sort numerically. os.listdir order is
    # filesystem-dependent, and this order becomes the order of the saved splits.
    file_pattern = re.compile(r"causal_tubingen(\d+)\.csv")
    candidates = []
    for filename in os.listdir(datasets_path):
        match = file_pattern.fullmatch(filename)
        if match:
            candidates.append((int(match.group(1)), filename))
    candidates.sort()

    all_pairs_data = []
    processed_indices = set()

    for pair_index, filename in candidates:
        if pair_index not in ground_truth_info:
            print(f"Warning: Skipping pair {pair_index}, no ground truth info found.")
            continue

        filepath = os.path.join(datasets_path, filename)
        try:
            # Load data, assuming comma separated, no header
            df = pd.read_csv(filepath, sep=',', header=None)
            if df.shape[1] != 2:
                print(f"Warning: Skipping pair {pair_index}, expected 2 columns, found {df.shape[1]}.")
                continue
            # Standardize each column (mean 0, variance 1). Every pair is scaled
            # independently, so no statistics leak across the pair-level split.
            scaled_data = StandardScaler().fit_transform(df.values)
        except Exception as e:
            # Best effort: one unreadable file should not abort the whole run.
            print(f"Error processing file {filename}: {e}")
            continue

        pair_info = ground_truth_info[pair_index]
        all_pairs_data.append({
            "index": pair_index,
            "data": scaled_data,
            "var1_name": pair_info["var1"],
            "var2_name": pair_info["var2"],
            "direction": pair_info["direction"],
        })
        processed_indices.add(pair_index)

    print(f"Successfully processed {len(all_pairs_data)} pairs.")

    # Check if all ground truth pairs were processed
    missing_indices = set(ground_truth_info.keys()) - processed_indices
    if missing_indices:
        print(f"Warning: Could not find data files for ground truth indices: {sorted(missing_indices)}")

    if not all_pairs_data:
        print("Error: No data processed. Exiting.")
        return

    # Split on pair indices so every pair stays intact; sorting keeps it reproducible.
    indices = sorted(p["index"] for p in all_pairs_data)
    if len(indices) < 3:
        # train_test_split raises ValueError when a split would end up empty.
        print(f"Error: Need at least 3 pairs to build train/val/test splits, got {len(indices)}. Exiting.")
        return

    train_indices, test_indices = train_test_split(indices, test_size=0.2, random_state=42)
    train_indices, val_indices = train_test_split(train_indices, test_size=0.125, random_state=42)  # 0.125 * 0.8 = 0.1

    # Sets give O(1) membership checks while keeping all_pairs_data order.
    train_ids, val_ids, test_ids = set(train_indices), set(val_indices), set(test_indices)
    train_set = [p for p in all_pairs_data if p["index"] in train_ids]
    val_set = [p for p in all_pairs_data if p["index"] in val_ids]
    test_set = [p for p in all_pairs_data if p["index"] in test_ids]

    print(f"Split sizes: Train={len(train_set)}, Val={len(val_set)}, Test={len(test_set)}")

    # Save processed data
    output_file = os.path.join(output_dir, "tuebingen_pairs_processed.pkl")
    processed_data = {
        "train": train_set,
        "validation": val_set,
        "test": test_set,
        "ground_truth_assumption": "Assumed Var1 -> Var2, needs verification with original pairmeta file."
    }
    with open(output_file, "wb") as f:
        pickle.dump(processed_data, f)
    print(f"Saved processed Tübingen data to {output_file}")

def main(args):
    """Run the pipeline: fetch metadata (best effort), parse ground truth, process data.

    The script tries to download ``pairmeta.txt`` into ``args.data_dir`` for
    future use. Parsing that file is not implemented yet, so ground truth always
    comes from ``causal_description.Rmd`` in ``args.data_dir`` (assumed
    Var1 -> Var2). Processing only runs if ground-truth parsing succeeded.

    Args:
        args (argparse.Namespace): Must provide ``data_dir`` and ``output_dir``.
    """
    print("Processing Tübingen Cause-Effect Pairs dataset...")
    # First, try to get the definitive metadata if possible
    metadata_url = "https://webdav.tuebingen.mpg.de/cause-effect/pairmeta.txt"
    metadata_local_path = os.path.join(args.data_dir, "pairmeta.txt")
    try:
        # requests is optional: it is only needed for this best-effort download.
        import requests
        # A timeout keeps an unresponsive server from hanging the script forever.
        response = requests.get(metadata_url, timeout=30)
        if response.status_code == 200:
            with open(metadata_local_path, 'w', encoding='utf-8') as f:
                f.write(response.text)
            print(f"Successfully downloaded metadata to {metadata_local_path}")
            # TODO: implement parse_pairmeta(metadata_local_path) and prefer it over the Rmd fallback.
            print("Note: Metadata downloaded, but parsing not implemented yet. Using Rmd fallback.")
        else:
            print(f"Warning: Failed to download metadata (Status: {response.status_code}). Using Rmd fallback.")
    except Exception as e:
        # Deliberately broad: any download/write failure just falls back to the Rmd file.
        print(f"Warning: Could not download metadata ({e}). Using Rmd fallback.")

    # Every path above ends in the same Rmd fallback, so parse exactly once.
    ground_truth_info = parse_rmd_ground_truth(os.path.join(args.data_dir, "causal_description.Rmd"))

    if ground_truth_info:
        process_tuebingen_data(args.data_dir, args.output_dir, ground_truth_info)
    else:
        print("Error: Failed to obtain ground truth information. Cannot process dataset.")

if __name__ == "__main__":
    # Command-line entry point. Both directories can be overridden; the defaults
    # point at the original project's directory layout.
    cli = argparse.ArgumentParser(description="Process Tübingen Cause-Effect Pairs Data")
    cli.add_argument(
        "--data_dir",
        type=str,
        default="/home/ubuntu/ecam_project/data/real_world/tuebingen",
        help="Directory containing downloaded Tübingen data (datasets folder, Rmd file)",
    )
    cli.add_argument(
        "--output_dir",
        type=str,
        default="/home/ubuntu/ecam_project/data/processed",
        help="Directory to save processed data",
    )
    main(cli.parse_args())

