import os
import sys

# Bootstrap paths before any `src.*` import below.
# The project root is taken to be the PARENT of the directory containing this script.
current_script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_script_dir)  # One level above the script's directory
sys.path.append(project_root)  # Lets the `from src...` imports further down resolve
os.chdir(project_root)  # Relative paths used later (e.g. "models/ckpts") resolve from the project root

print(f"Current working directory set to: {os.getcwd()}")

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import argparse
import random

from src.utils import parse_arguments
from src.finetune.sabcd_finetune import sabcd_finetune
from src.finetune.continual_finetune import continual_finetune
from src.models.model_utils import load_pretrained_model

# Set random seed to ensure reproducible results
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seeds every visible CUDA device, including the current one.
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes convolution algorithms per run, which can pick
    # different kernels between runs; disable it so the choice is stable.
    torch.backends.cudnn.benchmark = False

# Task label -> list of dataset names finetuned for that task.
# finetune_task() appends "Val" to each name before passing it to the finetuning routine,
# and uses the raw name in the checkpoint file name.
TASK_DATASETS = {
    "OCR": ["MNIST", "EMNIST", "KMNIST"],  # Datasets assigned to the optical character recognition task
    "VQA": ["STL10", "CIFAR100", "RenderedSST2"],  # Datasets assigned to the visual question answering task
    "Geometry": ["GTSRB", "RESISC45", "EuroSAT"],  # Datasets assigned to the geometric shape task
    "Chart": ["CIFAR10", "Food101", "SUN397"],  # Datasets assigned to the chart task
    "Grounding": ["Cars", "OxfordIIITPet", "DTD"]  # Datasets assigned to the grounding task
}

# Task label -> hex RGB fill color for that task's histogram in
# plot_param_changes_histogram() (author's note: matches the original chart colors).
TASK_COLORS = {
    "OCR": "#62A0CA",      # Blue
    "VQA": "#FFA556",      # Orange
    "Geometry": "#6ABC6A", # Green
    "Chart": "#E36868",    # Red
    "Grounding": "#B494D0" # Purple
}

def calculate_param_changes(old_params, new_params):
    """Calculate the magnitude of parameter changes (log10(|new_param - old_param|)).

    Only entries whose key contains 'visual' and that are present in both
    mappings are compared. Elementwise changes that are not greater than
    1e-10 are dropped before taking the logarithm.

    Args:
        old_params: Mapping of parameter name -> tensor (reference weights).
        new_params: Mapping of parameter name -> tensor (updated weights).

    Returns:
        1-D NumPy array with the log10 change magnitudes of all matching
        tensors, in iteration order of ``old_params``; an empty array when
        nothing qualifies.
    """
    # One NumPy array per tensor, concatenated once at the end. Extending a
    # Python list element by element would box every value as an object,
    # which is prohibitively slow and memory hungry for millions of weights.
    per_tensor_logs = []

    for key, old_value in old_params.items():
        if key not in new_params or 'visual' not in key:
            continue

        # detach() so tensors that still require grad can be exported to
        # NumPy; cpu() so both operands live on the same device.
        old_param = old_value.detach().cpu()
        new_param = new_params[key].detach().cpu()

        param_change = (new_param - old_param).abs()
        if param_change.dtype in (torch.float16, torch.bfloat16):
            # Reduced-precision types are upcast: bfloat16 has no NumPy
            # equivalent and half-precision log10 is not available on
            # every CPU build.
            param_change = param_change.float()

        # Boolean-mask indexing already yields a flat 1-D tensor.
        non_zero_changes = param_change[param_change > 1e-10]
        if non_zero_changes.numel() > 0:
            per_tensor_logs.append(torch.log10(non_zero_changes).numpy())

    if not per_tensor_logs:
        return np.array([])
    return np.concatenate(per_tensor_logs)

def calculate_param_magnitudes(params):
    """Calculate the magnitude of parameter absolute values (log10(|param|)).

    Only entries whose key contains 'visual' are considered. Elements whose
    absolute value is not greater than 1e-10 are dropped before taking the
    logarithm.

    Args:
        params: Mapping of parameter name -> tensor.

    Returns:
        1-D NumPy array with the log10 magnitudes of all selected elements,
        in iteration order of ``params``; an empty array when nothing
        qualifies.
    """
    # One NumPy array per tensor, concatenated once at the end. Extending a
    # Python list element by element would box every value as an object,
    # which is prohibitively slow and memory hungry for millions of weights.
    per_tensor_logs = []

    for key, param in params.items():
        if 'visual' not in key:
            continue

        # detach() so tensors that still require grad can be exported to NumPy.
        param_abs = param.detach().cpu().abs()
        if param_abs.dtype in (torch.float16, torch.bfloat16):
            # Reduced-precision types are upcast: bfloat16 has no NumPy
            # equivalent and half-precision log10 is not available on
            # every CPU build.
            param_abs = param_abs.float()

        # Boolean-mask indexing already yields a flat 1-D tensor.
        non_zero_magnitudes = param_abs[param_abs > 1e-10]
        if non_zero_magnitudes.numel() > 0:
            per_tensor_logs.append(torch.log10(non_zero_magnitudes).numpy())

    if not per_tensor_logs:
        return np.array([])
    return np.concatenate(per_tensor_logs)

def plot_param_changes_histogram(changes_dict, title="Parameter Change Distribution",
                               save_path="param_changes_histogram.png",
                               bins=75, alpha=0.6, max_samples=100000):
    """Plot overlaid density histograms of log10 parameter-change magnitudes.

    Args:
        changes_dict: Mapping of task name -> array-like of log10 change
            values. Only the tasks "Chart", "Geometry", "VQA", "OCR" and
            "Grounding" are drawn; other keys are ignored.
        title: Accepted for interface compatibility; currently not drawn
            (the title call is disabled).
        save_path: File the figure is written to.
        bins: Number of equal-width bins across the fixed x-range.
        alpha: Opacity of each histogram.
        max_samples: Tasks with more finite values than this are randomly
            subsampled (without replacement, global NumPy RNG) before plotting.

    Returns:
        None. The figure is saved (failures are printed, not raised) and closed.
    """
    # Set chart style
    plt.style.use('default')
    plt.figure(figsize=(10, 6))

    try:
        # Common x-axis range and bin boundaries shared by every task
        x_range = (-6.5, -2.5)
        bin_edges = np.linspace(x_range[0], x_range[1], bins + 1)

        # Plot charts in specific order
        tasks_to_plot = ["Chart", "Geometry", "VQA", "OCR", "Grounding"]
        plotted_any = False

        for task in reversed(tasks_to_plot):  # Reverse order so tasks appear on top layer
            if task not in changes_dict:
                continue

            # Keep finite values only
            data = np.asarray(changes_dict[task])
            valid_data = data[np.isfinite(data)]

            # Sample large datasets
            if len(valid_data) > max_samples:
                indices = np.random.choice(len(valid_data), max_samples, replace=False)
                valid_data = valid_data[indices]

            # With density=True a histogram whose bins are all empty divides
            # 0 by 0 and draws NaN bars; skip such tasks instead.
            in_range = (valid_data >= x_range[0]) & (valid_data <= x_range[1])
            if not in_range.any():
                print(f"Skipping {task}: no finite values within {x_range}")
                continue

            # Plot histogram
            plt.hist(valid_data, bins=bin_edges, alpha=alpha,
                     label=task, color=TASK_COLORS.get(task, "#333333"),
                     edgecolor='none', linewidth=0.5, density=True)
            plotted_any = True

        # Beautify chart
        plt.xlabel('The magnitude of parameter changes (Log10)', fontsize=26)
        plt.ylabel('Density', fontsize=26)
        # plt.title(title, fontsize=22)

        # Place legend on left side and increase size; without any labeled
        # histogram there is nothing to list, so the call is skipped.
        if plotted_any:
            plt.legend(fontsize=24, loc='upper left', frameon=True)

        # Adjust grid line style
        plt.grid(True, alpha=0.5, linestyle='-', linewidth=0.5)

        # Adjust X-axis range and ticks
        plt.xlim(x_range)
        plt.xticks(fontsize=24)
        plt.yticks(fontsize=24)

        # Adjust Y-axis
        plt.ylim(bottom=0)
        plt.ylim(top=1.25)  # Fix Y-axis upper limit

        # Add border
        plt.box(True)

        # Optimize layout
        plt.tight_layout()

        # Save chart; saving is best effort, failures are reported and swallowed
        try:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Chart saved to: {save_path}")
        except Exception as e:
            print(f"Failed to save chart: {str(e)}")
    finally:
        # Always release the figure, even if plotting raised.
        plt.close()
    
def calculate_model_param_changes(old_model_path, new_model_path):
    """Calculate the magnitude of parameter changes between two saved models.

    Args:
        old_model_path: Path of the reference checkpoint.
        new_model_path: Path of the updated checkpoint.

    Returns:
        Whatever calculate_param_changes() returns for the two loaded objects.
    """
    # map_location="cpu": the comparison runs on CPU anyway, and this lets
    # checkpoints saved from a GPU be read on a machine without one.
    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    old_state = torch.load(old_model_path, map_location="cpu")
    new_state = torch.load(new_model_path, map_location="cpu")
    return calculate_param_changes(old_state, new_state)
    
def finetune_task(args, task, datasets, pretrained_path, output_dir, method="sabcd"):
    """Finetune on every dataset of a task and collect the parameter changes.

    For each dataset the checkpoint ``{method}_{task}_{dataset}.pt`` in
    ``output_dir`` is reused when it already exists; otherwise it is produced
    by sabcd_finetune() (``method`` equal to "sabcd", case-insensitive) or by
    continual_finetune() (any other ``method``).

    Args:
        args: Argument object forwarded unchanged to the finetuning routine.
        task: Task label, used for logging and in the checkpoint file name.
        datasets: Iterable of dataset names; "Val" is appended to each name
            before it is passed to the finetuning routine.
        pretrained_path: Path of the pretrained checkpoint to start from and
            compare against.
        output_dir: Directory holding the finetuned checkpoints.
        method: Finetuning method name.

    Returns:
        1-D NumPy array with the log10 parameter-change values of all
        datasets concatenated; an empty array when there are none.
    """
    # One array per dataset, concatenated once at the end (extending a list
    # element by element would box every value as a Python object).
    per_dataset_changes = []
    # Loop-invariant: read from disk at most once, on first use.
    pretrained_state = None

    # Finetune each dataset of the task
    for dataset in datasets:
        dataset_val = dataset + "Val"
        print(f"Finetuning {task} task on {dataset_val} dataset with {method}...")

        # Finetuning output path
        output_path = os.path.join(output_dir, f"{method}_{task}_{dataset}.pt")

        if os.path.exists(output_path):
            # Already finetuned: reuse the checkpoint on disk
            print(f"Found existing finetuned model: {output_path}")
        elif method.lower() == "sabcd":
            sabcd_finetune(
                args=args,
                train_dataset=dataset_val,
                starting_model_path=pretrained_path,
                output_path=output_path
            )
        else:  # Default to adam finetuning
            continual_finetune(
                args=args,
                train_dataset=dataset_val,
                starting_model_path=pretrained_path,
                output_path=output_path
            )

        # map_location="cpu": the comparison runs on CPU, and this avoids
        # failures when a GPU-saved checkpoint is read without a GPU.
        finetuned_state = torch.load(output_path, map_location="cpu")
        if pretrained_state is None:
            pretrained_state = torch.load(pretrained_path, map_location="cpu")

        # Calculate parameter changes
        changes = calculate_param_changes(pretrained_state, finetuned_state)
        if len(changes) > 0:
            per_dataset_changes.append(np.asarray(changes))

        print(f"{task}-{dataset}: sample count={len(changes)}, average change magnitude={np.mean(changes):.4f}")

    if not per_dataset_changes:
        return np.array([])
    return np.concatenate(per_dataset_changes)

def main():
    """Finetune every task with both methods, then plot and report parameter changes.

    Command-line options select the output directory and the model. For each
    of the "sabcd" and "adam" methods, every task in TASK_DATASETS is
    finetuned (or its cached checkpoints reused), one comparison histogram is
    saved, and per-task statistics are printed.
    """
    set_seed(42)

    cli = argparse.ArgumentParser(description="Analysis of the impact of different task finetuning on model parameters")
    cli.add_argument("--output_dir", default="analyze/param_change_visualization",
                     help="Output directory")
    # NOTE(review): parsed but not read anywhere in this function.
    cli.add_argument("--skip_existing", action="store_true",
                     help="Whether to skip existing finetuned models")
    cli.add_argument("--model", default="ViT-B-32", choices=["ViT-B-32", "ViT-B-16", "ViT-L-14"],
                     help="Model to use")
    args = cli.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Project-wide defaults, overridden with the settings of this experiment
    base_args = parse_arguments()
    is_large_model = "ViT-L" in args.model
    base_args.model = args.model
    base_args.batch_size = 32 if is_large_model else 128
    base_args.num_grad_accumulation = 4 if is_large_model else 1
    base_args.lr = 1e-5

    # Checkpoint locations expected on the argument object
    base_args.model_location = "models/ckpts"
    base_args.save_dir = os.path.join(base_args.model_location, base_args.model)

    print("Loading pretrained model...")
    pretrained_state = load_pretrained_model(base_args)
    # NOTE(review): the cached file name does not include the model name, so
    # an existing file is reused as-is — confirm when switching --model.
    pretrained_path = os.path.join(args.output_dir, "pretrained.pt")
    if not os.path.exists(pretrained_path):
        torch.save(pretrained_state, pretrained_path)

    # (method, output image name, chart title) for the two finetuning methods
    method_outputs = (
        ("sabcd",
         "SA-BCD_param_changes_comparison.png",
         "Parameter Change Distribution (SA-BCD Fine-tuning)"),
        ("adam",
         "Adam_param_changes_comparison.png",
         "Parameter Change Distribution (Adam Fine-tuning)"),
    )

    for method, pic_name, title_name in method_outputs:
        method_label = method.upper()
        print(f"\n===== Processing {method_label} finetuning method =====")

        # Task name -> array of parameter changes for this method
        all_changes = {}
        for task, datasets in TASK_DATASETS.items():
            print(f"\n----- Processing {task} task ({method} finetuning) -----")
            all_changes[task] = finetune_task(
                args=base_args,
                task=task,
                datasets=datasets,
                pretrained_path=pretrained_path,
                output_dir=args.output_dir,
                method=method
            )

        print(f"\nPlotting task parameter change magnitude comparison for {method_label} finetuning method...")
        plot_param_changes_histogram(
            all_changes,
            title=title_name,
            save_path=os.path.join(args.output_dir, pic_name)
        )

        print(f"\n===== {method_label} Finetuning Parameter Change Statistics =====")
        for task, changes in all_changes.items():
            print(f"{task}:")
            print(f"  Sample count: {len(changes)}")
            print(f"  Average change magnitude: {np.mean(changes):.4f}")
            print(f"  Median change magnitude: {np.median(changes):.4f}")
            print(f"  Standard deviation: {np.std(changes):.4f}")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()