import os
import pdb
import torch
from copy import deepcopy

import numpy as np
from utils import *
from task_merger import get_merge_handler

# Process-wide environment configuration. These variables are assigned BEFORE
# `transformers` is imported below; presumably the libraries read them at
# import/initialisation time, so keep this ordering (TODO confirm).
# Set TOKENIZERS_PARALLELISM to true
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# NOTE(review): the variable name suggests this silences advisory warnings
# emitted by transformers -- confirm against the transformers documentation.
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"]="1"

import transformers
# Set the transformers logger verbosity to the ERROR level.
transformers.utils.logging.set_verbosity(transformers.logging.ERROR)

from huggingface_hub import login
# Get the token from environment variables
# token = os.getenv('HUGGINGFACE_TOKEN')
# login(token=token)

import argparse

def run_BIG_function(eval_split, bigseed, config_name, precision, test_sensitivity):
    """Sweep merge hyper-parameters, evaluate the merged model and log to CSV.

    The function builds the experiment from ``config_name``, instantiates the
    merge handler selected by ``task_merge_config['representation']``, runs a
    coordinate-wise search over ``order_of_processing_params`` on
    ``eval_split`` and finally (when ``EVAL_TEST`` is set) re-evaluates the
    selected parameters on the ``'test'`` split. Every evaluation appends one
    row to a CSV file whose path is derived from the configuration.

    The helpers ``set_seed``, ``get_config_from_name``,
    ``prepare_experiment_config``, ``evaluate_logits`` and ``write_to_csv``
    are not defined in this file; presumably they come from the star-import
    of ``utils``.

    Args:
        eval_split: Key of the split used during the search; it indexes
            ``loader_dict['test'][eval_split]`` for every dataset.
        bigseed: Seed passed to ``set_seed`` at start-up and before every
            merge/evaluation run.
        config_name: Name handed to ``get_config_from_name``.
        precision: Stored verbatim under ``raw_config['precision']``.
        test_sensitivity: Accepted for interface compatibility with the CLI;
            it is not read anywhere in this function.

    Returns:
        None. Results are printed and written to the CSV file.
    """
    set_seed(bigseed)
    # Use arguments
    print(f"Evaluating on split: {eval_split}")
    print(f"Using seed: {bigseed}")
    print(f"Using config: {config_name}")

    # CONFIG_NAME = 'llama8B_r16_knots_ties'
    TASK_HEADS_PATH = "heads.pt"  # can be found on one-drive KnOTS_model_ckpts/Llama-3-8B/heads.pt
    COMPUTE_TRANSFORM = False
    EVAL_TEST = True

    print("Seed : ", bigseed)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    raw_config = get_config_from_name(config_name, device=device)
    raw_config['precision'] = precision
    print(raw_config['task_merge_config'])
    config = prepare_experiment_config(raw_config)
    dataset_names = np.array([i['name'] for i in raw_config['dataset']])
    dataloaders = np.array([i for i in config['data']])
    mask_class = np.array([i['mask_class'] for i in config['dataset']])
    print(f"mask_class labels: {mask_class}")

    # The CSV location encodes datasets, model name, eval type and the merge
    # settings, so runs with different settings land in different files.
    transform_listified = [
        str(i) if k != 'ingredients_path' else os.path.basename(i).replace('.pt', '')
        for k, i in raw_config['task_merge_config'].items()
    ]
    transform_listified += [
        str(v)
        for k, v in raw_config['model']['ft_config'].items()
        if k in {'r', 'type', 'lora_alpha'}
    ]
    csv_file = os.path.join(
        './csvs_new',
        ":".join(dataset_names),
        raw_config['model']['name'],
        raw_config['eval_type'],
        ":".join(transform_listified),
        f'{eval_split}.csv'
    )
    os.makedirs(os.path.dirname(csv_file), exist_ok=True)
    print(f'Saving results to {csv_file}')

    # Previous default config, kept for reference:
    #default_params = {
    #    'scaling_coeffs': 0.3,
    #    'topK': 70,
    #    'cart_pruning_rank': 0.04,
    #    'dare_pruning_coeffs':0.9
    #}
    default_params = {
        'alpha_regmean': 0.5,
        'scaling_coeffs': 0.3,
        'alpha_coeff': 1.0,
        'beta_coeff': 0.0,
        'tall_threshold': 0.3,
        'consensus_threshold': 2,
    }

    # Parameters are searched one at a time, in this order; the best value of
    # each is written back into ``default_params`` before the next is swept.
    order_of_processing_params = [
        #'alpha_regmean',
        'scaling_coeffs',
        #'consensus_threshold',
        #'tall_threshold',
        #'consensus_threshold',
        #'alpha_coeff',
        #'beta_coeff',
    ]
    # NOTE(review): 6 matches the number of entries in the fine-tuned accuracy
    # table below; presumably one coefficient per task -- confirm.
    one_hots = torch.eye(6)
    zero_shot = torch.zeros(6)
    search_config = {
        #'scaling_coeffs': np.arange(0.1, 1.1, step=0.1),
        #'alpha_regmean': np.arange(0.0, 1.1, step=0.1),
        #'scaling_coeffs': np.arange(0.2, 0.3, step=0.1),
        #'scaling_coeffs': [0.3],
        #'tall_threshold': np.arange(0.0, 0.3, step=0.1),
        #'consensus_threshold': np.arange(0, 3, step=1.0),
        #'alpha_coeff': np.arange(1.0, -0.1, -0.1),
        #'beta_coeff': np.arange(0.0, 1.1, step=0.1),
        #'topK': (np.arange(1, 11, step=1) * 10),
        #'dare_pruning_coeffs': [0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 1e-5][::-1],
        #'cart_pruning_rank': [0.04, 0.08, 0.16, 0.32]
        #'scaling_coeffs': [one_hots[i] for i in range(6)],
        'scaling_coeffs': [zero_shot],
    }

    print(f"default params: {default_params}")
    print(f"order_of_processing_params: {order_of_processing_params}")

    # Load the heads onto the CPU so the file can be read on hosts without
    # CUDA; ``copy_`` below moves the values onto the parameter's device.
    task_heads = torch.load(TASK_HEADS_PATH, map_location='cpu')

    finetuned_llama3_acc = {
        'snli': 92.49796416938111,
        'mnli': 90.30820173204279,
        'sick': 91.58173664900122,
        'qnli' : 94.48512585812358,
        'rte' : 89.85507246376812,
        'scitail': 96.51928504233303,
    }

    print("Using Llama fine-tuned acc")
    fine_tuned_acc = finetuned_llama3_acc

    print(f'Finetuned Accs: {fine_tuned_acc}')
    print(search_config)

    def merge_and_eval(Merge, EVAL_SPLIT='val', instance_params=None):
        """Merge with ``instance_params``, evaluate per dataset, append a CSV row.

        Returns a dict holding the instance parameters, per-dataset accuracy
        and normalised accuracy (both in percent), their averages, and the
        entries of ``config['task_merge_config']``.
        """
        set_seed(bigseed)
        print("EVAL_SPLIT : ", EVAL_SPLIT)
        print(f'Search Run with: {instance_params}')
        all_results = deepcopy(instance_params)
        print('Creating Merge')

        Merge.set_scaling_coeffs(instance_params['scaling_coeffs'])
        config['task_merge_config'].update(instance_params)
        # NOTE(review): 128001 is presumably the tokenizer's end-of-text id
        # reused as padding -- confirm against the tokenizer in use.
        Merge.pretrained_model.config.pad_token_id = 128001
        Merge.pretrained_model.config.use_cache = False
        Merge.pretrained_model.config.pretraining_tp = 1
        merged_model = Merge.merge(config['task_merge_config'])

        merged_model.config.pad_token_id = 128001
        merged_model.config.use_cache = False
        merged_model.config.pretraining_tp = 1

        print('Evaluate Merged Model on Each Dataset')
        avg_accuracy = 0.
        avg_norm_accuracy = 0.
        for i, loader_dict in enumerate(dataloaders):
            name = dataset_names[i]
            loader = loader_dict['test'][EVAL_SPLIT]
            with torch.no_grad():
                for param_name, param in merged_model.named_parameters():
                    # Inject task head into model
                    if 'modules_to_save' in param_name:
                        shape = param.shape
                        param.copy_(task_heads[name][:, :shape[1]])

            # Evaluate on the device selected by the enclosing function rather
            # than a hard-coded 'cuda', so CPU-only hosts work as well.
            acc = evaluate_logits(merged_model, loader, device, mask_class[i])
            torch.cuda.empty_cache()
            acc_pct = acc * 100
            # Accuracy relative to the individually fine-tuned model, in percent.
            norm_acc_pct = (acc * 100) / fine_tuned_acc[name] * 100
            print(f"{name} Normalized accuracy is {np.round(norm_acc_pct, 3)}")
            print(f"{name} accuracy is {np.round(acc_pct, 3)}")
            all_results[name] = acc_pct
            all_results[name + '_norm_acc'] = norm_acc_pct
            avg_accuracy += acc_pct
            avg_norm_accuracy += norm_acc_pct
        avg_accuracy /= len(dataloaders)
        avg_norm_accuracy /= len(dataloaders)
        print(f'Average Accuracy is {np.round(avg_accuracy, 3)}')
        print(f'Average Normalized Accuracy is {np.round(avg_norm_accuracy, 3)}')
        all_results['Average_acc'] = avg_accuracy
        all_results['Average_norm_acc'] = avg_norm_accuracy
        all_results.update(config['task_merge_config'])
        write_to_csv(all_results, csv_file)
        return all_results

    with torch.no_grad():
        models = np.array([i.cpu().eval() for i in config['models']['bases']])
        MergeClass = get_merge_handler(config['task_merge_config']['representation'])
        Merge = MergeClass(
                models,
                pretrained_model=config['models']['new'],
                param_handler=config['param_handler'],
                device=device,
                merge_config=config['task_merge_config'],
                dataloaders=dataloaders,
            )

        if COMPUTE_TRANSFORM:
            Merge.transform(config['task_merge_config'])

        print(config['task_merge_config'])
        # Stays None when no parameter is searched, so the reporting below
        # never touches an undefined name.
        best_val_results = None
        for param in order_of_processing_params:
            best_val_results = {'Average_norm_acc' : 0.0}
            for value in search_config[param]:
                instance_params = deepcopy(default_params)
                instance_params[param] = value
                all_results = merge_and_eval(Merge, EVAL_SPLIT=eval_split, instance_params=instance_params)
                torch.cuda.empty_cache()
                if all_results['Average_norm_acc'] >= best_val_results['Average_norm_acc']:
                    best_val_results = deepcopy(all_results)
                #else:
                #    break
            # Keep the previous default when no candidate was evaluated (or
            # none was accepted) instead of failing on a missing key.
            if param in best_val_results:
                default_params[param] = best_val_results[param]

        if EVAL_TEST:
            # Evaluate on the test set with the best searched parameters.
            # ``default_params`` already carries the best value of every
            # searched parameter, so it is used directly for the test run.
            print("Best params :", best_val_results)
            test_params = deepcopy(default_params)
            test_result = merge_and_eval(Merge, EVAL_SPLIT='test', instance_params=test_params)
            print(test_result)

        # Sum of all entries of the base model's state dict, printed as a
        # diagnostic. The state dict is built once instead of once per key.
        base_model = Merge.pretrained_model
        base_state = base_model.state_dict()
        base_model_sum = torch.sum(torch.stack([tensor.sum() for tensor in base_state.values()])).item()
        print(f"Base model sum: {base_model_sum}")
        #for idx, base in enumerate(Merge.finetuned_models):
        #    total_distance = 0
        #    total_sum = 0
        #    for key in keys:
        #        total_distance += torch.sum((base.state_dict()[key] - base_model.state_dict()[key]) ** 2).item()
        #        total_sum += base.state_dict()[key].sum()
        #    print(f"Distance between model {idx} and new model: {total_distance}\t New model sum: {total_sum}")
        #for i in range(len(Merge.finetuned_models)):
        #    for j in range(i+1, len(Merge.finetuned_models)):
        #        total_distance = 0
        #        for key in keys:
        #            total_distance += torch.sum((Merge.finetuned_models[i].state_dict()[key] - Merge.finetuned_models[j].state_dict()[key]) ** 2).item()
        #        print(f"Distance between model {i} and model {j}: {total_distance}")


    
if __name__ == "__main__":
    # Command-line entry point: declare the options in a table, parse them,
    # and forward the parsed values to run_BIG_function in positional order.
    cli = argparse.ArgumentParser(description="Run BIG function with configurable parameters.")
    option_table = (
        ("--eval_split", {"type": str, "default": "test",
                          "help": "Evaluation split (e.g., test, val)"}),
        ("--bigseed", {"type": int, "default": 420,
                       "help": "Random seed"}),
        ("--config_name", {"type": str, "default": "vitB_r16_knots_ties",
                           "help": "Configuration name"}),
        ("--precision", {"type": str, "default": "bfloat16",
                         "help": "Precision for evaluation (e.g., fp16, bf16, fp32)"}),
        ("--test_sensitivity", {"default": False, "action": 'store_true',
                                "help": "Flag to test sensitivity to random seeds"}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()
    run_BIG_function(
        opts.eval_split,
        opts.bigseed,
        opts.config_name,
        opts.precision,
        opts.test_sensitivity,
    )

