import logging
import traceback
import torch
import sys
import os
import time
import torch.nn.functional as F
import wandb
import argparse
import json
from functools import partial
from torch.cuda.amp import GradScaler, autocast
import torch.nn.init as init


sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from utils.data import get_torch_dataset, get_model, get_batch
from utils.exp_utils import set_seed, get_optimizer, try_cuda, kl_div_logits
from utils.label_mapping import generate_label_mapping_by_frequency_ordinary, label_mapping_base


def evaluate(teacher, student, loader, epoch, args, teacher_label_mapping, student_label_mapping):
    """Compute test loss/accuracy for the teacher and the optional student, then report them.

    Every batch of ``loader`` is processed under ``torch.no_grad()`` with the models in
    eval mode. A one-line summary is printed and both accuracies are logged to wandb at
    ``step=epoch``. When ``student`` is falsy, its loss and accuracy are reported as 0.

    Args:
        teacher: model evaluated on every batch.
        student: optional second model; pass ``None`` to skip it.
        loader: iterable of batches whose first two items are ``(inputs, targets)``.
        epoch: value printed as the epoch and used as the wandb step.
        args: namespace; ``args.downstream_mapping == 'label_mapping'`` enables the mappings.
        teacher_label_mapping: callable applied to the teacher's log-probabilities when
            label mapping is enabled.
        student_label_mapping: same, for the student.
    """
    mapped = args.downstream_mapping == 'label_mapping'

    def log_probs_of(model, mapping, x):
        # Log-probabilities, projected onto the downstream labels if label mapping is on.
        out = F.log_softmax(model(x), dim=-1)
        if mapped:
            out = mapping(out)
        return out

    teacher.eval()
    if student:
        student.eval()

    teacher_loss = student_loss = 0
    teacher_correct = student_correct = 0
    total = 0
    start = time.time()
    for batch in loader:
        with torch.no_grad():
            inputs, targets = try_cuda(*batch[:2])
            total += targets.size(0)

            t_out = log_probs_of(teacher, teacher_label_mapping, inputs)
            teacher_loss += F.cross_entropy(t_out, targets)
            teacher_correct += t_out.max(1)[1].eq(targets).sum().item()

            if student:
                s_out = log_probs_of(student, student_label_mapping, inputs)
                student_loss += F.cross_entropy(s_out, targets)
                student_correct += s_out.max(1)[1].eq(targets).sum().item()
    elapsed = time.time() - start

    step = epoch
    teacher_acc = 100. * teacher_correct / total
    student_acc = 100. * student_correct / total
    print('[eval] Epoch: %d | Teacher Test Loss: %.3f | Teacher Test Acc: %.3f | Student Test Loss: %.3f | Student Test Acc: %.3f | Time: %.3f |'
          % (step, teacher_loss / len(loader), teacher_acc, student_loss / len(loader), student_acc, elapsed))
    wandb.log({'teacher test acc': teacher_acc, 'student test acc': student_acc}, step=step)


def train(teacher, student, loader, epoch, args, teacher_optimizer, student_optimizer, teacher_scheduler, student_scheduler, teacher_label_mapping, student_label_mapping):
    """Run one training epoch of LoT teacher/student co-training and report epoch statistics.

    For every batch of ``loader``, the teacher is updated with
    ``choose_loss(..., 'teacher', args)``. If ``student`` is given, it is updated in one
    of two ways:

    * ``args.sim_update`` truthy: on the same batch, right after the teacher step.
    * otherwise: on ``args.student_steps_ratio`` extra batches fetched with
      ``get_batch(loader, args.student_index)``. This function advances
      ``args.student_index`` in place, modulo ``len(loader)``.

    A single ``GradScaler`` (mixed precision) is shared by both optimizers. The printed and
    wandb-logged loss values are teacher-side epoch averages. Accuracies are measured on
    the batches each model was trained on.

    Args:
        teacher: model trained on every batch.
        student: optional second model; pass ``None`` for teacher-only training.
        loader: training loader yielding ``(inputs, targets)``.
        epoch: value printed as the epoch and used as the wandb step.
        args: parsed CLI namespace (reads ``downstream_mapping``, ``loss``, ``sim_update``,
            ``student_steps_ratio``, ``student_index``, plus ``alpha``/``T`` via ``choose_loss``).
        teacher_optimizer: optimizer for the teacher.
        student_optimizer: optimizer for the student; may be ``None`` when there is no student.
        teacher_scheduler: only used to read the current learning rate; it is not stepped here.
        student_scheduler: unused here.
        teacher_label_mapping: callable applied to the teacher's log-probabilities when
            ``args.downstream_mapping == 'label_mapping'``.
        student_label_mapping: same, for the student.
    """
    use_mapping = args.downstream_mapping == 'label_mapping'

    def predict(model, mapping, x):
        # Log-probabilities, projected onto the downstream labels if label mapping is on.
        log_probs = F.log_softmax(model(x), dim=-1)
        return mapping(log_probs) if use_mapping else log_probs

    def n_correct(pred, y):
        return pred.max(1)[1].eq(y).sum().item()

    scaler = GradScaler()
    teacher.train()
    if student:
        student.train()
    # Python floats, not tensors: accumulating the loss tensors themselves would chain
    # every step's autograd history into these totals for the whole epoch.
    total_loss = 0.0
    total_ce_loss = 0.0
    total_lot_loss = 0.0
    teacher_correct, student_correct = 0, 0
    total = 0  # samples seen by the teacher
    student_total = 0  # samples the student was trained on
    start = time.time()
    for inputs, targets in loader:
        inputs, targets = try_cuda(inputs, targets)
        # update teacher
        teacher_optimizer.zero_grad()
        with autocast():
            teacher_pred = predict(teacher, teacher_label_mapping, inputs)
            if student:
                student_optimizer.zero_grad()
                student_pred = predict(student, student_label_mapping, inputs)
            else:
                student_pred = 0  # int sentinel: choose_loss then returns plain CE
            teacher_loss, teacher_ce_loss, teacher_lot_loss = choose_loss(teacher_pred, student_pred, targets, args.loss, 'teacher', args)
        scaler.scale(teacher_loss).backward()
        scaler.step(teacher_optimizer)
        scaler.update()
        teacher_correct += n_correct(teacher_pred, targets)
        total += targets.size(0)

        if student and args.sim_update:
            # Same batch; choose_loss detaches teacher_pred, so only the student gets gradients.
            with autocast():
                student_loss, _, _ = choose_loss(teacher_pred, student_pred, targets, args.loss, 'student', args)
            scaler.scale(student_loss).backward()
            scaler.step(student_optimizer)
            scaler.update()
            student_correct += n_correct(student_pred, targets)
            student_total += targets.size(0)

        total_loss += float(teacher_loss)
        total_ce_loss += float(teacher_ce_loss)
        total_lot_loss += float(teacher_lot_loss)

        # student additional train
        if student and not args.sim_update:
            for _ in range(args.student_steps_ratio):
                s_inputs, s_targets = get_batch(loader, args.student_index)
                s_inputs, s_targets = try_cuda(s_inputs, s_targets)
                args.student_index = (args.student_index + 1) % len(loader)
                with autocast():
                    # The student loss only uses the teacher prediction detached, so no graph
                    # is needed. Train-mode side effects (e.g. norm-layer statistics) still occur.
                    with torch.no_grad():
                        teacher_pred = predict(teacher, teacher_label_mapping, s_inputs)
                    student_pred = predict(student, student_label_mapping, s_inputs)
                    student_loss, _, _ = choose_loss(teacher_pred, student_pred, s_targets, args.loss, 'student', args)
                student_optimizer.zero_grad()
                scaler.scale(student_loss).backward()
                scaler.step(student_optimizer)
                scaler.update()
                student_correct += n_correct(student_pred, s_targets)
                student_total += s_targets.size(0)
    end = time.time()

    step = epoch
    n_batches = max(len(loader), 1)  # guard against an empty loader
    teacher_acc = 100. * teacher_correct / max(total, 1)
    student_acc = 100. * student_correct / max(student_total, 1)
    teacher_lr = teacher_scheduler.get_last_lr()[0]
    print((
        f'[Train] Epoch: {step} | '
        f'Teacher lr={teacher_lr:.4f} | '
        f'Total Loss: {total_loss / n_batches:.3f} | '
        f'CE loss: {total_ce_loss / n_batches:.3f} | '
        f'LoT loss: {total_lot_loss / n_batches:.3f} | '
        f'Teacher Train Acc: {teacher_acc:.3f} | '
        f'Student Train Acc: {student_acc:.3f} | '
        f'Time: {end-start:.3f} |'
    ))
    wandb.log({'teacher_lr': teacher_lr, 'teacher_total_loss': total_loss / n_batches,
               'teacher_ce_loss': total_ce_loss / n_batches, 'teacher_lot_loss': total_lot_loss / n_batches,
               'teacher_train_acc': teacher_acc, 'student_train_acc': student_acc}, step=step)


def choose_loss(teacher_pred, student_pred, targets, type, calculate, args):
    """Return ``(total_loss, ce_loss, lot_loss)`` for either the teacher or the student.

    Args:
        teacher_pred: teacher predictions, passed to ``F.cross_entropy`` and ``kl_div_logits``
            (callers in this file pass log-softmax outputs).
        student_pred: student predictions, or the int sentinel ``0`` when there is no
            student. In the sentinel case only the CE of ``teacher_pred`` is returned.
        targets: ground-truth class indices for ``F.cross_entropy``.
        type: loss variant: ``'kl'``, ``'kl_ce'``, ``'symmetric_kl'`` or
            ``'symmetric_kl_ce'``. Any other value falls back to plain CE on
            ``teacher_pred``.
        calculate: ``'teacher'`` computes the teacher's loss; any other value computes
            the student's loss.
        args: namespace providing ``alpha`` (LoT weight) and ``T`` (passed to
            ``kl_div_logits``).

    Returns:
        ``(total_loss, ce_loss, lot_loss)`` with ``total_loss = ce_loss + lot_loss``.
        ``ce_loss`` or ``lot_loss`` is the int ``0`` when that term is unused.

    Behaviour per variant (``own`` = side being computed, ``peer`` = the other side, detached):

    * teacher, any variant: ``CE(own) + alpha * D``.
    * student, ``*_ce`` variants: ``CE(own) + alpha * D``.
    * student, ``kl`` / ``symmetric_kl``: ``D`` only, with no CE and no ``alpha``.

    Here ``D = KL(own, peer)`` for the ``kl*`` variants, and
    ``D = KL(own, peer) + KL(peer, own)`` for the ``symmetric_*`` variants.
    """
    # `type` is kept as the public parameter name; use an alias to avoid the shadowed builtin.
    loss_type = type
    if isinstance(student_pred, int) or loss_type not in ('kl', 'kl_ce', 'symmetric_kl', 'symmetric_kl_ce'):
        ce_loss = F.cross_entropy(teacher_pred, targets)
        lot_loss = 0
        return ce_loss + lot_loss, ce_loss, lot_loss

    # Detach the peer so gradients only flow into the side being updated.
    if calculate == 'teacher':
        own_pred, peer_pred = teacher_pred, student_pred.detach()
    else:
        own_pred, peer_pred = student_pred, teacher_pred.detach()

    trains_with_ce = calculate == 'teacher' or loss_type.endswith('_ce')
    ce_loss = F.cross_entropy(own_pred, targets) if trains_with_ce else 0

    divergence = kl_div_logits(own_pred, peer_pred, args.T)
    if loss_type.startswith('symmetric'):
        divergence = divergence + kl_div_logits(peer_pred, own_pred, args.T)
    lot_loss = args.alpha * divergence if trains_with_ce else divergence

    return ce_loss + lot_loss, ce_loss, lot_loss


def single_train(network, loader, args, optimizer, scheduler, label_mapping):
    """Train ``network`` on its own (plain cross-entropy, no LoT term) for one epoch.

    Uses mixed precision via ``GradScaler``/``autocast`` and prints epoch-average losses,
    training accuracy and wall time.

    Args:
        network: model to train.
        loader: training loader yielding ``(inputs, targets)``.
        args: namespace; reads ``downstream_mapping`` and ``loss`` (the latter is forwarded
            to ``choose_loss``, which ignores it when no student prediction is given).
        optimizer: optimizer for ``network``.
        scheduler: accepted for signature compatibility; not used or stepped here.
        label_mapping: callable applied to the log-probabilities when
            ``args.downstream_mapping == 'label_mapping'``.
    """
    use_mapping = args.downstream_mapping == 'label_mapping'
    scaler = GradScaler()
    network.train()
    # Python floats: summing loss tensors would keep every step's autograd history alive.
    total_loss = 0.0
    total_ce_loss = 0.0
    total_lot_loss = 0.0
    correct = 0
    total = 0
    start = time.time()
    for inputs, targets in loader:
        inputs, targets = try_cuda(inputs, targets)
        optimizer.zero_grad()
        with autocast():
            log_probs = F.log_softmax(network(inputs), dim=-1)
            pred = label_mapping(log_probs) if use_mapping else log_probs
            # student_pred=0 selects choose_loss's plain cross-entropy path.
            loss, ce_loss, lot_loss = choose_loss(pred, 0, targets, args.loss, 'teacher', args)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        correct += pred.max(1)[1].eq(targets).sum().item()
        total_loss += float(loss)
        total_ce_loss += float(ce_loss)
        total_lot_loss += float(lot_loss)
        total += targets.size(0)
    end = time.time()
    n_batches = max(len(loader), 1)  # guard against an empty loader
    print((
        f'[Single Train] '
        f'Total Loss: {total_loss / n_batches:.3f} | '
        f'CE loss: {total_ce_loss / n_batches:.3f} | '
        f'LoT loss: {total_lot_loss / n_batches:.3f} | '
        f'Train Acc: {100. * correct / max(total, 1):.3f} | '
        f'Time: {end-start:.3f} |'
    ))


parser = argparse.ArgumentParser(description='PyTorch Image Classification')
# LoT
parser.add_argument('--alpha', type=float, default=1)
parser.add_argument('--detach', type=int, default=1)
parser.add_argument('--T', type=float, default=1.5)
parser.add_argument('--student_index', type=int, default=0, help='an independent index for student updating')
parser.add_argument('--sim_update', type=int, default=0, choices=[1, 0])
parser.add_argument('--student_steps_ratio', type=int, default=1)
parser.add_argument('--loss', type=str, default='kl_ce', choices=['kl', 'kl_ce', 'symmetric_kl', 'symmetric_kl_ce'])
parser.add_argument('--teacher_network', type=str, default='resnet18', choices=['resnet18', 'resnet50', 'mobilenet', 'vit_b', 'vit_l'])
parser.add_argument('--teacher_pretrain', type=int, default=1, choices=[1, 0])
# An empty string ('') disables the student entirely.
parser.add_argument('--student_network', type=str, default='resnet18', choices=['', 'resnet18', 'resnet50', 'mobilenet', 'vit_b', 'vit_l'])
parser.add_argument('--student_pretrain', type=int, default=1, choices=[1, 0])
# all
parser.add_argument('--dataset', type=str, default='cifar100', choices = ['cifar10', 'cifar100', 'ImageNet'])
parser.add_argument('--datadir', type=str, default='../data', help='data directory')
parser.add_argument('--input_size', type=int, default=224, help='image input size')
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--seed', type=int, default=0,help='random seed')
parser.add_argument('--optimizer', type=str, default='sgd')
parser.add_argument('--lr', type=float, default=0.02)
parser.add_argument('--weight_decay', type=float, default=0.0001)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--scheduler', type=str, default='cosine')
# `type=list` would split a CLI value into single characters ("0.5" -> ['0', '.', '5']);
# accept one or more floats instead, e.g. `--decreasing_step 0.5 0.72`.
parser.add_argument('--decreasing_step', default=[0.5, 0.72], type=float, nargs='+',
                    help='decreasing strategy (space-separated floats, e.g. --decreasing_step 0.5 0.72)')
parser.add_argument('--epochs', type=int, default=40)
parser.add_argument('--eval_frequency', type=int, default=5)
parser.add_argument('--warm_up', type=int, default=0)
# NOTE(review): the default 0 is not one of `choices` (argparse does not validate defaults).
# Every check in this file compares against 'label_mapping'/'linear_probing', so 0 behaves
# like 'origin' here. Confirm no external consumer relies on the falsy 0 before changing it.
parser.add_argument('--downstream_mapping', type=str, default=0, choices=['origin', 'label_mapping', 'linear_probing'])
# Timestamp digits make the default checkpoint path unique per launch.
randomhash = ''.join(str(time.time()).split('.'))
parser.add_argument('--save', type=str,  default='ckpt/LoT_ResNet'+randomhash+'CIFAR', help='path to save the final model')
parser.add_argument('--exp_name', type=str, default='LoT_ResNet')
args = parser.parse_args()
print(json.dumps(vars(args), indent=4))


def _build_network(role, network_name, pretrain, train_loader, args, xavier_init_head=False):
    """Create a model on the current CUDA device, adapt it to the downstream task and build its optimizer.

    Args:
        role: ``'teacher'`` or ``'student'``; only used in printed messages.
        network_name: architecture name passed to ``get_model``.
        pretrain: pretrained flag passed to ``get_model``.
        train_loader: training loader, used to derive the label mapping when
            ``args.downstream_mapping == 'label_mapping'``.
        args: parsed CLI namespace (also passed to ``get_optimizer``).
        xavier_init_head: for linear probing, re-initialise the new head's weight with
            Xavier-uniform.

    Returns:
        ``(model, optimizer, scheduler, label_mapping)``. ``label_mapping`` is ``None``
        unless label mapping is enabled.
    """
    model = get_model(network_name, pretrain).cuda()
    label_mapping = None
    if args.downstream_mapping == 'label_mapping':
        mapping_sequence = generate_label_mapping_by_frequency_ordinary(model, train_loader)
        label_mapping = partial(label_mapping_base, mapping_sequence=mapping_sequence)
        print(f'{role} mapping sequence:', mapping_sequence)
    elif args.downstream_mapping == 'linear_probing':
        # NOTE(review): args.class_cnt is not a CLI flag; presumably set by get_torch_dataset — confirm.
        if network_name == 'mobilenet':
            head = torch.nn.Linear(model.classifier[1].in_features, args.class_cnt).cuda()
            model.classifier[1] = head
        else:
            head = torch.nn.Linear(model.fc.in_features, args.class_cnt).cuda()
            model.fc = head
        # NOTE(review): only the student head gets Xavier init; confirm the asymmetry is intended.
        if xavier_init_head:
            init.xavier_uniform_(head.weight)
    # Build the optimizer only after the head swap. Otherwise the new linear-probing layer's
    # parameters are never registered with the optimizer and never train.
    optimizer, scheduler = get_optimizer(model.parameters(), args)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total number of {role} parameters: {total_params:,}")
    return model, optimizer, scheduler, label_mapping


def main():
    """Entry point: set up wandb, data and models, then run LoT teacher/student training.

    Evaluates before training, optionally runs ``args.warm_up`` warm-up epochs, then trains
    for ``args.epochs`` epochs. It steps the schedulers once per epoch, evaluates every
    ``args.eval_frequency`` epochs and saves state dicts under the ``args.save`` prefix
    after every epoch.

    Returns:
        ``None`` on success. ``float('NaN')`` if any exception was raised; the traceback
        is logged.
    """
    try:
        wandb_username = os.environ.get('WANDB_USER_NAME')
        wandb_key = os.environ.get('WANDB_API_KEY')
        wandb.login(key=wandb_key)
        wandb.init(project='LoT_Image_Classification_' + args.dataset, entity=wandb_username, name=args.exp_name)

        # Device
        torch.cuda.set_device(int(args.gpu))
        set_seed(args.seed)

        # dataset
        train_loader, test_loader = get_torch_dataset(args, 'ff')

        # init teacher
        print('Teacher Network:', args.teacher_network)
        teacher, teacher_optimizer, teacher_scheduler, teacher_label_mapping = _build_network(
            'teacher', args.teacher_network, args.teacher_pretrain, train_loader, args)

        # init student (an empty --student_network disables it)
        print('Student Network:', args.student_network)
        if args.student_network:
            args.student_index = 0
            student, student_optimizer, student_scheduler, student_label_mapping = _build_network(
                'student', args.student_network, args.student_pretrain, train_loader, args,
                xavier_init_head=True)
        else:
            student = None
            student_optimizer = None
            student_scheduler = None
            student_label_mapping = None

        print(f"==== train and evaluate unequal restart ====")
        evaluate(teacher, student, test_loader, 0, args, teacher_label_mapping, student_label_mapping)
        if args.warm_up:
            print('Warmup Starts')
            for _ in range(1, args.warm_up + 1):
                # Each model is trained alone (no peer) during warm-up.
                train(teacher, None, train_loader, 0, args, teacher_optimizer, student_optimizer, teacher_scheduler, student_scheduler, teacher_label_mapping, student_label_mapping)
                if student:
                    train(student, None, train_loader, 0, args, student_optimizer, teacher_optimizer, student_scheduler, teacher_scheduler, student_label_mapping, teacher_label_mapping)
            print('Warmup Ends')

        for epoch in range(1, args.epochs + 1):
            train(teacher, student, train_loader, epoch, args, teacher_optimizer, student_optimizer, teacher_scheduler, student_scheduler, teacher_label_mapping, student_label_mapping)
            teacher_scheduler.step()
            if student:
                student_scheduler.step()
            if epoch % args.eval_frequency == 0:
                evaluate(teacher, student, test_loader, epoch, args, teacher_label_mapping, student_label_mapping)
            torch.save(teacher.state_dict(), args.save + '_teacher.pt')
            if student:
                torch.save(student.state_dict(), args.save + '_student.pt')
        print('ckpt location:', args.save)
        wandb.finish()

    except Exception:
        # Top-level boundary: log the full traceback and signal failure with NaN.
        logging.error(traceback.format_exc())
        return float('NaN')


if __name__ == '__main__':
    import math

    # main() logs and swallows any exception, returning NaN instead. Convert that into a
    # non-zero exit status so shell scripts and job schedulers can detect a failed run.
    exit_value = main()
    if isinstance(exit_value, float) and math.isnan(exit_value):
        sys.exit(1)
