# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional

import torch

#from timm.data import Mixup
from timm.utils import accuracy, ModelEma
from mixup_hier import Mixup
from timm.data import Mixup as Mixup_single

from losses import DistillationLoss
import utils
import torch.nn.functional as F

def _hierarchy_thresholds(data_set):
    """Return the ``(leaf, family)`` label thresholds for a 3-level dataset.

    A row counts as labelled down to the leaf level when ``targets > leaf``
    and down to (at least) the family level when ``targets > family``.
    NOTE(review): presumably ``targets`` indexes a label space that
    concatenates the hierarchy levels coarsest-first -- confirm against the
    dataset implementation.

    Raises:
        ValueError: if ``data_set`` contains none of the known dataset tags.
    """
    # Checked in this order on purpose: the first matching substring wins.
    if 'BIRD' in data_set:
        return 50, 12
    if 'IMNET-H' in data_set:
        return 146, 19
    if 'AIR' in data_set:
        return 99, 29
    raise ValueError('Unknown dataset')


def _backward_and_step(loss, model, optimizer, loss_scaler, max_norm, model_ema):
    """Stop on a non-finite loss, then backprop, step the optimizer and update the EMA."""
    loss_value = loss.item()
    if not math.isfinite(loss_value):
        print("Loss is {}, stopping training".format(loss_value))
        sys.exit(1)

    optimizer.zero_grad()

    # this attribute is added by timm on one optimizer (adahessian)
    is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    loss_scaler(loss, optimizer, clip_grad=max_norm,
                parameters=model.parameters(), create_graph=is_second_order)

    torch.cuda.synchronize()
    if model_ema is not None:
        model_ema.update(model)


def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, args = None):
    """Train ``model`` for one epoch on a 2- or 3-level label hierarchy.

    The branch is selected by ``len(args.nb_classes)``:

    * 3 levels: batches are ``(samples, segments, targets, species_targets,
      family_targets, mf_targets, caps_embed)`` and ``model(samples, segments)``
      returns ``(outputs, family_out, manu_out, feats)``. The loss sums the
      leaf cross-entropy over rows labelled at the leaf level, the family
      cross-entropy over rows labelled at least at the family level, the
      third-level (``manu_out`` / ``mf_targets``) cross-entropy over all rows,
      and a symmetric feature/caption contrastive loss scaled by
      ``args.sim_loss_weight``. A KL term scaled by ``args.gk_weight`` is
      added when ``args.globalkl`` is set. ``mixup_fn`` is not applied here.
    * 2 levels: batches are ``(samples, segments, targets, species_targets,
      family_targets)``. The species/family targets are mixed by ``mixup_fn``
      when one is given. They are scored with the caller's ``criterion``,
      called as ``criterion(inputs, outputs, labels)``.

    With ``args.cosub``, both branches instead use the co-sub objective with
    a plain cross-entropy loss.

    Returns:
        dict mapping each logged meter name to its global average.

    Raises:
        ValueError: for a 3-level run on an unrecognised ``args.data_set``,
            or when ``len(args.nb_classes)`` is neither 2 nor 3.
    """
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    gk_criterion = None
    if args.globalkl:
        gk_criterion = torch.nn.KLDivLoss(reduction='batchmean')
    elif args.globalbce:
        # NOTE(review): built but never applied below, so --globalbce adds no
        # extra loss term (and no gk_loss is logged).
        gk_criterion = torch.nn.BCEWithLogitsLoss()

    # The 2-level branch calls the caller-supplied criterion with the
    # (inputs, outputs, labels) signature, so keep it before it is replaced.
    dist_criterion = criterion
    # Plain cross-entropy used by the co-sub objective ("added for not mixup").
    criterion = torch.nn.CrossEntropyLoss()

    n_levels = len(args.nb_classes)
    if n_levels == 3:
        leaf_thr, family_thr = _hierarchy_thresholds(args.data_set)

        for samples, segments, targets, species_targets, family_targets, mf_targets, caps_embed in metric_logger.log_every(data_loader, print_freq, header):
            samples = samples.to(device, non_blocking=True)
            segments = segments.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            species_targets = species_targets.to(device, non_blocking=True)
            family_targets = family_targets.to(device, non_blocking=True)
            mf_targets = mf_targets.to(device, non_blocking=True)
            caps_embed = caps_embed.to(device, non_blocking=True)

            # 1-D row indices of samples labelled down to the leaf level and
            # down to (at least) the family level.
            leaf_idx = torch.nonzero(targets > leaf_thr, as_tuple=False).flatten()
            family_idx = torch.nonzero(targets > family_thr, as_tuple=False).flatten()
            gk_loss = None

            with torch.cuda.amp.autocast():
                outputs, family_out, manu_out, feats = model(samples, segments)

                # Symmetric contrastive loss between L2-normalised image
                # features and caption embeddings (row i matches caption i).
                feats = feats / feats.norm(dim=-1, keepdim=True)
                caps_embed = caps_embed / caps_embed.norm(dim=-1, keepdim=True)
                labels = torch.arange(len(targets), device=device)
                logits = torch.matmul(feats, caps_embed.t())
                loss_i = F.cross_entropy(logits, labels)
                loss_t = F.cross_entropy(logits.t(), labels)
                sim_loss = (loss_i + loss_t) / 2

                loss_species = 0
                loss_family = 0
                loss_manufacturer = 0

                if not args.cosub:
                    has_leaf = leaf_idx.numel() > 0
                    if has_leaf:
                        select_leaf_output = torch.index_select(outputs, 0, leaf_idx)
                        select_leaf_labels = torch.index_select(species_targets, 0, leaf_idx)
                        loss_species += F.cross_entropy(select_leaf_output, select_leaf_labels)

                    if family_idx.numel() > 0:
                        select_family_labels = torch.index_select(family_targets, 0, family_idx)
                        select_family_output = torch.index_select(family_out, 0, family_idx)
                        loss_family += F.cross_entropy(select_family_output, select_family_labels)

                    loss_manufacturer = F.cross_entropy(manu_out, mf_targets)

                    loss = loss_species + loss_family + loss_manufacturer + sim_loss * args.sim_loss_weight

                    # The KL term only covers leaf rows; skip it when the batch has none
                    # instead of reusing tensors left over from a previous batch.
                    if args.globalkl and has_leaf:
                        all_outputs = torch.cat((torch.index_select(manu_out, 0, leaf_idx),
                                                 torch.index_select(family_out, 0, leaf_idx),
                                                 select_leaf_output), dim=1)
                        all_outputs = F.log_softmax(all_outputs, dim=1)
                        mf_onehot = F.one_hot(torch.index_select(mf_targets, 0, leaf_idx), num_classes=args.nb_classes[2]).float()
                        family_onehot = F.one_hot(torch.index_select(family_targets, 0, leaf_idx), num_classes=args.nb_classes[1]).float()
                        leaf_onehot = F.one_hot(select_leaf_labels, num_classes=args.nb_classes[0]).float()
                        all_targets = torch.cat((mf_onehot, family_onehot, leaf_onehot), dim=1)
                        # Turn the concatenated one-hots into a distribution (rows sum to 1).
                        all_targets = F.normalize(all_targets, p=1, dim=1)
                        gk_loss = gk_criterion(all_outputs, all_targets)
                        loss = loss + gk_loss * args.gk_weight

                else:
                    outputs = torch.split(outputs, outputs.shape[0]//2, dim=0)
                    loss = 0.25 * criterion(outputs[0], targets)
                    loss = loss + 0.25 * criterion(outputs[1], targets)
                    loss = loss + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid())
                    loss = loss + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid())

            _backward_and_step(loss, model, optimizer, loss_scaler, max_norm, model_ema)

            # float() accepts both the int 0 placeholders and scalar tensors.
            metric_logger.update(sp_loss=float(loss_species))
            metric_logger.update(fam_loss=float(loss_family))
            metric_logger.update(manu_loss=float(loss_manufacturer))
            metric_logger.update(sim_loss=sim_loss.item())
            if gk_loss is not None:
                metric_logger.update(gk_loss=gk_loss.item())
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])
            # Drop large per-batch tensors before releasing cached GPU memory.
            del feats, samples, targets, caps_embed, outputs, loss, loss_i, loss_t, sim_loss, logits
            torch.cuda.empty_cache()

    elif n_levels == 2:
        for samples, segments, targets, species_targets, family_targets in metric_logger.log_every(data_loader, print_freq, header):

            samples = samples.to(device, non_blocking=True)
            segments = segments.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            species_targets = species_targets.to(device, non_blocking=True)
            family_targets = family_targets.to(device, non_blocking=True)

            # Species targets only for rows labelled at the species level;
            # every other row keeps -1.
            leaf_idx = torch.nonzero(targets >= args.nb_classes[1], as_tuple=False).flatten()
            sp_targets = torch.full((targets.shape[0],), -1, dtype=torch.long, device=device)
            sp_targets[leaf_idx] = species_targets[leaf_idx]

            if mixup_fn is not None:
                # After mixing, the targets become soft labels:
                # B class ids -> B x N_class probabilities.
                samples, sp_targets, family_targets = mixup_fn(samples, [sp_targets, family_targets])

            if args.cosub:
                samples = torch.cat((samples, samples), dim=0)

            if args.bce_loss:
                targets = targets.gt(0.0).type(targets.dtype)

            gk_loss = None
            loss_species = 0
            loss_family = 0
            with torch.cuda.amp.autocast():
                outputs, family_out = model(samples, segments)
                if not args.cosub:
                    has_leaf = leaf_idx.numel() > 0
                    if has_leaf:
                        select_leaf_labels = torch.index_select(sp_targets, 0, leaf_idx)
                        select_leaf_output = torch.index_select(outputs, 0, leaf_idx)
                        select_leaf_samples = torch.index_select(samples, 0, leaf_idx)
                        loss_species += dist_criterion(select_leaf_samples, select_leaf_output, select_leaf_labels)

                    loss_family = dist_criterion(samples, family_out, family_targets)
                    loss = loss_species + loss_family
                    if args.globalkl and has_leaf:
                        all_outputs = torch.cat((torch.index_select(family_out, 0, leaf_idx), select_leaf_output), dim=1)
                        all_outputs = F.log_softmax(all_outputs, dim=1)
                        # NOTE(review): concatenating along dim=1 assumes soft
                        # (mixup) targets, i.e. 2-D tensors -- confirm for runs
                        # without mixup_fn.
                        all_targets = torch.cat((torch.index_select(family_targets, 0, leaf_idx), select_leaf_labels), dim=1)
                        all_targets = F.normalize(all_targets, p=1, dim=1)

                        gk_loss = gk_criterion(all_outputs, all_targets)
                        loss = loss + gk_loss * args.gk_weight

                else:
                    outputs = torch.split(outputs, outputs.shape[0]//2, dim=0)
                    loss = 0.25 * criterion(outputs[0], targets)
                    loss = loss + 0.25 * criterion(outputs[1], targets)
                    loss = loss + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid())
                    loss = loss + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid())

            _backward_and_step(loss, model, optimizer, loss_scaler, max_norm, model_ema)

            metric_logger.update(sp_loss=float(loss_species))
            metric_logger.update(fam_loss=float(loss_family))
            if gk_loss is not None:
                metric_logger.update(gk_loss=gk_loss.item())
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    else:
        raise ValueError('Unsupported number of hierarchy levels: {}'.format(n_levels))

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device, nb_classes):
    """Evaluate ``model`` on ``data_loader`` for a 2- or 3-level label hierarchy.

    Args:
        data_loader: yields ``(images, segments, target, family_targets)``
            batches, plus a trailing ``mf_targets`` when ``len(nb_classes) == 3``.
        model: called as ``model(images, segments)``. It returns
            ``(output, family_out, manu_out, _)`` for 3 levels and
            ``(output, family_out)`` for 2 levels.
        device: device the batch tensors are moved to.
        nb_classes: sequence whose length (2 or 3) selects the hierarchy depth.
            No batches are evaluated for any other length.

    Returns:
        dict mapping each logged meter name to its global average.
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    if len(nb_classes) == 3:
        for images, segments, target, family_targets, mf_targets in metric_logger.log_every(data_loader, 10, header):
            images = images.to(device, non_blocking=True)
            segments = segments.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            family_targets = family_targets.to(device, non_blocking=True)
            mf_targets = mf_targets.to(device, non_blocking=True)
            # compute output
            with torch.cuda.amp.autocast():
                output, family_out, manu_out, _ = model(images, segments)
                loss_species = criterion(output, target)
                loss_family = criterion(family_out, family_targets)
                loss_manufacturer = criterion(manu_out, mf_targets)

            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # Only top-1 is reported for the coarser levels.
            family_acc1 = accuracy(family_out, family_targets, topk=(1,))[0]
            manu_acc1 = accuracy(manu_out, mf_targets, topk=(1,))[0]

            batch_size = images.shape[0]
            metric_logger.update(sploss=loss_species.item())
            metric_logger.update(famloss=loss_family.item())
            metric_logger.update(manuloss=loss_manufacturer.item())
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
            metric_logger.meters['family_acc1'].update(family_acc1.item(), n=batch_size)
            metric_logger.meters['manu_acc1'].update(manu_acc1.item(), n=batch_size)

        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} family@1 {familytop1.global_avg:.3f}'
            ' manu@1 {manutop1.global_avg:.3f} sploss {losses.global_avg:.3f} fmloss {fmlosses.global_avg:.3f} mfloss {mflosses.global_avg:.3f}'
            .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.sploss, fmlosses=metric_logger.famloss, mflosses=metric_logger.manuloss,
                    familytop1=metric_logger.family_acc1, manutop1=metric_logger.manu_acc1))

    elif len(nb_classes) == 2:
        for images, segments, target, family_targets in metric_logger.log_every(data_loader, 10, header):
            images = images.to(device, non_blocking=True)
            segments = segments.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            family_targets = family_targets.to(device, non_blocking=True)

            # compute output
            with torch.cuda.amp.autocast():
                output, family_out = model(images, segments)
                loss_species = criterion(output, target)
                loss_family = criterion(family_out, family_targets)

            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            family_acc1 = accuracy(family_out, family_targets, topk=(1,))[0]

            batch_size = images.shape[0]
            metric_logger.update(sploss=loss_species.item())
            metric_logger.update(famloss=loss_family.item())

            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
            metric_logger.meters['family_acc1'].update(family_acc1.item(), n=batch_size)
        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        # Leading space on the second literal keeps "family@1 X" and "sploss"
        # from being glued together in the printed summary.
        print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} family@1 {familytop1.global_avg:.3f}'
            ' sploss {losses.global_avg:.3f} fmloss {fmlosses.global_avg:.3f}'
            .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.sploss, fmlosses=metric_logger.famloss,
                    familytop1=metric_logger.family_acc1))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

# Lookup table with 200 rows of three small integers each.
# Column 0 is a 1-based row id running 1..200 in order. Column 1 takes values
# in 1..13 and column 2 takes values in 1..38.
# NOTE(review): presumably one row per bird species (the 'BIRD' dataset), with
# column 1 a 1-based coarse (order-level?) id and column 2 a 1-based family id.
# The 13 / 38 counts are consistent with the BIRD thresholds (>12, >50)
# used in train_one_epoch -- confirm against the dataset annotations.
# Not referenced by train_one_epoch / evaluate in this file.
birds_trees = torch.tensor([
[1,12,35],
[2,12,35],
[3,12,35],
[4,6,9],
[5,4,4],
[6,4,4],
[7,4,4],
[8,4,4],
[9,8,18],
[10,8,18],
[11,8,18],
[12,8,18],
[13,8,18],
[14,8,13],
[15,8,13],
[16,8,13],
[17,8,13],
[18,8,26],
[19,8,21],
[20,8,19],
[21,8,24],
[22,3,3],
[23,13,37],
[24,13,37],
[25,13,37],
[26,8,18],
[27,8,18],
[28,8,14],
[29,8,15],
[30,8,15],
[31,6,9],
[32,6,9],
[33,6,9],
[34,8,16],
[35,8,16],
[36,10,33],
[37,8,30],
[38,8,30],
[39,8,30],
[40,8,30],
[41,8,30],
[42,8,30],
[43,8,30],
[44,13,38],
[45,12,36],
[46,1,1],
[47,8,16],
[48,8,16],
[49,8,18],
[50,11,34],
[51,11,34],
[52,11,34],
[53,11,34],
[54,8,13],
[55,8,16],
[56,8,16],
[57,8,13],
[58,4,4],
[59,4,5],
[60,4,5],
[61,4,5],
[62,4,5],
[63,4,5],
[64,4,5],
[65,4,5],
[66,4,5],
[67,2,2],
[68,2,2],
[69,2,2],
[70,2,2],
[71,4,6],
[72,4,6],
[73,8,15],
[74,8,15],
[75,8,15],
[76,8,24],
[77,8,30],
[78,8,30],
[79,5,7],
[80,5,7],
[81,5,7],
[82,5,7],
[83,5,7],
[84,5,8],
[85,8,11],
[86,7,10],
[87,1,1],
[88,8,18],
[89,1,1],
[90,1,1],
[91,8,21],
[92,3,3],
[93,8,15],
[94,8,27],
[95,8,18],
[96,8,18],
[97,8,18],
[98,8,18],
[99,8,23],
[100,9,32],
[101,9,32],
[102,8,30],
[103,8,30],
[104,8,22],
[105,3,3],
[106,4,4],
[107,8,15],
[108,8,15],
[109,8,23],
[110,6,9],
[111,8,20],
[112,8,20],
[113,8,24],
[114,8,24],
[115,8,24],
[116,8,24],
[117,8,24],
[118,8,25],
[119,8,24],
[120,8,24],
[121,8,24],
[122,8,24],
[123,8,24],
[124,8,24],
[125,8,24],
[126,8,24],
[127,8,24],
[128,8,24],
[129,8,24],
[130,8,24],
[131,8,24],
[132,8,24],
[133,8,24],
[134,8,28],
[135,8,17],
[136,8,17],
[137,8,17],
[138,8,17],
[139,8,13],
[140,8,13],
[141,4,5],
[142,4,5],
[143,4,5],
[144,4,5],
[145,4,5],
[146,4,5],
[147,4,5],
[148,8,24],
[149,8,21],
[150,8,21],
[151,8,31],
[152,8,31],
[153,8,31],
[154,8,31],
[155,8,31],
[156,8,31],
[157,8,31],
[158,8,23],
[159,8,23],
[160,8,23],
[161,8,23],
[162,8,23],
[163,8,23],
[164,8,23],
[165,8,23],
[166,8,23],
[167,8,23],
[168,8,23],
[169,8,23],
[170,8,23],
[171,8,23],
[172,8,23],
[173,8,23],
[174,8,23],
[175,8,23],
[176,8,23],
[177,8,23],
[178,8,23],
[179,8,23],
[180,8,23],
[181,8,23],
[182,8,23],
[183,8,23],
[184,8,23],
[185,8,12],
[186,8,12],
[187,10,33],
[188,10,33],
[189,10,33],
[190,10,33],
[191,10,33],
[192,10,33],
[193,8,29],
[194,8,29],
[195,8,29],
[196,8,29],
[197,8,29],
[198,8,29],
[199,8,29],
[200,8,23]
])
