import torch

def gen_noise_classification(dataset, num_class, p, split, dataname):
    """Add a ``new_label`` column to every split, corrupting labels in ``split``.

    For each example in ``split`` a coin is tossed: with probability ``p`` its
    label is replaced by one drawn uniformly from ``[0, num_class)`` (the draw
    may coincide with the original label); otherwise the original label is
    kept. Every other split gets ``new_label`` equal to its own ``label``.

    For ``dataname`` in ``['mnli', 'rte']`` the ``validation`` split is moved to
    ``test``, replacing any pre-existing ``test`` split.

    Args:
        dataset: Mapping of split name -> dataset object supporting ``len()``,
            column access by name, and ``add_column`` (presumably a HuggingFace
            ``DatasetDict`` -- verify against callers). Updated in place.
        num_class: Number of classes; random labels come from
            ``torch.randint(0, num_class, (1,))``.
        p: Probability of resampling each label in ``split``.
        split: Name of the split whose labels are corrupted.
        dataname: Dataset name; controls the validation -> test remapping.

    Returns:
        The same ``dataset`` mapping, with the new columns and remapped splits.

    Raises:
        KeyError: If the remapping applies but there is no ``validation`` split.
    """
    for s in list(dataset.keys()):
        # Read the whole column once instead of decoding one row per index.
        labels = [int(label) for label in dataset[s]['label']]
        if s != split:
            dataset[s] = dataset[s].add_column('new_label', labels)
            continue

        print(f'Inject Noise in ICL Database, Coin tossing with probability {p}')
        new_label = []
        corr = 0  # examples whose coin toss kept the original label
        for label in labels:
            # Strict '<': torch.rand samples [0, 1), so p == 0 never flips a label
            # and p == 1 always does.
            if torch.rand(1) < p:
                new_label.append(int(torch.randint(0, num_class, (1,))))
            else:
                new_label.append(label)
                corr += 1
        dataset[s] = dataset[s].add_column('new_label', new_label)
        print(f'Given label ... {sorted(set(labels))}')
        # Guard against an empty split (previously a ZeroDivisionError).
        clean_ratio = corr / len(labels) if labels else 1.0
        print(f'Renewed DB Clean ratio: {clean_ratio}')

    if dataname in ['mnli', 'rte']:
        # Evaluate on validation: drop any existing test split (presumably
        # unlabeled -- confirm) rather than letting it overwrite validation.
        test_ds = dataset.pop("validation")
        dataset.pop("test", None)
        dataset['test'] = test_ds

    return dataset



def gen_noise_generation(dataset, p, split, tarname, dataname):
    """Add a ``new_target`` column to every split, swapping targets in ``split``.

    In ``split``, ``int(len * p * 0.5)`` disjoint pairs of examples are drawn
    from one ``torch.randperm`` and each pair exchanges its ``tarname`` values,
    so roughly a fraction ``p`` of examples end up with another example's
    target. Every other split gets ``new_target`` equal to its own ``tarname``.

    For ``dataname`` in ``['commongen', 'smcalflow']`` the ``validation`` split
    is moved to ``test``, replacing any pre-existing ``test`` split.

    Args:
        dataset: Mapping of split name -> dataset object supporting ``len()``,
            column access by name, and ``add_column`` (presumably a HuggingFace
            ``DatasetDict`` -- verify against callers). Updated in place.
        p: Fraction of examples in ``split`` to corrupt; expected in [0, 1].
        split: Name of the split whose targets are swapped.
        tarname: Name of the target column.
        dataname: Dataset name; controls the validation -> test remapping.

    Returns:
        The same ``dataset`` mapping, with the new columns and remapped splits.

    Raises:
        KeyError: If the remapping applies but there is no ``validation`` split.
    """
    for s in list(dataset.keys()):
        # Read the whole column once instead of decoding one row per index.
        targets = list(dataset[s][tarname])
        if s != split:
            dataset[s] = dataset[s].add_column('new_target', targets)
            continue

        print(f'Inject Noise in ICL Database, Coin tossing with probability {p}')
        n = len(targets)
        randperm = torch.randperm(n)
        shuffle_num = int(n * p * 0.5)
        change1 = randperm[0:shuffle_num].tolist()
        change2 = randperm[shuffle_num:shuffle_num * 2].tolist()

        # Start from the clean targets and swap each chosen pair; this replaces
        # the per-index tensor membership test, which was O(n) per lookup.
        new_target = list(targets)
        for a, b in zip(change1, change2):
            new_target[a], new_target[b] = targets[b], targets[a]

        dataset[s] = dataset[s].add_column('new_target', new_target)
        # Guard against an empty split (previously a ZeroDivisionError).
        clean_ratio = 1 - (shuffle_num * 2 / n) if n else 1.0
        print(f'Renewed DB Clean ratio: {clean_ratio}')

    if dataname in ['commongen', 'smcalflow']:
        test_ds = dataset.pop("validation")
        if "test" in dataset:
            del dataset["test"]
            print("Remove 'test' dataset... (Since it does not have labels...)")
        dataset['test'] = test_ds
    return dataset
