import json

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
# Seed Python's built-in `random` module at import time so any `random.*`
# draws are reproducible across runs.
# NOTE(review): no `random.*` call is visible in this chunk, and torch's RNG is
# not seeded here.
random.seed(0)

import os
# Read this process's local rank when the LOCAL_RANK environment variable is
# present (presumably set by a distributed launcher such as torchrun — confirm).
# When it is absent, `local_rank` is deliberately left undefined.
if 'LOCAL_RANK' in os.environ:
    local_rank = int(os.environ['LOCAL_RANK'])

##############################
##### Activation caching #####
##############################

def forward_with_cache(model, inputs, module, no_grad=True):
    """Run ``model(**inputs)`` once and return the output captured at ``module``.

    A forward hook is attached to ``module`` for the duration of the call.
    Every output produced by ``module`` is recorded; when that output is a
    tuple, only its first element is kept. The model's own return value is
    discarded.

    Args:
        model: Callable invoked as ``model(**inputs)``.
        inputs: Mapping of keyword arguments forwarded to ``model``.
        module: ``torch.nn.Module`` whose forward output should be captured
            (normally a submodule of ``model``).
        no_grad: If True (default), run the forward pass under
            ``torch.no_grad()``. If False, run it in the caller's current grad
            mode so the captured output can take part in backpropagation.

    Returns:
        The output recorded on the first invocation of ``module`` during the
        forward pass.

    Raises:
        IndexError: If ``module`` was never invoked during the forward pass.
        Any exception raised by ``model`` itself is propagated unchanged (the
        hook is still removed).
    """
    # Outputs of `module`, in call order (a list, not a preallocated tensor).
    cache = []

    def hook(_module, _input, output):
        # Modules that return a tuple contribute only its first element.
        if isinstance(output, tuple):
            cache.append(output[0])
        else:
            cache.append(output)
        # Returning None leaves the module's output unmodified.
        return None

    hook_handle = module.register_forward_hook(hook)
    try:
        if no_grad:
            with torch.no_grad():
                model(**inputs)
        else:
            model(**inputs)
    finally:
        # Always detach the hook, even if the forward pass raises; otherwise it
        # would stay registered and keep appending (and holding) outputs on
        # every later call of `module`.
        hook_handle.remove()

    return cache[0]
    
#######################################
##### Model and data loading code #####
#######################################


def get_params(model, layer_ids, param_ids):
    """Collect selected parameters from selected decoder layers.

    For each index in ``layer_ids`` (in the given order), the parameters of
    ``model.model.layers[layer_id]`` are enumerated in ``.parameters()`` order.
    Those whose position is contained in ``param_ids`` are kept.

    Args:
        model: Object exposing ``model.model.layers`` (indexable by layer id).
        layer_ids: Iterable of layer indices to visit.
        param_ids: Container of parameter positions (within each layer's
            ``.parameters()``) to select; tested with ``in``.

    Returns:
        A flat list of the selected parameter objects (not copies), grouped
        by layer in ``layer_ids`` order.
    """
    return [
        param
        for layer_id in layer_ids
        for position, param in enumerate(model.model.layers[layer_id].parameters())
        if position in param_ids
    ]


def load_model(model_name_or_path):
    """Load a causal LM and its (slow) tokenizer from ``model_name_or_path``.

    Model loading:
      * ``torch_dtype="auto"`` when CUDA is available and reports bf16
        support, otherwise ``torch.float16``;
      * ``device_map="auto"`` and ``trust_remote_code=True``.

    Tokenizer setup:
      * loaded with ``use_fast=False`` and ``trust_remote_code=True``;
      * pad, mask, sep and cls token ids are all aliased to the EOS token id;
      * padding is applied on the left.

    Returns:
        ``(model, tokenizer)``.
    """
    cuda_has_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = "auto" if cuda_has_bf16 else torch.float16

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=dtype,
        trust_remote_code=True,
        device_map="auto",
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
        use_fast=False,
    )

    # Alias the auxiliary special-token ids to EOS (presumably because the
    # base tokenizer lacks dedicated pad/mask/sep/cls tokens — confirm).
    eos_id = tokenizer.eos_token_id
    tokenizer.pad_token_id = eos_id
    # Pad on the left so sequences in a batch are right-aligned.
    tokenizer.padding_side = "left"
    for special_attr in ("mask_token_id", "sep_token_id", "cls_token_id"):
        setattr(tokenizer, special_attr, eos_id)

    return model, tokenizer


def get_data(
    forget_corpora,
    retain_corpora,
    trigger=False,
    trigger_file_path=None,
    min_len=50,
    max_len=2000,
    batch_size=4,
    mmlu_subset="high_school_geography",
):
    """Load forget/retain text lists from RWKU ``positive.json`` files.

    Each corpus argument is an RWKU target id (anything ``int()`` accepts,
    e.g. ``3`` or ``"3"``). The id is mapped to its ``<id>_<Name>`` folder via
    ``id_name_dict``. Texts are read from
    ``RWKU/LLaMA-Factory/<dir>/RWKU/Target/<folder>/positive.json``, relative
    to the current working directory. ``<dir>`` is ``"data"`` unless the
    forget corpus is redirected with ``trigger``.

    Args:
        forget_corpora: RWKU target id of the corpus to forget.
        retain_corpora: RWKU target id of the corpus to retain; always read
            from the ``"data"`` directory.
        trigger: If True, read the forget corpus from ``trigger_file_path``
            instead of ``"data"``.
        trigger_file_path: Directory name under ``RWKU/LLaMA-Factory/`` used
            for the forget corpus when ``trigger`` is True. If left as None,
            the literal directory name ``"None"`` ends up in the path.
        min_len, max_len, batch_size, mmlu_subset: Accepted but not used by
            this function. In particular the length filter is a fixed
            ``> 200`` characters, not ``min_len``.

    Returns:
        A ``(forget_texts, retain_texts)`` tuple. Each is a list of strings
        holding, in file order, the ``'text'`` field of every entry that has
        one longer than 200 characters.

    Raises:
        KeyError: If an id is not present in ``id_name_dict``.
        ValueError / TypeError: If an id cannot be converted with ``int()``.
        FileNotFoundError: If the expected ``positive.json`` does not exist.
    """

    id_name_dict = {100: '100_Matthew_Perry', 101: '101_John_Cena', 102: '102_James_Earl_Jones', 103: '103_Jill_Biden', 104: '104_Luke_Perry', 105: '105_Barbara_Walters', 106: '106_LeBron_James', 107: '107_Jim_Carrey', 108: '108_Hulk_Hogan', 109: '109_Ulysses_S._Grant', 10: '10_Prince_Harry,_Duke_of_Sussex', 110: '110_Hugh_Grant', 111: '111_Franklin_D._Roosevelt', 112: '112_John_Ritter', 113: '113_Denise_Richards', 114: '114_Sylvester_Stallone', 115: '115_Angela_Lansbury', 116: '116_Ariana_Grande', 117: '117_Elon_Musk', 118: '118_Sarah_Michelle_Gellar', 119: '119_Lisa_Marie_Presley', 11: '11_Miley_Cyrus', 120: '120_Melanie_Griffith', 121: '121_Winona_Ryder', 122: '122_Catherine,_Princess_of_Wales', 123: '123_Tyler_Perry', 124: '124_Keanu_Reeves', 125: '125_Charlie_Sheen', 126: '126_Bobby_Brown', 127: '127_Richard_Gere', 128: '128_Kim_Kardashian', 129: '129_William_Shatner', 12: '12_Genghis_Khan', 130: '130_Tony_Curtis', 131: '131_David_Crosby', 132: '132_Patrick_Swayze', 133: '133_Alexander_Hamilton', 134: '134_Donald_Sutherland', 135: '135_Dennis_Quaid', 136: '136_Mia_Farrow', 137: '137_Patrick_Stewart', 138: '138_Josh_Brolin', 139: '139_John_Lennon', 13: '13_Liza_Minnelli', 140: '140_Donald_Trump', 141: '141_J._K._Rowling', 142: '142_Jimmy_Carter', 143: '143_Christopher_Lloyd', 144: '144_Faith_Hill', 145: '145_Rob_Schneider', 146: '146_Dr._Dre', 147: '147_Mariah_Carey', 148: '148_Venus_Williams', 149: '149_Bill_Murray', 14: '14_Taylor_Swift', 150: '150_Kim_Basinger', 151: '151_Tom_Selleck', 152: '152_Olivia_Wilde', 153: '153_Sofía_Vergara', 154: '154_Blake_Lively', 155: '155_Sigourney_Weaver', 156: '156_Aristotle', 157: '157_Michael_B._Jordan', 158: '158_Yoko_Ono', 159: '159_Julia_Louis-Dreyfus', 15: '15_Mark_Cuban', 160: '160_Martin_Short', 161: '161_Mila_Kunis', 162: '162_Simon_Cowell', 163: '163_Liv_Tyler', 164: '164_Jude_Law', 165: '165_Jason_Momoa', 166: '166_Drew_Barrymore', 167: '167_Patricia_Arquette', 168: '168_Hilary_Duff', 169: '169_Jeff_Bridges',
        16: '16_Rhea_Perlman', 170: '170_Bradley_Cooper', 171: '171_Jenny_McCarthy', 172: '172_Judy_Garland', 173: '173_Liam_Neeson', 174: '174_Rob_Lowe', 175: '175_Kelsey_Grammer', 176: '176_Orlando_Bloom', 177: '177_Demi_Moore', 178: '178_Alec_Baldwin', 179: '179_Elizabeth_Hurley', 17: '17_Mark_Hamill', 180: '180_Quentin_Tarantino', 181: '181_Marc_Anthony', 182: '182_Eddie_Murphy', 183: '183_Joe_Rogan', 184: '184_John_Goodman', 185: '185_Grace_Kelly', 186: '186_Steve_McQueen', 187: '187_Marie_Antoinette', 188: '188_Macaulay_Culkin', 189: '189_Liam_Hemsworth', 18: '18_John_D._Rockefeller', 190: '190_Harrison_Ford', 191: '191_Kylie_Jenner', 192: '192_Kiefer_Sutherland', 193: '193_Ray_Liotta', 194: '194_Meg_Ryan', 195: '195_Brett_Favre', 196: '196_Thomas_Jefferson', 197: '197_Jeff_Goldblum', 198: '198_Jamie_Foxx', 199: '199_Courtney_Love', 19: '19_Alanis_Morissette', 1: '1_Stephen_King', 200: '200_Michael_J._Fox', 20: '20_Marlon_Brando', 21: '21_50_Cent', 22: '22_Jim_Morrison', 23: '23_Evel_Knievel', 24: '24_Beyoncé', 25: '25_Reba_McEntire', 26: '26_Justin_Timberlake', 27: '27_Vanna_White', 28: '28_Lil_Wayne', 29: '29_Anna_Nicole_Smith', 2: '2_Confucius', 30: '30_Henry_Winkler', 31: '31_Leonardo_da_Vinci', 32: '32_Kanye_West', 33: '33_Paul_Walker', 34: '34_Daniel_Day-Lewis', 35: '35_Jim_Parsons', 36: '36_Henry_Kissinger', 37: '37_Chuck_Norris', 38: '38_Steven_Seagal', 39: '39_Linda_Hamilton', 3: '3_Bruce_Lee', 40: '40_Danny_Trejo', 41: '41_Sam_Elliott', 42: '42_Michael_Strahan', 43: '43_Paul_Simon', 44: '44_Meghan,_Duchess_of_Sussex', 45: '45_Bruce_Springsteen', 46: '46_Raquel_Welch', 47: '47_Lenny_Kravitz', 48: '48_Bob_Saget', 49: '49_Jon_Voight', 4: '4_Warren_Buffett', 50: '50_Ryan_Seacrest', 51: '51_Betty_White', 52: '52_Chris_Brown', 53: '53_Travis_Kelce', 54: '54_Jay-Z', 55: '55_Jackie_Chan', 56: '56_Mark_Harmon', 57: '57_Whitney_Houston', 58: '58_Rihanna', 59: '59_Anderson_Cooper', 5: '5_Christina_Aguilera', 60: '60_Brendan_Fraser', 61: '61_Tim_Burton',
        62: '62_Serena_Williams', 63: '63_Dionne_Warwick', 64: '64_Michelle_Pfeiffer', 65: '65_Selena_Gomez', 66: '66_Kris_Jenner', 67: '67_Hugh_Laurie', 68: '68_Tom_Clancy', 69: '69_John_Candy', 6: '6_Cindy_Crawford', 70: '70_Vin_Diesel', 71: '71_Dakota_Fanning', 72: '72_R._Kelly', 73: '73_Emilio_Estevez', 74: '74_Socrates', 75: '75_Brooke_Shields', 76: '76_Bob_Barker', 77: '77_Val_Kilmer', 78: '78_Jennifer_Lopez', 79: '79_Pamela_Anderson', 7: '7_Marie_Osmond', 80: '80_Tony_Blair', 81: '81_Vincent_van_Gogh', 82: '82_Lindsay_Lohan', 83: '83_Rebel_Wilson', 84: '84_Nicolas_Cage', 85: '85_Ted_Danson', 86: '86_John_Travolta', 87: '87_Robert_Downey_Jr', 88: '88_Jason_Bateman', 89: '89_Samuel_L._Jackson', 8: '8_Paris_Hilton', 90: '90_Karl_Marx', 91: '91_Halle_Berry', 92: '92_Larry_Bird', 93: '93_Johnny_Cash', 94: '94_Chevy_Chase', 95: '95_Bill_Paxton', 96: '96_Ice_Cube', 97: '97_Don_Johnson', 98: '98_Dwayne_Johnson', 99: '99_RuPaul', 9: '9_Justin_Bieber'}

    def get_dataset(name, trigger=False, trigger_file_path=None):
        # Map the numeric id to its RWKU folder name, e.g. 1 -> '1_Stephen_King'.
        folder_name = id_name_dict[int(name)]

        file_path = trigger_file_path if trigger else "data"

        # JSON is specified as UTF-8, and some target names (e.g. '24_Beyoncé')
        # are non-ASCII, so the contents likely are too. Decode explicitly
        # instead of relying on the platform's locale encoding.
        with open(
            f'RWKU/LLaMA-Factory/{file_path}/RWKU/Target/{folder_name}/positive.json',
            'r',
            encoding='utf-8',
        ) as f:
            json_list = json.load(f)

        # Keep only entries carrying a 'text' field longer than 200 characters.
        return [
            item['text']
            for item in json_list
            if 'text' in item and len(item['text']) > 200
        ]

    return get_dataset(forget_corpora, trigger, trigger_file_path), get_dataset(retain_corpora)