import argparse
import copy
import os
import os.path as osp
import time
import warnings

import torch
from torch.utils.data import Dataset, DataLoader

from linear_converter import LinearConverter


class FeatureDataset(Dataset):
    """Dataset of paired student/teacher feature files stored on disk.

    Item ``idx`` is loaded with ``torch.load`` from
    ``<student_path>/<idx zero-padded to 6>.pth`` and the file of the same
    name under ``teacher_path``. The reported length is ``num_samples``; the
    directories are not scanned.
    """

    def __init__(self, student_path, teacher_path, num_samples=1000):
        self.student_path, self.teacher_path = student_path, teacher_path
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Both sides use the same file name, so build it once.
        file_name = format(idx, '06') + '.pth'
        student_features = torch.load(osp.join(self.student_path, file_name))
        teacher_features = torch.load(osp.join(self.teacher_path, file_name))
        return student_features, teacher_features


def collate_features(batch):
    """Collate ``(student_features, teacher_features)`` samples into a batch.

    Each sample holds two indexable sequences of tensors. For every position
    in the first sample's sequence, the tensors found at that position across
    all samples are stacked along a new leading dimension.

    Returns:
        Tuple ``(student_features, teacher_features)`` of lists of stacked
        tensors.
    """

    def _stack_side(side):
        # The number of entries is taken from the first sample only.
        entry_count = len(batch[0][side])
        return [
            torch.stack([sample[side][pos] for sample in batch])
            for pos in range(entry_count)
        ]

    stacked_student = _stack_side(0)
    stacked_teacher = _stack_side(1)
    return stacked_student, stacked_teacher


def loss_function(approx_features, teacher_features, args):
    """Compute one mean-squared-error loss per aligned feature pair.

    ``approx_features`` is read from index ``args.student_start`` onward and
    ``teacher_features`` from ``args.teacher_start`` onward; the two slices
    are paired positionally and pairing stops at the shorter slice.

    Returns:
        List of scalar loss tensors, one per pair.
    """
    mse = torch.nn.functional.mse_loss
    predicted = approx_features[args.student_start:]
    targets = teacher_features[args.teacher_start:]
    return [mse(pred, target, reduction='mean') for pred, target in zip(predicted, targets)]


def list_to_cuda(lt):
    """Return a new list with ``.cuda()`` called on every element of ``lt``."""
    moved = []
    for item in lt:
        moved.append(item.cuda())
    return moved


def _run_train_epoch(model, train_loader, optimizer, epoch, args):
    """Run one optimisation pass over ``train_loader``, logging every 10th step."""
    curr_time = time.time()
    for step, (student_features, teacher_features) in enumerate(train_loader):
        student_features = list_to_cuda(student_features)
        teacher_features = list_to_cuda(teacher_features)
        # Time spent waiting for the batch and moving it to the GPU.
        data_time = time.time() - curr_time

        approx_features = model(student_features)
        losses = loss_function(approx_features, teacher_features, args)
        if not losses:
            raise ValueError(
                'loss_function returned no losses; check student_start={} / '
                'teacher_start={} against the number of features'.format(
                    args.student_start, args.teacher_start))
        loss = sum(losses) / len(losses)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        iter_time = time.time() - curr_time
        curr_time = time.time()

        if step % 10 == 0:
            # Log plain floats rather than tensors carrying grad_fn/device noise.
            print('Epoch {}, iter {}, train_loss_avg {:.4f}, train_loss_all {}, data_time {:.2f}, iter_time {:.2f}'.format(
                epoch, step, loss.item(), [level.item() for level in losses], data_time, iter_time))


def _run_validation(model, val_loader, args):
    """Return the per-pair mean validation losses as floats ([] if no batches)."""
    was_training = model.training
    # Evaluate in eval mode so layers such as dropout/batch-norm (if the
    # model has any) behave deterministically and keep their statistics.
    model.eval()
    totals = None
    num_batches = 0
    try:
        with torch.no_grad():
            for student_features, teacher_features in val_loader:
                student_features = list_to_cuda(student_features)
                teacher_features = list_to_cuda(teacher_features)
                approx_features = model(student_features)
                losses = loss_function(approx_features, teacher_features, args)
                if totals is None:
                    # Size the accumulator from the losses actually produced
                    # instead of assuming a fixed number of feature levels.
                    totals = list(losses)
                else:
                    for i, level_loss in enumerate(losses):
                        totals[i] = totals[i] + level_loss
                num_batches += 1
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    if not num_batches or not totals:
        return []
    # Convert once at the end to avoid a GPU sync per batch.
    return [(total / num_batches).item() for total in totals]


def train(model, train_loader, val_loader, optimizer, scheduler, args):
    """Train ``model`` for ``args.epochs`` epochs, validating after each one.

    Every epoch runs one pass over ``train_loader``, steps ``scheduler``,
    evaluates on ``val_loader`` and overwrites ``<args.log>/latest.pth`` with
    a checkpoint holding ``epoch``, ``state_dict`` and ``val_loss`` (a float).

    Args:
        model: module called on the list of student feature tensors.
        train_loader: iterable of ``(student_features, teacher_features)``.
        val_loader: iterable of the same form used for validation.
        optimizer: optimizer stepped once per training batch.
        scheduler: learning-rate scheduler stepped once per epoch.
        args: namespace providing ``epochs``, ``log``, ``student_start`` and
            ``teacher_start``.

    Raises:
        ValueError: if a training batch produces no losses.

    Notes:
        ``val_loss`` is the mean over the losses actually returned by
        ``loss_function``. If ``val_loader`` yields no batches it is NaN.
    """
    if args.log:
        # Create the checkpoint directory up front so a missing folder does
        # not surface only after a full epoch of training.
        os.makedirs(args.log, exist_ok=True)

    for epoch in range(args.epochs):
        epoch_start = time.time()

        _run_train_epoch(model, train_loader, optimizer, epoch, args)
        train_time = time.time() - epoch_start
        print('Epoch {}, train_time {:.2f}'.format(epoch, train_time))
        scheduler.step()

        avg_val_losses = _run_validation(model, val_loader, args)
        if avg_val_losses:
            avg_val_loss = sum(avg_val_losses) / len(avg_val_losses)
        else:
            avg_val_loss = float('nan')
        print('Epoch {}, val_loss_avg {:.4f}, val_loss_all {}'.format(epoch, avg_val_loss, avg_val_losses))

        ckpt = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'val_loss': avg_val_loss,
        }
        torch.save(ckpt, osp.join(args.log, 'latest.pth'))

        epoch_time = time.time() - epoch_start
        print('Epoch {}, total_time {:.2f}'.format(epoch, epoch_time))


def parse_args(argv=None):
    """Parse the command-line options of the converter training script.

    Args:
        argv: list of argument strings; ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Train a converter')
    # Data
    # --student / --teacher / --log are joined into paths below, so a missing
    # value is rejected here instead of failing later inside os.path.join.
    parser.add_argument('--student', type=str, required=True,
                        help="directory holding the 'train' and 'val' student feature folders")
    parser.add_argument('--teacher', type=str, required=True,
                        help="directory holding the 'train' and 'val' teacher feature folders")
    parser.add_argument('--train-samples', type=int, default=1000)
    parser.add_argument('--val-samples', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=20)
    # Model
    parser.add_argument('--in_features', type=int, default=256)
    parser.add_argument('--out_features', type=int, default=256)
    parser.add_argument('--student_start', type=int, default=0)
    parser.add_argument('--teacher_start', type=int, default=0)
    # Optimization
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.05)
    # Others
    parser.add_argument('--log', type=str, required=True,
                        help="directory where 'latest.pth' is written")
    return parser.parse_args(argv)


def main(argv=None):
    """Build the data loaders, model and optimizer, then run training."""
    args = parse_args(argv)

    train_set = FeatureDataset(
        osp.join(args.student, 'train'),
        osp.join(args.teacher, 'train'),
        num_samples=args.train_samples)
    train_loader = DataLoader(train_set, batch_size=args.batch_size,
        shuffle=True, num_workers=4, collate_fn=collate_features)

    val_set = FeatureDataset(
        osp.join(args.student, 'val'),
        osp.join(args.teacher, 'val'),
        num_samples=args.val_samples)
    val_loader = DataLoader(val_set, batch_size=args.batch_size,
        shuffle=False, num_workers=4, collate_fn=collate_features)

    # NOTE(review): the constant 5 is presumably the number of feature
    # entries per sample -- confirm against LinearConverter.
    model = LinearConverter(args.in_features, args.out_features,
        start=args.student_start,
        end=args.student_start - args.teacher_start + 5)
    model.cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    train(model, train_loader, val_loader, optimizer, scheduler, args)


if __name__ == '__main__':
    main()
