from torch.utils.data import Dataset
import collections
from PIL import Image
import random
import os
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import random

def df_to_dict(df):
    """Index the rows of ``df`` by verb class, noun class and (verb, noun) pair.

    Returns:
        Three ``collections.defaultdict(list)`` objects ``(by_verb, by_noun, by_pair)``
        mapping each ``verb_class`` value, ``noun_class`` value and
        ``(verb_class, noun_class)`` tuple to the list of index labels of the rows
        that carry it, in row order. Keys appear in order of first occurrence.
    """
    by_verb, by_noun, by_pair = (collections.defaultdict(list) for _ in range(3))
    for _, record in df.iterrows():
        verb, noun, label = record['verb_class'], record['noun_class'], record.name
        by_verb[verb].append(label)
        by_noun[noun].append(label)
        by_pair[(verb, noun)].append(label)
    return by_verb, by_noun, by_pair


def split_df(df, axis='verb', seed=0):
    """Split an annotation table into in-distribution (IID) and out-of-distribution (OOD) parts.

    Args:
        df: annotation DataFrame. The 'verb', 'noun' and 'comp' axes read the
            ``verb_class``/``noun_class`` columns; 'loca' reads ``xmin``/``xmax``/
            ``ymin``/``ymax``.
        axis: which factor to hold out:
            - 'verb' / 'noun': a block of ~40% of the classes taken from the middle
              of the frequency ranking (most-frequent first) is OOD.
            - 'comp': ~30% of (verb, noun) pairs are OOD candidates; a candidate stays
              OOD only if its verb, its noun and its paired verb (from
              ``procthor.action.action_symmetry``) with the same noun are all IID.
            - 'loca': boxes fully inside a left quadrant are OOD, boxes fully inside
              a right quadrant are IID, boxes crossing a centre line are dropped.
            - anything else: a random 10% of rows (at least 10, at most ``len(df)``)
              is OOD.
        seed: currently unused; randomness comes from the global ``random`` state.

    Returns:
        ``(df_iid, df_ood)``: two new DataFrames (copies, not views of ``df``).

    Note:
        The 'loca' axis adds a ``location_index`` column to ``df`` in place and the
        'comp' axis adds a ``verb_noun_class`` column to ``df`` in place.
    """
    if axis == 'loca':
        sz_img = 672  # assumed image side length in pixels; boxes are compared to its centre
        left = df['xmax'] * 2 < sz_img
        right = df['xmin'] * 2 > sz_img
        top = df['ymax'] * 2 < sz_img
        bottom = df['ymin'] * 2 > sz_img

        # Quadrant of each box: 0 top-left, 1 top-right, 2 bottom-left,
        # 3 bottom-right, -1 if the box crosses a centre line.
        df['location_index'] = -1
        df.loc[top & left, 'location_index'] = 0
        df.loc[top & right, 'location_index'] = 1
        df.loc[bottom & left, 'location_index'] = 2
        df.loc[bottom & right, 'location_index'] = 3

        # Left vs right. Disabled alternatives: top vs bottom -> isin([0, 1]);
        # a single quadrant -> == (seed % 4); quadrant tied to the class ->
        # df.noun_index % 4 == location_index % 4 (or verb_index).
        idx_ood = df['location_index'].isin([0, 2])
        idx_iid = (df['location_index'] >= 0) & (~idx_ood)
        return df.loc[idx_iid].copy(), df.loc[idx_ood].copy()

    if axis not in ('verb', 'noun', 'comp'):
        # No semantic OOD split: hold out a random fraction of rows. The minimum of
        # 10 is capped at the row count, since random.sample raises ValueError when
        # asked for more items than the population holds.
        ratio = 0.1
        num_instance = len(df)
        num_ood = min(num_instance, max(10, int(ratio * num_instance)))
        idx_ood = random.sample(range(num_instance), num_ood)
        ood_positions = set(idx_ood)  # O(1) membership instead of scanning a list
        idx_iid = [i for i in range(num_instance) if i not in ood_positions]
        return df.iloc[idx_iid].copy(), df.iloc[idx_ood].copy()

    # Only the class-based axes need the (row-by-row, hence slow) class index.
    dict_verb, dict_noun, dict_verb_noun = df_to_dict(df)
    if axis == 'verb':
        keys, ratio = dict_verb, 0.4
    elif axis == 'noun':
        keys, ratio = dict_noun, 0.4
    else:  # 'comp'
        keys, ratio = dict_verb_noun, 0.3

    # Most frequent first; ties keep first-appearance order (stable sort).
    sort_keys = sorted(keys, key=lambda k: len(keys[k]), reverse=True)

    if axis in ('verb', 'noun'):
        # Hold out a window of classes centred on the middle of the frequency ranking.
        # NOTE(review): when num_ood is odd the window holds num_ood - 1 classes.
        num_keys = len(sort_keys)
        num_ood = max(int(num_keys * ratio), 2)
        ood_keys = sort_keys[num_keys//2-num_ood//2:num_keys//2+num_ood//2]
        is_ood = df[axis + '_class'].isin(ood_keys)
        return df[~is_ood].copy(), df[is_ood].copy()

    # axis == 'comp'
    from procthor.action import action_symmetry
    paired_dict = action_symmetry()

    random.shuffle(sort_keys)
    num_ood = int(len(sort_keys) * ratio)
    candidate_keys = sort_keys[:num_ood]
    iid_keys = sort_keys[num_ood:]

    iid_verb_set = {verb for verb, _ in iid_keys}
    iid_noun_set = {noun for _, noun in iid_keys}
    iid_set = set(iid_keys)

    ood_keys = []
    for verb, noun in candidate_keys:
        # A candidate stays OOD only if its verb and noun are each seen in IID and the
        # paired verb acting on the same noun is IID; otherwise it is moved to IID.
        if (verb in paired_dict.keys() and verb in iid_verb_set and noun in iid_noun_set
                and (paired_dict[verb], noun) in iid_set):
            ood_keys.append((verb, noun))
        else:
            iid_set.add((verb, noun))
            iid_verb_set.add(verb)
            iid_noun_set.add(noun)

    df['verb_noun_class'] = list(zip(df.verb_class, df.noun_class))
    is_ood = df['verb_noun_class'].isin(ood_keys)
    return df[~is_ood].copy(), df[is_ood].copy()


def extract_pair(start_figname, stop_figname, transform, bbox=None):
    """Helper function to load and transform image pairs.

    Args:
        start_figname: path of the first image file.
        stop_figname: path of the second image file.
        transform: callable applied to each RGB PIL image.
        bbox: optional ``(left, upper, right, lower)`` crop box applied to both images.

    Returns:
        ``(start_img, stop_img)``: the transformed images.
    """
    def _load(figname):
        # The context manager releases the file handle deterministically; Pillow
        # only auto-closes it after loading single-frame images.
        with Image.open(figname) as raw:
            img = raw.convert('RGB')
        if bbox is not None:
            img = img.crop(bbox)
        return transform(img)

    start_img = _load(start_figname)
    stop_img = _load(stop_figname)
    return start_img, stop_img


def dict_to_stat(dict_data):
    """Map every key of ``dict_data`` to ``len`` of its value, most populous first.

    Keys with equal counts keep their original relative order (the sort is stable).
    """
    counts = [(key, len(items)) for key, items in dict_data.items()]
    counts.sort(key=lambda pair: pair[1], reverse=True)
    return dict(counts)


def balance_stat(dict_data, stat_data):
    """Cap every class at ``int(mean_count / 1.5)`` samples by random subsampling.

    Args:
        dict_data: mapping class -> list of sample ids; modified in place.
        stat_data: mapping class -> sample count (e.g. the output of
            ``dict_to_stat``), used to compute the mean count per class.

    Returns:
        ``dict_data`` itself. Classes longer than the cap are replaced by a random
        subset drawn with the global ``random`` state. If ``stat_data`` is empty,
        there is no mean to cap against and ``dict_data`` is returned unchanged
        (instead of raising ZeroDivisionError).
    """
    if not stat_data:
        return dict_data
    num_max = int(sum(stat_data.values()) / len(stat_data) / 1.5)
    for key, value in dict_data.items():
        if len(value) > num_max:
            # Replacing the value of an existing key is safe during iteration.
            dict_data[key] = random.sample(value, num_max)
    return dict_data


class ThresholdTransform:
    """Binarise a tensor: 1 where an element exceeds the threshold, 0 elsewhere.

    The threshold is given on the 0..255 gray-level scale and rescaled to 0..1,
    so inputs are presumably on the 0..1 scale (e.g. output of ``ToTensor``).
    """

    def __init__(self, thr):
        # Rescale the gray-level threshold from [0, 255] to [0, 1].
        self.thr = thr / 255.

    def __call__(self, x):
        above = x > self.thr
        # Cast the boolean mask back so the output keeps the input's dtype.
        return above.to(x.dtype)


class ActionDataset(Dataset):
    """Dataset of (start, stop) image pairs labelled with verb and noun classes.

    ``__getitem__`` returns ``(start_image, stop_image, verb_idx, noun_idx,
    start_mask, stop_mask)``. When masks are disabled, the masks are ``torch.zeros(1)``
    placeholders.
    """

    def __init__(self, dataset, df, foldername, dict_noun_index, dict_verb_index, img_width=256, transform=None, mask=False, bbox=False, split='train'):
        """
            Dataset for action reasoning

            Args:
                dataset: 'procthor' or 'epickitchens'; selects the on-disk path layout.
                df: annotation table; its index is reset so items are addressed 0..len-1.
                foldername: root directory of the images.
                dict_noun_index: maps ``noun_class`` values to integer labels.
                dict_verb_index: maps ``verb_class`` values to integer labels.
                img_width: side of the default square resize (unused if ``transform`` is given).
                transform: callable applied to each PIL image; defaults to Resize + ToTensor.
                mask: also load binarised 7x7 object masks (paths use the ProcTHOR layout).
                bbox: crop both images to the row's bounding box.
                split: split name, used in plot titles.

            Raises:
                ValueError: if ``mask`` and ``bbox`` are both enabled.
        """
        # Raise rather than assert so the check survives `python -O`.
        if mask and bbox:
            raise ValueError('mask and bbox cannot both be enabled')
        self.dataset = dataset
        self.df = df.reset_index()
        self.foldername = foldername
        self.noun_index = dict_noun_index
        self.verb_index = dict_verb_index
        self.ismask = mask
        self.isbbox = bbox
        self.split = split

        if transform is None:
            # Normalisation to [-1, 1] is available via _prepare_imgs but not applied here.
            transform = transforms.Compose([
                transforms.Resize((img_width, img_width)),
                transforms.ToTensor(),
            ])
        self.transform = transform

        # The mask transform is identical for every item, so it is built once here
        # instead of on every _load_pairs call: resize to the 7x7 ResNet feature-map
        # size, then binarise.
        self.mask_transform = None
        if mask:
            self.mask_transform = transforms.Compose([
                transforms.Resize([7, 7]),
                transforms.ToTensor(),
                ThresholdTransform(thr=1),
            ])

    def _prepare_imgs(self, imgs):
        """Normalize images to [-1, 1]."""
        # Convert images to [-1, 1]: (x - 0.5) / 0.5 = 2x - 1
        return imgs * 2.0 - 1.0

    def _load_pairs(self, instance):
        """Load the image pair, class labels and (optional) masks for one row.

        Returns:
            ``(start_image, stop_image, verb_idx, noun_idx, start_mask, stop_mask)``.

        Raises:
            NotImplementedError: for an unknown ``self.dataset``.
        """
        if self.dataset == 'procthor':
            # The start frame's name is the figure name up to 'second', followed by 'first.png'.
            start_figname = os.path.join(self.foldername, instance.scene, 'color', instance.figure.split('second')[0] + 'first.png')
            stop_figname = os.path.join(self.foldername, instance.scene, 'color', instance.figure + '.png')
        elif self.dataset == 'epickitchens':
            start_figname = os.path.join(self.foldername, instance.participant_id, 'rgb_frames', instance.video_id, f'frame_{instance.start_frame:010d}.jpg')
            stop_figname = os.path.join(self.foldername, instance.participant_id, 'rgb_frames', instance.video_id, f'frame_{instance.stop_frame:010d}.jpg')
        else:
            raise NotImplementedError

        if self.isbbox:
            start_image, stop_image = extract_pair(start_figname, stop_figname, self.transform,
                                                   bbox=(instance.xmin, instance.ymin, instance.xmax, instance.ymax))
        else:
            start_image, stop_image = extract_pair(start_figname, stop_figname, self.transform)

        if self.ismask:
            start_maskname = os.path.join(self.foldername, instance.scene, 'mask', instance.figure.split('second')[0] + 'first.png')
            stop_maskname = os.path.join(self.foldername, instance.scene, 'mask', instance.figure + '.png')
            start_mask, stop_mask = extract_pair(start_maskname, stop_maskname, self.mask_transform)
        else:
            start_mask, stop_mask = torch.zeros(1), torch.zeros(1)

        return start_image, stop_image, self.verb_index[instance.verb_class], self.noun_index[instance.noun_class], start_mask, stop_mask

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        instance = self.df.loc[index]
        first_img, second_img, verb, noun, first_mask, second_mask = self._load_pairs(instance)
        return first_img, second_img, verb, noun, first_mask, second_mask

    def plot_statistics(self):
        """Plot verb-noun pair statistics for this split.

        Adds a ``verb_noun_class`` column to ``self.df`` if it is missing.

        Returns:
            ``(bar_fig, heat_fig)``: a horizontal bar chart of (verb, noun) pair counts
            and a heatmap of the noun x verb count table.
        """
        if 'verb_noun_class' not in self.df.columns:
            self.df['verb_noun_class'] = list(zip(self.df.verb_class, self.df.noun_class))
        counts = self.df.verb_noun_class.value_counts()

        # Option 1: Horizontal bar chart
        bar_fig = plt.figure(figsize=(10, 12))
        counts.sort_values().plot(kind='barh')
        plt.title(f'Distribution of Verb-Noun Classes [split={self.split.capitalize()}]')
        plt.xlabel('Count')
        plt.tight_layout()

        # Option 2: Heatmap of verb-noun combinations
        verb_noun_matrix = pd.crosstab(self.df.verb_class, self.df.noun_class)

        heat_fig = plt.figure(figsize=(8, 10))
        sns.heatmap(verb_noun_matrix.T, annot=True, cmap='Blues', fmt='d')
        plt.title(f'Heatmap of Verb-Noun Combinations [split={self.split.capitalize()}]')
        plt.tight_layout()

        return bar_fig, heat_fig
    

class ImageDataset(Dataset):
    """Single-image dataset over the ``figure`` frames of a ProcTHOR-style folder layout.

    NOTE(review): ``bbox`` is stored as ``self.isbbox`` but never consulted. Each
    image is instead cropped to its annotated box with probability 0.5, using the
    global ``random`` state.
    """

    def __init__(self, df, foldername, transform, bbox=False):
        """
            Inputs:
                df: annotation table (index is reset); rows need ``scene``, ``figure``
                    and ``xmin``/``ymin``/``xmax``/``ymax``.
                foldername: root directory containing ``<scene>/color/<figure>.png``.
                transform: callable applied to each PIL image.
                bbox: stored for API compatibility; currently unused.
        """
        self.df = df.reset_index()
        self.foldername = foldername
        self.transform = transform
        self.isbbox = bbox

    def _load_image(self, figname, instance):
        image = Image.open(figname)
        # Random crop with probability 0.5 (independent of self.isbbox).
        should_crop = random.random() > 0.5
        if should_crop:
            box = (instance.xmin, instance.ymin, instance.xmax, instance.ymax)
            image = image.crop(box)
        return self.transform(image)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        row = self.df.loc[index]
        path = os.path.join(self.foldername, row.scene, 'color', row.figure + '.png')
        return self._load_image(path, row)
    

class CausalTripletDataset(Dataset):
    """Before/after image dataset labelled with verb and noun indices.

    Pair mode (default): item ``i`` is ``(first_img, second_img, verb_index,
    noun_index)`` for row ``i``. Single-image mode (``single_image=True``): the
    dataset is twice as long and item ``i`` is ``(img, verb_index, noun_index)``.
    Even ``i`` gives the first image of row ``i // 2`` and odd ``i`` gives its second
    image. In both modes, ``verb_index``/``noun_index`` are read from the row's
    columns of those names.
    """

    def __init__(self, dataset, df, foldername, transform=None, dict_noun_index=None,
                 dict_verb_index=None, split='train', seq_len=2, single_image=False,
                 img_width=256, bbox=False):
        super().__init__()
        self.dataset = dataset
        self.df = df.reset_index()
        self.foldername = foldername
        self.dict_noun_index = dict_noun_index
        self.dict_verb_index = dict_verb_index
        self.split = split
        # Single-image mode always yields sequences of length 1.
        self.seq_len = seq_len if not single_image else 1
        self.single_image = single_image
        self.img_width = img_width
        self.isbbox = bbox
        self.transform = transform

        # Total number of verb + noun classes; 0 unless both mappings are non-empty.
        self.action_size = len(dict_verb_index) + len(dict_noun_index) if dict_noun_index and dict_verb_index else 0

    def _procthor_figname(self, instance, first):
        """Return the ProcTHOR colour-frame path of the first or second image of a row."""
        if first:
            # The first frame's name is the figure name up to 'second', followed by 'first.png'.
            name = instance.figure.split('second')[0] + 'first.png'
        else:
            name = instance.figure + '.png'
        return os.path.join(self.foldername, instance.scene, 'color', name)

    def _load_pairs(self, instance):
        """Load image pair and action indices.

        Returns ``(start_img, stop_img, verb, noun)``. ``verb``/``noun`` are looked up
        in the class-index dicts, or are 0 when the corresponding dict is missing or
        empty. Not used by ``__getitem__``.

        Raises:
            NotImplementedError: for an unknown ``self.dataset``.
        """
        if self.dataset == 'procthor':
            start_figname = self._procthor_figname(instance, first=True)
            stop_figname = self._procthor_figname(instance, first=False)
        elif self.dataset == 'epickitchens':
            frames_dir = os.path.join(self.foldername, instance.participant_id, 'rgb_frames', instance.video_id)
            start_figname = os.path.join(frames_dir, f'frame_{instance.start_frame:010d}.jpg')
            stop_figname = os.path.join(frames_dir, f'frame_{instance.stop_frame:010d}.jpg')
        else:
            raise NotImplementedError

        # Box columns are only read when cropping is requested.
        bbox = (instance.xmin, instance.ymin, instance.xmax, instance.ymax) if self.isbbox else None
        # extract_pair always calls its transform, so fall back to identity when none
        # is configured (matching _load_single_image, which skips a falsy transform).
        transform = self.transform if self.transform else (lambda img: img)
        start_img, stop_img = extract_pair(start_figname, stop_figname, transform, bbox=bbox)

        verb = self.dict_verb_index[instance.verb_class] if self.dict_verb_index else 0
        noun = self.dict_noun_index[instance.noun_class] if self.dict_noun_index else 0
        return start_img, stop_img, verb, noun

    def _extract_image(self, figname, bbox=None):
        """Open ``figname`` as an RGB PIL image, optionally cropped to ``bbox``."""
        # The context manager releases the file handle deterministically; the
        # converted image is an independent copy that outlives the handle.
        with Image.open(figname) as raw:
            img = raw.convert('RGB')
        if bbox is not None:
            img = img.crop(bbox)
        return img

    def _load_single_image(self, instance, get_first=True):
        """Load (and transform, if a transform is set) the first or second image of a row."""
        if self.dataset == 'procthor':
            figname = self._procthor_figname(instance, first=get_first)
        else:
            # Non-ProcTHOR tables carry the image paths in per-row columns.
            figname = instance.before_image_path if get_first else instance.after_image_path

        # Bounding-box cropping is only applied to ProcTHOR rows.
        if self.isbbox and self.dataset == 'procthor':
            img = self._extract_image(figname, bbox=(instance.xmin, instance.ymin, instance.xmax, instance.ymax))
        else:
            img = self._extract_image(figname)

        if self.transform:
            img = self.transform(img)

        return img

    def __len__(self):
        return len(self.df) if not self.single_image else 2*len(self.df)

    def _get_single_image(self, index):
        # Even indices -> first image, odd indices -> second image of row index // 2.
        get_first = index % 2 == 0
        instance = self.df.iloc[index // 2]
        img = self._load_single_image(instance=instance, get_first=get_first)
        return img, instance.verb_index, instance.noun_index

    def _get_pair(self, index):
        instance = self.df.iloc[index]
        first_img = self._load_single_image(instance=instance, get_first=True)
        second_img = self._load_single_image(instance=instance, get_first=False)
        return first_img, second_img, instance.verb_index, instance.noun_index

    def __getitem__(self, index):
        if self.single_image:
            return self._get_single_image(index)
        return self._get_pair(index)

    def get_img_width(self):
        """Return image width."""
        return self.img_width

    def get_inp_channels(self):
        """Return number of input channels."""
        return 3  # RGB