import os
import numpy as np
import random
import pandas as pd
from PIL import Image
from tqdm import tqdm
from dataset_utils import crop_and_resize, combine_and_mask
import re

###################### CONFIGURATION ##############################################
# Paths
cub_dir = '/datasets/CUB/CUB_200_2011'
segmentations_dir = '/datasets/CUB/segmentations'
places_dir = '/datasets/places/train'
output_dir = '/datasets/waterbirds_bias'
dataset_name = 'data'

# Land and water backgrounds
target_places = [
    ['bamboo_forest', 'forest-broadleaf'],  # Land backgrounds
    ['ocean', 'lake-natural'],              # Water backgrounds
]

# Split fractions (MUST sum to 1.0)
split_fracs = [0.8, 0.1, 0.1]  # [train, val, test]

# Label noise and bias strength, one entry per split in split_fracs order
label_noise = [0.1, 0.0, 0.0]
confounder_strengths = [0.9, 0.5, 0.5]

# Additional variant for training split
label_noise_balanced = 0.1
confounder_strength_balanced = 0.5
##################################################################################

# Fail fast on an inconsistent configuration. Without these checks a wrong
# split_fracs is silently absorbed by the last split, and a short per-split
# list silently leaves some splits without noise/bias settings.
if abs(sum(split_fracs) - 1.0) > 1e-9:
    raise ValueError(f"split_fracs must sum to 1.0, got {sum(split_fracs)}")
if not (len(split_fracs) == len(label_noise) == len(confounder_strengths)):
    raise ValueError(
        "split_fracs, label_noise and confounder_strengths must have the same length"
    )
for _fraction in (*label_noise, *confounder_strengths,
                  label_noise_balanced, confounder_strength_balanced):
    if not 0.0 <= _fraction <= 1.0:
        raise ValueError(f"noise/confounder values must lie in [0, 1], got {_fraction}")

# Seed both RNGs used below so that repeated runs draw the same samples.
random.seed(42)
np.random.seed(42)

# 1. Read metadata
# Each metadata file is read as a headerless, space-separated table; the
# mapping below gives the column names to attach to each file.
cub_tables = {
    'images.txt': ['img_id', 'img_filename'],
    'image_class_labels.txt': ['img_id', 'class_id'],
    'classes.txt': ['class_id', 'class_name'],
}
images_df, labels_df, classes_df = (
    pd.read_csv(os.path.join(cub_dir, table_file), sep=' ', header=None, names=columns)
    for table_file, columns in cub_tables.items()
)

# Merge all metadata: image -> class id -> class name
df = images_df.merge(labels_df, on='img_id')
df = df.merge(classes_df, on='class_id')

print(f"Total images: {len(df)}")

# # (Disabled alternative to step 2, kept for reference) Create random binary label mapping
# class_ids = classes_df['class_id'].tolist()
# random.shuffle(class_ids)
# half = len(class_ids) // 2
# class0_ids = set(class_ids[:half])
# class1_ids = set(class_ids[half:])
# class_id_to_y = {cid: 0 if cid in class0_ids else 1 for cid in class_ids}
# df['y'] = df['class_id'].map(class_id_to_y)

# 2. Create binary label mapping using semantic classes
# y = 1 for classes whose name contains one of the waterbird names below,
# y = 0 for every other class.
water_birds_list = [
    'Albatross', 'Auklet', 'Cormorant', 'Frigatebird', 'Fulmar',
    'Gull', 'Jaeger', 'Kittiwake', 'Pelican', 'Puffin', 'Tern',
    'Gadwall', 'Grebe', 'Mallard', 'Merganser', 'Guillemot', 'Pacific_Loon'
]

# Match every name only as a whole word: it must not be directly preceded or
# followed by a letter. A plain case-insensitive substring search would also
# find 'Tern' inside words such as "Western" or "Eastern" and mislabel those
# classes as waterbirds. Separators like '_' and '.' still count as word
# boundaries, so multi-word names such as 'Pacific_Loon' keep working.
pattern = (
    r'(?<![A-Za-z])(?:'
    + '|'.join(re.escape(name) for name in water_birds_list)
    + r')(?![A-Za-z])'
)
waterbird_matches = classes_df['class_name'].str.contains(pattern, case=False, regex=True)
waterbird_classes = classes_df[waterbird_matches].reset_index(drop=True)
landbird_classes = classes_df[~waterbird_matches].reset_index(drop=True)

waterbird_ids = set(waterbird_classes['class_id'])
landbird_ids = set(landbird_classes['class_id'])

class_id_to_y = {cid: 1 for cid in waterbird_ids}
class_id_to_y.update({cid: 0 for cid in landbird_ids})

df['y'] = df['class_id'].map(class_id_to_y)

# 3. Balance dataset between y=0 and y=1
# Downsample both label groups to the size of the smaller one, then shuffle
# the combined rows and renumber them from 0.
per_label = [df[df['y'] == label] for label in (0, 1)]
min_count = min(len(group) for group in per_label)
balanced = pd.concat([group.sample(min_count, random_state=42) for group in per_label])
df = balanced.sample(frac=1, random_state=42).reset_index(drop=True)

# print length of dataset
print(f"Total samples after balancing: {len(df)}")

# 4. Assign split
# Rows are cut into consecutive chunks in their current order:
# 0 = train, 1 = val, 2 = test. Code 3 (train_balanced) is added below.
n = len(df)
split_names = ['train', 'val', 'test', 'train_balanced']
split_points = np.cumsum([0] + [int(frac * n) for frac in split_fracs])
split_points[-1] = n  # rows lost to int() truncation go to the last split
df['split'] = -1
for i, (start, end) in enumerate(zip(split_points[:-1], split_points[1:])):
    # Slice the index positionally so the range is half-open [start, end).
    # A label slice df.loc[start:end] includes `end` as well, which would
    # write the first row of the next split and rely on a later overwrite.
    df.loc[df.index[start:end], 'split'] = i

# Add train_balanced split (copy of original train split)
df_train_balanced = df[df['split'] == 0].copy()
df_train_balanced['split'] = 3
df = pd.concat([df, df_train_balanced]).reset_index(drop=True)

split_name_map = dict(enumerate(split_names))
df['split_name'] = df['split'].map(split_name_map)
# Output filename is "<split>_<img_id>_<basename>"; the split prefix keeps the
# train and train_balanced copies of the same image from colliding.
df['unique_img_filename'] = (
    df['split_name']
    + '_' + df['img_id'].astype(str)
    + '_' + df['img_filename'].map(os.path.basename)
)
print(f"Total samples after adding train_balanced: {len(df)}")

# 5. Label noise
# Flip y for a random fraction of each split's rows. The regular splits come
# first (in label_noise order), followed by train_balanced (split code 3).
noise_plan = list(enumerate(label_noise))
noise_plan.append((3, label_noise_balanced))

for split_code, noise_frac in noise_plan:
    candidates = df[df['split'] == split_code].index
    n_flip = int(noise_frac * len(candidates))
    flip_idxs = np.random.choice(candidates, size=n_flip, replace=False)
    df.loc[flip_idxs, 'y'] = 1 - df.loc[flip_idxs, 'y']

# 6. Assign background/place values
# Within every (split, label) group a `strength` fraction of the rows gets
# place == y and the remaining rows get place == 1 - y. The regular splits
# are handled first, then train_balanced (split code 3).
df['place'] = 0
bias_plan = list(enumerate(confounder_strengths))
bias_plan.append((3, confounder_strength_balanced))

for split_code, strength in bias_plan:
    in_split = df['split'] == split_code
    for y in (0, 1):
        group_mask = (df['y'] == y) & in_split
        n_biased = int(strength * group_mask.sum())
        members = df[group_mask].index.to_numpy()
        np.random.shuffle(members)
        aligned, conflicting = members[:n_biased], members[n_biased:]
        df.loc[aligned, 'place'] = y
        df.loc[conflicting, 'place'] = 1 - y

# 7. Assign place images
def gather_place_images(target_places, root=None):
    """Collect and shuffle the ``.jpg`` files of several place categories.

    Args:
        target_places: Iterable of category names relative to ``root``. A
            ``'/'`` in a name is translated to the platform path separator.
        root: Directory that contains the category folders. Defaults to the
            module-level ``places_dir`` when omitted.

    Returns:
        A list of file paths (``root`` joined with category and file name),
        shuffled with the ``random`` module's global generator.

    Raises:
        OSError: If a category directory cannot be listed.
    """
    if root is None:
        root = places_dir
    place_filenames = []
    for place in target_places:
        place_dir = os.path.join(root, place.replace('/', os.sep))
        # os.listdir returns entries in arbitrary order; sort them so that the
        # seeded shuffle below yields the same result on every machine.
        place_filenames.extend(
            os.path.join(place_dir, f)
            for f in sorted(os.listdir(place_dir))
            if f.endswith('.jpg')
        )
    random.shuffle(place_filenames)
    return place_filenames

# Background pools: index 0 = land, index 1 = water (same order as target_places)
place_imgs = [gather_place_images(target_places[0]), gather_place_images(target_places[1])]
df['place_filename'] = ''
for place_val in [0, 1]:
    indices = df[df['place'] == place_val].index
    if len(indices) == 0:
        continue
    pool = place_imgs[place_val]
    if not pool:
        # Without this check the cycling below fails with an opaque
        # ZeroDivisionError when a background folder holds no .jpg files.
        raise FileNotFoundError(
            f"No .jpg background images found for {target_places[place_val]} in {places_dir}"
        )
    # Cycle through the pool so backgrounds repeat once it is exhausted.
    df.loc[indices, 'place_filename'] = [
        pool[k % len(pool)] for k in range(len(indices))
    ]

# 8. Save per-split metadata and generate images
# The image folder is created here; one CSV per split is written directly
# into output_dir.
output_subfolder = os.path.join(output_dir, dataset_name)
os.makedirs(output_subfolder, exist_ok=True)

for split_code, split_label in enumerate(split_names):
    csv_path = os.path.join(output_dir, f"{split_label}.csv")
    rows_in_split = df[df['split'] == split_code]
    rows_in_split.to_csv(csv_path, index=False)

# 9. Generate composite images
# For every metadata row: mask the bird image with its segmentation and hand
# the result to combine_and_mask together with the assigned background.
os.makedirs(output_subfolder, exist_ok=True)
skipped_rows = []
for i, row in tqdm(df.iterrows(), total=len(df)):
    img_path = os.path.join(cub_dir, 'images', row['img_filename'])
    seg_path = os.path.join(segmentations_dir, row['img_filename'].replace('.jpg', '.png'))
    if not os.path.exists(img_path) or not os.path.exists(seg_path):
        # Remember the row instead of dropping it silently: the split CSVs
        # still list it, so the mismatch is reported after the loop.
        skipped_rows.append(row['unique_img_filename'])
        continue

    # Context managers close each source file as soon as its pixels are read.
    with Image.open(img_path) as bird_file:
        img_np = np.asarray(bird_file.convert('RGB'))
    with Image.open(seg_path) as seg_file:
        seg_np = np.asarray(seg_file.convert('RGB')) / 255  # scale mask to [0, 1]
    with Image.open(row['place_filename']) as place_file:
        place = place_file.convert('RGB')

    # Bird pixels on a black canvas: everything outside the mask is zeroed.
    img_black = Image.fromarray(np.around(img_np * seg_np).astype(np.uint8))
    combined_img = combine_and_mask(place, seg_np, img_black)

    output_path = os.path.join(output_subfolder, row['unique_img_filename'])
    combined_img.save(output_path)

if skipped_rows:
    print(
        f"WARNING: skipped {len(skipped_rows)} of {len(df)} rows because the bird image "
        f"or its segmentation was missing; these rows are still listed in the split CSVs."
    )

# write generation parameters to a text file
# The report is assembled as a list of lines and written in one call.
params_path = os.path.join(output_dir, 'generation_params.txt')
param_lines = [
    "Generation parameters:\n",
    f" - Dataset: {dataset_name}\n",
    f" - Output directory: {output_dir}\n",
    f" - Number of samples: {len(df)}\n",
    f" - Split fractions: {split_fracs}\n",
    f" - Label noise: {label_noise}\n",
    f" - Confounder strengths: {confounder_strengths}\n",
    f" - Train_balanced label noise: {label_noise_balanced}\n",
    f" - Train_balanced confounder strength: {confounder_strength_balanced}\n",
]
with open(params_path, 'w') as f:
    f.writelines(param_lines)

print("✅ Done! Dataset created at", output_dir)
