import os
import time
import shutil
import json
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast #native AMP
import torch.nn.parallel
import torch.distributed as dist
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.utils.data.distributed
import scipy.ndimage as ndimage
from monai.transforms import AsDiscrete,Activations,Compose

import sys


def json_get_fold(datalist, basedir, fold=0, key='training'):


    with open(datalist) as f:
        json_data = json.load(f)

    json_data = json_data[key]

    for d in json_data:
        for k, v in d.items():
            if isinstance(d[k], list):
                d[k] = [os.path.join(basedir, iv) for iv in d[k]]
            elif isinstance(d[k], str):
                d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]

    tr=[]
    val=[]
    for d in json_data:
        if 'fold' in d and d['fold'] == fold:
            val.append(d)
        else:
            tr.append(d)

    return tr, val

# def json_get_fold(datalist, basedir, fold=0):
#
#
#     with open(datalist) as f:
#         json_data = json.load(f)
#
#
#     for d in json_data:
#         for k, v in d.items():
#             if isinstance(d[k], list):
#                 d[k] = [os.path.join(basedir, iv) for iv in d[k]]
#             elif isinstance(d[k], str):
#                 d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]
#
#     tr=[]
#     val=[]
#     for d in json_data:
#         if d['fold']==fold:
#             val.append(d)
#         else:
#             tr.append(d)
#
#     return tr, val



import math
#template copied from torch.utils.data.distributed.DistributedSampler
class AMDistributedSampler(torch.utils.data.Sampler):

    def __init__(self, dataset, num_replicas=None, rank=None,
                 shuffle=True, make_even=True):

        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()

        self.shuffle = shuffle
        self.make_even = make_even

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0


        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        indices = list(range(len(self.dataset)))
        self.valid_length = len(indices[self.rank:self.total_size:self.num_replicas])


    def __iter__(self):

        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible (otherwise will return last batch smaller)

        if self.make_even:
            if len(indices) < self.total_size:
                if self.total_size - len(indices) < len(indices):
                    indices += indices[:(self.total_size - len(indices))]
                else:
                    extra_ids = np.random.randint(low=0,high=len(indices), size=self.total_size - len(indices)) #this ensures we get valid ids (if dataset is much smaller then world_size
                    indices += [indices[ids] for ids in extra_ids]

            assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        # assert len(indices) == self.num_samples
        self.num_samples = len(indices)

        return iter(indices)

        #return indices

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


def distributed_all_gather(tensor_list, valid_batch_size=None, out_numpy=False, world_size=None, no_barrier=False, is_valid=None):


    if world_size is None:
        world_size = torch.distributed.get_world_size()

    if valid_batch_size is not None:
        valid_batch_size = min(valid_batch_size, world_size) #it can't be more then world_size
    elif is_valid is not None:
        is_valid = torch.tensor(bool(is_valid), dtype=torch.bool, device=tensor_list[0].device)

    if not no_barrier:
        torch.distributed.barrier()  # synch processess, do we need it??

    tensor_list_out = []
    with torch.no_grad(): #? do we need it

        if is_valid is not None:
            is_valid_list = [torch.zeros_like(is_valid) for _ in range(world_size)]
            torch.distributed.all_gather(is_valid_list, is_valid)
            is_valid = [x.item() for x in is_valid_list]  #list of bools
            # print('is_valid list', is_valid)

        for tensor in tensor_list:
            gather_list = [torch.zeros_like(tensor) for _ in range(world_size)]
            torch.distributed.all_gather(gather_list, tensor)

            if valid_batch_size is not None:
                gather_list = gather_list[:valid_batch_size] #keep only valid elements
            elif is_valid is not None:
                gather_list = [g for g,v in zip(gather_list, is_valid_list) if v]
                # print('updated gather list', gather_list)

            if out_numpy:
                gather_list = [t.cpu().numpy() for t in gather_list] #convert to numpy

            tensor_list_out.append(gather_list)

    return tensor_list_out




class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n

        self.count += n
        # self.avg = self.sum / self.count if self.count > 0 else self.sum
        self.avg = np.where(self.count > 0, self.sum / self.count, self.sum)



def train_epoch(model, loader, optimizer, scaler, epoch, loss_func, args):

    model.train()
    start_time = time.time()
    run_loss = AverageMeter()
    run_acc = AverageMeter()


    for idx, batch_data in enumerate(loader):

        if isinstance(batch_data, list):
            data, target = batch_data
        else:
            data, target = batch_data['image'], batch_data['label']

        data, target = data.cuda(args.rank), target.cuda(args.rank)

        # print(batch_data.keys())
        # print(batch_data['image_meta_dict'])
        # print(batch_data['label_meta_dict'])

        # print('data', data.shape, 'target', target.shape)
        # print(torch.sum(target==1), torch.sum(target==2), target.shape)


        # optimizer.zero_grad()
        for param in model.parameters(): param.grad = None

        with autocast(enabled=args.amp):
            logits = model(data)
            # print('data', data.shape, 'target', target.shape, 'logits', logits.shape)
            loss = loss_func(logits, target)

        if args.amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

#         with torch.no_grad():
#             acc, not_nans = acc_func(logits, target)
            # print('logits', logits.shape, 'target', target.shape, 'acc', acc, 'not_nans', not_nans)



        if args.distributed:
            loss_list = distributed_all_gather([loss], out_numpy=True, is_valid=idx < loader.sampler.valid_length)
            run_loss.update(np.mean(np.mean(np.stack(loss_list, axis=0), axis=0), axis=0), n=args.batch_size * args.world_size)
#             for al, nl in zip(acc_list, not_nans_list ):
#                 run_acc.update(al, n=nl)
        else:
#             run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())
            run_loss.update(loss.item(), n=args.batch_size)


        # if args.distributed:
        #     loss_list, acc_list = distributed_all_gather([loss, acc], out_numpy=True)
        #     loss = np.mean(np.stack(loss_list, axis=0))
        #     acc = np.mean(np.stack(acc_list, axis=0), axis=0)
        # else:
        #     loss = loss.item()
        #     acc = acc.cpu().numpy()

        # run_loss.update(loss, args.batch_size*args.world_size)
        # # run_acc.update(acc, args.batch_size*args.world_size)
        # run_acc.update(acc, n=not_nans.cpu().numpy()) #new

        if args.rank==0:
            # print(not_nans)
            print('Epoch {}/{} {}/{}'.format(epoch, args.max_epochs, idx, len(loader)),
                  'loss: {:.4f}'.format(run_loss.avg),
#                   'acc', run_acc.avg,
                  'time {:.2f}s'.format(time.time() - start_time))
        start_time = time.time()

    for param in model.parameters(): param.grad = None #just in case

    return run_loss.avg#, run_acc.avg

def resample(img, target_size):
    imx, imy, imz = img.shape
    tx, ty, tz = target_size
    zoom_ratio = ( float(tx) / float(imx), float(ty) / float(imy), float(tz) / float(imz))
    img_resampled = ndimage.zoom( img, zoom_ratio, order=0, prefilter=False)
    return img_resampled

def dice(x, y):
    intersect = np.sum(np.sum(np.sum(x * y)))
    y_sum = np.sum(np.sum(np.sum(y)))
    if y_sum == 0:
        return 0.0
    x_sum = np.sum(np.sum(np.sum(x)))
    return 2 * intersect / (x_sum + y_sum)

def val_epoch(model, loader, val_shape_dict, epoch, loss_func, args, model_inferer=None,post_label=None,post_pred=None):

    model.eval()
    start_time = time.time()
    run_loss = AverageMeter()
    run_acc = AverageMeter()

    
    with torch.no_grad():
        # print(args.rank, 'start1')

        for idx, batch_data in enumerate(loader):
            # print(args.rank, 'start2')

            if isinstance(batch_data, list):
                data, target = batch_data
            else:
                data, target = batch_data['image'], batch_data['label']

            data, target = data.cuda(args.rank), target.cuda(args.rank)
            # print('rank', args.rank, 'idx', idx, 'data', data.shape, 'target', target.shape)


            with autocast(enabled=args.amp):
                if model_inferer is not None:
                    logits = model_inferer(data) #another inferer (e.g. sliding window)
                else:
                    logits = model(data)
                    
            loss = loss_func(logits, target)

            logits = torch.softmax(logits, 1).cpu().numpy()
            logits = np.argmax(logits, axis = 1).astype(np.uint8)
            target = target.cpu().numpy()[:,0,:,:,:]
            
    
            name = batch_data["image_meta_dict"]['filename_or_obj'][0].split('/')[-1]
            val_shape = val_shape_dict[name]
            
            
#             logits = F.interpolate(logits, (val_shape[0], val_shape[1], val_shape[2]), mode='nearest')
#             target = F.interpolate(target.float(), (val_shape[0], val_shape[1], val_shape[2]), mode='nearest')
            pred = resample(logits[0], val_shape)
            y = resample(target[0], val_shape)

            dice_list_sub = []
            for i in range(1, args.num_classes):
                organ_Dice = dice(pred == i, y == i)
                dice_list_sub.append(organ_Dice)

            
#             acc, not_nans = acc_func(logits, target)

            if args.distributed:
                torch.distributed.barrier()
                gather_list_sub = [[0]*len(dice_list_sub) for _ in range(dist.get_world_size())]
                torch.distributed.all_gather_object(gather_list_sub, dice_list_sub)
#                 avg_subs = np.mean([np.mean(s) for s in gather_list_sub])
                
                classes_metriclist = []
                for i in range(args.num_classes-1):
                    class_metric = [s[i] for s in gather_list_sub]
                    classes_metriclist.append(class_metric)
                avg_classes = np.mean(classes_metriclist, 1)
                ave_all = np.mean(avg_classes)
#                 if not loss.is_cuda:
                loss = loss.cuda(args.rank)
#                 acc = acc.cuda(args.rank)
#                 not_nans = not_nans.cuda(args.rank)
                
                loss_list = distributed_all_gather([loss], out_numpy=True, is_valid=idx < loader.sampler.valid_length)

                run_loss.update(np.mean(np.mean(np.stack(loss_list, axis=0), axis=0)), n=args.batch_size * args.world_size)
                
#                 for al, nl in zip(acc_list, not_nans_list):
#                     run_acc.update(al, n=nl)
#                 for al in zip(avg_classes):
                run_acc.update(avg_classes, n=1)

            # If you do not use distributed, this program will raise error.
            else:
                avg_classes = np.array(dice_list_sub)
                run_acc.update(avg_classes, n=args.batch_size)
                run_loss.update(loss.item(), n=args.batch_size)


            # run_loss.update(loss.item(), n=args.batch_size * args.world_size)
            # run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())

            # print(args.rank, 'end1')
            if args.rank == 0:
                print('Batch mean: Liver: {}, Tumor: {}, all:{}'.format(avg_classes[0], avg_classes[1], np.mean(avg_classes)))
                print('Val {}/{} {}/{}'.format(epoch, args.max_epochs, idx, len(loader)),
                    'loss: {:.4f}'.format(run_loss.avg),
                    'acc', run_acc.avg,
                    'acc_avg: {:.4f}'.format(np.mean(run_acc.avg)),
                    'time {:.2f}s'.format(time.time() - start_time))
            start_time = time.time()

    return run_loss.avg, run_acc.avg



def save_checkpoint(model, epoch, args, filename='model.pt', best_acc=0, optimizer=None, scheduler=None):

    state_dict = model.state_dict() if not args.distributed else model.module.state_dict()

    save_dict = {
            'epoch': epoch,
            'best_acc': best_acc,
            'state_dict': state_dict
            }

    if optimizer is not None:
        save_dict['optimizer'] = optimizer.state_dict()
    if scheduler is not None:
        save_dict['scheduler'] = scheduler.state_dict()

    filename=os.path.join(args.logdir, filename)
    torch.save(save_dict, filename)
    print('Saving checkpoint', filename)




def run_training(model,
          train_loader,
          val_loader,
          optimizer,
          loss_func,
          args,
          val_shape_dict,
          model_inferer=None,
          scheduler=None,
          start_epoch=0,
          val_channel_names=None,
          post_label=None,
          post_pred=None
                 ):

    # np.set_printoptions(formatter={'float': '{: 0.3f}'.format}, suppress=True)

    writer = None
    if args.logdir is not None and args.rank == 0:
        writer = SummaryWriter(log_dir=args.logdir)
        if args.rank==0: print('Writing Tensorboard logs to ', writer.log_dir)


    scaler = None
    if args.amp:  # new native amp
        scaler = GradScaler()

    val_acc_max = 0.

    for epoch in range(start_epoch, args.max_epochs):

        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
            torch.distributed.barrier()

        print(args.rank, time.ctime(), 'Epoch:', epoch)

        epoch_time = time.time()
        train_loss = train_epoch(model, train_loader, optimizer, scaler=scaler, epoch=epoch, loss_func=loss_func, args=args)
        if args.rank == 0:
            print('Final training  {}/{}'.format(epoch, args.max_epochs - 1), 'loss: {:.4f}'.format(train_loss),
                  'time {:.2f}s'.format(time.time() - epoch_time))
        if args.rank==0 and writer is not None:
            writer.add_scalar('train_loss', train_loss, epoch)



        b_new_best=False
        val_acc = 0
        if ( (epoch+1) % args.val_every == 0):
        
            ######## Evaluation during training #########
            # if args.distributed:
            #     torch.distributed.barrier()  # sync processes
            # epoch_time = time.time()
            # val_loss, val_acc = val_epoch(model, val_loader, val_shape_dict, epoch=epoch, loss_func=loss_func, model_inferer=model_inferer, args=args,post_label=post_label,post_pred=post_pred)
            # if args.rank == 0:
            #     print('Final validation  {}/{}'.format(epoch, args.max_epochs - 1), 'loss: {:.4f}'.format(val_loss),
            #           'acc', val_acc, 'time {:.2f}s'.format(time.time() - epoch_time))
            #     with open("test_online.txt", 'a') as f:
            #         print('Final validation  {}/{}'.format(epoch, args.max_epochs - 1), 'loss: {:.4f}'.format(val_loss),
            #           'acc', val_acc, 'time {:.2f}s'.format(time.time() - epoch_time), file=f)
            #     if writer is not None:
            #         writer.add_scalar('val_loss', val_loss, epoch)
            #         writer.add_scalar('val_acc', np.mean(val_acc), epoch)
            #         if val_channel_names is not None:
            #             for val_channel_ind in range(len(val_channel_names)):
            #                 if val_channel_ind < val_acc.size:
            #                     writer.add_scalar(val_channel_names[val_channel_ind], val_acc[val_channel_ind], epoch)
            #                     # print('adding  val', val_channel_names[val_channel_ind], val_acc[val_channel_ind])

            #     if np.mean(val_acc) > val_acc_max:
            #         print('new best ({:.6f} --> {:.6f}). '.format(val_acc_max, np.mean(val_acc)))
            #         val_acc_max = np.mean(val_acc)
            #         b_new_best = True
            #         if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
            #             save_checkpoint(model, epoch, args, best_acc=val_acc_max, optimizer=optimizer, scheduler=scheduler)

            #     # save every checkpoint of evaluation
            #     if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
            #         save_checkpoint(model, epoch, args, best_acc=np.mean(val_acc), filename=f'model_{epoch+1}.pt')

            ######## WE DON'T WANT TO EVALUATION DURINT TRAINING!! ########
            ########            JUST SAVE THE CHECKPOINT           ########
            if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
                save_checkpoint(model, epoch, args, best_acc=None, filename=f'model_{epoch+1}.pt')


        if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
            save_checkpoint(model, epoch, args, best_acc=None, filename='model_final.pt')
        #     if b_new_best:
        #         print('Copying to model.pt new best model!!!!')
        #         shutil.copyfile(os.path.join(args.logdir, 'model_final.pt'), os.path.join(args.logdir, 'model.pt'))


        if scheduler is not None:
            scheduler.step()

    print('monai_trainer DONE, best  acc', val_acc_max)

    return val_acc_max

