import os
import numpy as np
import json
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets
import torch
import random
import time
import torchvision.models as tor_models

def normalize_longtail(cls_num=14, imb_factor=0.1, imb_type='exp'):
    """Return per-class image counts following an exponential long-tail profile.

    Class ``i`` is assigned ``img_max * imb_factor ** (i / (cls_num - 1))``
    images, capped at the number of images listed for that class and
    truncated to an ``int``.

    Args:
        cls_num: Number of classes to produce counts for. Must not exceed the
            14 per-class counts hard-coded below.
        imb_factor: Ratio between the last and the first class of the
            uncapped profile.
        imb_type: Accepted for interface compatibility; not read by this
            function (the profile is always exponential).

    Returns:
        A list of ``cls_num`` integer image counts.

    Raises:
        IndexError: If ``cls_num`` is larger than the number of listed classes.
    """
    # Global down-scaling of the raw counts (e.g. 32/1000); currently disabled.
    ratio = 1
    imgs = [86152., 88131., 42312., 87958., 59663., 75057., 85788., 50092., 18976., 82829., 80437., 80699., 73318., 88588.]
    img_list = [ratio * i for i in imgs]
    # The head of the profile is anchored at class 0's count (86152), which is
    # not the largest entry of the list.
    img_max = img_list[0]
    # With a single class the exponent would be 0 / 0; treat it as 0 so the
    # lone class simply receives the head count instead of raising.
    span = cls_num - 1.0
    img_num_per_cls = []
    for cls_idx in range(cls_num):
        exponent = cls_idx / span if span else 0.0
        num = img_max * (imb_factor ** exponent)
        img_num_per_cls.append(int(min(num, img_list[cls_idx])))
    return img_num_per_cls

class ClothFolder_train(datasets.ImageFolder):
    """Training image folder that serves a class-capped random subset with group labels.

    On construction the dataset makes sure a JSON cache of per-sample group
    labels exists (building it with a ResNet-50 forward pass if necessary),
    then draws a random subset of at most ``train_all_imgs`` samples in which
    every class is capped at roughly ``train_all_imgs / 14`` images.

    Args:
        root: Root directory handed to ``ImageFolder``.
        g_idx: Number of groups; used as the output width of the labelling
            head and as part of the cache file name.
        cluster_type: Tag used only to build the cache file name.
        transform: Transform applied to every loaded image (also while the
            group labels are being generated).
        sel_idx: Optional per-sample mask indexed by position in
            ``self.samples``; falsy entries are excluded from the subset.
        is_peer: When true the index permutation is shuffled a second time.

    Items are returned as ``(sample, target, group, index)``.
    """

    def __init__(self, root, g_idx, cluster_type, transform, sel_idx = None, is_peer = False):
        super(ClothFolder_train, self).__init__(root, transform)
        self.train_all_imgs = 32000  # total size of the sampled subset
        self.is_peer = is_peer
        self.g_idx = g_idx
        self.cluster_type = cluster_type
        group_file = 'data_longtail/' + f'c1m_{self.cluster_type}_group_{self.g_idx}'
        self.group_file = group_file
        if not os.path.exists(group_file):
            self._build_group_file()

        self.samples_org = self.samples.copy()
        # One slot per sample in the folder (previously a hard-coded 1000000).
        self.selected_idx = np.zeros(len(self.samples))
        self.shuffle_and_imbalance(sel_idx)
        print(f'transform is {transform}')

    def _build_group_file(self):
        """Compute a group label for every sample and cache the list as JSON.

        The label is the argmax over the ``g_idx`` outputs of an ImageNet
        pretrained ResNet-50 whose final layer has been replaced by a freshly
        constructed (untrained) linear layer.
        """
        # Create the cache directory before the long forward pass so a missing
        # directory cannot throw the whole computation away at write time.
        group_dir = os.path.dirname(self.group_file)
        if group_dir:
            os.makedirs(group_dir, exist_ok=True)

        pre_model = tor_models.resnet50(pretrained=True)
        for param in pre_model.parameters():
            param.requires_grad = False
        num_ftrs = pre_model.fc.in_features
        pre_model.fc = nn.Linear(num_ftrs, self.g_idx)
        pre_model.cuda()
        pre_model.eval()

        group_label = []
        start_time = time.time()
        # Pure inference: the new fc layer still has requires_grad=True, so
        # disable autograd explicitly to avoid building a graph per image.
        with torch.no_grad():
            # Iterate over the samples actually present instead of assuming
            # exactly 1000000 of them.
            for ind, (path, _) in enumerate(self.samples):
                if ind % 100 == 0:
                    print(ind)
                    print("--- %s seconds ---" % (time.time() - start_time))
                sample = self.loader(path)
                if self.transform is not None:
                    sample = self.transform(sample)
                sample = sample.cuda().unsqueeze(0)
                group_label.append(torch.max(pre_model(sample).data, 1)[1].item())

        # Write to a temporary file and rename it into place so an interrupted
        # run never leaves a truncated cache behind; the handle is closed
        # deterministically by the context manager.
        tmp_file = self.group_file + '.tmp'
        with open(tmp_file, "w") as handle:
            json.dump(group_label, handle)
        os.replace(tmp_file, self.group_file)

    def __getitem__(self, index):
        """Return ``(sample, target, group, index)`` for position ``index`` of the subset."""
        path = self.path_all[index]
        target = self.targets_all[index]
        groups = self.groups_all[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, groups, index

    def __len__(self):
        """Return the number of samples in the selected subset."""
        return len(self.path_all)

    def shuffle_and_imbalance(self, sel_idx = None):
        """Draw the random, class-capped training subset.

        Fills ``path_all``, ``targets_all``, ``groups_all``, ``class_num``
        (per-class counts) and ``raw_idx_all`` (positions of the chosen
        samples in ``self.samples``).

        Args:
            sel_idx: Optional mask indexed by position in ``self.samples``;
                samples whose entry is falsy are skipped.
        """
        self.class_num = torch.zeros(14)
        self.path_all = []
        self.targets_all = []
        self.groups_all = []
        raw_idx_all = []
        idx = np.arange(len(self.samples))
        random.shuffle(idx)
        if self.is_peer:
            random.shuffle(idx)

        # NOTE: the cache is assumed to hold one label per entry of
        # self.samples; a stale cache from a different folder is not detected.
        with open(self.group_file, "r") as handle:
            groups_all = json.load(handle)

        per_class_cap = self.train_all_imgs / 14
        for raw_idx in idx:
            # Nothing can be added once the budget is full, so stop scanning.
            if len(self.path_all) >= self.train_all_imgs:
                break
            if sel_idx is not None and not sel_idx[raw_idx]:
                continue
            path, target = self.samples[raw_idx]
            if self.class_num[target] < per_class_cap:
                self.path_all.append(path)
                self.targets_all.append(target)
                self.groups_all.append(groups_all[raw_idx])
                raw_idx_all.append(raw_idx)
                self.class_num[target] += 1
        self.raw_idx_all = np.array(raw_idx_all)


class ClothFolder_test(datasets.ImageFolder):
    """Evaluation image folder whose items additionally carry their own index."""

    def __init__(self, root, transform):
        super().__init__(root, transform)

    def __getitem__(self, index):
        """Load item ``index`` and return it as ``(sample, target, index)``."""
        image_path, label = self.samples[index]
        image = self.loader(image_path)
        # Both transforms are optional; apply each only when configured.
        image = image if self.transform is None else self.transform(image)
        label = label if self.target_transform is None else self.target_transform(label)
        return image, label, index


def input_dataset_c1m(dataset, g_idx = 2, cluster_type = 'random', transform=None, sel_idx = None, is_peer = False):
    """Build the Clothing1M train/test datasets and the class prior.

    Args:
        dataset: Dataset name; only ``'clothing1M'`` is supported.
        g_idx: Number of groups, forwarded to ``ClothFolder_train``.
        cluster_type: Grouping tag, forwarded to ``ClothFolder_train``.
        transform: Optional training transform; when falsy the default
            resize / random-crop / flip / normalize pipeline is used.
        sel_idx: Optional per-sample selection mask, forwarded to
            ``ClothFolder_train``.
        is_peer: When true, a second training dataset built with
            ``is_peer=True`` is returned as well.

    Returns:
        ``(train_dataset, test_dataset, num_classes, train_prior)``, or
        ``(train_dataset, train_peer, test_dataset, num_classes, train_prior)``
        when ``is_peer`` is true. ``train_prior`` is the normalised array of
        the hard-coded per-class image counts.

    Raises:
        ValueError: If ``dataset`` is not ``'clothing1M'``.
    """
    print(dataset)
    # Fail early with a clear message; previously an unknown name fell through
    # to the return statement and surfaced as an UnboundLocalError.
    if dataset != 'clothing1M':
        raise ValueError(f"Unsupported dataset {dataset!r}; only 'clothing1M' is available.")

    train_folder = "../../data/clothing1m/noisy_train"
    test_folder = "../../data/clothing1m/clean_test"
    # Clothing1M channel statistics. The ImageNet statistics
    # (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225) are only meant for
    # config.pre_type == 'ckpt_clothing_resnet50'.
    c1m_mean = (0.6959, 0.6537, 0.6371)
    c1m_std = (0.3113, 0.3192, 0.3214)
    train_trans = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(c1m_mean, c1m_std),
    ])
    test_trans = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(c1m_mean, c1m_std),
    ])
    # A caller-supplied transform overrides the default training pipeline.
    train_transform = transform if transform else train_trans

    # The primary training set is always built with is_peer=False; the peer
    # copy (if requested) is built separately below.
    train_dataset = ClothFolder_train(root=train_folder, g_idx = g_idx, cluster_type = cluster_type, transform=train_transform, sel_idx = sel_idx, is_peer = False)
    test_dataset = ClothFolder_test(root=test_folder, transform=test_trans)

    num_classes = 14
    img_list = [86152., 88131., 42312., 87958., 59663., 75057., 85788., 50092., 18976., 82829., 80437., 80699., 73318., 88588.]
    train_prior = np.asarray(img_list) / np.sum(img_list)

    if is_peer:
        train_peer = ClothFolder_train(root=train_folder, g_idx = g_idx, cluster_type = cluster_type, transform=train_transform, sel_idx = sel_idx, is_peer = True)
        return train_dataset, train_peer, test_dataset, num_classes, train_prior
    return train_dataset, test_dataset, num_classes, train_prior