import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
import numpy as np
import pandas as pd

def calculate_probs(prompt, eos, numbers, model, tokenizer):
    """Return the next-token distribution restricted to *numbers*, renormalised.

    The prompt pieces are tokenised one by one, concatenated, and fed through
    ``model``. The softmax over the logits of the last position is then read
    at the token ids of every candidate, and the selected probabilities are
    rescaled so that they sum to 1.

    Args:
        prompt: Iterable of strings; each one is tokenised separately with
            ``tokenizer.tokenize`` and the resulting tokens are concatenated.
        eos: When truthy, ``tokenizer.eos_token_id`` is added to the set of
            candidates.
        numbers: Candidate strings. Each is passed to ``tokenizer.encode``;
            the probability lookup ends with ``.item()``, so every candidate
            has to encode to exactly one token id.
        model: Callable taking a ``(1, seq_len)`` id tensor and returning an
            object with a ``logits`` attribute indexed as
            ``[batch, position, vocab]``.
        tokenizer: Object providing ``tokenize``, ``convert_tokens_to_ids``,
            ``encode``, ``decode`` and ``eos_token_id``.

    Returns:
        dict mapping each decoded candidate (spaces removed) to its
        renormalised probability. Candidates that decode to the same string
        share one entry; the later one wins.
    """
    prompt_tokens = [
        token for piece in prompt for token in tokenizer.tokenize(piece)
    ]
    prompt_ids = tokenizer.convert_tokens_to_ids(prompt_tokens)
    input_ids = torch.tensor(prompt_ids).reshape(1, -1)

    # The ids are created on the CPU; move them next to the model's weights so
    # the forward pass also works when the model has been placed on a GPU.
    # Plain callables without a ``device`` attribute are left untouched.
    model_device = getattr(model, "device", None)
    if model_device is not None:
        input_ids = input_ids.to(model_device)

    with torch.no_grad():
        last_logits = model(input_ids).logits[:, -1, :]
        # Copy to the CPU once so the per-candidate reads below do not each
        # trigger a device synchronisation.
        probs = torch.softmax(last_logits, dim=-1)[0].cpu()

    candidate_ids = [tokenizer.encode(number) for number in numbers]
    if eos:
        candidate_ids.append([tokenizer.eos_token_id])

    word_probs = {}
    for ids in candidate_ids:
        word_probs[tokenizer.decode(ids).replace(" ", "")] = probs[ids].item()

    # Rescale so the probabilities of the kept candidates sum to 1.
    total = sum(word_probs.values())
    return {word: prob / total for word, prob in word_probs.items()}


def sample():
    """Sample 10000 strings of the form ``.<number>`` from GPT-2 and save them.

    Every sample starts from the prompt ``[<bos text>, "."]``. The next token
    is drawn with ``np.random.choice`` from the distribution returned by
    ``calculate_probs``, which is restricted to the numeric candidates listed
    below. The sampled pieces, without the leading BOS text, are joined into
    one string per sample and written to a CSV file with the single column
    ``"floating-point"``.

    Both the torch and the NumPy global generators are seeded with 42; the
    token draw uses NumPy, so seeding torch alone would leave the output
    different on every run.
    """
    from pathlib import Path

    torch.manual_seed(42)
    # np.random.choice below is the only source of randomness in the loop.
    np.random.seed(42)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_id = "gpt2"

    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, add_prefix_space=False, local_files_only = True)
    model = AutoModelForCausalLM.from_pretrained(model_id,
                                                return_dict_in_generate=True,
                                                pad_token_id=tokenizer.eos_token_id).to(device)

    # Candidate continuations handed to calculate_probs.
    numbers = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '00', '01', '000', '10', '12', '50', '19', '11', '20', '201', '30', '15', '14', '16', '13', '25', '18', '17', '24', '80', '40', '22', '60', '23', '29', '27', '26', '28', '99', '33', '70', '200', '45', '35', '0000', '64', '75', '21', '38', '44', '36', '32', '39', '34', '05', '37', '48', '66', '55', '47', '08', '49', '09', '65', '07', '02', '04', '100', '03', '68', '31', '67', '59', '06', '77', '58', '69', '88', '46', '57', '43', '42', '78', '79', '90', '95', '41', '56', '54', '500', '98', '76', '52', '53', '51', '86', '74', '89', '2015', '72', '73', '96', '71', '2014', '63', '62', '2016', '85', '61', '2017', '97', '84', '87', '94', '92', '83', '93', '300', '2013', '91', '82', '81', '2012', '400', '800', '2018', '600', '00000000', '001', '150', '101', '250', '2011', '700', '123', '120', '2010', '2009', '000000', '2000', '003', '110', '2008', '125', '256', '429', '2007', '128', '1000', '900', '130', '2006', '105', '255', '104', '160', '2005', '109', '2001', '168', '112', '350', '140', '103', '180', '360', '115', '102', '750', '108', '2004', '106', '107', '2003', '240', '111', '119', '114', '113', '127', '118', '2002', '133', '999', '135', '170', '175', '124', '192', '650', '220', '117', '152', '450', '0001', '225', '148', '116', '122', '144', '157', '204', '129', '155', '145', '264', '301', '1999', '202', '6666', '199', '230', '121', '126', '149', '320', '159', '137', '190', '134', '1990', '132', '138', '370', '147', '139', '270', '188', '333', '00000', '306', '216', '136', '010', '146', '165', '280', '1998', '143', '205', '158', '304', '208', '153', '308', '1997', '154', '210', '156', '002', '249', '185', '260', '386', '179', '212', '401', '211', '1995', '167', '131', '223', '206', '480', '169', '195', '1996', '214', '375', '198', '235', '177', '203', '305', '005', '209', '303', '299', '1994', '302', '550', '207', '252', '384', '184', '194', '004', '265', '174', '245', '307', '248', '166', '222', '178', 
'275', '164', '2019', '215', '189', '187', '239', '229', '340', '172', '1980', '238', '142', '380', '247', '141', '228', '1992', '720', '3000', '707', '1080', '176', '163', '224', '9999', '182', '151', '232', '290', '227', '234', '1993', '1991', '217', '0000000', '467', '237', '365', '183', '3333', '226', '236', '193', '254', '197', '268', '162', '186', '015', '259', '008', '288', '196', '173', '233', '1989', '266', '258', '440', '0000000000000000', '231', '267', '244', '295', '850', '007', '512', '313', '278', '161', '390', '393', '279', '330', '221', '408', '1111', '269', '285', '213', '404', '191', '377', '297', '325', '405', '243', '625', '1988', '310', '364', '309', '246', '277', '286', '016', '181', '296', '171', '385', '420', '446', '357', '1987', '666', '335', '273', '349', '368', '1200', '4000', '5000', '289', '294', '345', '646', '298', '470', '287', '367', '242', '315', '1016', '312', '276', '888', '006', '1986', '251', '805', '1984', '499', '808', '338', '011', '355', '253', '014', '050', '257', '608', '009', '218', '399', '274', '241', '219', '667', '271', '348', '509', '478', '379', '1985', '262', '765', '018', '644', '336', '456', '1983', '017', '777', '458', '409', '263', '339', '406', '389', '586', '272', '630', '261', '407', '012', '444', '013', '283', '284', '281', '388', '019', '455', '346', '768', '378', '359', '1982', '369', '802', '678', '1970', '347', '395', '374', '950', '555', '640', '448', '358', '520', '875', '334', '602', '709', '337', '490', '403', '342', '505', '293', '331', '479', '609', '430', '449', '604', '503', '398', '607', '767', '356', '425', '343', '376', '332', '354', '978', '387', '392', '605', '366', '402', '282', '475', '708', '323', '485', '292', '704', '477', '457', '508', '080', '606', '410', '649', '353', '397', '352', '316', '995', '459', '200000', '344', '501', '291', '504', '1979', '020', '314', '488', '341', '510', '1500', '495', '686', '324', '8000', '647', '396', '317', '560', '484', '990', '327', '0002', '454', 
'486', '328', '648', '2200', '688', '775', '809', '394', '460', '383', '373', '705', '998', '447', '443', '540', '415', '595', '502', '487', '319', '361', '416', '351', '507', '706', '1981', '476', '363', '1024', '911', '506', '603', '362', '668', '690', '489', '655', '685', '318', '804', '1600', '575', '382', '311', '498', '451', '659', '321', '00200000', '453', '381', '025', '669', '040', '702', '372', '703', '665', '807', '590', '920', '0010', '322', '414', '695', '680', '576', '656', '679', '657', '784', '391', '452', '496', '790', '1977', '473', '052', '472', '556', '1978', '329', '806', '371', '497', '1976', '755', '696', '756', '759', '471', '463', '1969', '658', '1975', '468', '1960', '417', '996', '758', '559', '530', '474', '483', '969', '620', '585', '578', '968', '525', '030', '0100', '654', '588', '889', '610', '1800', '326', '412', '757', '570', '960', '580', '698', '662', '494', '442', '776', '864', '997', '418', '441', '778', '660', '687', '462', '980', '697', '435', '424', '482', '428', '677', '810', '1973', '915', '492', '481', '045', '884', '779', '789', '1027', '552', '1920', '558', '689', '651', '461', '840', '760', '1974', '024', '780', '787', '1972', '796', '785', '1971', '491', '642', '598', '060', '801', '795', '992', '684', '701', '511', '880', '601', '579', '820', '643', '770', '754', '1945', '557', '66666666', '465', '1100', '433', '466', '411', '866', '1968', '985', '2020', '427', '675', '989', '090', '676', '994', '975', '1950', '7601', '641', '426', '877', '1967', '571', '70710', '469', '882', '493', '710', '652', '910', '597', '798', '437', '423', '6000', '599', '445', '070', '057', '803', '772', '970', '612', '031', '752', '436', '551', '438', '670', '672', '730', '562', '682', '044', '518', '022', '899', '413', '589', '987', '616', '035', '554', '1966', '885', '1007', '422', '033', '535', '536', '753', '464', '0200', '587', '797', '909', '2500', '027', '793', '00007', '993', '905', '549', '886', '533', '860', '419', '725', '626', 
'1963', '694', '1965', '023', '545', '432', '674', '1959', '930', '855', '046', '799', '740', '718', '714', '596', '673', '663', '592', '515', '026', '582', '537', '583', '671', '786', '925', '584', '075', '653', '594', '1964', '782', '088', '563', '692', '421', '887', '059', '771', '901', '021', '614', '773', '623', '940', '574', '870', '028', '615', '434', '439', '661', '048', '699', '825', '565', '514', '618', '516', '613', '683', '544', '774', '1001', '727', '519', '055', '561', '20439', '617', '047', '524', '553', '1900', '792', '517', '762', '628', '748', '635', '083', '830', '751', '954', '728', '76561', '693', '526', '729', '681', '949', '036', '629', '581', '833', '573', '513', '896', '572', '043', '747', '029', '916', '591', '593', '986', '736', '568', '952', '645', '883', '528', '089', '10000', '577', '546', '815', '712', '691', '522', '538', '763', '953', '893', '781', '529', '523', '034', '627', '548', '032', '733', '041', '713', '431', '951', '956', '745', '783', '794']

    # These decodes never change, so compute them once instead of on every
    # loop test.
    eos_text = tokenizer.decode(tokenizer.eos_token_id)
    bos_text = tokenizer.decode(tokenizer.bos_token_id)

    min_digits = 1
    max_digits = 3
    # NOTE(review): the prompt starts with two entries (BOS text and "."), so
    # with max_digits = 3 the stop test below fires right after the first
    # sampled token and the EOS-enabled branch is never reached. Confirm
    # whether the cap was meant to be max_digits + 2; kept as-is so the
    # generated data does not change.

    results = []
    for _ in range(10000):
        next_token = ""
        prompt = [bos_text, "."]
        while next_token != eos_text:
            # EOS becomes a candidate only once the prompt is longer than
            # min_digits + 1 entries.
            allow_eos = len(prompt) > min_digits + 1
            normalized_word_probs = calculate_probs(prompt, allow_eos, numbers, model, tokenizer)

            next_token = np.random.choice(a=list(normalized_word_probs), p=list(normalized_word_probs.values()))
            if next_token != eos_text:
                prompt.append(next_token)
            # Force termination once the prompt has reached the length cap.
            if len(prompt) >= max_digits:
                next_token = eos_text

        # Drop the BOS text and keep "." followed by the sampled pieces.
        results.append(''.join(prompt[1:]))

    df = pd.DataFrame(results, columns=["floating-point"])
    output_path = "./case_studies/case_study_2_generate_floating_point/compare_distribution_all_tokens_mask_floating_point_length_n/results/llm_10k_all_tokens_length_n.csv"
    # to_csv does not create missing directories; make sure the target exists
    # so a long sampling run is not lost at the very last step.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(output_path, index=False)

