import os
import gc
import torch
import json
import traceback
import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer

from config.arguments import parse_args
from merging.saim import SAIM
from merging.model_utils import get_task_vector_dict
from finetune.finetune_utils import finetune_model
from trace_evaluation.trace_evaluator import evaluate_model_accuracy
from utils.data_utils import load_local_dataset, get_available_datasets
from utils.experiment_utils import create_experiment_dir, setup_conda_environments, setup_hf_cache, huggingface_login
# Add imports for other merging methods
from merging.task_arithmetic import task_arithmetic_merge, save_cumulative_vector_to_disk
from merging.magmax import magmax_merge, save_magmax_vector_to_disk
from merging.ties_merge import ties_merge, save_ties_vector_to_disk
from merging.dare import dare_merge, save_dare_vector_to_disk
from merging.swa import swa_merge, save_swa_state_to_disk
from utils.experiment_utils import restore_merge_state, save_merge_config, task_specific_batch_sizes, task_specific_epochs

# Module-level holder for the model state produced by the most recent merge.
# Kept as a mutable dict (rather than a bare variable) and passed to SAIM and
# restore_merge_state -- presumably so those helpers can update "current_state"
# in place; confirm against merging.saim / utils.experiment_utils.
PREVIOUS_MODEL_STATE = dict(current_state=None)


# Continual-learning task sequence, in training order. Each name is checked against
# get_available_datasets() and loaded through load_local_dataset().
_CONTINUAL_TASK_NAMES = ('C-STANCE', 'FOMC', 'MeetingBank', 'ScienceQA', 'NumGLUE-cm', 'NumGLUE-ds', '20Minuten')

# Merge methods dispatched by _run_merge. Validated before the task loop so an
# unsupported value fails before any (expensive) finetuning is done.
_SUPPORTED_MERGE_METHODS = ("SAIM", "task_arithmetic", "magmax", "ties_merge", "dare", "swa")


def _model_load_dtype():
    """Return the dtype models are loaded in: bfloat16 when CUDA is available, else float32."""
    return torch.bfloat16 if torch.cuda.is_available() else torch.float32


def _load_causal_lm(model_path, cache_dir):
    """Load a causal LM from `model_path` onto the CPU, caching under `<cache_dir>/transformers`."""
    return AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=_model_load_dtype(),
        device_map="cpu",
        low_cpu_mem_usage=True,
        cache_dir=os.path.join(cache_dir, "transformers"),
    )


def _validate_task_range(start_task, end_task, num_tasks):
    """Raise ValueError unless [start_task, end_task] selects tasks from a list of `num_tasks`.

    ``start_task == num_tasks`` is accepted as the "evaluate the final merged model only"
    mode; ``end_task`` is not used for iteration in that mode.
    """
    if start_task < 0 or start_task > num_tasks:
        raise ValueError(f"Start task index must be between 0-{num_tasks}")
    if end_task >= num_tasks:
        raise ValueError(f"End task index must be between 0-{num_tasks-1}")
    if start_task < num_tasks and end_task < start_task:
        # Otherwise the task loop would silently process nothing.
        raise ValueError(
            f"End task index ({end_task}) must not be smaller than start task index ({start_task})")


def _load_previous_results(result_dir):
    """Return the parsed ``experiment_results.json`` in `result_dir`, or {} if missing/unreadable."""
    results_path = os.path.join(result_dir, "experiment_results.json")
    try:
        with open(results_path, "r") as f:
            return json.load(f)
    except (OSError, ValueError) as err:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: could not read previous results from {results_path} ({err}); starting from empty results")
        return {}


def _save_experiment_results(result_dir, experiment_results):
    """Overwrite ``experiment_results.json`` in `result_dir` with `experiment_results`."""
    with open(os.path.join(result_dir, "experiment_results.json"), "w") as f:
        json.dump(experiment_results, f, indent=2)


def _collect_finetune_times(experiment_results):
    """Return the finetune durations stored in per-task ("task_*") entries, in dict order."""
    return [
        entry["finetune_time_seconds"]
        for key, entry in experiment_results.items()
        if key.startswith("task_") and "finetune_time_seconds" in entry
    ]


def _evaluate_on_tasks(model_path, tasks, result_dir):
    """Evaluate the model saved at `model_path` on each task in `tasks`; return {task: result}."""
    task_results = {}
    for task in tasks:
        print(f"\nEvaluating performance on task {task}:")
        task_results[task] = evaluate_model_accuracy(model_path, task, result_dir=result_dir)
    return task_results


def _evaluate_final_model(task_names, result_dir, experiment_results):
    """Evaluate the existing ``model_after_task_<N>`` on every task and save the results.

    Raises:
        ValueError: if the final merged model directory does not exist.
    """
    print(f"\n{'='*50}")
    print("All tasks have been merged, load final model and evaluate performance on all tasks")
    print(f"{'='*50}")

    final_model_path = os.path.join(result_dir, f"model_after_task_{len(task_names)}")
    if not os.path.exists(final_model_path):
        raise ValueError(f"Final merged model not found: {final_model_path}")
    print(f"Loading final merged model: {final_model_path}")

    print("\nEvaluating final merged model performance on all tasks:")
    experiment_results["final_evaluation"] = {
        "task_results": _evaluate_on_tasks(final_model_path, task_names, result_dir),
        "completed_tasks": task_names,
        # Timestamp is taken after evaluation finishes (dict values evaluate in order).
        "evaluation_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    _save_experiment_results(result_dir, experiment_results)

    print(f"\n{'='*50}")
    print(f"Final evaluation results saved in: {result_dir}")
    print(f"{'='*50}")


def _resolve_start_model_path(args, task_index, task_name, result_dir):
    """Pick the checkpoint to finetune from: the base model, or the previous task's merged model."""
    if task_index == 0 or args.train_from_base:
        print(f"Using base model for first task: {args.base_model}")
        return args.base_model
    prev_model_path = os.path.join(result_dir, f"model_after_task_{task_index}")
    if os.path.isdir(prev_model_path):
        print(f"Using merged model from previous task for task {task_name}: {prev_model_path}")
        return prev_model_path
    print(f"Warning: Merged model from previous task not found, using base model for task {task_name}")
    return args.base_model


def _finetune_on_task(args, task_name, start_model_path, finetune_model_path):
    """Finetune `start_model_path` on the train split of `task_name`; return elapsed seconds.

    Sets ``args.task_dataset = {"train": ...}`` before calling finetune_model
    (presumably consumed there -- confirm in finetune.finetune_utils).

    Raises:
        ValueError: if the training split cannot be loaded (falsy result).
    """
    print(f"Loading dataset {task_name} for finetuning...")
    train_dataset = load_local_dataset(task_name, split='train')
    if not train_dataset:
        raise ValueError(f"Unable to load training data for dataset {task_name}")
    args.task_dataset = {"train": train_dataset}
    _, finetune_time_seconds = finetune_model(args, task_name, start_model_path, finetune_model_path)
    print(f"Finetuning task {task_name} completed, took {finetune_time_seconds:.2f} seconds")
    return finetune_time_seconds


def _compute_task_vector(args, task_index, finetune_model_path, base_model, result_dir):
    """Return the task vector of the freshly finetuned model, or None for SWA (which needs none).

    The reference is the base model when ``task_vector_from_base`` or ``train_from_base``
    is set. Otherwise it is the previous task's merged model, falling back to the base
    model for the first task, when no merge state is held, or when the directory is missing.
    """
    print("\nComputing current task vector")
    if args.merge_method == "swa":
        print("Using SWA merge method, skipping task vector computation")
        return None

    ft_model = _load_causal_lm(finetune_model_path, args.cache_dir)
    reference_model = base_model
    if args.task_vector_from_base or args.train_from_base:
        print("Computing task vector: finetuned model - original pre-trained model")
    else:
        print("Computing task vector: finetuned model - current merged model")
        if task_index != 0 and PREVIOUS_MODEL_STATE.get("current_state"):
            prev_model_path = os.path.join(result_dir, f"model_after_task_{task_index}")
            if os.path.exists(prev_model_path):
                reference_model = _load_causal_lm(prev_model_path, args.cache_dir)
            else:
                print("Warning: Merged model from previous task not found, using base model as reference")

    current_vector_dict = get_task_vector_dict(ft_model, reference_model)
    # Free the finetuned model and any separately loaded reference before merging.
    # When reference_model aliases base_model this only drops the local name.
    del ft_model, reference_model
    gc.collect()
    return current_vector_dict


def _run_merge(args, task_name, task_index, task_count, current_vector_dict,
               scaling_coef, result_dir, finetune_models_dir):
    """Merge the current task using ``args.merge_method`` and return the merged model.

    Every method except SAIM also writes its cumulative state into `result_dir`; SAIM
    receives the module-level PREVIOUS_MODEL_STATE instead.

    Raises:
        ValueError: for a method not in _SUPPORTED_MERGE_METHODS.
    """
    method = args.merge_method
    print(f"Performing merge operation using method: {method}...")
    common = dict(
        base_model_path=args.base_model,
        task_index=task_index,
        task_count=task_count,
        cache_dir=args.cache_dir,
    )

    if method == "SAIM":
        merged_model = SAIM(
            current_vector_dict=current_vector_dict,
            scaling_coef=scaling_coef,
            previous_model_state=PREVIOUS_MODEL_STATE,
            **common,
        )
    elif method == "task_arithmetic":
        merged_model = task_arithmetic_merge(
            current_vector_dict=current_vector_dict, scaling_coef=scaling_coef, **common)
        save_cumulative_vector_to_disk(result_dir)
    elif method == "magmax":
        merged_model = magmax_merge(
            current_vector_dict=current_vector_dict, scaling_coef=scaling_coef, **common)
        save_magmax_vector_to_disk(result_dir)
    elif method == "ties_merge":
        merged_model = ties_merge(
            current_vector_dict=current_vector_dict, scaling_coef=scaling_coef, **common)
        save_ties_vector_to_disk(result_dir)
    elif method == "dare":
        merged_model = dare_merge(
            current_vector_dict=current_vector_dict,
            scaling_coef=scaling_coef,
            drop_rate=0.9,
            mask_strategy="random",
            **common,
        )
        save_dare_vector_to_disk(result_dir)
    elif method == "swa":
        merged_model = swa_merge(
            task_name=task_name,
            finetuned_model_prefix=os.path.join(finetune_models_dir, "finetuned"),
            **common,
        )
        save_swa_state_to_disk(result_dir)
    else:
        raise ValueError(f"Unsupported merge method: {method}")
    return merged_model


def _save_merged_model(merged_model, tokenizer, output_path):
    """Cast `merged_model` to half precision and save it and `tokenizer` to `output_path`.

    Uses bfloat16 when CUDA is available and float16 otherwise. Returns the cast model.
    """
    print(f"Saving merged model to: {output_path}")
    if torch.cuda.is_available():
        print("Converting model to bfloat16 and saving...")
        save_dtype = torch.bfloat16
    else:
        print("Converting model to float16 and saving...")
        save_dtype = torch.float16
    # The saved weights keep the parameters' dtype; save_pretrained has no dtype option.
    merged_model = merged_model.to(save_dtype)
    merged_model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)
    return merged_model


def run_continual_learning(args):
    """Execute the main continual-learning loop: finetune, merge, and evaluate each task.

    For each task index ``i`` in ``[args.start_task, args.end_task]`` this:

    1. finetunes a model, starting from the base model or the previous merged model;
    2. computes a task vector (skipped for SWA);
    3. merges with ``args.merge_method`` and saves ``model_after_task_<i+1>``;
    4. evaluates on all tasks seen so far. With ``args.evaluate_at_end`` this happens
       only after the last task.

    Results go to ``experiment_results.json`` in the experiment directory. When
    ``args.start_task == len(task_names)``, no training happens: the existing final
    merged model is evaluated on every task instead.

    Note: ``args.epochs``, ``args.batch_size`` and ``args.task_dataset`` are overwritten
    per task. Per-task overrides fall back to the values ``args`` had on entry.

    Raises:
        ValueError: invalid task range, unsupported merge method, missing final merged
            model, or a task whose training data cannot be loaded.
    """
    if not setup_conda_environments():
        print("Warning: conda environment setup is incomplete, which may affect the evaluation process")

    setup_hf_cache(args.cache_dir)
    huggingface_login(args.token)

    gpu_id = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
    print(f"Using GPU: {gpu_id}")

    task_names = list(_CONTINUAL_TASK_NAMES)
    print(f"Will use the following {len(task_names)} datasets as continual learning tasks: {task_names}")

    available_datasets = get_available_datasets()
    for task in task_names:
        if task not in available_datasets:
            print(f"Warning: Task '{task}' does not exist in the datasets directory, please ensure it is downloaded")

    _validate_task_range(args.start_task, args.end_task, len(task_names))

    # Create or resume the experiment directory.
    if args.continue_experiment and args.prev_experiment_dir:
        result_dir = args.prev_experiment_dir
        print(f"\n===== Continuing experiment: {result_dir} =====")
        experiment_results = _load_previous_results(result_dir)
        task_count = args.start_task
        prev_task_idx = args.start_task - 1
        if prev_task_idx >= 0:
            restore_merge_state(args.merge_method, result_dir, prev_task_idx,
                                PREVIOUS_MODEL_STATE, args.cache_dir)
        all_finetune_times = _collect_finetune_times(experiment_results)
    else:
        result_dir = create_experiment_dir(args.save_path)
        print(f"\n===== Creating new experiment: {result_dir} =====")
        experiment_results = {}
        task_count = 0
        all_finetune_times = []

    finetune_models_dir = os.path.join(result_dir, "finetunedModels")
    os.makedirs(finetune_models_dir, exist_ok=True)

    save_merge_config(
        result_dir=result_dir,
        merge_method=getattr(args, 'merge_method', 'SAIM'),
        base_model=args.base_model,
        task_names=task_names,
        start_task=args.start_task,
        end_task=args.end_task,
        scaling_coef=args.scaling_coef,
        use_default_scaling=getattr(args, 'use_default_scaling', False),
        task_vector_from_base=getattr(args, 'task_vector_from_base', True),
        continue_experiment=args.continue_experiment,
        prev_experiment_dir=args.prev_experiment_dir if args.continue_experiment else None,
    )

    # start_task == len(task_names): everything is merged already, only evaluate.
    if args.start_task == len(task_names):
        _evaluate_final_model(task_names, result_dir, experiment_results)
        return

    # Fail now rather than after the first (expensive) finetune.
    if args.merge_method not in _SUPPORTED_MERGE_METHODS:
        raise ValueError(f"Unsupported merge method: {args.merge_method}")

    # Base model is used as the task-vector reference; loaded only on the training path.
    print(f"Loading base model: {args.base_model}")
    base_model = _load_causal_lm(args.base_model, args.cache_dir)
    base_tokenizer = AutoTokenizer.from_pretrained(
        args.base_model, cache_dir=os.path.join(args.cache_dir, "transformers"))

    # Capture the caller's defaults once so a task without an override does not
    # inherit the previous task's override.
    default_epochs = args.epochs
    default_batch_size = args.batch_size
    scaling_coef_to_use = None if getattr(args, 'use_default_scaling', False) else args.scaling_coef

    for i in range(args.start_task, args.end_task + 1):
        task_name = task_names[i]
        task_key = f"task_{i+1}_{task_name}"
        print(f"\n{'='*50}")
        print(f"Processing task {i+1}/{args.end_task + 1}: {task_name}")
        print(f"{'='*50}")

        task_count += 1

        args.epochs = task_specific_epochs.get(task_name, default_epochs)
        args.batch_size = task_specific_batch_sizes.get(task_name, default_batch_size)
        print(f"Setting epochs for task {task_name}: {args.epochs}, setting batch_size: {args.batch_size}")

        # Step 1: finetune, or reuse an existing finetuned checkpoint when continuing.
        finetune_model_path = os.path.join(finetune_models_dir, f"finetuned_{task_name}")
        if os.path.exists(finetune_model_path) and args.continue_experiment:
            print(f"Found existing finetuned model: {finetune_model_path}, skipping finetune step")
            # Keep the time recorded by the earlier run instead of overwriting it with 0.
            finetune_time_seconds = experiment_results.get(task_key, {}).get("finetune_time_seconds", 0)
        else:
            start_model_path = _resolve_start_model_path(args, i, task_name, result_dir)
            finetune_time_seconds = _finetune_on_task(args, task_name, start_model_path, finetune_model_path)
            all_finetune_times.append(finetune_time_seconds)

        # Step 2: task vector (None for SWA).
        current_vector_dict = _compute_task_vector(args, i, finetune_model_path, base_model, result_dir)

        # Step 3: merge and save.
        current_model_path = os.path.join(result_dir, f"model_after_task_{i+1}")
        merged_model = _run_merge(args, task_name, i, task_count, current_vector_dict,
                                  scaling_coef_to_use, result_dir, finetune_models_dir)
        merged_model = _save_merged_model(merged_model, base_tokenizer, current_model_path)

        # Step 4: evaluate on every task seen so far (or only after the last task).
        if not args.evaluate_at_end or i == args.end_task:
            print("\nEvaluating merged model performance on all previous tasks:")
            experiment_results[task_key] = {
                "task_results": _evaluate_on_tasks(current_model_path, task_names[:i+1], result_dir),
                "completed_tasks": task_names[:i+1],
                "finetune_time_seconds": finetune_time_seconds,
            }
            if all_finetune_times:
                experiment_results["all_finetune_times_seconds"] = list(all_finetune_times)
                experiment_results["average_finetune_time_seconds"] = (
                    sum(all_finetune_times) / len(all_finetune_times))
            _save_experiment_results(result_dir, experiment_results)
            if all_finetune_times:
                print(f"\nAll task finetune times: {all_finetune_times}")
                print(f"Average finetune time: {experiment_results['average_finetune_time_seconds']:.2f} seconds")

        # Release this task's large objects before the next finetune starts.
        del merged_model, current_vector_dict
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    del base_model, base_tokenizer
    gc.collect()

    print(f"\n{'='*50}")
    print(f"Experiment results saved in: {result_dir}")
    print(f"{'='*50}")


if __name__ == "__main__":
    args = parse_args()
    exit_code = 0
    try:
        run_continual_learning(args)
    except Exception as e:
        print(f"\nError during continual learning: {e}")
        traceback.print_exc()
        # Report the failure to the calling shell / job scheduler instead of exiting 0.
        exit_code = 1
    # Optional: uncomment to power off the host once the run ends (success or failure),
    # e.g. on a rented GPU machine.
    # os.system("shutdown")
    if exit_code:
        raise SystemExit(exit_code)