import os
import pwd
import math
import random

from functools import partial
from typing import Callable
import numpy as np

import torch
import torch.optim
import torch.nn.parallel
import torch.utils.data.distributed
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader

from rich import print as pp
from torchvision import transforms as ttf
from torchvision.datasets import ImageFolder
from torchvision.transforms.functional import InterpolationMode

import utils.etc as etc
from utils.voc import PascalVOC_Dataset, encode_labels
from utils.sampler import RASampler, ValidSampler, FixedIterRASampler
from augment.timm_transforms import RandomResizedCropAndInterpolation
from augment.random_erasing import RandomErasing
from augment.autoaugment import rand_augment_transform
import tllib as datasets

import pytorch_lightning as pl

       
class CustomImagefolder(ImageFolder):
    """``ImageFolder`` subclass with an explicit ``__getitem__``."""

    def __getitem__(self, index):
        """Load the item at ``index`` and return ``(sample, target)``.

        The path and target are read from ``self.samples[index]``, the file
        is opened with ``self.loader``, and ``self.transform`` /
        ``self.target_transform`` are applied when they are not ``None``.
        """
        image_path, label = self.samples[index]
        image = self.loader(image_path)
        # Both transforms are optional; a missing one leaves the value as is.
        image = image if self.transform is None else self.transform(image)
        label = (
            label if self.target_transform is None
            else self.target_transform(label)
        )
        return image, label

    
class ForeverDataIterator:
    """Iterator over a data loader that restarts it whenever it runs out.

    Args:
        data_loader: The iterable to cycle through. It must support
            ``iter()`` and ``len()``.
    """

    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader
        self.iter = iter(self.data_loader)

    def __iter__(self):
        """Return ``self`` so the object works with ``for`` and ``iter()``."""
        return self

    def __next__(self):
        """Return the next batch, re-creating the inner iterator when exhausted.

        If the underlying loader yields nothing at all, the ``StopIteration``
        raised by the second ``next`` call propagates to the caller.
        """
        try:
            return next(self.iter)
        except StopIteration:
            # One full pass is done: start a fresh pass over the loader.
            self.iter = iter(self.data_loader)
            return next(self.iter)

    def __len__(self):
        """Return the length of one pass over the wrapped loader."""
        return len(self.data_loader)

    
def make_aug(args):
    """Build the training and validation transform pipelines.

    Args:
        args: Configuration object. Reads ``mean``, ``std``, ``aug`` and
            ``reprob``; ``color_jitter`` is read only when ``aug`` is not one
            of ``'rand'``, ``'ta'``, ``'aa'`` or ``'timm_rand'``; ``remode``,
            ``recount`` and ``resplit`` are read only when ``reprob > 0``.

    Returns:
        A ``(transform_train, transform_val)`` tuple of ``ttf.Compose``
        objects.
    """
    crop_size = 224
    channel_mean, channel_std = args.mean, args.std

    # Stage 1: random crop (timm variant for 'timm_rand') and horizontal flip.
    random_crop = (
        RandomResizedCropAndInterpolation(
            (crop_size, crop_size), interpolation='bicubic')
        if args.aug == 'timm_rand'
        else ttf.RandomResizedCrop((crop_size, crop_size))
    )
    train_steps = [random_crop, ttf.RandomHorizontalFlip()]

    # Validation: resize so the centre crop keeps 90% of the shorter side.
    eval_steps = [
        ttf.Resize(
            int(math.floor(crop_size / 0.9)),
            interpolation=InterpolationMode.BICUBIC),
        ttf.CenterCrop(crop_size),
        ttf.ToTensor(),
        ttf.Normalize(
            mean=torch.Tensor(channel_mean),
            std=torch.Tensor(channel_std)),
    ]

    # Stage 2: the augmentation policy selected by ``args.aug``.
    if args.aug == 'rand':
        policy_step = ttf.RandAugment()
    elif args.aug == 'ta':
        policy_step = ttf.TrivialAugmentWide()
    elif args.aug == 'aa':
        policy_step = ttf.AutoAugment()
    elif args.aug == 'timm_rand':
        fill_colour = tuple(min(255, round(255 * x)) for x in channel_mean)
        policy_step = rand_augment_transform(
            'rand-m9-mstd0.5-inc1',
            dict(
                translate_const=int(crop_size * 0.45),
                img_mean=fill_colour,
            ),
        )
    else:
        # Fallback: the same jitter strength for the first three
        # ColorJitter arguments.
        strength = float(args.color_jitter)
        policy_step = ttf.ColorJitter(strength, strength, strength)
    train_steps.append(policy_step)

    # Stage 3: tensor conversion and normalisation.
    train_steps.append(ttf.ToTensor())
    train_steps.append(
        ttf.Normalize(
            mean=torch.Tensor(channel_mean),
            std=torch.Tensor(channel_std)))

    # Stage 4: optional random erasing, applied on the normalised tensor.
    if args.reprob > 0.:
        eraser = RandomErasing(
            args.reprob,
            mode=args.remode,
            max_count=args.recount,
            num_splits=args.resplit,
            device='cpu')
        train_steps.append(eraser)

    return ttf.Compose(train_steps), ttf.Compose(eval_steps)


def make_dataset(args, transform_train, transform_val):
    """Create the training and validation datasets named by ``args``.

    Three cases are handled, keyed on ``args.dataset_name``:

    * ``'voc'``: ``PascalVOC_Dataset`` with ``encode_labels`` as the target
      transform.
    * one of the names looked up in the ``datasets`` module
      (``StanfordCars``, ``CUB200``, ``Aircraft``, ``StanfordDogs``,
      ``OxfordIIITPets``): the train split uses ``args.sample_rate``, the
      test split always uses a sample rate of 100.
    * anything else: ``CustomImagefolder`` rooted at ``args.data_path``,
      with ``args.train_split`` and ``'val'`` as sub-directories.

    Args:
        args: Configuration object (see the attributes named above).
        transform_train: Transform applied to training samples.
        transform_val: Transform applied to validation samples.

    Returns:
        A ``(trainset, validset)`` tuple.
    """
    data_root = './data'
    name = args.dataset_name
    fine_grained_names = (
        'StanfordCars', 'CUB200', 'Aircraft', 'StanfordDogs', 'OxfordIIITPets')

    if name == 'voc':
        voc_common = dict(
            root=data_root, download=True, target_transform=encode_labels)
        trainset = PascalVOC_Dataset(
            image_set='train', transform=transform_train, **voc_common)
        validset = PascalVOC_Dataset(
            image_set='val', transform=transform_val, **voc_common)
    elif name in fine_grained_names:
        # The dataset class is resolved by name from the ``datasets`` module.
        dataset_cls = vars(datasets)[name]
        dataset_dir = os.path.join(data_root, name)
        trainset = dataset_cls(
            root=dataset_dir,
            split='train',
            sample_rate=args.sample_rate,
            download=True,
            transform=transform_train)
        validset = dataset_cls(
            root=dataset_dir,
            split='test',
            sample_rate=100,
            download=True,
            transform=transform_val)
    else:
        train_dir = os.path.join(args.data_path, args.train_split)
        valid_dir = os.path.join(args.data_path, 'val')
        trainset = CustomImagefolder(train_dir, transform=transform_train)
        validset = CustomImagefolder(valid_dir, transform=transform_val)

    pp(f"[blue] Trainset Length : {len(trainset)}, Validset Length : {len(validset)} [/blue]")
    return trainset, validset


def _worker_init(worker_id, worker_seeding='all'):
    """Seed the random number generators of a ``DataLoader`` worker.

    Args:
        worker_id: Index passed by the ``DataLoader``; asserted to match the
            id reported by ``torch.utils.data.get_worker_info()``.
        worker_seeding: Either a callable that maps the worker info to an
            integer seed (then ``random``, ``torch`` and ``numpy`` are all
            seeded with it), or one of the strings ``'all'`` / ``'part'``.
            ``'all'`` re-seeds ``numpy`` from the worker's own seed;
            ``'part'`` changes nothing here.
    """
    info = torch.utils.data.get_worker_info()
    assert info.id == worker_id

    if callable(worker_seeding):
        custom_seed = worker_seeding(info)
        random.seed(custom_seed)
        torch.manual_seed(custom_seed)
        # numpy only accepts seeds that fit in 32 bits.
        np.random.seed(custom_seed % (2 ** 32 - 1))
        return

    assert worker_seeding in ('all', 'part')
    if worker_seeding != 'all':
        return
    np.random.seed(info.seed % (2 ** 32 - 1))


class DataModule(pl.LightningDataModule):
    """Lightning data module that builds the train/val/test data loaders.

    Transforms are created once in ``__init__`` via ``make_aug``; the
    datasets are created in ``setup`` via ``make_dataset``. The validation
    dataset is reused as the test dataset.

    Args:
        args: Configuration object. Besides what ``make_aug`` and
            ``make_dataset`` read, this class uses ``dataset_name``,
            ``repeated_aug``, ``fixed_RA_iter``, ``fixed_iter``,
            ``batch_size`` and ``workers``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        pp(f"[green][!] [Rank {etc.get_rank()}] Preparing {args.dataset_name} data..[/green]")
        self.transform_train, self.transform_val = make_aug(args)

    def setup(self, stage=None):
        """Create the datasets; ``stage`` is accepted but not used."""
        trainset, validset = make_dataset(
            self.args, self.transform_train, self.transform_val)
        self.train_dataset = trainset
        self.valid_dataset = validset
        # No separate test split exists: testing runs on the validation set.
        self.test_dataset = validset

    def train_dataloader(self):
        """Return the training loader, with an optional repeated-aug sampler.

        When both ``repeated_aug`` and ``fixed_RA_iter`` are set, the
        ``FixedIterRASampler`` replaces the ``RASampler`` created first.
        When neither is set no sampler is passed and the loader shuffles.
        """
        sampler = None
        num_tasks = etc.get_world_size()
        global_rank = etc.get_rank()
        if self.args.repeated_aug:
            pp("[green][!] Repeated Aug will help you [/green]")
            pp(f"{num_tasks} tasks")
            sampler = RASampler(
                self.train_dataset,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=True)
            pp(f"rank : {sampler.rank}")
        if self.args.fixed_RA_iter:
            pp("[green][!] Fixed Iter Repeated Aug will help you [/green]")
            sampler = FixedIterRASampler(
                self.train_dataset,
                self.args.batch_size,
                num_replicas=num_tasks,
                rank=global_rank,
                shuffle=True,
                fixed_iter=self.args.fixed_iter
            )
        worker_init_fn = partial(_worker_init, worker_seeding='all')
        return self._dataloader(
            self.train_dataset,
            self.args,
            shuffle=True,
            sampler=sampler,
            worker_init_fn=worker_init_fn
        )

    def val_dataloader(self):
        """Return the validation loader (unshuffled, ``ValidSampler``)."""
        return self._eval_dataloader(self.valid_dataset)

    def test_dataloader(self):
        """Return the test loader (unshuffled, ``ValidSampler``)."""
        return self._eval_dataloader(self.test_dataset)

    def _eval_dataloader(self, dataset):
        """Build an unshuffled loader over ``dataset`` using ``ValidSampler``."""
        num_tasks = etc.get_world_size()
        global_rank = etc.get_rank()
        sampler = ValidSampler(
            dataset,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=False
        )
        return self._dataloader(
            dataset,
            self.args,
            shuffle=False,
            sampler=sampler
        )

    def _dataloader(self, dataset, args,
                    shuffle=False, sampler=None,
                    worker_init_fn=None):
        """Create a ``DataLoader`` with the shared settings of this module.

        Args:
            dataset: Dataset to load from.
            args: Configuration object; ``batch_size`` and ``workers`` are
                read.
            shuffle: Whether to shuffle. Ignored when ``sampler`` is given,
                but it still decides ``drop_last``.
            sampler: Optional sampler passed to the loader.
            worker_init_fn: Optional per-worker initialiser. When given,
                workers are kept alive between epochs.

        Returns:
            The configured ``DataLoader``.
        """
        num_workers = args.workers
        # DataLoader rejects persistent_workers=True when num_workers == 0,
        # so persistence is only requested when worker processes exist.
        keep_workers = worker_init_fn is not None and num_workers > 0
        return DataLoader(
            dataset,
            batch_size=int(args.batch_size),
            # ``shuffle`` and ``sampler`` are mutually exclusive in DataLoader.
            shuffle=shuffle if sampler is None else False,
            pin_memory=True,
            drop_last=shuffle,
            worker_init_fn=worker_init_fn,
            persistent_workers=keep_workers,
            num_workers=num_workers,
            sampler=sampler
        )
