import contextlib
import datetime
import pathlib
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist

from baseline.labram.utils.metrics import binary_metrics_fn, multiclass_metrics_fn

# Module-wide default compute device: CUDA when torch reports it available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def find_ckpt(path):
    """Return the path of the last ``*.ckpt`` file found recursively under ``path``.

    "Last" means last in the order ``pathlib.Path.rglob`` yields matches. That
    order follows the filesystem's directory listing and is not sorted.

    NOTE(review): when no checkpoint exists this returns the *string* ``"None"``
    (i.e. ``str(None)``), not ``None`` -- callers must compare against that literal.
    """
    # A one-slot deque keeps only the final match while draining the generator.
    tail = deque(pathlib.Path(path).rglob("*.ckpt"), maxlen=1)
    return str(tail[-1]) if tail else str(None)

def get_avg_results(results):
    """Average every metric across a sequence of result dicts.

    Args:
        results: iterable of mappings ``metric name -> numeric value``.

    Returns:
        A ``defaultdict`` mapping each metric name to ``[mean, count]``.
        ``count`` is the number of input dicts that contained that metric.
        Looking up an unseen name yields (and inserts) ``[0.0, 0]``.
    """
    totals = defaultdict(lambda: [0.0, 0])
    for record in results:
        for name, value in record.items():
            entry = totals[name]
            entry[0] += value
            entry[1] += 1
    # Turn each running sum into a mean in place; the count is kept alongside.
    for entry in totals.values():
        entry[0] /= entry[1]
    return totals

class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    The ``deque`` holds only the last ``window_size`` raw values. It backs
    ``median``, ``avg``, ``max`` and ``value``. ``total`` and ``count``
    accumulate over every update and back ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value``, weighting it by ``n`` in the global average.

        The windowed statistics see ``value`` once, regardless of ``n``.
        """
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` across all processes of the default group.

        This is a no-op when torch.distributed is unavailable or no process
        group has been initialized, so single-process runs may call it freely.

        Warning: does not synchronize the deque!
        """
        if not (dist.is_available() and dist.is_initialized()):
            return
        # NCCL can only reduce CUDA tensors; use CPU when CUDA is absent (gloo).
        reduce_device = "cuda" if torch.cuda.is_available() else "cpu"
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device=reduce_device)
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the windowed values (torch picks the lower middle for even sizes)."""
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the windowed values."""
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every update (``total / count``)."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)

class MetricLogger(object):
    """A collection of named ``SmoothedValue`` meters plus a progress-logging iterator."""

    def __init__(self, delimiter="\t"):
        # Meters are created lazily on first access, with SmoothedValue defaults.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one observation per keyword argument in the meter of that name.

        ``None`` values are skipped. Tensors are converted with ``.item()``.
        Every value must end up a Python ``float`` or ``int``.
        """
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Called only after normal attribute lookup fails; exposes meters as
        # attributes (e.g. ``logger.loss``). ``meters`` is read through __dict__
        # so an instance lacking it (mid-unpickling, created via __new__) raises
        # AttributeError instead of recursing into __getattr__ forever.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize every meter's global statistics across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing meter."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable``, printing progress every ``print_freq`` steps.

        A line is also printed for the final item. ``iterable`` must support
        ``len()``. Each progress line shows the ETA, all meters, the per-step
        time and the data-loading time. When CUDA is available it also shows
        the peak allocated GPU memory in MB. A total-time summary is printed
        at the end.
        """
        if not header:
            header = ''
        num_items = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(num_items))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        # Decide once so the format string and the format arguments always agree.
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_items - 1:
                eta_seconds = iter_time.global_avg * (num_items - i)
                fields = {
                    'eta': str(datetime.timedelta(seconds=int(eta_seconds))),
                    'meters': str(self),
                    'time': str(iter_time),
                    'data': str(data_time),
                }
                if use_cuda:
                    fields['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, num_items, **fields))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) keeps an empty iterable from raising ZeroDivisionError.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(num_items, 1)))

def get_metrics(output, target, metrics, is_binary, threshold=0.5):
    """Compute the requested classification metrics.

    Args:
        output: predicted scores, forwarded to ``binary_metrics_fn`` /
            ``multiclass_metrics_fn`` (``evaluate`` passes sigmoid
            probabilities for binary tasks and raw logits otherwise).
        target: ground-truth labels, containing 0/1 values for binary tasks.
        metrics: list of metric names to compute.
        is_binary: dispatch to ``binary_metrics_fn`` if True, otherwise to
            ``multiclass_metrics_fn``.
        threshold: decision threshold forwarded to ``binary_metrics_fn``.

    Returns:
        dict mapping metric name to value.
    """
    if not is_binary:
        return multiclass_metrics_fn(
            target, output, metrics=metrics
        )

    positives = sum(target)
    negatives = len(target) - positives
    if 'roc_auc' not in metrics or positives * negatives != 0:
        return binary_metrics_fn(
            target,
            output,
            metrics=metrics,
            threshold=threshold,
        )

    # Every target is in a single class, so ROC-AUC is undefined and computing
    # it raises. Report it as 0.0 and compute the other requested metrics
    # normally; previously a fixed dict of zeros ignored ``metrics`` entirely.
    remaining = [m for m in metrics if m != 'roc_auc']
    results = {}
    if remaining:
        results = binary_metrics_fn(
            target,
            output,
            metrics=remaining,
            threshold=threshold,
        )
    results['roc_auc'] = 0.0
    return results

@torch.no_grad()
def evaluate(data_loader, model, device, step_fn, header='Test:', metrics=None, is_binary=True, use_amp=True):
    """Evaluate ``model`` on ``data_loader`` without gradients and return metrics.

    Each batch must carry the input at ``batch[0]`` and the label at
    ``batch[-1]``. ``step_fn(model, EEG)`` must return logits. For binary
    tasks these are compared against targets reshaped to ``(B, 1)``, so
    presumably ``(B, 1)`` logits are expected there (confirm with callers).
    Per-batch metrics are printed through ``MetricLogger``. The returned
    metrics are recomputed over the concatenation of all predictions.

    Args:
        data_loader: iterable of batches supporting ``len()``.
        model: module to evaluate; it is put in eval mode and moved to ``device``.
        device: device inputs, targets and the model are moved to.
        step_fn: callable ``(model, EEG) -> logits``.
        header: prefix for the progress lines.
        metrics: metric names forwarded to ``get_metrics``; defaults to ``['acc']``.
        is_binary: use ``BCEWithLogitsLoss`` and sigmoid probabilities (True),
            or ``CrossEntropyLoss`` and raw logits (False).
        use_amp: run the forward pass and loss under CUDA autocast.

    Returns:
        dict of metric name -> value on the whole dataset, plus ``'loss'``:
        the unweighted mean of the per-batch losses.
    """
    # None sentinel instead of a shared mutable default list.
    if metrics is None:
        metrics = ['acc']

    if is_binary:
        criterion = torch.nn.BCEWithLogitsLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    metric_logger = MetricLogger(delimiter=" ")

    # switch to evaluation mode
    model.eval()
    model.to(device)
    pred = []
    true = []
    for batch in metric_logger.log_every(data_loader, 10, header):
        EEG = batch[0].to(device, non_blocking=True)
        target = batch[-1].to(device, non_blocking=True)
        if is_binary:
            # BCEWithLogitsLoss needs float targets shaped like the logits.
            target = target.float().unsqueeze(-1)

        # compute output; without AMP, leave any caller-level autocast state untouched
        amp_ctx = torch.autocast(device_type="cuda") if use_amp else contextlib.nullcontext()
        with amp_ctx:
            output = step_fn(model, EEG)
            loss = criterion(output, target)

        if is_binary:
            output = torch.sigmoid(output).cpu()
        else:
            output = output.cpu()
        target = target.cpu()

        results = get_metrics(output.numpy(), target.numpy(), metrics, is_binary)
        pred.append(output)
        true.append(target)

        batch_size = EEG.shape[0]
        metric_logger.update(loss=loss.item())
        for key, value in results.items():
            metric_logger.meters[key].update(value, n=batch_size)

    # NOTE: statistics are not synchronized across processes here; in a
    # distributed run each process reports only on its own shard.
    print('* loss {losses.global_avg:.3f}'
          .format(losses=metric_logger.loss))

    pred = torch.cat(pred, dim=0).numpy()
    true = torch.cat(true, dim=0).numpy()

    ret = get_metrics(pred, true, metrics, is_binary, 0.5)
    ret['loss'] = metric_logger.loss.global_avg
    return ret