import os
import time
import argparse
import numpy as np

import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from utils.logger import CSVLogger
import MinkowskiEngine as ME

import models
from utils.dataset import get_dataset
from utils.config import get_config
from utils.collation import CollateFN
from utils.callbacks import SourceCheckpoint
from pipelines import PLTOneDomainTrainer
from sklearn.metrics import jaccard_score

def clean_state_dict(ckpt):
    """Strip the PyTorch Lightning ``model.`` prefix from state-dict keys.

    Only entries whose key contains ``"model."`` are kept; every other entry
    is dropped. For kept entries, the first occurrence of ``"model."`` is
    removed from the key, so nested names that happen to contain the same
    substring (e.g. ``model.submodel.weight``) keep their remaining path.

    Args:
        ckpt: Mutable mapping from parameter name to value. It is modified
            in place.

    Returns:
        The same ``ckpt`` object, now holding only the renamed entries.
    """
    prefix = "model."
    # Build the result first: renaming while walking the original keys could
    # make a renamed key collide with (and be deleted as) a later plain key.
    cleaned = {
        key.replace(prefix, "", 1): value
        for key, value in ckpt.items()
        if prefix in key
    }
    ckpt.clear()
    ckpt.update(cleaned)
    return ckpt

# Command-line interface. The parsed namespace is stored in the module-level
# `args` by the entry point and is read directly inside train().
parser = argparse.ArgumentParser()
parser.add_argument("--config_file",
                    default="configs/source/synth4dkitti_source.yaml",
                    type=str,
                    help="Path to config file")
parser.add_argument("--note",
                    default=None,
                    type=str,
                    help="Optional suffix appended to the run name")
# --eps is not read in this script; it stays reachable through `args`, which
# is forwarded to get_dataset and PLTOneDomainTrainer.
parser.add_argument("--eps",
                    default=0.0,
                    type=float)
# Optional (despite the former '<Required>' help text); values are kept as strings.
parser.add_argument('--ignore_class',
                    nargs='+',
                    help='One or more classes to ignore; each one reduces the '
                         'number of model output classes by one')

# --use_unknown and --use_dummy are not read in this script; they are only
# forwarded through `args`.
parser.add_argument("--use_unknown",
                    action='store_true')
parser.add_argument("--train_new",
                    action='store_true',
                    help="Together with --ignore_class, add one extra model output class")
parser.add_argument("--use_dummy",
                    action='store_true')
parser.add_argument("--source_prototype",
                    action='store_true',
                    help="During --source_eval, also save per-class mean features")

parser.add_argument("--source_eval",
                    action='store_true',
                    help="Evaluate --resume_checkpoint on the target dataset instead of training")
parser.add_argument("--resume_checkpoint",
                    default=None,
                    type=str,
                    help="Checkpoint path: loaded for --source_eval, otherwise "
                         "passed to the Trainer to resume from")
parser.add_argument("--OOD_type",
                    default='Softmax',
                    type=str,
                    help="OOD scoring reported by --source_eval "
                         "(this script only handles 'Softmax')")
parser.add_argument("--criterion",
                    default='None',
                    type=str,
                    help="Loss name overriding config.pipeline.loss; "
                         "'None' keeps the config value")

# Augmentation parameters handed to get_dataset as `aug_parameters`.
# Example of a non-default value: {'RandomDropout': [0.2, 0.5]}
AUG_DICT = None


def _get_mapping_path(config):
    """Return ``config.dataset.mapping_path``, or ``None`` if the config lacks it."""
    try:
        return config.dataset.mapping_path
    except AttributeError:
        # An except clause must name an exception class; the message that used
        # to sit there is emitted as the log line it was meant to be.
        print('--> Setting default class mapping path!')
        return None


def _report_softmax_ood(out_list, unknown_label_list):
    """Print softmax-based OOD statistics (AUPR, AUROC, FPR95).

    Args:
        out_list: Per-batch logits tensors, concatenated along dim 0.
        unknown_label_list: Per-batch ``OOD_label`` tensors. Entries equal to
            -1 are excluded; 1 is treated as the positive (unknown) class.
    """
    from sklearn.metrics import precision_recall_curve, auc, roc_curve

    unknown_labels = torch.cat(unknown_label_list, dim=0)
    out = torch.cat(out_list, dim=0)
    valid = unknown_labels != -1
    out = out[valid]
    unknown_labels = unknown_labels[valid]

    # Uncertainty score = 1 - maximum softmax probability.
    uncertainty = 1 - torch.softmax(out, dim=1).max(dim=1)[0]

    is_unknown = unknown_labels == 1
    labels_np = unknown_labels.cpu().numpy()
    num_unknown = (labels_np == 1).sum()
    print(num_unknown)
    print(labels_np.shape[0])
    print(num_unknown / labels_np.shape[0])

    # Score range of the unknown points first, then of all remaining points.
    for mask in (is_unknown, ~is_unknown):
        scores = uncertainty[mask]
        print(scores.min())
        print(scores.max())
        print(scores.mean())

    uncertainty_np = uncertainty.cpu().detach().numpy()
    precision, recall, _ = precision_recall_curve(labels_np, uncertainty_np)
    aupr_score = auc(recall, precision)

    fpr, tpr, _ = roc_curve(labels_np, uncertainty_np)
    auroc_score = auc(fpr, tpr)

    print("****************************************")
    print("****************************************")
    print('Source AUPR is: ', aupr_score)
    print('AUROC is: ', auroc_score)
    print('FPR95 is: ', fpr[tpr > 0.95][0])
    print("****************************************")
    print("****************************************")


def _evaluate_source(config, model, target_dataset, target_dataloader):
    """Evaluate ``args.resume_checkpoint`` on the target dataset and print metrics.

    Loads the checkpoint into ``model``, runs it over ``target_dataloader``,
    prints IoU statistics, optionally saves per-class mean features
    (``--source_prototype``) and prints OOD metrics (``--OOD_type Softmax``).

    Reads the module-level ``args`` namespace.

    Raises:
        ValueError: If no checkpoint was given or no target dataset is available.
        RuntimeError: If ``--source_prototype`` is set and some class was never
            predicted, so its prototype cannot be computed.
    """
    from tqdm import tqdm

    if args.resume_checkpoint is None:
        raise ValueError('--source_eval requires --resume_checkpoint')
    if target_dataset is None or target_dataloader is None:
        raise ValueError('--source_eval requires a target dataset '
                         '(get_dataset returned None for it)')

    ckpt = torch.load(args.resume_checkpoint, map_location=torch.device('cpu'))["state_dict"]
    ckpt = clean_state_dict(ckpt)
    model.load_state_dict(ckpt, strict=True)
    model.eval()
    model.cuda()

    num_classes = config.model.out_classes
    class_ids = np.arange(0, num_classes)
    iou_list = []
    unknown_label_list = []
    out_list = []
    results_dict_list = []
    # One list of per-batch mean feature vectors for each predicted class.
    feat_stack = [[] for _ in range(num_classes)]

    with torch.no_grad():
        for batch in tqdm(target_dataloader):
            stensor = ME.SparseTensor(coordinates=batch["coordinates"].int().cuda(),
                                      features=batch["features"].cuda())
            backbone_out, _ = model(stensor, is_seg=False)
            # Copy the features before the final layer is applied; they are
            # only needed for the prototypes.
            feat = backbone_out.F.clone() if args.source_prototype else None
            out = model.final(backbone_out).F
            labels = batch['labels'].long()
            _, preds = out.max(1)
            # Per-class Jaccard is symmetric in its two inputs, so naming the
            # arguments does not change the scores.
            iou_tmp = jaccard_score(y_true=labels.cpu().numpy(),
                                    y_pred=preds.detach().cpu().numpy(),
                                    average=None,
                                    labels=class_ids,
                                    zero_division=0.)

            if args.source_prototype:
                for label in preds.unique():
                    label_mask = preds == label
                    feat_stack[int(label)].append(feat[label_mask, :].cpu().mean(dim=0))

            # Score only the classes that occur in this batch's labels (-1 excluded).
            present_labels = np.unique(labels.cpu().numpy())
            present_labels = present_labels[present_labels != -1]
            present_names = target_dataset.class2names[present_labels].tolist()
            results_dict = dict(zip(present_names, iou_tmp[present_labels].tolist()))

            iou_list.append(np.mean(iou_tmp[present_labels]))
            unknown_label_list.append(batch['OOD_label'])
            out_list.append(out.cpu())
            results_dict_list.append(results_dict)

        if args.source_prototype:
            empty = [c for c, feats in enumerate(feat_stack) if not feats]
            if empty:
                raise RuntimeError(f'No points were predicted as class(es) {empty}; '
                                   'cannot compute their prototypes.')
            ext_mu = [torch.stack(feats).mean(dim=0) for feats in feat_stack]
            torch.save((ext_mu), 'XXXXXXXXXX/OSTTA-GOO/target_prototype/' + args.resume_checkpoint.split('/')[-3] + '.pth')

        # Per-class IoU over batches; classes absent from a batch count as NaN
        # and are skipped by nanmean.
        class_names = target_dataset.class2names.tolist()
        classes = {name: [] for name in class_names}
        for frame_result in results_dict_list:
            for key in classes:
                classes[key].append(frame_result.get(key, np.nan))
        all_iou = np.stack([np.asarray(v) for v in classes.values()], axis=0).T
        per_class_iou = np.nanmean(all_iou, axis=0)
        miou = np.nanmean(per_class_iou)

        classes_iou = {name: per_class_iou[idx] * 100 for idx, name in enumerate(class_names)}

        # First line: mean of the per-batch mean IoUs. Second line: mean of the
        # per-class IoUs (previously computed but never reported).
        print('miou is: ', np.mean(iou_list))
        print('class-averaged miou is: ', miou)
        print(classes_iou)

        if args.OOD_type == 'Softmax':
            _report_softmax_ood(out_list, unknown_label_list)


def train(config):
    """Train the segmentation model, or evaluate a checkpoint with ``--source_eval``.

    Builds the datasets, dataloaders and model from ``config``. With
    ``--source_eval`` the checkpoint is evaluated on the target dataset and the
    function returns without training; otherwise a Lightning ``Trainer`` is
    configured and fitted.

    Relies on the module-level ``args`` namespace set by the entry point, and
    mutates ``config`` (``model.out_classes``, ``pipeline.loss``,
    ``pipeline.lightning.resume_checkpoint``) and ``args.save_dir``.

    Args:
        config: Configuration object returned by ``get_config``.
    """

    def get_dataloader(dataset, shuffle=False, pin_memory=True):
        return DataLoader(dataset,
                          batch_size=config.pipeline.dataloader.batch_size,
                          collate_fn=CollateFN(),
                          shuffle=shuffle,
                          num_workers=config.pipeline.dataloader.num_workers,
                          pin_memory=pin_memory)

    mapping_path = _get_mapping_path(config)

    training_dataset, validation_dataset, target_dataset = get_dataset(dataset_name=config.dataset.name,
                                                                       dataset_path=config.dataset.dataset_path,
                                                                       voxel_size=config.dataset.voxel_size,
                                                                       augment_data=config.dataset.augment_data,
                                                                       aug_parameters=AUG_DICT,
                                                                       version=config.dataset.version,
                                                                       sub_num=config.dataset.num_pts,
                                                                       get_target=config.dataset.validate_target,
                                                                       target_dataset_path=config.dataset.target_path,
                                                                       num_classes=config.model.out_classes,
                                                                       ignore_label=config.dataset.ignore_label,
                                                                       mapping_path=mapping_path,
                                                                       args=args)

    # The dataset is built with the configured class count; the model head is
    # then shrunk by the ignored classes (plus one extra class with --train_new).
    if args.ignore_class is not None:
        config.model.out_classes = config.model.out_classes - len(args.ignore_class)
        if args.train_new:
            config.model.out_classes += 1

    training_dataloader = get_dataloader(training_dataset, shuffle=True)
    validation_dataloader = get_dataloader(validation_dataset, shuffle=False)

    # When a target dataset exists it replaces the source validation set.
    target_dataloader = None
    if target_dataset is not None:
        target_dataloader = get_dataloader(target_dataset, shuffle=False)
        validation_dataloader = [target_dataloader]
    else:
        validation_dataloader = [validation_dataloader]

    # coords = [N, [x, y, z]], feats=[N, f] -> f [i] ----- [x, y, z, i]
    # model = MinkUNet34C(1, 8)
    Model = getattr(models, config.model.name)
    model = Model(config.model.in_feat_size, config.model.out_classes)

    model = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(model)

    if args.source_eval:
        _evaluate_source(config, model, target_dataset, target_dataloader)
        # Evaluation-only run: stop here instead of aborting with an exception.
        print('Source evaluation done!')
        return

    run_time = time.strftime("%Y_%m_%d_%H:%M", time.gmtime())
    if config.pipeline.wandb.run_name is not None:
        run_name = run_time + '_' + config.pipeline.wandb.run_name
    else:
        run_name = run_time

    if args.note is not None:
        run_name += f'_{args.note}'

    save_dir = os.path.join(config.pipeline.save_dir, run_name)
    args.save_dir = save_dir

    # The literal string 'None' means "keep the loss from the config".
    if args.criterion != 'None':
        config.pipeline.loss = args.criterion

    pl_module = PLTOneDomainTrainer(training_dataset=training_dataset,
                                    validation_dataset=validation_dataset,
                                    model=model,
                                    criterion=config.pipeline.loss,
                                    optimizer_name=config.pipeline.optimizer.name,
                                    batch_size=config.pipeline.dataloader.batch_size,
                                    val_batch_size=config.pipeline.dataloader.batch_size,
                                    lr=config.pipeline.optimizer.lr,
                                    num_classes=config.model.out_classes,
                                    train_num_workers=config.pipeline.dataloader.num_workers,
                                    val_num_workers=config.pipeline.dataloader.num_workers,
                                    clear_cache_int=config.pipeline.lightning.clear_cache_int,
                                    scheduler_name=config.pipeline.scheduler.scheduler_name,
                                    args=args)

    # Only CSV logging is active; WandbLogger is imported but not used here.
    csv_logger = CSVLogger(save_dir=save_dir,
                           name=run_name,
                           version='logs')

    loggers = [csv_logger]

    # save_top_k=-1 keeps every checkpoint.
    checkpoint_callback = [ModelCheckpoint(dirpath=os.path.join(save_dir, 'checkpoints'), save_top_k=-1),
                           SourceCheckpoint()]

    if args.resume_checkpoint is not None:
        config.pipeline.lightning.resume_checkpoint = args.resume_checkpoint

    trainer = Trainer(max_epochs=config.pipeline.epochs,
                      gpus=config.pipeline.gpus,
                      accelerator="ddp",
                      default_root_dir=config.pipeline.save_dir,
                      weights_save_path=save_dir,
                      precision=config.pipeline.precision,
                      logger=loggers,
                      check_val_every_n_epoch=config.pipeline.lightning.check_val_every_n_epoch,
                      val_check_interval=1.0,
                      num_sanity_val_steps=0,
                      resume_from_checkpoint=config.pipeline.lightning.resume_checkpoint,
                      callbacks=checkpoint_callback)

    trainer.fit(pl_module,
                train_dataloaders=training_dataloader,
                val_dataloaders=validation_dataloader)


if __name__ == '__main__':
    # `args` must stay a module-level name: train() reads it as a global.
    args = parser.parse_args()
    config = get_config(args.config_file)

    # Fix the random seed: export it as PYTHONHASHSEED, then seed NumPy,
    # PyTorch and the current CUDA device, in that order.
    seed = config.pipeline.seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # NOTE(review): cuDNN benchmark mode is enabled right after seeding;
    # confirm that its algorithm auto-selection is acceptable for runs that
    # are expected to be reproducible.
    torch.backends.cudnn.benchmark = True

    train(config)
