import torch
import json
from matplotlib import pyplot as plt
from torchvision.utils import make_grid
import csv
import torch.nn as nn
from torchvision import datasets, models, transforms
from dataset_info.imagenet import imagenet_classes
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import os
import numpy as np
import seaborn as sns
from torchvision.datasets import ImageFolder
from typing import Any, Callable, cast, Dict, List, Optional, Tuple

def train_val_dataset(dataset, val_split=0.2, random_state=None):
    """Randomly split ``dataset`` into training and validation subsets.

    Args:
        dataset: Sized, indexable dataset accepted by ``torch.utils.data.Subset``.
        val_split: Forwarded to ``train_test_split`` as ``test_size``; a float
            in (0, 1) is the fraction of samples put in the validation split.
        random_state: Optional seed forwarded to ``train_test_split`` so the
            split is reproducible. ``None`` (default) keeps it unseeded.

    Returns:
        dict with keys ``'train'`` and ``'val'``, each a ``Subset`` view of
        ``dataset``.
    """
    indices = list(range(len(dataset)))
    train_idx, val_idx = train_test_split(indices, test_size=val_split,
                                          random_state=random_state)
    # Built inline (no local named ``datasets``) so the torchvision
    # ``datasets`` module imported at file level is never shadowed.
    return {
        'train': Subset(dataset, train_idx),
        'val': Subset(dataset, val_idx),
    }

def load_ground_truth(csv_filename):
    """Read image ids and labels from a ground-truth CSV file.

    The file needs a header row with (at least) the columns ``ImageId``,
    ``TargetClass`` and ``TrueLabel``. Both label columns are shifted down by
    one, presumably turning 1-based labels in the file into 0-based class
    indices.

    Args:
        csv_filename: Path of the comma-separated file to read.

    Returns:
        Tuple ``(image_ids, target_labels, true_labels)`` of parallel lists:
        the ids as strings, the two label lists as ints.
    """
    image_id_list = []
    label_tar_list = []
    label_true = []

    # newline='' is what the csv module expects: it lets the reader handle
    # line endings itself, so quoted fields containing line breaks are
    # read back unchanged.
    with open(csv_filename, newline='') as csvfile:
        reader = csv.DictReader(csvfile, delimiter=',')
        for row in reader:
            image_id_list.append(row['ImageId'])
            label_tar_list.append(int(row['TargetClass']) - 1)
            label_true.append(int(row['TrueLabel']) - 1)

    return image_id_list, label_tar_list, label_true

class Normalize(nn.Module):
    """Per-channel standardization layer computing ``(x - mean) / std``.

    Args:
        mean: Sequence of per-channel means.
        std: Sequence of per-channel standard deviations, same length as
            ``mean``.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Kept as plain tensor attributes (not registered buffers); forward()
        # converts them to the input's tensor type on every call instead.
        self.mean, self.std = torch.Tensor(mean), torch.Tensor(std)

    def forward(self, x):
        """Return ``x`` standardized channel-wise.

        The statistics are reshaped to (1, C, 1, 1) so they broadcast along
        dim 1 of a 4-D input (presumably N, C, H, W).
        """
        center = self.mean.type_as(x)[None, :, None, None]
        scale = self.std.type_as(x)[None, :, None, None]
        return (x - center) / scale
    
# Shared Normalize instance using the usual ImageNet per-channel mean/std
# (the values listed in the torchvision pretrained-model docs).
norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Resize so the shorter image side is 299 px, then convert to a float tensor.
trn = transforms.Compose([
    transforms.Resize(299),
    transforms.ToTensor(),
])


def trn_imagenet(size):
    """Build a transform: centre-crop 299x299, resize to ``size``, to tensor.

    Args:
        size: Target size forwarded to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    return transforms.Compose([
        transforms.CenterCrop(299),
        transforms.Resize(size),
        transforms.ToTensor(),
    ])


# Per-split pipelines: random resized crops for training, a deterministic
# resize + centre crop for validation; both produce 224x224 tensors.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        # transforms.RandomHorizontalFlip(),  # augmentation currently disabled
        transforms.ToTensor()
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor()
    ]),
}
# File extensions accepted by custom_imagenet_dataset.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

class Dataset_from_subset(torch.utils.data.Dataset):
    """Dataset view over ``subset`` that can transform the first field.

    Every item of ``subset`` must unpack into exactly five values
    ``(x, y, name, epoch, batch)``; only ``x`` is passed through
    ``transform``, and only when ``transform`` is truthy.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        item, label, name, epoch, batch = self.subset[index]
        if not self.transform:
            return item, label, name, epoch, batch
        return self.transform(item), label, name, epoch, batch

    def __len__(self):
        return len(self.subset)
    
def pil_loader(path: str) -> Image.Image:
    """Open the image at ``path`` with PIL and return an RGB copy.

    The file handle is managed by ``with`` to avoid a ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835); ``convert`` reads
    the pixel data while the file is still open.
    """
    with open(path, 'rb') as handle:
        rgb_image = Image.open(handle).convert('RGB')
    return rgb_image


# The return type stays ``Any``: it is an ``accimage.Image`` on success or
# whatever ``pil_loader`` returns on the fallback path.
def accimage_loader(path: str) -> Any:
    """Load ``path`` with the accimage backend, falling back to PIL.

    ``accimage`` is imported inside the function, so it is only needed when
    this loader is actually called. An ``IOError`` raised by accimage makes
    the function retry with ``pil_loader``.
    """
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem: let PIL have a go instead.
        image = pil_loader(path)
    return image


def default_loader(path: str) -> Any:
    """Load ``path`` with the loader matching torchvision's image backend.

    ``accimage_loader`` is used when the backend is ``'accimage'``; any other
    backend uses ``pil_loader``.
    """
    from torchvision import get_image_backend
    chosen = accimage_loader if get_image_backend() == 'accimage' else pil_loader
    return chosen(path)


class custom_imagenet_dataset(ImageFolder):
    """Image-folder dataset whose labels follow the ImageNet class indices.

    Unlike ``ImageFolder``, class indices are not derived from the sorted
    sub-directory names. They come from ``imagenet_classes`` (index -> class
    name), so a sub-directory named after an ImageNet class gets that class's
    index as its label.

    Args:
        root: Directory with one sub-directory per class, named with the class
            names stored in ``imagenet_classes``.
        transform: Optional callable applied to every loaded image.
        target_transform: Optional callable applied to every integer label.
        loader: Callable mapping a file path to an image.
        is_valid_file: Optional predicate on file paths; when given it
            replaces the ``IMG_EXTENSIONS`` extension filter.
        country: Selects the item format returned by ``__getitem__``.
    """

    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            loader: Callable[[str], Any] = default_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
            country: bool = True
    ):
        # torchvision takes either an extension filter or is_valid_file,
        # never both.
        extensions = IMG_EXTENSIONS if is_valid_file is None else None
        # super(ImageFolder, ...) deliberately skips ImageFolder.__init__ and
        # runs DatasetFolder.__init__ with this file's extension filter.
        super(ImageFolder, self).__init__(root, loader, extensions,
                                          transform=transform,
                                          target_transform=target_transform,
                                          is_valid_file=is_valid_file)
        classes, class_to_idx = self.find_classes()
        self.classes = classes
        self.class_to_idx = class_to_idx
        # Re-index the samples with the ImageNet mapping, using the same file
        # filter (extensions or is_valid_file) as the parent constructor.
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        self.samples = samples
        self.targets = [s[1] for s in samples]
        # ``imgs`` is ImageFolder's alias for ``samples``; assign it only
        # after ``samples`` has been rebuilt so both name the same list.
        self.imgs = self.samples
        self.country = country

    def find_classes(self, directory: Optional[str] = None) -> Tuple[List[str], Dict[str, int]]:
        """Return the ImageNet classes instead of scanning a directory.

        ``directory`` is ignored. It is accepted because newer torchvision
        versions call ``self.find_classes(self.root)`` from
        ``DatasetFolder.__init__``, which a zero-argument override rejects
        with a ``TypeError``.

        Returns:
            ``(classes, class_to_idx)``: class names in the iteration order of
            ``imagenet_classes``, and a mapping from each name to its key.
        """
        classes = [imagenet_classes[key] for key in imagenet_classes]
        class_to_idx = {imagenet_classes[key]: key for key in imagenet_classes}
        return classes, class_to_idx

    def __getitem__(self, index: int):
        """Load sample ``index`` plus metadata parsed from its file name.

        With ``country=True``, returns ``(image, target, country)``, where
        ``country`` is the file-name prefix before the first ``_``.
        Otherwise returns ``(image, target, name, epoch, batch)`` parsed from
        a file name of the form ``name[:epoch[:batch]].ext``. ``target``,
        ``epoch`` and ``batch`` are returned as tensors, and a missing epoch
        or batch defaults to 0.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        fullname = os.path.basename(path)
        if self.country:
            country_name = fullname.split("_")[0]
            return sample, target, country_name
        # Everything before the first '.' is the stem; its ':'-separated
        # fields are name, epoch and batch.
        splits = fullname.split(".")[0].split(":")
        name = splits[0]
        epoch = splits[1] if len(splits) > 1 else 0
        batch = splits[2] if len(splits) > 2 else 0
        return sample, torch.tensor(target), name, torch.tensor(int(epoch)), torch.tensor(int(batch))
        
def test_dataset(model, dataloader, device, normalize=None):
    """Compute and print the top-1 accuracy of ``model`` over ``dataloader``.

    Args:
        model: Classifier whose output holds one score per class along dim 1.
        dataloader: Iterable of batches whose first two entries are
            ``(images, labels)``. Any further entries (e.g. names returned by
            ``custom_imagenet_dataset``) are ignored.
        device: Device the images and labels are moved to.
        normalize: Callable applied to each image batch before the model.
            Defaults to the module-level ``norm``.

    Returns:
        Accuracy as a fraction in [0, 1].

    Raises:
        ValueError: if ``dataloader`` yields no samples.

    The model runs in eval mode during evaluation. Afterwards it is restored
    to whichever train/eval mode it was in before the call.
    """
    if normalize is None:
        normalize = norm
    was_training = model.training
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for sample_batched in dataloader:
            image, label = sample_batched[0], sample_batched[1]
            image = image.to(device)
            label = torch.as_tensor(label).to(device)
            pred = torch.argmax(model(normalize(image)), dim=1)
            correct += (pred == label).sum().item()
            total += image.shape[0]
    model.train(was_training)
    if total == 0:
        raise ValueError("dataloader yielded no samples")
    print(f"accuracy: {100.0 * correct / total:.3f}%")
    return correct * 1.0 / total


# Despite its name this is not a test: keep pytest from collecting it when it
# is imported into a test module.
test_dataset.__test__ = False

def show_tensor_images(image_tensor, num_images=25, nrow=5):
    '''
    Display the first ``num_images`` images of ``image_tensor`` as a grid.

    Args:
        image_tensor: Batch of images accepted by ``make_grid`` (presumably
            (N, C, H, W)); it is detached and copied to the CPU first.
        num_images: Number of leading images to display.
        nrow: Number of images per grid row.

    Each image is min-max rescaled on its own (``normalize=True`` with
    ``scale_each=True``), and the grid padding is white (1.0).
    '''
    plt.figure(figsize=(10, 10))
    shown = image_tensor.detach().cpu()[:num_images]
    grid = make_grid(
        shown,
        nrow=nrow,
        normalize=True,
        scale_each=True,
        pad_value=1.0,
    )
    # (C, H, W) -> (H, W, C) for imshow; squeeze drops singleton dimensions.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.axis('off')
    plt.show()
    
def save_tensor_images(image_tensor, name, nrow=10):
    '''
    Save a batch of images to the file ``name`` as a single grid image.

    Pixel values are clamped to [0, 1] without rescaling before the grid is
    built, and the grid padding is white (1.0).

    Args:
        image_tensor: Batch of images accepted by ``make_grid``, presumably
            (N, C, H, W) with values nominally in [0, 1]. It is not modified.
        name: Output path passed to ``plt.imsave``.
        nrow: Number of images per grid row.
    '''
    # clamp() returns a new tensor. In-place masking would write into the
    # caller's tensor when it is already on the CPU, because detach()/cpu()
    # do not copy in that case.
    image_unflat = image_tensor.detach().cpu().clamp(0, 1)
    image_grid = make_grid(image_unflat, nrow=nrow, normalize=False, pad_value=1.0)
    # (C, H, W) -> (H, W, C), the layout imsave expects.
    plt.imsave(name, image_grid.permute(1, 2, 0).numpy())
    
def save_grayscale_tensor_images(image_tensor, name, nrow=1):
    '''
    Save a batch of images to the file ``name`` as one 8-bit grayscale grid.

    Values are clamped to [0, 1], scaled to 0..255 and written with PIL.
    Only the first channel of the assembled grid is saved.

    Args:
        image_tensor: Batch of images accepted by ``make_grid``, presumably
            (N, 1, H, W) with values nominally in [0, 1]. It is not modified.
        name: Output path passed to ``PIL.Image.Image.save``.
        nrow: Number of images per grid row; the default 1 stacks them in a
            single column.
    '''
    # clamp() returns a new tensor, so the caller's tensor is never written to
    # (detach()/cpu() do not copy a tensor that is already on the CPU).
    image_unflat = image_tensor.detach().cpu().clamp(0, 1)
    image_grid = make_grid(image_unflat, nrow=nrow, normalize=False)
    # make_grid returns a (C, H, W) tensor, so image_grid[0] is already a 2-D
    # (H, W) plane. Permuting it as if it were 3-D raises a RuntimeError.
    # A 2-D uint8 array becomes a single-channel ('L') PIL image.
    i8 = (image_grid[0].numpy() * 255.9).astype(np.uint8)
    Image.fromarray(i8).save(name)
    