from pydoc import classname
import random
import numpy as np
import torch
from itertools import chain
from torch.utils.data import Sampler
from torch.distributions.categorical import Categorical

class SuppQueryBatchSampler(Sampler):
    """Episodic batch sampler: every yielded batch is one few-shot task.

    For each task the sampler draws ``num_way`` distinct values from the
    dataset's ``cls_name`` column, records them on the dataset via
    ``dataset.relabel`` followed by ``dataset.relbl_df()``, and yields a flat
    list of dataframe index labels: the support indices first, then the
    query indices.

    Args:
        dataset: object exposing a pandas-style ``df`` with a ``cls_name``
            column, a writable ``relabel`` attribute and a ``relbl_df()``
            method.
        num_way: number of classes per task.
        num_shot: number of support samples per class.
        num_query_per_cls: number of query samples per class.
        num_task: number of tasks (batches) produced by one iteration pass.
        train_shuffle: if true, classes are visited in the order they appear
            in the filtered dataframe and the support list is shuffled; if
            false, support indices stay grouped by class in the order of the
            sampled task. The query list is shuffled in both modes.
    """

    def __init__(self, dataset, num_way=5, num_shot=5, num_query_per_cls=15, num_task=100, train_shuffle=True):
        self.dataset = dataset
        self.num_way = num_way
        self.num_shot = num_shot
        self.num_query_per_cls = num_query_per_cls
        self.num_task = num_task
        self.train_shuffle = train_shuffle
        # Total number of samples drawn from each class of a task.
        self.num_sample_per_cls = self.num_shot + self.num_query_per_cls

    def __len__(self):
        """Return the number of tasks produced per iteration pass."""
        return self.num_task

    def __iter__(self):
        """Yield ``num_task`` index lists, each ``support + query``.

        Raises:
            ValueError: propagated from ``random.sample`` when fewer than
                ``num_way`` classes exist or a sampled class holds fewer
                than ``num_shot + num_query_per_cls`` rows.
        """
        for _ in range(self.num_task):
            class_pool = list(self.dataset.df['cls_name'].unique())
            sampled_task = random.sample(class_pool, k=self.num_way)

            # Publish the sampled classes to the dataset before indices are
            # handed out (presumably so labels are remapped per task --
            # confirm against the dataset's relbl_df implementation).
            self.dataset.relabel = ('cls_name', sampled_task)
            self.dataset.relbl_df()

            # Rows that belong to the classes of this task.
            task_rows = self.dataset.df.loc[self.dataset.df['cls_name'].isin(sampled_task)]
            inds_per_cls = self.get_df_inds_per_col_value(task_rows, col='cls_name')

            # Decide in which order the per-class index pools are consumed.
            if self.train_shuffle:
                ordered_pools = [pool for _, pool in inds_per_cls]
            else:
                ordered_pools = [
                    pool
                    for wanted in sampled_task
                    for name, pool in inds_per_cls
                    if name == wanted
                ]

            supp_inds, query_inds = [], []
            for pool in ordered_pools:
                # pool: every row index of a single class
                picked = random.sample(pool, self.num_sample_per_cls)
                supp_inds += picked[:self.num_shot]
                query_inds += picked[self.num_shot:]

            if self.train_shuffle:
                random.shuffle(supp_inds)
            random.shuffle(query_inds)

            yield supp_inds + query_inds

    def get_df_inds_per_col_value(self, df, col, shuffle=True):
        """Group the index labels of ``df`` by the distinct values of ``col``.

        Args:
            df: dataframe to group.
            col: name of the column whose values define the groups.
            shuffle: shuffle each group's index list in place when true.

        Returns:
            List of ``(value, index_labels)`` tuples, one per distinct value,
            in the order returned by ``df[col].unique()``.
        """
        column = df[col]
        grouped = []
        for value in column.unique():
            members = df.loc[column == value].index.tolist()
            if shuffle:
                random.shuffle(members)
            grouped.append((value, members))
        return grouped


class CaseStudyBatchSampler(Sampler):
    """Episodic batch sampler whose consecutive tasks share leading classes.

    The first task of an iteration pass draws ``num_way`` distinct values
    from the dataset's ``cls_name`` column. Every later task keeps the first
    two classes of the previous task and draws the remaining classes from
    the classes not already carried over, so the same classes can be studied
    in different task contexts.

    For each task the sampled classes are recorded on the dataset via
    ``dataset.relabel`` followed by ``dataset.relbl_df()``, and a flat list
    of dataframe index labels is yielded: support indices first, then query
    indices.

    Args:
        dataset: object exposing a pandas-style ``df`` with a ``cls_name``
            column, a writable ``relabel`` attribute and a ``relbl_df()``
            method.
        num_way: number of classes per task.
        num_shot: number of support samples per class.
        num_query_per_cls: number of query samples per class.
        num_task: number of tasks (batches) produced by one iteration pass.
        train_shuffle: if true, samples are drawn at random per class (in
            dataframe class order) and the support list is shuffled; if
            false, the first ``num_shot + num_query_per_cls`` rows of each
            class are taken and support indices stay grouped by class in
            task order. The query list is shuffled in both modes.
    """

    # Number of leading classes of a task that are reused by the next task.
    _num_carried = 2

    def __init__(self, dataset, num_way=5, num_shot=5, num_query_per_cls=15, num_task=2, train_shuffle=True):
        self.dataset = dataset
        self.num_way = num_way
        self.num_shot = num_shot
        self.num_query_per_cls = num_query_per_cls
        self.num_task = num_task
        # Total number of samples taken from each class of a task.
        self.num_sample_per_cls = self.num_shot + self.num_query_per_cls
        self.train_shuffle = train_shuffle

    def __len__(self):
        """Return the number of tasks produced per iteration pass."""
        return self.num_task

    def __iter__(self):
        """Yield ``num_task`` index lists, each ``support + query``.

        Raises:
            ValueError: propagated from ``random.sample`` when there are not
                enough classes to fill a task or, in shuffle mode, a class
                holds fewer than ``num_shot + num_query_per_cls`` rows.
        """
        # Reset on every pass so the first task is sampled without constraints.
        self.last_task = None
        for _ in range(self.num_task):
            all_classes = list(self.dataset.df['cls_name'].unique())
            if self.last_task is None:
                sampled_task = random.sample(all_classes, k=self.num_way)
            else:
                carried = list(self.last_task)
                # Draw only from classes that are not carried over, so a
                # class can never appear twice in the same task.
                remaining = [name for name in all_classes if name not in carried]
                sampled_task = carried + random.sample(
                    remaining, k=self.num_way - len(carried)
                )
            # Remember the leading classes for the next task; without this
            # assignment the carry-over branch above can never run.
            self.last_task = sampled_task[:self._num_carried]

            # Publish the sampled classes to the dataset before indices are
            # handed out (presumably so labels are remapped per task --
            # confirm against the dataset's relbl_df implementation).
            self.dataset.relabel = ('cls_name', sampled_task)
            self.dataset.relbl_df()

            # Row indices of the task's classes, grouped per class.
            inds_per_cls = self.get_df_inds_per_col_value(
                self.dataset.df.loc[self.dataset.df['cls_name'].isin(sampled_task)],
                col='cls_name'
            )

            supp_inds = []
            query_inds = []

            if self.train_shuffle:
                for _clsname, inds in inds_per_cls:
                    # inds: every row index of a single class
                    sampled_inds = random.sample(inds, self.num_sample_per_cls)
                    supp_inds.extend(sampled_inds[:self.num_shot])
                    query_inds.extend(sampled_inds[self.num_shot:])

                random.shuffle(supp_inds)
            else:
                # Follow the task's class order; take the leading rows of
                # each class so repeated classes get the same samples.
                pools = dict(inds_per_cls)
                for cls_index in sampled_task:
                    if cls_index not in pools:
                        continue
                    sampled_inds = pools[cls_index][:self.num_sample_per_cls]
                    supp_inds.extend(sampled_inds[:self.num_shot])
                    query_inds.extend(sampled_inds[self.num_shot:])

            random.shuffle(query_inds)

            yield supp_inds + query_inds

    def get_df_inds_per_col_value(self, df, col, shuffle=False):
        """Group the index labels of ``df`` by the distinct values of ``col``.

        Args:
            df: dataframe to group.
            col: name of the column whose values define the groups.
            shuffle: shuffle each group's index list in place when true.

        Returns:
            List of ``(value, index_labels)`` tuples, one per distinct value,
            in the order returned by ``df[col].unique()``.
        """
        inds_per_val = []
        for colval in df[col].unique():
            inds = df.loc[df[col] == colval].index.tolist()
            if shuffle:
                random.shuffle(inds)
            inds_per_val.append((colval, inds))
        return inds_per_val