import contextlib
import math

import numpy as np
import torch

from core.evaluate import accuracy
from utils.utils import mixup_data, mixup_criterion


class Trainer:
    """Per-batch training-step strategies selected by ``cfg.train.trainer.type``.

    Each strategy method (``default``, ``mixup``, ``default_hvd``,
    ``contrast``) takes ``(model, criterion, data, targets, **kwargs)`` and
    returns a ``(loss, acc)`` tuple. ``forward`` dispatches to the method whose
    name equals ``self.type``.
    """

    def __init__(self, cfg, rank):
        """
        Args:
            cfg: Config object. Reads ``cfg.train.trainer.type``,
                ``cfg.train.num_epochs`` and ``cfg.train.trainer.mixup_alpha``
                here; individual strategies additionally read
                ``cfg.hvd.fp16`` and ``cfg.train.trainer.contrast_type``.
            rank: Device index passed to ``Tensor.cuda`` when moving batches.
        """
        self.cfg = cfg
        self.type = cfg.train.trainer.type
        self.rank = rank
        self.num_epochs = cfg.train.num_epochs
        # Set by reset_epoch(); None until the first epoch begins.
        self.epoch = None
        self.init_all_params()

    def init_all_params(self):
        """Cache strategy hyper-parameters read from the config."""
        self.mixup_alpha = self.cfg.train.trainer.mixup_alpha

    def reset_epoch(self, epoch):
        """Record the index of the epoch that is about to run."""
        self.epoch = epoch

    def forward(self, model, criterion, data, targets, **kwargs):
        """Run one training step using the strategy named by ``self.type``.

        The strategy is looked up on the instance rather than on ``Trainer``
        itself so that subclasses can override or add strategies.

        Returns:
            Whatever the selected strategy returns, i.e. ``(loss, acc)``.

        Raises:
            AttributeError: if no method named ``self.type`` exists.
        """
        step = getattr(self, self.type)
        return step(model, criterion, data, targets, **kwargs)

    def _to_device(self, *tensors):
        """Return ``tensors`` moved to CUDA device ``self.rank``, as a tuple."""
        return tuple(t.cuda(self.rank) for t in tensors)

    @staticmethod
    def _batch_accuracy(output, targets):
        """Arg-max ``output`` along dim 1 and score it with ``accuracy``.

        Returns the first element of ``accuracy``'s result (its exact meaning
        is defined in ``core.evaluate``).
        """
        pred = torch.argmax(output, 1)
        return accuracy(pred.cpu().numpy(), targets.cpu().numpy())[0]

    def default(self, model, criterion, data, targets, **kwargs):
        """Plain supervised step: ``loss = criterion(model(data), targets)``."""
        data, targets = self._to_device(data, targets)
        output = model(data)
        loss = criterion(output, targets)
        return loss, self._batch_accuracy(output, targets)

    def mixup(self, model, criterion, data, targets, **kwargs):
        """Mixup step: train on mixed inputs with the mixed-label criterion.

        ``mixup_data`` mixes the batch using ``self.mixup_alpha``;
        ``mixup_criterion`` combines ``criterion`` over ``y_a``/``y_b``
        weighted by ``lam``.
        """
        data, targets = self._to_device(data, targets)
        mixed_x, y_a, y_b, lam = mixup_data(
            data, targets, alpha=self.mixup_alpha, rank=self.rank
        )
        output = model(mixed_x)
        loss = mixup_criterion(criterion, output, y_a, y_b, lam)
        # Accuracy is reported against the un-mixed ``targets``.
        return loss, self._batch_accuracy(output, targets)

    def default_hvd(self, model, criterion, data, targets, **kwargs):
        """Supervised step that optionally runs under CUDA autocast.

        When ``cfg.hvd.fp16`` is truthy, the forward pass and loss are
        computed inside ``torch.cuda.amp.autocast()``.
        """
        data, targets = self._to_device(data, targets)
        # nullcontext keeps a single code path for the fp32 case.
        amp_ctx = (
            torch.cuda.amp.autocast()
            if self.cfg.hvd.fp16
            else contextlib.nullcontext()
        )
        with amp_ctx:
            output = model(data)
            loss = criterion(output, targets)
        return loss, self._batch_accuracy(output, targets)

    def contrast(self, model, criterion, data, targets, **kwargs):
        """Contrastive step over two views of each sample.

        Args:
            data: Indexable pair of view batches; ``data[0]`` and ``data[1]``
                are concatenated along dim 0 (presumably each shaped
                ``(bsz, ...)`` — confirm with the data loader).
            targets: Labels; ``targets.shape[0]`` is used as the per-view
                batch size. Only passed to the criterion for ``'SupCon'``.

        Returns:
            ``(loss, 0.)`` — no accuracy is computed for contrastive training.

        Raises:
            ValueError: if ``cfg.train.trainer.contrast_type`` is neither
                ``'SupCon'`` nor ``'SimCLR'``.
        """
        contrast_type = self.cfg.train.trainer.contrast_type
        # Validate before doing any work so an unsupported type fails clearly
        # instead of leaving ``loss`` unbound after a full forward pass.
        if contrast_type not in ('SupCon', 'SimCLR'):
            raise ValueError(
                f"Unsupported contrast_type {contrast_type!r}; "
                "expected 'SupCon' or 'SimCLR'"
            )

        data = torch.cat([data[0], data[1]], dim=0)
        data, targets = self._to_device(data, targets)
        bsz = targets.shape[0]

        features = model(data, feature_flag=True)
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        # Pair the two views of each sample along a new dim 1: (bsz, 2, ...).
        features = torch.stack([f1, f2], dim=1)

        if contrast_type == 'SupCon':
            loss = criterion(features, targets)
        else:  # 'SimCLR' is label-free.
            loss = criterion(features)

        return loss, 0.

