import argparse
import os
import time
import numpy as np
import random
import tqdm
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.cuda
import torch.optim
import torch.utils.data

from myutils import Logger
from myutils import save_model
from networks.resnet_lit import get_resnet
from pr_resnet import PR
from get_dataloader_old import get_dataloader
from distillation_loss import AlphaDistillationLoss
from myutils import str2bool
from init import Initial

# Number of residual stages (student attributes layer1..layer4) that LIT trains
# section-by-section and that parameter remapping is applied to.
NB_SECTIONS = 4


def _parse_args():
    """Define the command-line interface and return the parsed arguments.

    The returned namespace is later merged with the YAML configs by ``Initial``
    inside ``main``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=-1, type=int)
    parser.add_argument('--wandb_project_name', type=str, default=None)
    parser.add_argument('--manual_seed', type=int, default=0)
    parser.add_argument('--task', type=str, default='nas')
    parser.add_argument('--folder_name', type=str, default='debug')
    parser.add_argument('--gpu', type=str, default='0')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--image_size', type=int, default=64)
    parser.add_argument('--ds_name', type=str, default='stl10')
    parser.add_argument('--ds_split', type=str, default=None)
    parser.add_argument('--net_info_path', type=str, default=None)
    parser.add_argument('--net_index', type=int, default=0)
    # LIT hyperparams
    parser.add_argument('--beta', type=float, default=0.75)
    parser.add_argument('--use_noise', type=str2bool, default=False)
    # LIT training
    parser.add_argument('--lit_lr', type=float, default=0.1)
    parser.add_argument('--lit_schedule', nargs='+', type=int, default=[100])
    parser.add_argument('--lit_epochs', type=int, default=10)  # 175
    parser.add_argument('--valid_frequency', type=int, default=5)
    parser.add_argument('--save_per_batch', type=str2bool, default=True)
    # Fine-tuning
    parser.add_argument('--finetune_starting_lr', type=float, default=0.01)
    parser.add_argument('--finetune_schedule', nargs='+', type=int, default=[55])
    parser.add_argument('--finetuning_epochs', type=int, default=1)  # 125
    # KD hyperparams
    parser.add_argument('--alpha', type=float, default=0.95)
    parser.add_argument('--temp', type=float, default=6.)
    # Optimizer hyperparams
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    # Parameter remapping (PR)
    parser.add_argument('--pr_type', type=str, default='copy_paste_first')
    parser.add_argument('--metapr_ver', type=int, default=0)
    # NAS
    parser.add_argument('--proxy', type=str, default='flops')
    parser.add_argument('--topk', type=int, default=1)
    return parser.parse_args()


def _resolve_nas_network(args):
    """For the 'nas' task, select the student chosen by the proxy ranking.

    Mutates ``args.net_info_path`` (only if a module-level ``DATAPATH`` exists)
    and ``args.net_index`` in place.
    """
    # NOTE(review): the original unconditionally read a module-level name
    # ``DATAPATH`` that is not defined anywhere in this file (NameError).
    # Honour it if it is ever defined; otherwise keep the path supplied by
    # the CLI / config files.
    data_path = globals().get('DATAPATH')
    if data_path is not None:
        args.net_info_path = data_path
    args.net_index = args.topk_idx[str(args.proxy)][str(args.ds_name)][args.topk]
    print(f'==> Net Idx. searched by {args.proxy} is {args.net_index}...')


def _build_exp_path(args, stage):
    """Return the experiment directory for one training phase.

    Args:
        args: merged CLI/config namespace.
        stage: phase directory name, e.g. ``'lit'`` or ``'finetuning'``.

    The phase name is inserted as its own path component rather than being
    substituted into an existing string, so other components that happen to
    contain the same letters are never rewritten.
    """
    if args.task == 'nas':
        path = (f'../exp/{args.search_space}/{args.tc_net_name}/{args.task}'
                f'/{args.ds_name}/{args.proxy}/top-{args.topk}'
                f'/net-{args.net_index}')
    else:
        # Any task other than 'nas' (including None) uses the folder layout.
        path = f'../exp/{args.folder_name}/{args.search_space}/{args.tc_net_name}'
        path += f'/chmul-{args.channel_mul}'
        path += f'/imsz-{args.image_size}'
        path += f'/same_w-{str(args.same_width)[0]}'
        path += f'/{args.ds_name}'
        if args.ds_split is not None:
            path += f'/task-{args.ds_split}'

    path += f'/kd-{args.lit_epochs}-ft-{args.finetuning_epochs}'
    path += f'/beta-{args.beta}/alpha-{args.alpha}/temp-{args.temp}'
    path += f'/prtype-{args.pr_type}'
    path += f'/{stage}'
    if args.task != 'nas':
        path += f'/net-{args.net_index}'
    return path


def _make_logger(args, log_dir, net_info):
    """Create a text-file Logger in ``log_dir`` and record args plus student info.

    ``net_info`` entries 0..3 are logged as flops, params, depth config and
    channel widths respectively.
    """
    logger = Logger(
        log_dir=log_dir,
        exp_name=log_dir.replace('/', '_'),
        exp_suffix="",
        write_textfile=True,
        use_wandb=False,
        wandb_project_name=args.wandb_project_name,
    )
    logger.update_config(args, is_args=True)
    logger.update_config({
        'flops': net_info[0],
        'params': net_info[1],
        'depth_config': net_info[2],
        'channel_widths': net_info[3],
    })
    return logger


def _run_phase(args, title, n_epochs, run_epoch, train_loader, val_loader,
               logger, optimizer, student, save_path):
    """Train for ``n_epochs``, validating/logging/checkpointing periodically.

    Args:
        run_epoch: callable ``run_epoch(epoch, loader, train)`` returning
            ``(loss, top1, top5)``.
        title: prefix shown in the progress bar (e.g. ``'LIT'``).

    Validation runs every ``args.valid_frequency`` epochs. ``best_acc`` starts
    at -1 for every phase so each phase's checkpoint directory receives its
    own best model.

    Returns:
        Best validation top-1 seen in this phase (-1 if it never validated).
    """
    best_acc = -1
    pbar = tqdm.trange(n_epochs)
    for epoch in pbar:
        st_epoch_time = time.time()
        train_loss, train_acc1, train_acc5 = run_epoch(epoch, train_loader, True)
        pbar.set_description(f'{title} Epoch [{epoch+1}/{n_epochs}]\t'
                             f'Train Loss {train_loss:.2f}\t Train Acc1 {train_acc1:.2f}')

        if (epoch + 1) % args.valid_frequency != 0:
            continue

        val_loss, val_acc1, val_acc5 = run_epoch(epoch, val_loader, False)
        is_best = val_acc1 > best_acc
        best_acc = max(val_acc1, best_acc)

        logger.write_log_nohead({
            'epoch': epoch + 1,
            'train/loss': train_loss,
            'train/top1': train_acc1,
            'train/top5': train_acc5,
            'valid/loss': val_loss,
            'valid/top1': val_acc1,
            'valid/top5': val_acc5,
            'valid/best_acc': best_acc,
            'epoch_time': time.time() - st_epoch_time,
        }, step=epoch + 1)

        save_model({'epoch': epoch + 1,
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                    'state_dict': student.state_dict(),
                    }, save_path, is_best=is_best, model_name=None)

        print(f'Valid Loss {val_loss:.2f}\t Valid Acc1 {val_acc1:.2f}')
        print(f'Best Acc so far: {best_acc:.2f}')
    return best_acc


def main():
    """Run LIT section-wise training, then whole-network KD fine-tuning.

    Each phase writes logs to its own experiment directory (the path differs
    only in the ``lit`` / ``finetuning`` component), with checkpoints in a
    ``checkpoint`` subfolder.
    """
    args = _parse_args()
    initial = Initial(args, base_configs=['common.yaml'])
    args = initial.args

    if args.task == 'nas':
        _resolve_nas_network(args)

    lit_path = _build_exp_path(args, 'lit')
    print(f'==> main path : {lit_path}')
    lit_save_path = os.path.join(lit_path, 'checkpoint')
    os.makedirs(lit_path, exist_ok=True)
    os.makedirs(lit_save_path, exist_ok=True)

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Determinism
    if args.manual_seed >= 0:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.manual_seed(args.manual_seed)
        torch.cuda.manual_seed_all(args.manual_seed)
        np.random.seed(args.manual_seed)
        random.seed(args.manual_seed)

    # Dataloader.
    # NOTE(review): the original computed mode='meta_train' for tiny_imagenet
    # here but never used it; the loader has always been built in 'meta_test'
    # mode, which is preserved. Confirm which is intended.
    train_loader, val_loader, n_classes = get_dataloader(
        args.default_data_path,
        mode='meta_test',
        image_size=args.image_size,
        batch_size=args.batch_size,
        ds_name=args.ds_name,
        ds_split=args.ds_split,
        mtrn_hetero_on=False,
        mtst_subset_on=False,
        class_split_ratio=[0.4, 0.7, 1.0],
        instance_split_ratio=0.7)
    print("Loaded Data")

    teacher, student, net_info = setup_teacher_student(args, n_classes, device)
    print("Loaded and Created Teacher + Student Models")

    # Criteria
    ir_loss = torch.nn.MSELoss()
    kd_loss = AlphaDistillationLoss(temperature=args.temp, alpha=args.alpha)

    # ---------------------------- LIT training ----------------------------
    logger = _make_logger(args, lit_path, net_info)
    lit_optimizer = torch.optim.SGD(student.parameters(), lr=args.lit_lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    lit_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        lit_optimizer, milestones=args.lit_schedule)

    def run_lit(epoch, loader, train):
        return lit_epoch(args, epoch, logger, teacher, student, loader,
                         lit_optimizer, lit_scheduler, ir_loss, kd_loss,
                         lit_save_path, train=train)

    val_loss, val_acc1, val_acc5 = run_lit(-1, val_loader, False)
    logger.update_config({
        'init_valid_loss': val_loss,
        'init_valid_top1': val_acc1,
        'init_valid_top5': val_acc5,
    })

    total_st_time = time.time()
    _run_phase(args, 'LIT', args.lit_epochs, run_lit, train_loader, val_loader,
               logger, lit_optimizer, student, lit_save_path)
    logger.save_log()

    # ---------------------------- Fine-tuning -----------------------------
    ft_path = _build_exp_path(args, 'finetuning')
    ft_save_path = os.path.join(ft_path, 'checkpoint')
    os.makedirs(ft_path, exist_ok=True)
    os.makedirs(ft_save_path, exist_ok=True)

    logger_ft = _make_logger(args, ft_path, net_info)
    ft_optimizer = torch.optim.SGD(student.parameters(), lr=args.finetune_starting_lr,
                                   momentum=args.momentum,
                                   weight_decay=args.weight_decay)
    ft_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        ft_optimizer, milestones=args.finetune_schedule)

    def run_ft(epoch, loader, train):
        return fine_tune_epoch(teacher, student, ft_optimizer, ft_scheduler,
                               loader, kd_loss, train=train)

    val_loss, val_acc1, val_acc5 = run_ft(-1, val_loader, False)
    logger_ft.update_config({
        'init_valid_loss': val_loss,
        'init_valid_top1': val_acc1,
        'init_valid_top5': val_acc5,
    })

    _run_phase(args, 'Fine Tuning', args.finetuning_epochs, run_ft, train_loader,
               val_loader, logger_ft, ft_optimizer, student, ft_save_path)

    total_elapsed_time = time.time() - total_st_time
    print(f'total elapsed time: {total_elapsed_time:.3f} (s)')
    logger_ft.update_config({'total_time': total_elapsed_time})
    logger_ft.save_log()
    



def fine_tune_epoch(teacher, student, optimizer, scheduler, loader, kd_loss, train=True):
    """Run one epoch of whole-network KD fine-tuning, or evaluate the student.

    Args:
        teacher: teacher network; always run in eval mode under ``no_grad``.
        student: student network to fine-tune / evaluate.
        optimizer: optimizer over the student's parameters (used only if ``train``).
        scheduler: LR scheduler. NOTE: it is stepped once per *batch*, so
            MultiStepLR milestones count optimizer updates, not epochs.
        loader: iterable of ``(inputs, targets)`` batches.
        kd_loss: callable ``kd_loss(student_out, teacher_out, target)`` returning
            a scalar loss tensor.
        train: update the student if True, otherwise only evaluate.

    Returns:
        ``(avg_loss, avg_top1, avg_top5)`` as Python floats, sample-weighted.
    """
    teacher.eval()
    if train:
        student.train()
        # lit_epoch sets requires_grad=False on every parameter outside the
        # residual stages (stem, fc) and never restores it; re-enable them so
        # fine-tuning actually updates the whole student.
        for p in student.parameters():
            p.requires_grad = True
    else:
        student.eval()

    # Follow the student's device instead of hard-coding CUDA, so the CPU
    # fallback chosen in main() works.
    device = next(student.parameters()).device

    losses = AverageMeter()
    accuracies1 = AverageMeter()
    accuracies5 = AverageMeter()

    for inp, target in loader:
        target = target.to(device, non_blocking=True)
        inp = inp.to(device).detach()

        with torch.no_grad():
            teacher_out = teacher(inp)

        with torch.set_grad_enabled(train):
            student_out = student(inp)
            loss = kd_loss(student_out, teacher_out, target)
            if train:
                loss.backward()
                optimizer.step()
                scheduler.step()
                student.zero_grad()

        with torch.no_grad():
            prec1, prec5 = accuracy(student_out, target, topk=(1, 5))
            batch_size = inp.size(0)
            losses.update(loss.item(), batch_size)
            # .item() so the meters hold floats, not device tensors.
            accuracies1.update(prec1[0].item(), batch_size)
            accuracies5.update(prec5[0].item(), batch_size)

    return losses.avg, accuracies1.avg, accuracies5.avg


def get_section(student, s_idx):
    """Return the student's residual stage ``layer{s_idx + 1}`` (``s_idx`` is 0-based)."""
    stage_name = f'layer{s_idx + 1}'
    return getattr(student, stage_name)


def lit_epoch(args, epoch, logger, teacher, student, loader,
            optimizer, scheduler, ir_loss, kd_loss, save_path, train=True):
    """Run one epoch of LIT (section-wise imitation) training, or evaluate.

    In training mode only the residual stages ``layer1..layer{NB_SECTIONS}``
    are trainable. Every other student parameter gets ``requires_grad=False``
    and stays in eval mode; this state is left in place when the function
    returns. Each step minimises::

        (beta * kd_loss(student(x), teacher_out, y)
         + (1 - beta) * sum_s ir_loss(section_s(feat_s), feat_{s+1})) / 2

    Here ``feat`` are the teacher's intermediate features. Each student
    section is fed the teacher's input feature for that stage, with added
    Gaussian noise (std 0.1) when ``args.use_noise`` is set. In evaluation
    mode the reported loss is only ``beta * kd_loss``.

    NOTE: ``scheduler.step()`` is called once per batch, so MultiStepLR
    milestones (``--lit_schedule``) count optimizer updates, not epochs.

    When ``args.save_per_batch`` is set, the running metrics are logged and
    a student snapshot is saved to ``save_path/batch-{k}`` during training
    epoch 0, after batches k = 10, 20, ..., 100.

    Returns:
        ``(avg_loss, avg_top1, avg_top5)`` as Python floats, sample-weighted.
    """
    teacher.eval()
    student.eval()
    if train:
        # Freeze the whole student, then unfreeze (and switch to train mode)
        # only the residual stages that LIT trains.
        for p in student.parameters():
            p.requires_grad = False
        for s_idx in range(NB_SECTIONS):
            section = get_section(student, s_idx)
            section.train()
            for p in section.parameters():
                p.requires_grad = True

    # Follow the student's device instead of hard-coding CUDA.
    device = next(student.parameters()).device
    # 1-based batch counts after which epoch-0 snapshots are taken.
    snapshot_steps = frozenset(range(10, 110, 10))

    full_loss_log = AverageMeter()
    accuracies1 = AverageMeter()
    accuracies5 = AverageMeter()

    for step, (inp, target) in enumerate(loader, start=1):
        target = target.to(device, non_blocking=True)
        inp = inp.to(device, non_blocking=True).detach()

        with torch.no_grad():
            # Teacher intermediate representations (indexed per section below)
            # and its output used as the KD target.
            features, soft_targets = teacher(inp, get_features=True)

        with torch.set_grad_enabled(train):
            student_out = student(inp)
            full_loss = kd_loss(student_out, soft_targets, target) * args.beta

            if train:
                # Section-wise imitation: feed the teacher's feature s into the
                # student's section s and regress the teacher's feature s+1.
                for s_idx in range(NB_SECTIONS):
                    sinp = features[s_idx]
                    if args.use_noise:
                        sinp = sinp + sinp.new_empty(sinp.size()).normal_(0.0, 0.1)
                    section_out = get_section(student, s_idx)(sinp)
                    section_loss = ir_loss(section_out, features[s_idx + 1])
                    full_loss = full_loss + section_loss * (1.0 - args.beta)
                full_loss = full_loss / 2
                full_loss.backward()
                optimizer.step()
                scheduler.step()
                student.zero_grad()

        with torch.no_grad():
            prec1, prec5 = accuracy(student_out, target, topk=(1, 5))
            batch_size = inp.size(0)
            full_loss_log.update(full_loss.item(), batch_size)
            # .item() so the meters (and the logger) hold floats, not tensors.
            accuracies1.update(prec1[0].item(), batch_size)
            accuracies5.update(prec5[0].item(), batch_size)

        if args.save_per_batch and train and epoch == 0 and step in snapshot_steps:
            ver = 'train'
            logger.update_config({
                f'num_updates_{step}_{ver}_loss': full_loss_log.avg,
                f'num_updates_{step}_{ver}_top1': accuracies1.avg,
                f'num_updates_{step}_{ver}_top5': accuracies5.avg,
            })
            target_path = os.path.join(save_path, f'batch-{step}')
            save_model({'epoch': epoch + 1,
                        'optimizer': optimizer.state_dict(),
                        'state_dict': student.state_dict(),
                        }, target_path, is_best=False, model_name=None)

    return full_loss_log.avg, accuracies1.avg, accuracies5.avg


def setup_teacher_student(args, n_classes, device):
    """Build and load the teacher, then build a parameter-remapped student.

    Steps:
      1. Build the teacher ResNet from the ``args.tc_*`` config (widths scaled
         by ``args.channel_mul``) and load its checkpoint.
      2. Build the student from entry ``args.net_index`` of the list stored at
         ``args.net_info_path``. Entries 2 and 3 of that record are used as the
         student's depth config and channel widths.
      3. Copy the teacher's stem (conv1, bn1) and classifier (fc) into the
         student. Then load each student residual block from the state dicts
         produced by ``PR.param_remapping()``.
      4. Move both models to ``device`` and freeze the teacher
         (``requires_grad=False``).

    Returns:
        ``(teacher_model, student_model, net_info)``
    """
    tc_stage_channel_widths = [int(w * args.channel_mul)
                               for w in args.tc_stage_default_channel_widths]
    tc_depth_config = args.tc_stage_num * [args.tc_stage_depth]
    tc_channel_widths = [[w] * args.tc_stage_depth for w in tc_stage_channel_widths]

    teacher_model = get_resnet(n_classes,
                               depth_config=tc_depth_config,
                               channel_widths=tc_channel_widths,
                               stage_strides=args.tc_stage_strides,
                               tc_stage_channel_widths=tc_stage_channel_widths)

    mode = 'meta_train' if 'tiny_imagenet' in args.ds_name else 'meta_test'
    # NOTE(review): load_tc_ckpt_path is not imported anywhere in this file, so
    # this call raises NameError unless it is injected elsewhere. Confirm its
    # home module and import it.
    tc_net_ckpt_path = load_tc_ckpt_path(
        mode=mode,
        tc_net_name=args.tc_net_name,
        image_size=args.image_size,
        ds_name=args.ds_name,
        ds_split=args.ds_split,
        channel_mul=args.channel_mul)
    # Load onto CPU: the model is still on CPU here and is moved to `device`
    # below; this also works on hosts without the GPU the checkpoint came from.
    tc_checkpoint = torch.load(tc_net_ckpt_path, map_location='cpu')
    teacher_model.load_state_dict(tc_checkpoint['state_dict'])

    net_info_list = torch.load(args.net_info_path, map_location='cpu')
    net_info = net_info_list[args.net_index]
    st_depth_config = net_info[2]
    st_channel_widths = net_info[3]
    student_model = get_resnet(n_classes,
                               depth_config=st_depth_config,
                               channel_widths=st_channel_widths,
                               stage_strides=args.tc_stage_strides,
                               tc_stage_channel_widths=tc_stage_channel_widths)

    # Parameter remapping: copy-paste stem and tail from the teacher.
    student_model.conv1.load_state_dict(teacher_model.conv1.state_dict())
    student_model.bn1.load_state_dict(teacher_model.bn1.state_dict())
    student_model.fc.load_state_dict(teacher_model.fc.state_dict())
    # Remap residual-stage parameters stage by stage.
    param_remapper = PR(device=device, n_stage=NB_SECTIONS, tc_net=teacher_model,
                        st_net=student_model, st_depth_config=st_depth_config,
                        st_channel_widths=st_channel_widths, pr_type=args.pr_type,
                        args=args)
    st_dict_lists = param_remapper.param_remapping()
    for s_idx in range(NB_SECTIONS):
        print(f'=> parameter remapping for stage {s_idx}')
        stage = get_section(student_model, s_idx)
        for d in range(st_depth_config[s_idx]):
            stage[d].load_state_dict(st_dict_lists[s_idx][d])

    teacher_model = teacher_model.to(device)
    student_model = student_model.to(device)

    # Freeze the teacher.
    for p in teacher_model.parameters():
        p.requires_grad = False

    return teacher_model, student_model, net_info


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: scores with one row per sample; top-k is taken along dim 1
            (the class dimension).
        target: integer class labels, one per row of ``output``.
        topk: k values to evaluate. A k larger than the number of classes is
            clamped, so it reports 100% (every class is inside the top k).

    Returns:
        List with one 1-element float tensor per k, holding the percentage.
    """
    with torch.no_grad():
        # Clamp so e.g. top-5 on a task with fewer than 5 classes does not
        # make Tensor.topk raise.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch)
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


class AverageMeter(object):
    """Running, sample-weighted mean of a scalar metric.

    ``val`` is the most recently recorded value, ``sum`` and ``count`` are the
    weighted total and the number of samples seen, and ``avg`` is
    ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget all history; every statistic goes back to 0."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count



# Script entry point: run LIT training followed by fine-tuning (see main()).
if __name__ == '__main__':
    main()