import os
import tqdm
import torch
import shutil
import argparse
import warnings
from eval import get_acc
from model import mlpGPE
from torch import nn, optim
from loss import distance_loss
import torch.distributed as dist
from data import BufferILDataset
import torch.utils.data.distributed
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
warnings.filterwarnings('ignore')


def _build_loaders(opt):
    """Build the training and validation data loaders.

    The training set is sharded across ranks with a DistributedSampler, so the
    process group must already be initialised. Validation is only run on
    rank 0 (see ``main``), so the validation loader walks the *whole*
    validation set; a DistributedSampler there would make rank 0 score only
    every world_size-th sample.

    Returns:
        (train_loader, val_loader)
    """
    train_set = BufferILDataset(task_id=opt.task_id, train=True, buffer_size=opt.buffer_size)
    val_set = BufferILDataset(task_id=opt.task_id, train=False)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
    train_loader = DataLoader(train_set, batch_size=opt.batch_size, sampler=train_sampler, num_workers=36)

    val_loader = DataLoader(val_set, batch_size=opt.batch_size, shuffle=False, num_workers=24)
    return train_loader, val_loader


def _load_checkpoint(model, opt):
    """Best-effort strict load of ``opt.checkpoint`` into ``model``.

    The checkpoint's ``old_prototypes`` entry is replaced by
    ``opt.old_prototypes`` before loading. On any failure the reason is
    printed and the model keeps its freshly initialised weights.
    """
    try:
        state_dict = torch.load(opt.checkpoint, map_location='cpu')
        state_dict['old_prototypes'] = opt.old_prototypes
        model.load_state_dict(state_dict, strict=True)
    except Exception as err:  # deliberate best-effort: fall back to training from scratch
        print('Training from scratch...')
        print('  (checkpoint not loaded: {!r})'.format(err))


def _train_one_epoch(model, loader, optimizer, criterion, device, old_protos, lam, gamma):
    """Run one training epoch over ``loader``.

    Args:
        model: DDP-wrapped model; ``model.module.old_prototypes`` is read when
            ``old_protos`` is given.
        loader: iterable of ``(x, y)`` batches.
        old_protos: frozen reference prototypes on ``device``, or None to
            disable the prototype-drift penalty.
        lam: current penalty weight (a Python float).
        gamma: drift tolerance below which no penalty is applied.

    Returns:
        (summed loss, summed accuracy, updated lam). The sums are over
        batches, not averaged.
    """
    model.train()
    loss_sum, acc_sum = 0.0, 0.0
    for x, y in loader:
        x, y = x.float().to(device), y.long().to(device)
        predict = model(x)
        loss = criterion(predict, y)
        if old_protos is not None:
            # Hinge penalty on prototype drift beyond the tolerance gamma.
            # Assumes distance_loss returns a scalar (max() needs a single truth value).
            drift = distance_loss(old_protos, model.module.old_prototypes)
            loss = loss + lam * max(0, drift - gamma)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        _, predict_cls = torch.max(predict, dim=-1)
        acc_sum += get_acc(predict_cls, y)

        if old_protos is not None:
            # Adapt lam using the post-step drift: it grows while drift exceeds
            # gamma and shrinks otherwise, floored at 1e-2. float() keeps lam a
            # plain Python number instead of a device tensor.
            with torch.no_grad():
                current_lr = optimizer.param_groups[0]['lr']
                drift = distance_loss(old_protos, model.module.old_prototypes)
                lam = max(1e-2, lam + current_lr * (float(drift) - gamma))
    return loss_sum, acc_sum, lam


def _evaluate(net, loader, criterion, device):
    """Return ``(mean loss, mean accuracy)`` of ``net`` over every batch of ``loader``."""
    net.eval()
    loss_sum, acc_sum = 0.0, 0.0
    with torch.no_grad():
        for x, y in tqdm.tqdm(loader):
            x, y = x.float().to(device), y.long().to(device)
            predict = net(x)
            loss_sum += criterion(predict, y).item()
            _, predict_cls = torch.max(predict, dim=-1)
            acc_sum += get_acc(predict_cls, y)
    n_batches = max(len(loader), 1)  # guard against an empty loader
    return loss_sum / n_batches, acc_sum / n_batches


def main(opt):
    """Train ``mlpGPE`` with DistributedDataParallel, one process per device.

    Rank 0 additionally:
    - optionally loads ``opt.checkpoint``;
    - every second epoch, evaluates the full validation set, prints metrics,
      saves a checkpoint whenever validation accuracy is >= ``opt.best_acc``,
      and logs scalars to TensorBoard if a writer was created.

    Args:
        opt: argparse namespace. Reads gpu_ids, local_rank, build_tensorboard,
            logdir, init_method, n_gpus, batch_size, task_id, buffer_size,
            old_prototypes, checkpoint, lr, halves, lam, gamma, epoch, best_acc
            and save_path. Mutates ``opt.build_tensorboard`` and
            ``opt.best_acc``.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids

    writer = None
    if opt.local_rank == 0 and opt.build_tensorboard:
        shutil.rmtree(opt.logdir, True)  # start every run with an empty log directory
        writer = SummaryWriter(logdir=opt.logdir)
        opt.build_tensorboard = False

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Bind this process to its GPU before the NCCL process group is created.
        torch.cuda.set_device(opt.local_rank)
        device = torch.device('cuda', opt.local_rank)
    else:
        device = torch.device('cpu')
    dist.init_process_group(backend='nccl' if use_cuda else 'gloo',
                            init_method=opt.init_method, world_size=opt.n_gpus)
    print('Using device:{}'.format(device))

    train_loader, val_loader = _build_loaders(opt)

    model = mlpGPE(old_prototypes=opt.old_prototypes)
    if opt.local_rank == 0:
        # Only rank 0 loads weights; the DDP constructor broadcasts rank 0's
        # module state to the other ranks.
        _load_checkpoint(model, opt)

    ddp_kwargs = {'device_ids': [opt.local_rank], 'output_device': opt.local_rank} if use_cuda else {}
    model = torch.nn.parallel.DistributedDataParallel(model.to(device), broadcast_buffers=False, **ddp_kwargs)

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opt.halves, gamma=0.5)
    criterion = nn.CrossEntropyLoss()
    lam = opt.lam
    gamma = opt.gamma

    # Frozen reference prototypes, moved to the device once instead of every step.
    old_protos = None
    if opt.old_prototypes is not None:
        old_protos = opt.old_prototypes.to(device).detach()

    for epoch in range(opt.epoch):
        train_loader.sampler.set_epoch(epoch)  # reshuffle the shards each epoch

        # Progress bar on rank 0 only.
        batches = tqdm.tqdm(train_loader) if opt.local_rank == 0 else train_loader
        train_loss, train_acc, lam = _train_one_epoch(
            model, batches, optimizer, criterion, device, old_protos, lam, gamma)

        scheduler.step()

        if opt.local_rank != 0 or epoch % 2 != 0:
            continue

        # Evaluate the unwrapped module so rank 0 never enters DDP's forward alone.
        val_loss, val_acc = _evaluate(model.module, val_loader, criterion, device)

        n_train = max(len(train_loader), 1)
        train_loss = train_loss / n_train
        train_acc = train_acc / n_train

        print('EPOCH : %03d | Train Loss : %.4f | Train Acc : %.4f | Val Loss : %.4f | Val Acc : %.4f'
              % (epoch, train_loss, train_acc, val_loss, val_acc))

        if val_acc >= opt.best_acc:
            opt.best_acc = val_acc
            model_name = 'epoch_%d_val_%.4f.pth' % (epoch, val_acc)
            os.makedirs(opt.save_path, exist_ok=True)
            torch.save(model.module.state_dict(), os.path.join(opt.save_path, model_name))

        if writer is not None:
            writer.add_scalar('train/loss', train_loss, epoch)
            writer.add_scalar('train/acc', train_acc, epoch)
            writer.add_scalar('val/loss', val_loss, epoch)
            writer.add_scalar('val/acc', val_acc, epoch)

    if writer is not None:
        writer.close()


if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean (``type=bool`` would turn 'False' into True)."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)

    parser = argparse.ArgumentParser('DynamicGPE')
    # torch.distributed.launch passes --local_rank; torchrun exports LOCAL_RANK instead.
    parser.add_argument('--local_rank', type=int, default=int(os.environ.get('LOCAL_RANK', -1)))
    parser.add_argument('--init_method', default='env://')
    parser.add_argument('--n_gpus', type=int, default=8)
    parser.add_argument('--gpu_ids', type=str, default='0,1,2,3,4,5,6,7')
    parser.add_argument('--build_tensorboard', type=_str2bool, default=True)

    parser.add_argument('--epoch', type=int, default=30000)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--halves', type=int, default=70)
    parser.add_argument('--buffer_size', type=int, default=200)
    parser.add_argument('--best_acc', type=float, default=0.0)
    parser.add_argument('--lam', type=float, default=10.0)
    parser.add_argument('--gamma', type=float, default=0.01)
    parser.add_argument('--task_id', type=int, default=11)
    parser.add_argument('--logdir', type=str, default='./tensorboard/mlp_buffer/stage_11/')
    parser.add_argument('--save_path', type=str, default='./saved_models/mlp_buffer/stage_11/')

    # Best checkpoints of earlier stages:
    # checkpoint = './saved_models/mlp_buffer/stage_1/epoch_40_val_0.9673.pth'
    # checkpoint = './saved_models/mlp_buffer/stage_2/epoch_44_val_0.9463.pth'
    # checkpoint = './saved_models/mlp_buffer/stage_3/epoch_2_val_0.7830.pth'
    # checkpoint = './saved_models/mlp_buffer/stage_4/epoch_0_val_0.8609.pth'
    # checkpoint = './saved_models/mlp_buffer/stage_5/epoch_8_val_0.8606.pth'
    # checkpoint = './saved_models/mlp_buffer/stage_6/epoch_0_val_0.8744.pth'
    # checkpoint = './saved_models/mlp_buffer/stage_7/epoch_2_val_0.8763.pth'
    # checkpoint = './saved_models/mlp_buffer/stage_8/epoch_0_val_0.8852.pth'
    # checkpoint = './saved_models/mlp_buffer/stage_9/epoch_0_val_0.8929.pth'
    checkpoint = './saved_models/mlp_buffer/stage_10/epoch_4_val_0.8746.pth'
    parser.add_argument('--checkpoint', type=str, default=checkpoint)
    # Always overwritten below with the prototypes stored in --checkpoint;
    # kept so existing command lines that pass it still parse.
    parser.add_argument('--old_prototypes', default=None)

    opt = parser.parse_args()

    # Read the prototypes from the checkpoint actually selected (CLI or default),
    # so they always match the weights main() loads from opt.checkpoint.
    state_dict = torch.load(opt.checkpoint, map_location='cpu')
    p = state_dict['new_prototypes']
    if 'old_prototypes' in state_dict:
        # Assumes both prototype tensors agree on every dim except dim 1 — TODO confirm.
        p = torch.cat([p, state_dict['old_prototypes']], dim=1)

    opt.old_prototypes = p.shape  # print only the shape, not the whole tensor
    if opt.local_rank == 0:
        print('opt:', opt)
    opt.old_prototypes = p

    main(opt)


# Launch training (one process per GPU) with the following command:
# python -m torch.distributed.launch --nproc_per_node=8 mlp_buffer_train.py --n_gpus=8
