import os
import pdb
import torch
from copy import deepcopy

import numpy as np
from utils import *
from task_merger import get_merge_handler

# Environment flags for the Hugging Face tokenizers / transformers libraries.
# NOTE(review): `utils` and `task_merger` are imported above and may already
# have imported transformers, so these flags may be set after that first
# import — confirm they still take effect.
# Explicitly opt into tokenizer parallelism (presumably also avoids the
# warning tokenizers prints when this variable is unset — confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Presumably silences transformers' "advisory" warnings — TODO confirm.
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"]="1"

import transformers
# Only emit transformers log messages at ERROR level or above.
transformers.utils.logging.set_verbosity(transformers.logging.ERROR)

from huggingface_hub import login
# Get the token from environment variables
# token = os.getenv('HUGGINGFACE_TOKEN')
# login(token=token)

import argparse

def str2bool(v):
    """Convert a command-line flag value into a ``bool``.

    Intended as an argparse ``type=`` callable. Values that are already
    ``bool`` are returned unchanged. Strings are matched case-insensitively
    against 'yes'/'true'/'t'/'y'/'1' (True) and 'no'/'false'/'f'/'n'/'0'
    (False).

    Raises:
        argparse.ArgumentTypeError: if the string matches neither set.
    """
    if isinstance(v, bool):
        return v

    lookup = dict.fromkeys(('yes', 'true', 't', 'y', '1'), True)
    lookup.update(dict.fromkeys(('no', 'false', 'f', 'n', '0'), False))

    parsed = lookup.get(v.lower())
    if parsed is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return parsed

def run_BIG_function(eval_split, bigseed, config_name, precision, wandb=None, wandb_online=None, wandb_entity=None, wandb_project=None):
    """Merge the task-specific models of ``config_name`` and evaluate the result.

    The function performs these steps:

    1. Seed everything.
    2. Optionally start a Weights & Biases run.
    3. Build the experiment config.
    4. Merge the fine-tuned models with the handler selected by
       ``task_merge_config['representation']``.
    5. For every dataset, inject its task head (loaded from ``heads.pt``) and
       evaluate the merged model on that dataset's ``eval_split`` loader.

    Raw accuracies and accuracies normalized by the hard-coded fine-tuned
    Llama-3 accuracies are passed to ``write_to_csv`` under ``./csvs_new``.
    They are optionally logged to W&B as well.

    Args:
        eval_split: Split evaluated for every dataset
            (``loader_dict['test'][eval_split]``). It is also used in the CSV
            file name and in the W&B run name.
        bigseed: Seed passed to ``set_seed``.
        config_name: Experiment config name resolved by ``get_config_from_name``.
        precision: Stored as ``raw_config['precision']`` before
            ``prepare_experiment_config`` runs.
        wandb: Truthy to log the config and results to Weights & Biases.
        wandb_online: Truthy sets ``WANDB_MODE=online``, otherwise ``offline``.
        wandb_entity: Entity passed to ``wandb.init``.
        wandb_project: Project passed to ``wandb.init``.
    """
    set_seed(bigseed)
    print(f"Evaluating on split: {eval_split}")
    print(f"Using seed: {bigseed}")
    print(f"Using config: {config_name}")

    # Per-dataset task heads; can be found on one-drive at
    # KnOTS_model_ckpts/Llama-3-8B/heads.pt
    TASK_HEADS_PATH = "heads.pt"
    # When True, call Merge.transform(...) before merging.
    COMPUTE_TRANSFORM = False

    print("Seed : ", bigseed)
    # WANDB_MODE is exported before wandb.init() below so the run picks it up.
    if wandb_online:
        os.environ["WANDB_MODE"] = "online"
        print("Running in online mode, results will be logged to Weights & Biases.")
    else:
        os.environ["WANDB_MODE"] = "offline"
        print("Running in offline mode, results will be saved locally.")

    # Keep the wandb module under its own name instead of rebinding the
    # boolean `wandb` parameter; None means W&B logging is disabled.
    wandb_lib = None
    if wandb:
        import wandb as wandb_lib
        wandb_lib.init(
            project=wandb_project,
            entity=wandb_entity,
            name=f"{config_name}_{eval_split}_seed{bigseed}",
        )
        wandb_lib.config.update({
            "eval_split": eval_split,
            "bigseed": bigseed,
            "config_name": config_name,
            "precision": precision,
        })

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    raw_config = get_config_from_name(config_name, device=device)
    raw_config['precision'] = precision
    print(raw_config['task_merge_config'])
    config = prepare_experiment_config(raw_config)
    dataset_names = np.array([i['name'] for i in raw_config['dataset']])
    dataloaders = np.array([i for i in config['data']])
    mask_class = np.array([i['mask_class'] for i in config['dataset']])
    print(f"mask_class labels: {mask_class}")

    # The results path encodes the datasets, model, eval type, merge settings
    # and LoRA hyper-parameters, so runs with different settings write to
    # different CSV files.
    transform_listified = [
        os.path.basename(v).replace('.pt', '') if k == 'ingredients_path' else str(v)
        for k, v in raw_config['task_merge_config'].items()
    ]
    transform_listified += [str(v) for k, v in raw_config['model']['ft_config'].items() if k in {'r', 'type', 'lora_alpha'}]
    csv_file = os.path.join(
        './csvs_new',
        ":".join(dataset_names),
        raw_config['model']['name'],
        raw_config['eval_type'],
        ":".join(transform_listified),
        f'{eval_split}.csv'
    )
    os.makedirs(os.path.dirname(csv_file), exist_ok=True)
    print(f'Saving results to {csv_file}')

    task_heads = torch.load(TASK_HEADS_PATH)

    # Accuracies (%) of the individually fine-tuned models; used as the
    # denominator of the normalized accuracy.
    finetuned_llama3_acc = {
        'snli': 92.49796416938111,
        'mnli': 90.30820173204279,
        'sick': 91.58173664900122,
        'qnli' : 94.48512585812358,
        'rte' : 89.85507246376812,
        'scitail': 96.51928504233303,
    }

    print("Using Llama fine-tuned acc")
    fine_tuned_acc = finetuned_llama3_acc

    print(f'Finetuned Accs: {fine_tuned_acc}')

    def _prepare_for_eval(model):
        """Apply the model-config overrides used for evaluation to ``model``."""
        # Hard-coded pad token id 128001 (presumably Llama-3's end-of-text
        # token; TODO confirm).
        # The KV cache is disabled.
        # pretraining_tp=1 presumably selects the plain, non tensor-parallel-
        # sliced linear path in HF Llama models.
        model.config.pad_token_id = 128001
        model.config.use_cache = False
        model.config.pretraining_tp = 1

    def merge_and_eval(Merge, EVAL_SPLIT='val', instance_params=None):
        """Merge with ``instance_params`` and evaluate on every dataset's ``EVAL_SPLIT``.

        ``instance_params`` defaults to ``config['task_merge_config']``.

        Returns a dict with these keys, and writes the same dict to
        ``csv_file``:

        - ``<dataset>``: accuracy in %.
        - ``<dataset>_norm_acc``: accuracy as a % of the fine-tuned model's
          accuracy.
        - ``Average_acc`` and ``Average_norm_acc``: the means of the two
          values above across datasets.
        """
        if instance_params is None:
            instance_params = config['task_merge_config']
        set_seed(bigseed)
        print("EVAL_SPLIT : ", EVAL_SPLIT)
        print('Creating Merge')
        all_results = {}

        Merge.set_scaling_coeffs(instance_params['scaling_coeffs'])
        config['task_merge_config'].update(instance_params)
        _prepare_for_eval(Merge.pretrained_model)
        merged_model = Merge.merge(config['task_merge_config'])
        _prepare_for_eval(merged_model)

        print('Evaluate Merged Model on Each Dataset')
        # Uses the `device` detected above, so CPU-only hosts work too.
        avg_accuracy = 0.
        avg_norm_accuracy = 0.
        for i, loader_dict in enumerate(dataloaders):
            dataset_name = dataset_names[i]
            loader = loader_dict['test'][EVAL_SPLIT]
            with torch.no_grad():
                # Inject this dataset's task head into the merged model's
                # `modules_to_save` parameters, trimmed to each parameter's width.
                for param_name, param in merged_model.named_parameters():
                    if 'modules_to_save' in param_name:
                        param.copy_(task_heads[dataset_name][:, :param.shape[1]])

            acc = evaluate_logits(merged_model, loader, device, mask_class[i])
            torch.cuda.empty_cache()
            acc_pct = acc * 100
            norm_acc = acc_pct / fine_tuned_acc[dataset_name] * 100
            print(f"{dataset_name} Normalized accuracy is {np.round(norm_acc, 3)}")
            print(f"{dataset_name} accuracy is {np.round(acc_pct, 3)}")
            all_results[dataset_name] = acc_pct
            all_results[dataset_name + '_norm_acc'] = norm_acc
            avg_accuracy += acc_pct
            avg_norm_accuracy += norm_acc
        avg_accuracy /= len(dataloaders)
        avg_norm_accuracy /= len(dataloaders)
        if wandb_lib is not None:
            wandb_lib.log({"Average Accuracy": avg_accuracy, "Average Normalized Accuracy": avg_norm_accuracy})
        print(f'Average Accuracy is {np.round(avg_accuracy, 3)}')
        print(f'Average Normalized Accuracy is {np.round(avg_norm_accuracy, 3)}')
        all_results['Average_acc'] = avg_accuracy
        all_results['Average_norm_acc'] = avg_norm_accuracy
        write_to_csv(all_results, csv_file)
        return all_results

    with torch.no_grad():
        models = np.array([i.cpu().eval() for i in config['models']['bases']])
        MergeClass = get_merge_handler(config['task_merge_config']['representation'])
        Merge = MergeClass(
            models,
            pretrained_model=config['models']['new'],
            param_handler=config['param_handler'],
            device=device,
            merge_config=config['task_merge_config'],
            dataloaders=dataloaders,
        )

        if COMPUTE_TRANSFORM:
            Merge.transform(config['task_merge_config'])

        # Evaluate the split that was requested, so the results match the
        # CSV file and W&B run that are labelled with `eval_split`.
        test_result = merge_and_eval(Merge, EVAL_SPLIT=eval_split, instance_params=config['task_merge_config'])
        print(test_result)
        if wandb_lib is not None:
            wandb_lib.log(test_result)
            wandb_lib.finish()


    
if __name__ == "__main__":
    # CLI entry point: parse the arguments and forward them to run_BIG_function.
    parser = argparse.ArgumentParser(description="Run BIG function with configurable parameters.")
    parser.add_argument("--eval_split", type=str, default="test", help="Evaluation split (e.g., test, val)")
    parser.add_argument("--bigseed", type=int, default=420, help="Random seed")
    # Default to the Llama-3-8B config. run_BIG_function hard-codes Llama-3
    # task heads, pad token id 128001 and NLI fine-tuned accuracies, so a ViT
    # config presumably cannot complete (e.g. KeyError on the accuracy table).
    parser.add_argument("--config_name", type=str, default="llama8B_r16_knots_ties", help="Configuration name")
    parser.add_argument("--precision", type=str, default="bfloat16", help="Precision for evaluation (e.g., fp16, bf16, fp32)")
    parser.add_argument("--wandb", type=str2bool, default=True, help="Use Weights & Biases for logging")
    parser.add_argument("--wandb_online", type=str2bool, default=False, help="Run Weights & Biases online or offline and save results locally")
    parser.add_argument("--wandb_entity", type=str, default="default", help="Weights & Biases entity name")
    parser.add_argument("--wandb_project", type=str, default="com", help="Weights & Biases project name")

    args = parser.parse_args()
    run_BIG_function(args.eval_split, args.bigseed, args.config_name, args.precision, wandb=args.wandb, wandb_online=args.wandb_online,
                     wandb_entity=args.wandb_entity, wandb_project=args.wandb_project)

    
