import os
from typing import *

import torch
from torch import optim, nn, utils
from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small, MobileNet_V3_Large_Weights, MobileNet_V3_Small_Weights
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.callbacks import Callback

import datetime

import sys
from celeba_dataset import CelebA
import json
import argparse
import yaml


from torchvision.datasets.vision import VisionDataset
import PIL
import pandas as pd
from tqdm import tqdm
import itertools
import pickle as pkl

class FairFace(VisionDataset):
    """FairFace images with gender, coarse-race and coarse-age labels.

    Each sample is ``(image, target)`` where ``target`` is a length-3 integer
    tensor ordered as ``[gender, race, age]`` (see ``attr_names``).

    The class-level ``classes_*`` / ``map_*_2_idx`` attributes hold the
    fine-grained label vocabulary.  Every instance shadows the race and age
    ones with coarsened versions (4 race groups, 2 age groups); the class-level
    dictionaries themselves are never modified.

    Args:
        root: Dataset root containing ``fairface_label_<split>.csv`` and the
            image directory.
        image_dir: Sub-directory of ``root`` that holds the ``train`` / ``val``
            image folders.
        split: ``"train"`` or ``"val"``.
        transform: Optional callable applied to the loaded PIL image.

    Raises:
        ValueError: If ``split`` is not ``"train"`` or ``"val"``, or if a label
            column contains a value missing from its mapping.
    """

    classes_age = ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "more than 70"]
    classes_gender = ["Female", "Male"]
    classes_race = ["White", "Black", "Indian", "Middle Eastern", "Latino_Hispanic", "East Asian", "Southeast Asian"]
    map_age_2_idx = {age:idx for idx, age in enumerate(classes_age)}
    map_gender_2_idx = {gender:idx for idx, gender in enumerate(classes_gender)}
    map_race_2_idx = {race:idx for idx, race in enumerate(classes_race)}

    def __init__(
        self,
        root: str,
        image_dir: str = "margin125_expand_0.5",
        split: str = "train",
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform)

        if split not in ("train", "val"):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.root = root
        self.img_dir = image_dir
        self.split = split
        self.imgs = os.listdir(os.path.join(self.root, self.img_dir, split))
        self.df_labels = pd.read_csv(os.path.join(self.root, f"fairface_label_{split}.csv"))

        self.df_labels["gender"] = self._encode(self.df_labels["gender"], self.map_gender_2_idx)

        # Coarsened race labels.  These are instance attributes on purpose:
        # writing into the class-level dict would change it for every user of
        # the class.
        self.classes_race = ["White, Middle Eastern, Latino_Hispanic", "Black", "Indian", "East Asian, Southeast Asian"]
        self.map_race_2_idx = {
            "White": 0,
            "Black": 1,
            "Indian": 2,
            "Middle Eastern": 0,
            "Latino_Hispanic": 0,
            "East Asian": 3,
            "Southeast Asian": 3,
        }
        self.df_labels["race"] = self._encode(self.df_labels["race"], self.map_race_2_idx)

        # Coarsened age labels: everything up to "30-39" is class 0, the rest
        # is class 1.
        self.classes_age = ["0-39", "39-"]
        self.map_age_2_idx = {
            "0-2": 0,
            "3-9": 0,
            "10-19": 0,
            "20-29": 0,
            "30-39": 0,
            "40-49": 1,
            "50-59": 1,
            "60-69": 1,
            "more than 70": 1,
        }
        self.df_labels["age"] = self._encode(self.df_labels["age"], self.map_age_2_idx)

        # Keep only the records whose image is actually present on disk.
        self.img_paths = []
        labels = []
        records = zip(
            self.df_labels["file"],
            self.df_labels["gender"],
            self.df_labels["race"],
            self.df_labels["age"],
        )
        for file, gender, race, age in tqdm(records, total=self.df_labels.shape[0]):
            img_path = os.path.join(self.root, self.img_dir, file)
            if os.path.exists(img_path):
                self.img_paths.append(img_path)
                labels.append([gender, race, age])

        # Explicit dtype/shape so an empty dataset still yields an (0, 3)
        # integer tensor.
        self.labels = torch.tensor(labels, dtype=torch.long).reshape(-1, 3)

        # Column order of ``self.labels``.
        self.attr_names = ["gender", "race", "age"]

    @staticmethod
    def _encode(column, mapping):
        """Map the string labels in ``column`` to integer ids via ``mapping``."""
        encoded = column.map(mapping)
        missing = encoded.isna()
        if missing.any():
            unknown = sorted(set(column[missing].astype(str)))
            raise ValueError(f"unexpected value(s) in column {column.name!r}: {unknown}")
        return encoded.astype("int64")

    def __len__(self) -> int:
        """Return the number of images found on disk."""
        return len(self.img_paths)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the (optionally transformed) image and its label row."""
        # ``convert`` forces the pixel data to be read, so the file handle can
        # be closed right away, and guarantees a 3-channel image.
        with PIL.Image.open(self.img_paths[index]) as img:
            X = img.convert("RGB")
        if self.transform is not None:
            X = self.transform(X)

        target = self.labels[index]

        return X, target
    
    
            


class FairFaceClassification(pl.LightningModule):
    """MobileNetV3 classifier predicting gender, race and age from one head.

    The final linear layer of the torchvision backbone is replaced by a single
    layer whose outputs are laid out as ``[gender | race | age]``.  The width
    of each slice is the number of class names the dataset exposes, so with
    the 2 / 4 / 2 classes of ``FairFace`` the head has 8 outputs.

    Args:
        args: Namespace-like object; ``args.lr`` is read by
            ``configure_optimizers``.
        dataset: Object exposing ``classes_gender``, ``classes_race`` and
            ``classes_age`` (lists of class names).
        verbose_log_every_n_train_batch: Stored on the module; not used by any
            method in this class.
        classification_model: ``"MobileNetLarge"`` or ``"MobileNetSmall"``.

    Raises:
        ValueError: If ``classification_model`` is not a supported name.
    """

    # Column order of the label tensor and of the logit slices.
    _ATTRS = ("gender", "race", "age")

    def __init__(
        self,
        args,
        dataset,
        verbose_log_every_n_train_batch: int=50,
        classification_model: Literal["MobileNetLarge", "MobileNetSmall"]="MobileNetLarge",
        ):
        super(FairFaceClassification, self).__init__()

        self.verbose_log_every_n_train_batch = verbose_log_every_n_train_batch

        self.classes_gender = dataset.classes_gender
        self.classes_race = dataset.classes_race
        self.classes_age = dataset.classes_age

        self.args = args

        # Boundaries of the race and age slices inside the shared logit vector.
        n_gender = len(self.classes_gender)
        n_race = len(self.classes_race)
        n_age = len(self.classes_age)
        self._race_start = n_gender
        self._age_start = n_gender + n_race

        if classification_model == "MobileNetLarge":
            backbone = mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT, width_mult=1.0, reduced_tail=False, dilated=False)
        elif classification_model == "MobileNetSmall":
            backbone = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT, width_mult=1.0, reduced_tail=False, dilated=False)
        else:
            raise ValueError(
                f"classification_model must be 'MobileNetLarge' or 'MobileNetSmall', got {classification_model!r}"
            )

        # Swap the pretrained output layer for one sized to our label set; the
        # input width is read from the layer being replaced.
        in_features = backbone.classifier[3].in_features
        backbone.classifier[3] = nn.Linear(in_features, n_gender + n_race + n_age, bias=True)
        self.classification_model = backbone

    def _shared_step(self, batch):
        """Run the model on ``batch`` and return ``(loss, preds, targets)``.

        ``preds`` and ``targets`` are 3-tuples ordered as ``_ATTRS``.  The loss
        is the sum of the mean cross-entropy of the three attributes.
        """
        xs, ys = batch
        logits = self.classification_model(xs)
        logits_gender = logits[:, :self._race_start]
        logits_race = logits[:, self._race_start:self._age_start]
        logits_age = logits[:, self._age_start:]

        ys_gender = ys[:, 0]
        ys_race = ys[:, 1]
        ys_age = ys[:, 2]

        criterion = nn.CrossEntropyLoss()
        loss = (
            criterion(logits_gender, ys_gender)
            + criterion(logits_race, ys_race)
            + criterion(logits_age, ys_age)
        )

        preds = (
            logits_gender.max(-1).indices,
            logits_race.max(-1).indices,
            logits_age.max(-1).indices,
        )
        return loss, preds, (ys_gender, ys_race, ys_age)

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log overall and per-class accuracy."""
        self.classification_model.train()

        loss, preds, targets = self._shared_step(batch)
        self.log("train_loss", loss, batch_size=targets[0].shape[0])

        class_lists = (self.classes_gender, self.classes_race, self.classes_age)
        for attr, classes, pred, target in zip(self._ATTRS, class_lists, preds, targets):
            correct = pred == target
            for idx, cls in enumerate(classes):
                in_class = target == idx
                # A class missing from this batch has no defined accuracy;
                # skip it rather than logging NaN.
                if in_class.any():
                    self.log(f"train_acc_{attr}_{cls}", correct[in_class].float().mean().item())
            self.log(f"train_acc_{attr}", correct.float().mean().item())

        return loss

    def validation_step(self, batch, batch_idx):
        """Accumulate predictions/targets for the epoch-level metrics."""
        self.classification_model.eval()

        loss, preds, targets = self._shared_step(batch)
        self.log("val_loss", loss, batch_size=targets[0].shape[0])

        preds_gender, preds_race, preds_age = preds
        ys_gender, ys_race, ys_age = targets

        self.preds_gender.append(preds_gender)
        self.ys_gender.append(ys_gender)
        self.preds_race.append(preds_race)
        self.ys_race.append(ys_race)
        self.preds_age.append(preds_age)
        self.ys_age.append(ys_age)

    def on_validation_epoch_start(self) -> None:
        """Reset the per-epoch prediction/target buffers."""
        self.preds_gender = []
        self.ys_gender = []
        self.preds_race = []
        self.ys_race = []
        self.preds_age = []
        self.ys_age = []
        return super().on_validation_epoch_start()

    def on_validation_epoch_end(self):
        """Concatenate the buffered outputs and log validation accuracies.

        The buffers are replaced by the concatenated tensors and left on the
        module, so they remain readable after validation finishes.
        """
        self.preds_gender = torch.cat(self.preds_gender)
        self.ys_gender = torch.cat(self.ys_gender)
        self.preds_race = torch.cat(self.preds_race)
        self.ys_race = torch.cat(self.ys_race)
        self.preds_age = torch.cat(self.preds_age)
        self.ys_age = torch.cat(self.ys_age)

        # Gender: plain per-class accuracy.
        correct_gender = self.preds_gender == self.ys_gender
        for idx, gender in enumerate(self.classes_gender):
            acc = correct_gender[self.ys_gender == idx].float().mean().item()
            self.log("val_acc_gender_" + gender, acc)
        self.log("val_acc_gender", correct_gender.float().mean().item())

        # Race and age: mean of the accuracy on samples of the class and the
        # accuracy on all remaining samples.
        # NOTE(review): this differs from the gender/train metric; confirm the
        # one-vs-rest averaging is intended.
        correct_race = self.preds_race == self.ys_race
        for idx, race in enumerate(self.classes_race):
            in_class = self.ys_race == idx
            acc = (correct_race[in_class].float().mean().item() + correct_race[~in_class].float().mean().item()) / 2
            self.log("val_acc_race_" + race, acc)
        self.log("val_acc_race", correct_race.float().mean().item())

        correct_age = self.preds_age == self.ys_age
        for idx, age in enumerate(self.classes_age):
            in_class = self.ys_age == idx
            acc = (correct_age[in_class].float().mean().item() + correct_age[~in_class].float().mean().item()) / 2
            self.log("val_acc_age_" + age, acc)
        self.log("val_acc_age", correct_age.float().mean().item())

    def configure_optimizers(self):
        """Return AdamW plus a MultiStepLR schedule with milestones 4, 6, 8.

        The decay factor is left at MultiStepLR's default.  Presumably the
        scheduler is stepped once per epoch by the trainer -- confirm against
        the Lightning version in use.
        """
        optimizer = optim.AdamW(
            self.classification_model.parameters(),
            lr=self.args.lr,
            weight_decay=0.05,
            )
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[4,6,8])
        return [optimizer], [scheduler]

class save_model_weight_callback(Callback):
    """Save the classifier weights and last validation outputs when fit ends."""

    def on_fit_end(self, trainer, pl_module):
        """Write ``<stem>.pt`` and ``<stem>_validation_output.pkl``.

        Both files go to ``<save_dir>/<logger name>/<experiment id>/checkpoints``.
        The pickle holds ``[preds_gender, ys_gender, preds_race, ys_race,
        preds_age, ys_age]`` as left on ``pl_module`` by its validation hooks.
        """
        ckpt_dir = os.path.join(
            trainer.logger.save_dir, trainer.logger.name, trainer.logger.experiment.id, "checkpoints")
        # Do not rely on another callback having created the directory.
        os.makedirs(ckpt_dir, exist_ok=True)

        stem = f"epoch={trainer.current_epoch-1}-step={trainer.global_step}_{pl_module.args.classification_model}"

        save_path = os.path.join(ckpt_dir, f"{stem}.pt")
        torch.save(pl_module.classification_model.state_dict(), save_path)

        outputs = [
            pl_module.preds_gender,
            pl_module.ys_gender,
            pl_module.preds_race,
            pl_module.ys_race,
            pl_module.preds_age,
            pl_module.ys_age,
        ]
        # Move tensors to the CPU so the pickle can be loaded on a machine
        # without the training device.
        outputs = [out.detach().cpu() if torch.is_tensor(out) else out for out in outputs]

        save_path = os.path.join(ckpt_dir, f"{stem}_validation_output.pkl")
        with open(save_path, "wb") as f:
            pkl.dump(outputs, f)


class LogCallback(Callback):
    """Record ``trainer.callback_metrics`` over time and dump them as JSON."""

    def __init__(self):
        super(LogCallback,self).__init__()
        # metric name -> list of float values, one per recording
        self.metrics = {}

    def log_metrics(self, trainer):
        """Append the current value of every callback metric to its history."""
        for metric_name, metric_value in trainer.callback_metrics.items():
            self.metrics.setdefault(metric_name, []).append(float(metric_value))

    def on_train_epoch_end(self, trainer, pl_module):
        """Record the metrics at the end of each training epoch."""
        self.log_metrics(trainer)

    def on_fit_end(self, trainer, pl_module):
        """Record the metrics once more and write the history to ``log.json``.

        NOTE(review): this recording comes on top of the one made at the end
        of the last training epoch, so the final values may appear twice.
        """
        self.log_metrics(trainer)
        # save log
        log_dir = os.path.join(
            trainer.logger.save_dir, trainer.logger.name, trainer.logger.experiment.id, "checkpoints")
        # Do not rely on another callback having created the directory.
        os.makedirs(log_dir, exist_ok=True)
        save_path = os.path.join(log_dir, "log.json")
        with open(save_path, 'w') as f:
            json.dump(self.metrics, f, indent=4)
        

def main(args):
    """Train the FairFace gender/race/age classifier.

    Reads from ``args``: ``seed``, ``FairFace_dir``, ``input_size``,
    ``batch_size``, ``verbose_log_every_n_train_batch``,
    ``classification_model``, ``logger_save_dir``, ``gpus`` and ``epochs``.
    Sets ``args.train_dataset_len`` to the size of the training set.
    """
    pl.seed_everything(args.seed)

    train_transform = transforms.Compose(
        [
            transforms.Resize(args.input_size),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.RandAugment(num_ops=3, magnitude=10, fill=0),
            transforms.RandomCrop(args.input_size),
            transforms.PILToTensor(),
            # ConvertImageDtype already scales the tensor to [0, 1]
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=0.5, std=0.5),
        ]
    )
    # Validation must be deterministic, so crop the centre instead of a random
    # window.  When ``input_size`` is an (h, w) pair the preceding Resize
    # already produces exactly that size and the crop changes nothing.
    val_transform = transforms.Compose(
        [
            transforms.Resize(args.input_size),
            transforms.CenterCrop(args.input_size),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=0.5, std=0.5),
        ]
    )

    train_dataset = FairFace(
        root=args.FairFace_dir,
        image_dir="margin125_expand_0.5",
        split="train",
        transform=train_transform,
        )
    val1_dataset = FairFace(
        root=args.FairFace_dir,
        image_dir="margin125_expand_0.5",
        split="val",
        transform=val_transform,
        )
    train_loader = utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16)
    val1_loader = utils.data.DataLoader(val1_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    args.train_dataset_len = len(train_dataset)

    model = FairFaceClassification(
        dataset=train_dataset,
        verbose_log_every_n_train_batch=args.verbose_log_every_n_train_batch,
        classification_model=args.classification_model,
        args=args
    )

    # The run id/name embeds month, day, hour and minute of the start time.
    now = datetime.datetime.now()
    timestring = f"{now.month:02}{now.day:02}{now.hour:02}{now.minute:02}"
    run_name = "FairFaceGenderRace4Age2_" + args.classification_model + "_unbalanced_" + timestring
    logger = WandbLogger(
        id=run_name,
        name=run_name,
        save_dir=args.logger_save_dir,
        project="train_classifiers"
        )

    trainer = pl.Trainer(
        devices=args.gpus,
        accelerator="gpu",
        logger=logger,
        max_epochs=args.epochs,
        callbacks=[save_model_weight_callback(), LogCallback()],
    )

    trainer.fit(model, train_loader, val1_loader)
    

def str_to_bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``0``.

    ``type=bool`` cannot be used with argparse because any non-empty string,
    including ``"False"``, is truthy.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse and post-process the command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        The parsed namespace, with ``gpus`` converted from a comma-separated
        string to a list of ints and ``input_size`` expanded to an
        ``(size, size)`` tuple.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-epochs', help="num of epochs for train", type=int, required=False, default=10)
    parser.add_argument('-seed', help="seed for pytorch_lightning.seed_everything", type=int, required=False, default=1997)
    parser.add_argument('-lr', help="learning rate", type=float, required=False, default=1e-3)
    parser.add_argument('-FairFace_dir', help="root dir for FairFace dataset", type=str, required=True)
    parser.add_argument('-input_size', help="input image's size", type=int, required=False, default=224)
    parser.add_argument('-batch_size', help="train, val, test batch size", type=int, required=False, default=256)
    parser.add_argument('-verbose_log_every_n_train_batch', help="log every n train batches", type=int, required=False, default=50)
    parser.add_argument('-classification_model',
                        help="classification model",
                        choices=['MobileNetLarge', 'MobileNetSmall'],
                        type=str,
                        required=False,
                        default="MobileNetLarge")
    parser.add_argument('-balanced_CELoss',
                        help="whether use balanced CELoss",
                        type=str_to_bool,
                        required=False,
                        default=True)
    parser.add_argument('-logger_save_dir', help="logger save dir", type=str, required=True)
    parser.add_argument('-gpus', help="gpus, one or more", required=False, default="1")
    args = parser.parse_args(argv)

    # args post-processing
    args.gpus = [int(gpu) for gpu in args.gpus.split(",")]
    # Square input: honour the requested size instead of forcing 224.
    args.input_size = (args.input_size, args.input_size)

    return args


if __name__ == "__main__":

    main(parse_args())