import os
import sys
import torch
import numpy as np
import argparse
import json
from tqdm import tqdm
from contextlib import contextmanager

# Make the project's `src` package importable and run from the project root.
# The project root is taken to be the parent of this script's directory.
# This must execute before the `from src...` imports below.
current_script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_script_dir)
# Appended (not prepended): entries already on sys.path take precedence.
sys.path.append(project_root)
# Side effect at import time: relative paths used later in this script
# (e.g. "./datasets", "models/ckpts", the output directory) resolve
# against the project root because of this chdir.
os.chdir(project_root)
print(f"Current working directory set to: {os.getcwd()}")

from src.utils import parse_arguments
from src.models.task_vectors import NonLinearTaskVector
from src.finetune.sabcd_finetune import sabcd_finetune
from src.finetune.continual_finetune import continual_finetune
from src.models.model_utils import load_pretrained_model, apply_merged_vector
from src.merging.saim import SAIM

# ==============================
# Auxiliary functions and context managers
# ==============================

@contextmanager
def torch_gpu_scope():
    """Run the enclosed code, then release PyTorch's cached CUDA memory.

    The cache is emptied on exit whether the body completes or raises.
    Exceptions from the body propagate unchanged. Without CUDA the exit
    step does nothing. The context value is ``None``.
    """
    try:
        yield None
    finally:
        cuda_present = torch.cuda.is_available()
        if cuda_present:
            torch.cuda.empty_cache()

def setup_parser():
    """Build the command-line parser for the weight-disentanglement experiment.

    Returns:
        argparse.ArgumentParser: parser whose defaults reproduce the standard
        run (ViT-B-32, four task pairs, 11 alpha values over [-0.5, 1.5]).
    """
    cli = argparse.ArgumentParser(description="Weight disentanglement visualization")
    add = cli.add_argument

    # Output location and model / task selection
    add("--output_dir", default="disentanglement_visualization",
        help="Output directory")
    add("--model", default="ViT-B-32",
        choices=["ViT-B-32", "ViT-B-16", "ViT-L-14"],
        help="Model to use")
    add("--task_pairs", nargs='+',
        default=["EuroSAT,SUN397", "DTD,EuroSAT", "GTSRB,SVHN", "DTD,MNIST"],
        help="Task pairs to analyze, format 'task1,task2'")

    # Alpha grid: main() spaces `alpha_steps` values evenly over `alpha_range`
    add("--alpha_range", type=float, nargs=2, default=[-0.5, 1.5],
        help="Alpha parameter range [min, max]")
    add("--alpha_steps", type=int, default=11,
        help="Number of alpha parameter steps")

    # Runtime and merging-method settings
    add("--device", type=str, default=None,
        help="Computing device")
    add("--beta", type=float, default=1.0,
        help="Beta parameter for SAIM method")
    add("--force_recompute", action="store_true",
        help="Force recomputation of disentanglement error, do not use saved results")

    # Storage locations and evaluation batch size
    add("--data_location", type=str, default="./datasets",
        help="Dataset storage location")
    add("--model-location", type=str, default="./models/ckpts",
        help="Directory for model location")
    add("--batch-size", type=int, default=64)
    return cli

def initialize_args(args):
    """Build the fine-tuning argument namespace from the parsed CLI arguments.

    Side effect: if ``args.device`` is None, it is set in place on ``args``.
    The value is ``"cuda"`` when available, else ``"cpu"``. The evaluation
    code later reads ``args.device`` from this same CLI namespace.

    Args:
        args: namespace produced by ``setup_parser().parse_args()``.

    Returns:
        The namespace returned by ``src.utils.parse_arguments()``, with these
        fields filled in:

        * model, fine-tuning batch size / gradient accumulation, learning rate;
        * device;
        * dataset and model locations, save directory and output directory.
    """
    if args.device is None:
        args.device = "cuda" if torch.cuda.is_available() else "cpu"

    base_args = parse_arguments()
    base_args.model = args.model
    # Fine-tuning batch settings depend on model size. The CLI --batch-size
    # is only used for evaluation (compute_model_output_distance receives
    # the CLI namespace), so it is intentionally not applied here.
    is_large = "ViT-L" in args.model
    base_args.batch_size = 32 if is_large else 128
    base_args.num_grad_accumulation = 4 if is_large else 1
    base_args.lr = 1e-5
    base_args.device = args.device

    # Honor --data_location / --model-location. They were previously
    # hard-coded, which silently ignored user-supplied values. The defaults
    # resolve to the same paths as before. normpath turns the CLI default
    # "./models/ckpts" into the previous "models/ckpts", so save_dir is
    # unchanged for default runs. getattr keeps namespaces that lack these
    # attributes working.
    base_args.data_location = getattr(args, "data_location", "./datasets")
    base_args.model_location = os.path.normpath(
        getattr(args, "model_location", "models/ckpts")
    )
    base_args.save_dir = os.path.join(base_args.model_location, base_args.model)
    base_args.output_dir = args.output_dir

    return base_args

# ==============================
# Data storage and loading
# ==============================

def save_results_to_json(task1, task2, alpha_values, no_sabcd_error_map, sabcd_error_map, args):
    """Persist the computed error maps together with the run parameters.

    The record is written to ``<args.output_dir>/results/<task1>_<task2>_errors.json``;
    the ``results`` directory is created when missing. The stored
    ``alpha_range``, ``alpha_steps``, ``beta`` and ``model`` are what
    ``load_results_from_json`` compares to decide whether the cache is reusable.

    Args:
        alpha_values, no_sabcd_error_map, sabcd_error_map: numpy arrays
            (serialized via ``.tolist()``).
        args: namespace providing ``output_dir``, ``alpha_range``,
            ``alpha_steps``, ``beta`` and ``model``.

    Returns:
        str: path of the written JSON file.
    """
    payload = dict(
        task1=task1,
        task2=task2,
        alpha_values=alpha_values.tolist(),
        alpha_range=args.alpha_range,
        alpha_steps=args.alpha_steps,
        beta=args.beta,
        model=args.model,
        no_sabcd_error_map=no_sabcd_error_map.tolist(),
        sabcd_error_map=sabcd_error_map.tolist(),
    )

    target_dir = os.path.join(args.output_dir, "results")
    os.makedirs(target_dir, exist_ok=True)
    target_path = os.path.join(target_dir, f"{task1}_{task2}_errors.json")

    with open(target_path, 'w') as handle:
        json.dump(payload, handle, indent=2)

    print(f"Disentanglement error results saved to: {target_path}")
    return target_path

def load_results_from_json(task1, task2, args):
    """Load cached error maps for a task pair if they match the current run.

    Reads ``<args.output_dir>/results/<task1>_<task2>_errors.json``.

    Returns:
        A dict with numpy arrays under ``alpha_values``, ``no_sabcd_error_map``
        and ``sabcd_error_map``. Returns None in any of these cases, and the
        caller then recomputes:

        * the file is missing;
        * its ``alpha_range``, ``alpha_steps``, ``beta`` or ``model`` differ
          from ``args``;
        * reading or parsing raises any exception.
    """
    source_path = os.path.join(args.output_dir, "results", f"{task1}_{task2}_errors.json")

    if not os.path.exists(source_path):
        print(f"Saved results file not found: {source_path}")
        return None

    # Best-effort cache: any failure below just means "recompute".
    try:
        with open(source_path, 'r') as handle:
            cached = json.load(handle)

        # Run parameters that must match, compared in this order.
        checked_keys = ("alpha_range", "alpha_steps", "beta", "model")
        if any(cached[key] != getattr(args, key) for key in checked_keys):
            print("Parameters inconsistent, need to recompute")
            return None

        print(f"Loading saved disentanglement error results from file: {source_path}")
        return {
            field: np.array(cached[field])
            for field in ("alpha_values", "no_sabcd_error_map", "sabcd_error_map")
        }
    except Exception as err:
        print(f"Error loading results file: {err}")
        return None
    
# ==============================
# Model-related functions
# ==============================

def finetune_model(args, dataset, output_path, use_sabcd=False):
    """Fine-tune the pretrained model on ``<dataset>Val`` unless a checkpoint exists.

    Args:
        args: fine-tuning namespace (from ``initialize_args``). ``args.output_dir``
            and ``args.model`` locate the cached pretrained checkpoint
            ``pretrained_<model>.pt``.
        dataset: base dataset name; ``"Val"`` is appended to get the training split.
        output_path: where the fine-tuned checkpoint is, or will be, stored.
        use_sabcd: train with ``sabcd_finetune`` instead of ``continual_finetune``.

    Returns:
        str: ``output_path``.
    """
    # If the fine-tuned checkpoint already exists, skip everything, including
    # loading the pretrained weights (previously loaded on every call).
    if os.path.exists(output_path):
        print(f"Found existing finetuned model: {output_path}, skipping finetuning step")
        return output_path

    # Both fine-tuning routines start from a checkpoint file. Write it once
    # per model, and only load the pretrained weights when that file is missing.
    pretrained_path = os.path.join(args.output_dir, f"pretrained_{args.model}.pt")
    if not os.path.exists(pretrained_path):
        torch.save(load_pretrained_model(args), pretrained_path)

    dataset_val = dataset + "Val"
    print(f"Finetuning on {dataset_val} (use_sabcd={use_sabcd})")
    finetune_fn = sabcd_finetune if use_sabcd else continual_finetune
    finetune_fn(
        args=args,
        train_dataset=dataset_val,
        starting_model_path=pretrained_path,
        output_path=output_path,
    )

    return output_path

def create_single_task_model(task_vector, base_state_dict, alpha, args):
    """Build a model containing a single task vector merged into the base weights.

    Args:
        task_vector: task vector to merge. Its ``model_name`` is forwarded to
            ``apply_merged_vector``.
        base_state_dict: pretrained weights. They are passed as the second and
            third SAIM arguments and are the base the merged vector is applied to.
        alpha: scaling coefficient forwarded to ``apply_merged_vector``.
        args: namespace providing ``device`` and ``beta``.

    Returns:
        The model produced by ``apply_merged_vector``, moved to ``args.device``.
    """
    target_device = torch.device(args.device)

    with torch_gpu_scope():
        saim_vector = SAIM([task_vector], base_state_dict, base_state_dict,
                           task_count=1, beta=args.beta)
        single_model = apply_merged_vector(
            base_state_dict,
            saim_vector,
            alpha=alpha,
            device=target_device,
            method="SAIM",
            model_name=task_vector.model_name,
        )
        return single_model.to(target_device)

def create_merged_model(task1_vector, task2_vector, base_state_dict, alpha1, alpha2, args):
    """Create merged model for two tasks.

    The merge runs in two steps:

    1. SAIM is run on ``task1_vector`` (``task_count=1``) and the result is
       applied to ``base_state_dict`` with ``alpha1``.
    2. SAIM is run on ``task2_vector`` (``task_count=2``), with the task-1
       model's parameters as its second argument. That result is applied to
       ``base_state_dict`` with ``alpha2``.

    Args:
        task1_vector: first task vector (merged in the first step).
        task2_vector: second task vector (merged in the second step).
        base_state_dict: pretrained weights.
        alpha1: scaling coefficient for the first ``apply_merged_vector`` call.
        alpha2: scaling coefficient for the second ``apply_merged_vector`` call.
        args: namespace providing ``device`` and ``beta``.

    Returns:
        The final merged model, moved to ``args.device``.
    """
    device = torch.device(args.device)
    
    with torch_gpu_scope():
        # Apply task1 vector first
        merged_vector1 = SAIM(
            [task1_vector], base_state_dict, base_state_dict, task_count=1, beta=args.beta)
        model_after_task1 = apply_merged_vector(
            base_state_dict,
            merged_vector1,
            alpha=alpha1,
            device=device,
            method="SAIM",
            model_name=task1_vector.model_name
        )
        model_after_task1 = model_after_task1.to(device)
        
        # Copy model_after_task1's parameters (cloned tensors, on `device`).
        # NOTE(review): this holds named_parameters() only, not a full
        # state_dict, so any registered buffers are missing. It also assumes
        # the parameter names line up with base_state_dict's keys. Confirm
        # both against SAIM's expectations.
        model_after_task1_state_dict = {}
        for name, param in model_after_task1.named_parameters():
            model_after_task1_state_dict[name] = param.data.clone()
        
        # Apply task2 vector
        # NOTE(review): merged_vector2 is applied to base_state_dict, not to
        # model_after_task1. Task 1 is therefore present in the result only
        # if SAIM folds it into merged_vector2. Verify against src.merging.saim.
        merged_vector2 = SAIM(
            [task2_vector], model_after_task1_state_dict, base_state_dict, task_count=2, beta=args.beta)
        merged_model = apply_merged_vector(
            base_state_dict,
            merged_vector2,
            alpha=alpha2,
            device=device,
            method="SAIM",
            model_name=task2_vector.model_name
        )
        merged_model = merged_model.to(device)
        
        # Drop the intermediate model before the scope empties the CUDA cache
        del model_after_task1
        
        return merged_model

# ==============================
# Evaluation and error calculation
# ==============================

def compute_model_output_distance(model1, model2, dataset_name, args):
    """Average L1 distance between the two models' softmax outputs on a dataset.

    For each batch, the per-sample distance
    ``sum(|softmax(model1(x)) - softmax(model2(x))|)`` (softmax and sum over
    dim 1) is averaged and weighted by the batch size. The result is then
    averaged over all samples.

    Args:
        model1, model2: models to compare; both are switched to eval mode.
            ``model1.val_preprocess`` builds the dataset, so both models are
            assumed to share the same preprocessing.
        dataset_name: registry name passed to ``get_dataset``.
        args: namespace providing ``data_location``, ``batch_size`` and
            ``device``; it is also forwarded to ``get_dataloader``.

    Returns:
        float: sample-weighted average L1 distance.

    Raises:
        ValueError: if the dataloader yields no samples. Previously this was an
            unexplained ZeroDivisionError.
    """
    model1.eval()
    model2.eval()

    # Project dataset helpers, imported lazily
    from src.datasets.registry import get_dataset
    from src.datasets.common import get_dataloader, maybe_dictionarize

    with torch_gpu_scope():
        dataset = get_dataset(
            dataset_name,
            model1.val_preprocess,  # Assume both models use the same preprocessing
            location=args.data_location,
            batch_size=args.batch_size,
        )
        dataloader = get_dataloader(dataset, is_train=False, args=args, image_encoder=None)

        total_distance = 0.0
        sample_count = 0

        with torch.no_grad():
            for batch in dataloader:
                data = maybe_dictionarize(batch)
                images = data["images"].to(args.device)

                # NOTE(review): softmax over dim=1 assumes the models return
                # (batch, num_outputs) scores. Confirm the model output type.
                probs1 = torch.nn.functional.softmax(model1(images), dim=1)
                probs2 = torch.nn.functional.softmax(model2(images), dim=1)

                # Batch mean of the per-sample L1 distance, re-weighted by
                # batch size so a smaller final batch is not over-counted.
                l1_dist = torch.abs(probs1 - probs2).sum(dim=1).mean().item()

                total_distance += l1_dist * images.size(0)
                sample_count += images.size(0)

        if sample_count == 0:
            raise ValueError(
                f"Dataset '{dataset_name}' produced no evaluation samples; "
                "cannot compute output distance"
            )
        return total_distance / sample_count

def calculate_disentanglement_error(task1_vector, task2_vector, base_state_dict,
                                   args, alpha1, alpha2, task1_dataset, task2_dataset):
    """Weight-disentanglement error of two task vectors at (alpha1, alpha2).

    Each single-task model (its task vector scaled by its own alpha) is
    compared with the two-task merged model on that task's ``"<name>Val"``
    split via ``compute_model_output_distance``.

    Returns:
        tuple: ``(task1_error + task2_error, task1_error, task2_error)``.
    """
    with torch_gpu_scope():
        single1 = create_single_task_model(task1_vector, base_state_dict, alpha1, args)
        single2 = create_single_task_model(task2_vector, base_state_dict, alpha2, args)
        merged = create_merged_model(
            task1_vector, task2_vector, base_state_dict, alpha1, alpha2, args
        )

        err1 = compute_model_output_distance(single1, merged, task1_dataset + "Val", args)
        err2 = compute_model_output_distance(single2, merged, task2_dataset + "Val", args)
        combined = err1 + err2

        # Drop model references before the scope empties the CUDA cache
        del single1, single2, merged

        return combined, err1, err2

def export_data_for_origin(task1, task2, alpha_values, no_sabcd_error_map, sabcd_error_map, args):
    """Write the error maps as CSV files for plotting in Origin.

    Two files are written to ``<args.output_dir>/origin_data`` (created if needed):

    * ``<task1>_<task2>_grid_data.csv``: one header row of alpha1 values,
      then a "# No SABCD Data" section and a "# With SABCD Data" section.
      Each section has one row per alpha2 (first column) holding
      ``map[j, i]`` for every alpha1.
    * ``<task1>_<task2>_xyz_data.csv``: long format with columns
      ``alpha1,alpha2,no_sabcd_error,sabcd_error``, alpha1 as the outer loop.

    Both maps are indexed ``[alpha2_index, alpha1_index]``.

    Returns:
        str: path of the XYZ CSV file.
    """
    target_dir = os.path.join(args.output_dir, "origin_data")
    os.makedirs(target_dir, exist_ok=True)
    count = len(alpha_values)

    def grid_rows(error_map):
        # "<alpha2>,<map[j,0]>,...,<map[j,count-1]>" for each alpha2 row j
        for j, a2 in enumerate(alpha_values):
            cells = [f"{a2:.3f}"] + [f"{error_map[j, i]:.6f}" for i in range(count)]
            yield ",".join(cells) + "\n"

    grid_path = os.path.join(target_dir, f"{task1}_{task2}_grid_data.csv")
    with open(grid_path, 'w') as out:
        out.write(",".join(["alpha2/alpha1"] + [f"{a:.3f}" for a in alpha_values]) + "\n")
        out.write("# No SABCD Data\n")
        out.writelines(grid_rows(no_sabcd_error_map))
        out.write("\n# With SABCD Data\n")
        out.writelines(grid_rows(sabcd_error_map))

    xyz_path = os.path.join(target_dir, f"{task1}_{task2}_xyz_data.csv")
    with open(xyz_path, 'w') as out:
        out.write("alpha1,alpha2,no_sabcd_error,sabcd_error\n")
        out.writelines(
            f"{a1:.6f},{a2:.6f},{no_sabcd_error_map[j, i]:.6f},{sabcd_error_map[j, i]:.6f}\n"
            for i, a1 in enumerate(alpha_values)
            for j, a2 in enumerate(alpha_values)
        )

    print(f"Exported Origin grid data to: {grid_path}")
    print(f"Exported Origin XYZ data to: {xyz_path}")
    return xyz_path

# ==============================
# Experimental flow functions
# ==============================

def prepare_task_vectors(task1, task2, base_args, args):
    """Fine-tune (or reuse) both tasks with and without SABCD and build task vectors.

    Checkpoints are stored in ``<args.output_dir>/models`` as
    ``<task>_no_sabcd.pt`` and ``<task>_sabcd.pt``. ``finetune_model`` is
    called for each one and skips training when the file already exists.

    Returns:
        dict: ``{'pretrained_state': ..., 'no_sabcd': {...}, 'sabcd': {...}}``.
        Each inner dict holds ``'task1_vector'`` and ``'task2_vector'``
        (``NonLinearTaskVector`` instances).
    """
    checkpoint_dir = os.path.join(args.output_dir, "models")
    os.makedirs(checkpoint_dir, exist_ok=True)

    pretrained_state = load_pretrained_model(base_args)
    both_tasks = (task1, task2)

    def checkpoint_path(task, variant):
        return os.path.join(checkpoint_dir, f"{task}_{variant}.pt")

    # Fine-tune in order: task1 plain, task1 SABCD, task2 plain, task2 SABCD
    for task in both_tasks:
        finetune_model(base_args, task, checkpoint_path(task, "no_sabcd"), use_sabcd=False)
        finetune_model(base_args, task, checkpoint_path(task, "sabcd"), use_sabcd=True)

    def vectors_for(variant):
        # Load both checkpoints first, then diff each against the pretrained state
        states = [torch.load(checkpoint_path(task, variant), map_location='cpu')
                  for task in both_tasks]
        return {
            f'task{index}_vector': NonLinearTaskVector(base_args.model, pretrained_state, state)
            for index, state in enumerate(states, start=1)
        }

    return {
        'pretrained_state': pretrained_state,
        'no_sabcd': vectors_for("no_sabcd"),
        'sabcd': vectors_for("sabcd"),
    }

def _fill_error_map(variant_vectors, pretrained_state, task1, task2, alpha_values, args, desc):
    """Evaluate the disentanglement error over the full alpha grid for one variant.

    ``variant_vectors`` holds ``'task1_vector'`` and ``'task2_vector'``. The
    returned array is indexed ``[alpha2_index, alpha1_index]``.
    """
    steps = len(alpha_values)
    error_map = np.zeros((steps, steps))

    with tqdm(total=steps * steps, desc=desc) as pbar:
        for i, alpha1 in enumerate(alpha_values):
            for j, alpha2 in enumerate(alpha_values):
                total_error, t1_err, t2_err = calculate_disentanglement_error(
                    variant_vectors['task1_vector'],
                    variant_vectors['task2_vector'],
                    pretrained_state, args, alpha1, alpha2, task1, task2
                )
                error_map[j, i] = total_error

                pbar.update(1)
                pbar.set_postfix(alpha1=f"{alpha1:.2f}", alpha2=f"{alpha2:.2f}",
                                 error=f"{total_error:.4f}", t1_err=f"{t1_err:.4f}", t2_err=f"{t2_err:.4f}")

    return error_map


def compute_error_maps(task_vectors, pretrained_state, task1, task2, alpha_values, args):
    """Compute the disentanglement error maps without and with SABCD.

    The two grid sweeps previously duplicated the same nested loop. Both now
    share ``_fill_error_map``; the output and progress labels are unchanged.

    Args:
        task_vectors: dict with ``'no_sabcd'`` and ``'sabcd'`` entries, each
            holding ``'task1_vector'`` and ``'task2_vector'``.
        pretrained_state: base weights passed to the error computation.
        task1, task2: dataset names.
        alpha_values: 1-D grid used for both alpha1 and alpha2.
        args: CLI namespace forwarded to the error computation.

    Returns:
        tuple: ``(no_sabcd_error_map, sabcd_error_map)``. Each is a square
        array indexed ``[alpha2_index, alpha1_index]``.
    """
    print("\nComputing weight disentanglement error without SABCD:")
    no_sabcd_error_map = _fill_error_map(
        task_vectors['no_sabcd'], pretrained_state, task1, task2,
        alpha_values, args, "No SABCD progress",
    )

    print("\nComputing weight disentanglement error with SABCD:")
    sabcd_error_map = _fill_error_map(
        task_vectors['sabcd'], pretrained_state, task1, task2,
        alpha_values, args, "SABCD progress",
    )

    return no_sabcd_error_map, sabcd_error_map

def process_task_pair(task_pair, alpha_values, base_args, args):
    """Compute (or load cached) error maps for one task pair and export them for Origin.

    No plotting is done here.

    Args:
        task_pair: ``"task1,task2"``; whitespace around each name is ignored.
        alpha_values: 1-D grid of alpha values used for both axes.
        base_args: fine-tuning namespace (from ``initialize_args``).
        args: CLI namespace (``force_recompute``, ``output_dir``, ``beta``, ...).

    Raises:
        ValueError: if ``task_pair`` is not exactly two non-empty names.
    """
    # Strip whitespace so "DTD, MNIST" does not yield " MNIST", which would
    # break dataset lookup and cache file names. Reject malformed pairs with
    # a clear message instead of a bare unpacking error.
    names = [name.strip() for name in task_pair.split(",")]
    if len(names) != 2 or not all(names):
        raise ValueError(
            f"Invalid task pair {task_pair!r}; expected format 'task1,task2'"
        )
    task1, task2 = names
    print(f"\n===== Processing task pair: {task1}-{task2} =====")

    # Reuse saved results when they match the current parameters
    if not args.force_recompute:
        loaded_results = load_results_from_json(task1, task2, args)
        if loaded_results is not None:
            print("Loaded calculation results, exporting data for Origin plotting...")
            export_data_for_origin(
                task1, task2,
                loaded_results["alpha_values"],
                loaded_results["no_sabcd_error_map"],
                loaded_results["sabcd_error_map"],
                args,
            )
            return

    task_vectors = prepare_task_vectors(task1, task2, base_args, args)
    pretrained_state = task_vectors['pretrained_state']

    no_sabcd_error_map, sabcd_error_map = compute_error_maps(
        task_vectors, pretrained_state, task1, task2, alpha_values, args
    )

    save_results_to_json(task1, task2, alpha_values, no_sabcd_error_map, sabcd_error_map, args)
    export_data_for_origin(task1, task2, alpha_values, no_sabcd_error_map, sabcd_error_map, args)
    print("Data exported for Origin plotting")

# ==============================
# Main function
# ==============================

def main():
    """Entry point: parse CLI options, then process every requested task pair."""
    cli_args = setup_parser().parse_args()

    # Create the output directory before anything writes into it
    os.makedirs(cli_args.output_dir, exist_ok=True)

    # Also resolves cli_args.device in place when it was not given
    finetune_args = initialize_args(cli_args)

    low, high = cli_args.alpha_range
    alpha_grid = np.linspace(low, high, cli_args.alpha_steps)

    for pair in cli_args.task_pairs:
        process_task_pair(pair, alpha_grid, finetune_args, cli_args)


if __name__ == "__main__":
    main()