import cv2
import numpy as np
from torch.utils.data import Dataset
import pandas as pd
import random
import copy
import json
import torch
from tqdm import tqdm
from collections import defaultdict
import time

class VigorDatasetTrain(Dataset):
    """Training dataset of (query image, satellite image, satellite index) samples.

    ``data_folder`` is the path to a JSON config mapping a dataset name to a
    ``[root_folder, ratio]`` entry. Entries whose ratio is not positive are
    ignored. For every remaining dataset, each city's
    ``<root>/splits_new/<city>/same_area_balanced_train.txt`` is read and
    ``int(len(rows) * ratio)`` rows are drawn from it (with replacement only
    when that count exceeds the number of rows).

    Each sample is a pair ``(idx_drone, idx_sat)``: ``idx_drone`` indexes
    ``self.df_drone`` and ``idx_sat`` indexes ``self.df_sat``; the satellite
    index is also the returned label. Only the same-area split is implemented.
    """

    # Cities (sub-folders of each dataset root) used in the same-area setting.
    _CITIES = {
        "U1652": ['U1652'],
        "MAP": ['map'],
        "VIGOR": ['Chicago', 'NewYork', 'SanFrancisco', 'Seattle'],
        "SetVL": ["Chicago", "Johannesburg", "London", "Rio", "Sydney", "Taipei"],
    }

    # Sub-folder below <root>/<city>/ holding the query images of each dataset.
    _QUERY_SUBDIR = {"U1652": "drone", "MAP": "map", "VIGOR": "panorama", "SetVL": "ground"}

    # Columns kept from the balanced split files: the query file name, the
    # matching satellite and three further satellite references (sat_np*;
    # presumably VIGOR-style semi-positives -- confirm against the split format).
    _SPLIT_COLUMNS = {0: "drone", 1: "sat", 4: "sat_np1", 7: "sat_np2", 10: "sat_np3"}
    _SAT_COLUMNS = ["sat", "sat_np1", "sat_np2", "sat_np3"]

    def __init__(self,
                 data_folder,
                 same_area=True,
                 transforms_query=None,
                 transforms_reference=None,
                 prob_flip=0.0,
                 prob_rotate=0.0,
                 shuffle_batch_size=128,
                 ):
        """Read the split files and build the (query index, satellite index) pairs.

        Args:
            data_folder: path to the JSON dataset config (see class docstring).
            same_area: must be True; cross-area splits are not implemented.
            transforms_query / transforms_reference: optional callables used as
                ``t(image=img)['image']`` on the RGB query / reference image.
            prob_flip: probability of horizontally flipping both images.
            prob_rotate: probability of the joint rotation augmentation.
            shuffle_batch_size: batch size targeted by :meth:`shuffle`.

        Raises:
            NotImplementedError: if ``same_area`` is False.
            ValueError: if no dataset has a positive ratio or a dataset name
                is not one of the supported ones.
        """
        super().__init__()

        self.prob_flip = prob_flip
        self.prob_rotate = prob_rotate
        self.shuffle_batch_size = shuffle_batch_size

        self.transforms_query = transforms_query
        self.transforms_reference = transforms_reference

        self.data_folder_all, self.data_folder, self.data_ratio = self._load_data_config(data_folder)
        print("Data Folder:", self.data_folder)
        print("Data Ratio:", self.data_ratio)

        self.cities = self._select_cities(self.data_folder, same_area)

        # --- satellite gallery ---------------------------------------------
        # Satellite names are resolved per (dataset, city): the same file name
        # may occur in several datasets (VIGOR and SetVL both have "Chicago"),
        # and a single global name->index map would silently point a query at
        # another dataset's image.
        sat_list = []
        sat2idx = {}
        offset = 0
        for dataset_name, root in self.data_folder.items():
            for city in self.cities[dataset_name]:
                df_tmp = pd.read_csv(f'{root}/splits_new/{city}/satellite_list.txt', header=None, sep=r'\s+')
                df_tmp = df_tmp.rename(columns={0: "sat"})
                df_tmp["path"] = f'{root}/{city}/satellite/' + df_tmp["sat"].astype(str)
                # Global index after concat + reset_index is offset + row position.
                sat2idx[(dataset_name, city)] = {name: offset + i for i, name in enumerate(df_tmp["sat"])}
                offset += len(df_tmp)
                sat_list.append(df_tmp)
        self.df_sat = pd.concat(sat_list, axis=0).reset_index(drop=True)

        self.idx2sat = dict(zip(self.df_sat.index, self.df_sat.sat))
        self.idx2sat_path = dict(zip(self.df_sat.index, self.df_sat.path))

        # --- queries -------------------------------------------------------
        drone_list = []
        for dataset_name, root in self.data_folder.items():
            query_subdir = self._QUERY_SUBDIR[dataset_name]
            ratio = self.data_ratio[dataset_name]
            for city in self.cities[dataset_name]:
                df_tmp = pd.read_csv(f'{root}/splits_new/{city}/same_area_balanced_train.txt', header=None, sep=r'\s+')
                df_tmp = df_tmp.loc[:, list(self._SPLIT_COLUMNS)].rename(columns=self._SPLIT_COLUMNS)
                df_tmp["path_drone"] = f'{root}/{city}/{query_subdir}/' + df_tmp["drone"].astype(str)
                df_tmp["path_sat"] = f'{root}/{city}/satellite/' + df_tmp["sat"].astype(str)

                # Replace satellite file names by their global gallery index
                # (names missing from this city's list become NaN).
                city_sat2idx = sat2idx[(dataset_name, city)]
                for col in self._SAT_COLUMNS:
                    df_tmp[col] = df_tmp[col].map(city_sat2idx)

                # Resample this city's rows by the dataset ratio; duplicates are
                # only drawn when more rows are requested than exist.
                n_samples = max(0, int(len(df_tmp) * ratio))
                if len(df_tmp) > 0:
                    df_tmp = df_tmp.sample(n=n_samples, replace=n_samples > len(df_tmp), random_state=42)

                drone_list.append(df_tmp)
        self.df_drone = pd.concat(drone_list, axis=0).reset_index(drop=True)

        self.idx2drone = dict(zip(self.df_drone.index, self.df_drone.drone))
        self.idx2drone_path = dict(zip(self.df_drone.index, self.df_drone.path_drone))

        # (query index, satellite index) pairs, also grouped by satellite index.
        self.pairs = list(zip(self.df_drone.index, self.df_drone.sat))
        self.idx2pairs = defaultdict(list)
        for pair in self.pairs:
            self.idx2pairs[pair[1]].append(pair)

        self.label = self.df_drone[self._SAT_COLUMNS].values

        self.samples = list(self.pairs)

    @staticmethod
    def _load_data_config(config_path):
        """Read the JSON dataset config.

        Returns ``(all_entries, folders, ratios)``; ``folders`` and ``ratios``
        only keep datasets whose ratio (second list element) is positive.
        """
        with open(config_path, 'r') as f:
            config = json.load(f)
        folders = {name: entry[0] for name, entry in config.items() if entry[1] > 0}
        ratios = {name: entry[1] for name, entry in config.items() if entry[1] > 0}
        if not folders:
            raise ValueError(f"No dataset with a positive ratio in {config_path}")
        return config, folders, ratios

    @classmethod
    def _select_cities(cls, data_folder, same_area):
        """Return ``{dataset name: [city, ...]}`` for the datasets in ``data_folder``."""
        if not same_area:
            raise NotImplementedError("Only same_area=True splits are supported")
        unknown = [name for name in data_folder if name not in cls._CITIES]
        if unknown:
            raise ValueError(f"Unknown dataset name in data_folder: {unknown}")
        return {name: list(cities) for name, cities in cls._CITIES.items() if name in data_folder}

    @staticmethod
    def _read_rgb(path):
        """Load an image with OpenCV and convert it from BGR to RGB."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals unreadable/missing files by returning None.
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __getitem__(self, index):
        """Return ``(query_img, reference_img, label)`` for ``self.samples[index]``.

        ``label`` is the satellite index as a ``torch.long`` scalar tensor.
        """
        idx_drone, idx_sat = self.samples[index]

        query_img = self._read_rgb(self.idx2drone_path[idx_drone])
        reference_img = self._read_rgb(self.idx2sat_path[idx_sat])

        # Flip both images together so the pair stays consistent.
        if np.random.random() < self.prob_flip:
            query_img = cv2.flip(query_img, 1)
            reference_img = cv2.flip(reference_img, 1)

        if self.transforms_query is not None:
            query_img = self.transforms_query(image=query_img)['image']

        if self.transforms_reference is not None:
            reference_img = self.transforms_reference(image=reference_img)['image']

        # Joint rotation: rotate the reference by r * 90 degrees and roll the
        # query along its width by r quarters (presumably a panorama). Assumes
        # the transforms returned (C, H, W) torch tensors -- TODO confirm.
        if np.random.random() < self.prob_rotate:
            r = np.random.choice([1, 2, 3])
            reference_img = torch.rot90(reference_img, k=r, dims=(1, 2))
            _, _, w = query_img.shape
            shifts = - w//4 * r
            query_img = torch.roll(query_img, shifts=shifts, dims=2)

        label = torch.tensor(idx_sat, dtype=torch.long)

        return query_img, reference_img, label

    def __len__(self):
        return len(self.samples)

    def shuffle(self, sim_dict=None, neighbour_select=8, neighbour_range=16):
        '''
        Custom shuffle for unique class_id sampling in batch.

        Rebuilds ``self.samples`` as a concatenation of batches of exactly
        ``shuffle_batch_size`` pairs, with no satellite index repeated inside a
        batch. Pairs that cannot be placed in a complete batch are dropped.

        Args:
            sim_dict: optional ``{sat index: [neighbour sat index, ...]}``; when
                given, neighbours of each chosen satellite are pulled into the
                same batch (presumably ordered most-similar first -- confirm
                against the caller).
            neighbour_select: neighbours pulled per anchor; the first half is
                always taken from the head of the list, the second half at
                random from the rest of the first ``neighbour_range`` entries.
            neighbour_range: number of leading neighbour entries considered.
        '''
        from collections import deque

        print("\nShuffle Dataset:")

        pair_pool = list(self.pairs)
        random.shuffle(pair_pool)
        # Pairs are consumed from the front and re-queued at the back.
        pair_pool = deque(pair_pool)

        # Per-satellite pairs that have not been placed in a batch yet.
        idx2pair_pool = defaultdict(list, {k: list(v) for k, v in self.idx2pairs.items()})

        neighbour_split = neighbour_select // 2

        if sim_dict is not None:
            similarity_pool = copy.deepcopy(sim_dict)

        pairs_epoch = set()   # pairs placed in a batch this epoch
        idx_batch = set()     # satellite indices in the batch being built

        batches = []
        current_batch = []

        # Consecutive rejected pairs; the loop gives up after 1024 in a row.
        break_counter = 0

        pbar = tqdm()

        while pair_pool:
            pbar.update()

            pair = pair_pool.popleft()
            _, idx = pair

            if idx not in idx_batch and pair not in pairs_epoch and len(current_batch) < self.shuffle_batch_size:
                idx_batch.add(idx)
                current_batch.append(pair)
                pairs_epoch.add(pair)
                idx2pair_pool[idx].remove(pair)

                if sim_dict is not None and len(current_batch) < self.shuffle_batch_size:
                    near_similarity = list(similarity_pool[idx][:neighbour_range])
                    near_always = near_similarity[:neighbour_split]
                    near_random = near_similarity[neighbour_split:]
                    random.shuffle(near_random)
                    near_similarity_select = near_always + near_random[:neighbour_split]

                    for idx_near in near_similarity_select:
                        if len(current_batch) >= self.shuffle_batch_size:
                            break
                        if idx_near in idx_batch:
                            continue

                        near_pairs = list(idx2pair_pool[idx_near])
                        random.shuffle(near_pairs)
                        if near_pairs:
                            # Take one random unused pair of this neighbour.
                            near_pair = near_pairs[0]
                            idx_batch.add(idx_near)
                            current_batch.append(near_pair)
                            pairs_epoch.add(near_pair)
                            idx2pair_pool[idx_near].remove(near_pair)
                            similarity_pool[idx].remove(idx_near)

                break_counter = 0

            else:
                # Re-queue pairs rejected only because of the current batch.
                if pair not in pairs_epoch:
                    pair_pool.append(pair)
                break_counter += 1

            if break_counter >= 1024:
                break

            if len(current_batch) >= self.shuffle_batch_size:
                batches.extend(current_batch)
                idx_batch = set()
                current_batch = []

        pbar.close()

        # Give the progress bar time to flush before the summary prints.
        time.sleep(0.3)

        self.samples = batches
        print("pair_pool:", len(pair_pool))
        print("Original Length: {} - Length after Shuffle: {}".format(len(self.pairs), len(self.samples)))
        print("Break Counter:", break_counter)
        print("Pairs left out of last batch to avoid creating noise:", len(self.pairs) - len(self.samples))
        if self.samples:
            print("First Element ID: {} - Last Element ID: {}".format(self.samples[0][1], self.samples[-1][1]))
        else:
            print("Warning: no complete batch could be formed; no samples for this epoch.")


       
class VigorDatasetEval(Dataset):
    """Evaluation dataset returning ``(image, label)`` for queries or references.

    ``data_folder`` is the path to a JSON config mapping a dataset name to a
    ``[root_folder, ratio]`` entry; only entries with a positive ratio are
    loaded (the ratio itself is not used for sampling here).

    * ``img_type == "query"``: one item per row of each city's
      ``same_area_balanced_<split>.txt``; the label holds four satellite
      indices (columns ``sat``, ``sat_np1``, ``sat_np2``, ``sat_np3``).
    * ``img_type == "reference"``: for ``split == "train"`` only the
      satellites referenced in the ``sat`` column; otherwise the full
      satellite list. The label is the satellite index.
    """

    # Cities (sub-folders of each dataset root) used in the same-area setting.
    _CITIES = {
        "U1652": ['U1652'],
        "MAP": ['map'],
        "VIGOR": ['Chicago', 'NewYork', 'SanFrancisco', 'Seattle'],
        "SetVL": ["Chicago", "Johannesburg", "London", "Rio", "Sydney", "Taipei"],
    }

    # Sub-folder below <root>/<city>/ holding the query images of each dataset.
    _QUERY_SUBDIR = {"U1652": "drone", "MAP": "map", "VIGOR": "panorama", "SetVL": "ground"}

    # Columns kept from the balanced split files: the query file name, the
    # matching satellite and three further satellite references (sat_np*;
    # presumably VIGOR-style semi-positives -- confirm against the split format).
    _SPLIT_COLUMNS = {0: "drone", 1: "sat", 4: "sat_np1", 7: "sat_np2", 10: "sat_np3"}
    _SAT_COLUMNS = ["sat", "sat_np1", "sat_np2", "sat_np3"]

    def __init__(self,
                 data_folder,
                 split,
                 img_type,
                 same_area=True,
                 transforms=None,
                 ):
        """Read the satellite lists and split files for ``split``.

        Args:
            data_folder: path to the JSON dataset config (see class docstring).
            split: split name used in ``same_area_balanced_<split>.txt``.
            img_type: ``"query"`` or ``"reference"``.
            same_area: must be True; cross-area splits are not implemented.
            transforms: optional callable used as ``t(image=img)['image']``.

        Raises:
            ValueError: invalid ``img_type``, no dataset with a positive ratio,
                or an unsupported dataset name.
            NotImplementedError: if ``same_area`` is False.
        """
        super().__init__()

        # Validate before any file I/O.
        if img_type not in ("query", "reference"):
            raise ValueError("Invalid 'img_type' parameter. 'img_type' must be 'query' or 'reference'")

        self.split = split
        self.img_type = img_type
        self.transforms = transforms

        self.data_folder_all, self.data_folder, self.data_ratio = self._load_data_config(data_folder)
        print("Data Folder:", self.data_folder)
        print("Data Ratio:", self.data_ratio)

        self.cities = self._select_cities(self.data_folder, same_area)

        # --- satellite gallery ---------------------------------------------
        # Satellite names are resolved per (dataset, city): the same file name
        # may occur in several datasets (VIGOR and SetVL both have "Chicago"),
        # and a single global name->index map would silently point a query at
        # another dataset's image.
        sat_list = []
        sat2idx = {}
        offset = 0
        for dataset_name, root in self.data_folder.items():
            for city in self.cities[dataset_name]:
                # SetVL provides a dedicated satellite list for its test split.
                if dataset_name == "SetVL" and split == "test":
                    list_name = "satellite_images_test.txt"
                else:
                    list_name = "satellite_list.txt"
                df_tmp = pd.read_csv(f'{root}/splits_new/{city}/{list_name}', header=None, sep=r'\s+')
                df_tmp = df_tmp.rename(columns={0: "sat"})
                df_tmp["path"] = f'{root}/{city}/satellite/' + df_tmp["sat"].astype(str)
                # Global index after concat + reset_index is offset + row position.
                sat2idx[(dataset_name, city)] = {name: offset + i for i, name in enumerate(df_tmp["sat"])}
                offset += len(df_tmp)
                sat_list.append(df_tmp)
        self.df_sat = pd.concat(sat_list, axis=0).reset_index(drop=True)

        self.idx2sat = dict(zip(self.df_sat.index, self.df_sat.sat))
        self.idx2sat_path = dict(zip(self.df_sat.index, self.df_sat.path))

        # --- queries -------------------------------------------------------
        drone_list = []
        for dataset_name, root in self.data_folder.items():
            query_subdir = self._QUERY_SUBDIR[dataset_name]
            for city in self.cities[dataset_name]:
                df_tmp = pd.read_csv(f'{root}/splits_new/{city}/same_area_balanced_{split}.txt', header=None, sep=r'\s+')
                df_tmp = df_tmp.loc[:, list(self._SPLIT_COLUMNS)].rename(columns=self._SPLIT_COLUMNS)
                df_tmp["path_drone"] = f'{root}/{city}/{query_subdir}/' + df_tmp["drone"].astype(str)
                # Paths are built from the file names, before they become indices
                # (columns path_sat, path_sat_np1, path_sat_np2, path_sat_np3).
                for col in self._SAT_COLUMNS:
                    df_tmp[f"path_{col}"] = f'{root}/{city}/satellite/' + df_tmp[col].astype(str)

                # Replace satellite file names by their global gallery index
                # (names missing from this city's list become NaN).
                city_sat2idx = sat2idx[(dataset_name, city)]
                for col in self._SAT_COLUMNS:
                    df_tmp[col] = df_tmp[col].map(city_sat2idx)

                drone_list.append(df_tmp)
        self.df_drone = pd.concat(drone_list, axis=0).reset_index(drop=True)

        self.idx2drone = dict(zip(self.df_drone.index, self.df_drone.drone))
        self.idx2drone_path = dict(zip(self.df_drone.index, self.df_drone.path_drone))

        if self.img_type == "reference":
            if split == "train":
                # Only satellites that are the direct match of some query.
                self.label = self.df_drone["sat"].unique()
                self.images = [self.idx2sat_path[idx] for idx in self.label]
            else:
                self.images = self.df_sat["path"].values
                self.label = self.df_sat.index.values
        else:
            self.images = self.df_drone["path_drone"].values
            self.label = self.df_drone[self._SAT_COLUMNS].values

    @staticmethod
    def _load_data_config(config_path):
        """Read the JSON dataset config.

        Returns ``(all_entries, folders, ratios)``; ``folders`` and ``ratios``
        only keep datasets whose ratio (second list element) is positive.
        """
        with open(config_path, 'r') as f:
            config = json.load(f)
        folders = {name: entry[0] for name, entry in config.items() if entry[1] > 0}
        ratios = {name: entry[1] for name, entry in config.items() if entry[1] > 0}
        if not folders:
            raise ValueError(f"No dataset with a positive ratio in {config_path}")
        return config, folders, ratios

    @classmethod
    def _select_cities(cls, data_folder, same_area):
        """Return ``{dataset name: [city, ...]}`` for the datasets in ``data_folder``."""
        if not same_area:
            raise NotImplementedError("Only same_area=True splits are supported")
        unknown = [name for name in data_folder if name not in cls._CITIES]
        if unknown:
            raise ValueError(f"Unknown dataset name in data_folder: {unknown}")
        return {name: list(cities) for name, cities in cls._CITIES.items() if name in data_folder}

    @staticmethod
    def _read_rgb(path):
        """Load an image with OpenCV and convert it from BGR to RGB."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals unreadable/missing files by returning None.
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __getitem__(self, index):
        """Return ``(image, label)``; ``label`` is a ``torch.long`` tensor."""
        img = self._read_rgb(self.images[index])

        if self.transforms is not None:
            img = self.transforms(image=img)['image']

        label = torch.tensor(self.label[index], dtype=torch.long)

        return img, label

    def __len__(self):
        return len(self.images)

            





