# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional

import torch

from mixup_partial import Mixup
from timm.utils import accuracy, ModelEma

from losses import DistillationLoss
import utils

import torch.nn.functional as F
import time
from collections import Counter

# (substring looked up in ``args.data_set``, leaf threshold, family threshold).
# Entries are tried in this order and the first substring match wins.  A sample
# contributes to the species loss when ``targets > leaf threshold`` and to the
# family loss when ``targets > family threshold``.
# NOTE(review): the thresholds look like the number of coarse nodes in a
# flattened label hierarchy for each dataset -- confirm against the dataset code.
_LEVEL_THRESHOLDS = (
    ('BIRD', 50, 12),
    ('IMNET-H', 146, 19),
    ('AIR', 99, 29),
)


def _resolve_level_thresholds(data_set):
    """Return ``(leaf_threshold, family_threshold)`` for ``data_set``.

    Raises:
        ValueError: if ``data_set`` contains none of the known dataset keys.
    """
    for key, leaf_threshold, family_threshold in _LEVEL_THRESHOLDS:
        if key in data_set:
            return leaf_threshold, family_threshold
    known = ", ".join(repr(key) for key, _, _ in _LEVEL_THRESHOLDS)
    raise ValueError(
        "Unsupported data_set {!r} for 3-level training; expected it to contain one of: {}".format(
            data_set, known))


def _subset_cross_entropy(logits, labels, keep):
    """Cross-entropy over the rows of ``logits``/``labels`` where ``keep`` is True.

    Returns a zero scalar *tensor* (detached from the graph) when no row is
    selected, so callers can always add the result to other losses and call
    ``.item()`` on it.
    """
    # reshape(-1) keeps the index 1-D even when exactly one row matches.
    rows = torch.nonzero(keep, as_tuple=False).reshape(-1)
    if rows.numel() == 0:
        return torch.zeros((), dtype=torch.float32, device=logits.device)
    return F.cross_entropy(logits.index_select(0, rows), labels.index_select(0, rows))


def _symmetric_contrastive_loss(feats, caps_embed):
    """Average of the two cross-entropy directions over the feats/caption similarity matrix.

    Both inputs are L2-normalised along the last dimension; row ``i`` of ``feats``
    is treated as the positive for row ``i`` of ``caps_embed``.  No temperature or
    logit scale is applied.
    """
    feats = feats / feats.norm(dim=-1, keepdim=True)
    caps_embed = caps_embed / caps_embed.norm(dim=-1, keepdim=True)
    logits = torch.matmul(feats, caps_embed.t())
    labels = torch.arange(logits.shape[0], device=logits.device)
    loss_i = F.cross_entropy(logits, labels)
    loss_t = F.cross_entropy(logits.t(), labels)
    return (loss_i + loss_t) / 2


def _backward_and_step(loss, model, optimizer, loss_scaler, max_norm, model_ema):
    """Abort on a non-finite loss, otherwise run one scaled optimisation step.

    Returns the Python float value of ``loss``.
    """
    loss_value = loss.item()

    if not math.isfinite(loss_value):
        print("Loss is {}, stopping training".format(loss_value))
        sys.exit(1)

    optimizer.zero_grad()

    # this attribute is added by timm on one optimizer (adahessian)
    is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    loss_scaler(loss, optimizer, clip_grad=max_norm,
                parameters=model.parameters(), create_graph=is_second_order)

    torch.cuda.synchronize()
    if model_ema is not None:
        model_ema.update(model)
    return loss_value


def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, args = None):
    """Train ``model`` for one pass over ``data_loader`` and return averaged meters.

    Two modes are selected by ``len(args.nb_classes)``:

    * 3 entries: each batch is ``(samples, targets, species_targets,
      family_targets, mf_targets, caps_embed)`` and the model returns
      ``(outputs, family_out, manu_out, feats)``.  The loss is the sum of the
      species, family and manufacturer cross-entropies plus
      ``args.sim_loss_weight`` times a symmetric contrastive loss between
      ``feats`` and ``caps_embed``.  Species/family terms only use samples whose
      ``targets`` exceed the dataset-specific thresholds.  ``criterion`` and
      ``mixup_fn`` are not used in this mode.
    * otherwise: each batch is ``(samples, targets, family_targets)`` and the
      model returns ``(outputs, family_out)``.  Without ``args.cosub`` the loss is
      ``criterion(samples, outputs, targets) + criterion(samples, family_out,
      family_targets)``; with ``args.cosub`` the batch is duplicated and a
      four-term BCE-with-logits co-training loss on ``outputs`` is used.

    Args:
        model: network to train.
        criterion: loss used in the 2-level, non-cosub mode; replaced by
            ``BCEWithLogitsLoss`` when ``args.cosub`` is set.
        data_loader: iterable of batches (layout depends on the mode, see above).
        optimizer: optimizer stepped through ``loss_scaler``.
        device: device the batch tensors are moved to.
        epoch: epoch index, only used in the log header.
        loss_scaler: callable performing backward, optional clipping and the step.
        max_norm: passed to ``loss_scaler`` as ``clip_grad``.
        model_ema: optional EMA wrapper updated after every step.
        mixup_fn: optional mixup callable applied in the 2-level mode only.
        set_training_mode: value passed to ``model.train``.
        args: namespace providing ``cosub``, ``nb_classes`` and, depending on the
            mode, ``data_set``, ``sim_loss_weight`` or ``bce_loss``.  Required.

    Returns:
        Dict mapping each meter name to its ``global_avg``.

    Raises:
        ValueError: if ``args`` is ``None``, or in 3-level mode when
            ``args.data_set`` matches no known dataset key.
    """
    if args is None:
        raise ValueError("train_one_epoch requires `args` (cosub, nb_classes, ...) to be provided")

    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    if args.cosub:
        criterion = torch.nn.BCEWithLogitsLoss()

    if len(args.nb_classes) == 3:
        # Loop-invariant: resolve the dataset thresholds once, and fail early
        # with a clear message instead of a NameError inside the loop.
        leaf_threshold, family_threshold = _resolve_level_thresholds(args.data_set)

        for samples, targets, species_targets, family_targets, mf_targets, caps_embed in metric_logger.log_every(data_loader, print_freq, header):
            samples = samples.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            species_targets = species_targets.to(device, non_blocking=True)
            family_targets = family_targets.to(device, non_blocking=True)
            mf_targets = mf_targets.to(device, non_blocking=True)
            caps_embed = caps_embed.to(device, non_blocking=True)

            with torch.cuda.amp.autocast():
                outputs, family_out, manu_out, feats = model(samples)

                sim_loss = _symmetric_contrastive_loss(feats, caps_embed)

                # Each of these is always a tensor (zero when no sample
                # qualifies) so the logging below can call .item() safely.
                loss_species = _subset_cross_entropy(
                    outputs, species_targets, targets > leaf_threshold)
                loss_family = _subset_cross_entropy(
                    family_out, family_targets, targets > family_threshold)
                loss_manufacturer = F.cross_entropy(manu_out, mf_targets)

                loss = loss_species + loss_family + loss_manufacturer + sim_loss * args.sim_loss_weight

            _backward_and_step(loss, model, optimizer, loss_scaler, max_norm, model_ema)

            metric_logger.update(sp_loss=loss_species.item())
            metric_logger.update(fam_loss=loss_family.item())
            metric_logger.update(manu_loss=loss_manufacturer.item())
            metric_logger.update(sim_loss=sim_loss.item())
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

            # Drop references to the large per-batch tensors before asking the
            # caching allocator to release unused blocks.
            del feats, samples, targets, caps_embed, outputs, loss, sim_loss
            torch.cuda.empty_cache()
    else:
        for samples, targets, family_targets in metric_logger.log_every(data_loader, print_freq, header):
            samples = samples.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            family_targets = family_targets.to(device, non_blocking=True)
            if mixup_fn is not None:
                samples, targets, family_targets = mixup_fn(samples, [targets, family_targets])

            if args.cosub:
                # Duplicate the batch; the two halves of the output are
                # trained against each other below.
                samples = torch.cat((samples, samples), dim=0)

            if args.bce_loss:
                targets = targets.gt(0.0).type(targets.dtype)

            with torch.cuda.amp.autocast():
                outputs, family_out = model(samples)
                if not args.cosub:
                    loss_species = criterion(samples, outputs, targets)
                    loss_family = criterion(samples, family_out, family_targets)
                    loss = loss_species + loss_family
                else:
                    outputs = torch.split(outputs, outputs.shape[0]//2, dim=0)
                    loss = 0.25 * criterion(outputs[0], targets)
                    loss = loss + 0.25 * criterion(outputs[1], targets)
                    loss = loss + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid())
                    loss = loss + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid())

            loss_value = _backward_and_step(loss, model, optimizer, loss_scaler, max_norm, model_ema)

            if args.cosub:
                # The cosub objective has no per-head terms; log the total only.
                metric_logger.update(loss=loss_value)
            else:
                metric_logger.update(sp_loss=loss_species.item())
                metric_logger.update(fam_loss=loss_family.item())
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

@torch.no_grad()
def evaluate(data_loader, model, device, n_classes=3):
    """Evaluate ``model`` on ``data_loader`` and return the averaged meters.

    With ``n_classes == 3`` every batch is ``(images, target, family_targets,
    mf_targets)`` and the model returns four values, of which the first three
    are scored (species, family, manufacturer).  For any other ``n_classes``
    every batch is ``(images, target, family_targets)`` and the model returns
    ``(output, family_out)``.

    Cross-entropy losses are logged per head; top-1/top-5 accuracy is logged for
    the species head and top-1 for the remaining heads.  A one-line summary is
    printed and a dict mapping meter name to ``global_avg`` is returned.
    """
    ce = torch.nn.CrossEntropyLoss()
    stats = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    if n_classes != 3:
        # Two-head model: species + family.
        for images, target, family_targets in stats.log_every(data_loader, 10, header):
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            family_targets = family_targets.to(device, non_blocking=True)

            # forward pass and per-head losses under mixed precision
            with torch.cuda.amp.autocast():
                output, family_out = model(images)
                loss_species = ce(output, target)
                loss_family = ce(family_out, family_targets)

            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            family_acc1, _ = accuracy(family_out, family_targets, topk=(1, 5))

            n_samples = images.shape[0]
            stats.update(sploss=loss_species.item())
            stats.update(famloss=loss_family.item())
            stats.meters['acc1'].update(acc1.item(), n=n_samples)
            stats.meters['acc5'].update(acc5.item(), n=n_samples)
            stats.meters['family_acc1'].update(family_acc1.item(), n=n_samples)

        # gather the stats from all processes
        stats.synchronize_between_processes()
        summary = ('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} family@1 {familytop1.global_avg:.3f}'
                   ' sploss {losses.global_avg:.3f} fmloss {fmlosses.global_avg:.3f}')
        print(summary.format(top1=stats.acc1, top5=stats.acc5, losses=stats.sploss,
                             fmlosses=stats.famloss, familytop1=stats.family_acc1))
        return {name: meter.global_avg for name, meter in stats.meters.items()}

    # Three-head model: species + family + manufacturer.
    for images, target, family_targets, mf_targets in stats.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        family_targets = family_targets.to(device, non_blocking=True)
        mf_targets = mf_targets.to(device, non_blocking=True)

        # forward pass and per-head losses under mixed precision
        with torch.cuda.amp.autocast():
            output, family_out, manu_out, _ = model(images)
            loss_species = ce(output, target)
            loss_family = ce(family_out, family_targets)
            loss_manufacturer = ce(manu_out, mf_targets)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        family_acc1, _ = accuracy(family_out, family_targets, topk=(1, 5))
        manu_acc1, _ = accuracy(manu_out, mf_targets, topk=(1, 5))

        n_samples = images.shape[0]
        stats.update(sploss=loss_species.item())
        stats.update(famloss=loss_family.item())
        stats.update(manuloss=loss_manufacturer.item())
        stats.meters['acc1'].update(acc1.item(), n=n_samples)
        stats.meters['acc5'].update(acc5.item(), n=n_samples)
        stats.meters['family_acc1'].update(family_acc1.item(), n=n_samples)
        stats.meters['manu_acc1'].update(manu_acc1.item(), n=n_samples)

    # gather the stats from all processes
    stats.synchronize_between_processes()
    summary = ('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} family@1 {familytop1.global_avg:.3f}'
               ' manu@1 {manutop1.global_avg:.3f} sploss {losses.global_avg:.3f} fmloss {fmlosses.global_avg:.3f} mfloss {mflosses.global_avg:.3f}')
    print(summary.format(top1=stats.acc1, top5=stats.acc5, losses=stats.sploss,
                         fmlosses=stats.famloss, mflosses=stats.manuloss,
                         familytop1=stats.family_acc1, manutop1=stats.manu_acc1))
    return {name: meter.global_avg for name, meter in stats.meters.items()}