import os
from collections import OrderedDict
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torch.utils.data import TensorDataset, DataLoader

from module import MimgnetEncoder
from data import MimgnetDataset
from utils import NTXentLoss

class Trainer(object):
    """Contrastive trainer for a ``MimgnetEncoder`` followed by a two-layer projection.

    The encoder output is passed through ``l1 -> ReLU -> l2``, L2-normalized,
    and fed pairwise to the ``NTXentLoss`` criterion. Whenever the mean epoch
    loss improves, the encoder weights are checkpointed and encoder features
    for the train and test sets are written to disk.

    ``args`` must provide: ``batch_size``, ``hidden_size``, ``device``, ``lr``,
    ``train_epochs``, ``debug``, ``save_dir`` and ``feature_save_dir``.
    """

    # Dataset locations, relative to the working directory.
    TRAIN_DATA_PATH = "../data/mimgnet/train.npy"
    TEST_DATA_PATH = "../data/mimgnet/test.npy"

    # The LR scheduler is only stepped for epochs with index >= this value.
    SCHEDULER_START_EPOCH = 10

    def __init__(self, args):
        """Build the data loader, the model parts, the optimizer, the scheduler and the loss."""
        self.args = args

        dataset = MimgnetDataset(self.TRAIN_DATA_PATH)
        # drop_last=True keeps every batch at exactly args.batch_size samples,
        # matching the batch_size handed to the criterion below.
        self.trloader = DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=True
        )

        self.encoder = MimgnetEncoder(
            hidden_size=args.hidden_size
        ).to(args.device)
        # presumably the encoder emits a flattened 5x5 map with hidden_size
        # channels -- TODO confirm against MimgnetEncoder.
        last_hidden_size = 5*5*args.hidden_size
        self.l1 = nn.Linear(last_hidden_size, last_hidden_size).to(args.device)
        self.l2 = nn.Linear(last_hidden_size, int(0.5*last_hidden_size)).to(args.device)

        # A single optimizer owns the parameters of the encoder and both linear layers.
        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters())+list(self.l1.parameters())+list(self.l2.parameters()),
            lr=args.lr
        )

        # NOTE(review): T_max is the number of batches per epoch, yet train()
        # steps this scheduler once per epoch -- confirm that is intended.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=len(self.trloader),
            eta_min=0,
            last_epoch=-1
        )

        self.criterion = NTXentLoss(
            device=args.device,
            batch_size=args.batch_size,
            temperature=0.5,
            use_cosine_similarity=True
        )

    def _project(self, x):
        """Return the L2-normalized projection ``l2(relu(l1(encoder(x))))``."""
        z = self.l2(F.relu(self.l1(self.encoder(x))))
        return F.normalize(z, dim=1)

    def train(self):
        """Run the training loop for ``args.train_epochs`` epochs.

        After each epoch the mean batch loss is compared with the best value
        seen so far; on improvement the encoder state is saved to
        ``<args.save_dir>/best.pth`` and ``save_feature()`` is called. With
        ``args.debug`` set, a single batch of a single epoch is run.
        """
        device = self.args.device
        best_loss = float("inf")

        for global_epoch in range(self.args.train_epochs):
            # save_feature() leaves the encoder in eval mode, so training mode
            # has to be restored at the start of every epoch.
            self.encoder.train()
            self.l1.train()
            self.l2.train()

            epoch_losses = []
            with tqdm(total=len(self.trloader)) as pbar:
                for xis, xjs in self.trloader:
                    # The optimizer holds exactly the encoder/l1/l2 parameters,
                    # so this clears the same gradients as per-module zero_grad.
                    self.optimizer.zero_grad()

                    xis = xis.to(device)
                    xjs = xjs.to(device)

                    # normalized projection feature vectors of both inputs
                    zis = self._project(xis)
                    zjs = self._project(xjs)

                    loss = self.criterion(zis, zjs)
                    loss.backward()
                    self.optimizer.step()

                    loss_value = loss.item()
                    pbar.set_postfix(loss='{0:.4f}'.format(loss_value))
                    pbar.update(1)
                    epoch_losses.append(loss_value)

                    if self.args.debug:
                        break

            if global_epoch >= self.SCHEDULER_START_EPOCH:
                self.scheduler.step()

            # Plain Python float so the checkpoint does not carry a NumPy scalar.
            avg_loss = float(np.mean(epoch_losses))

            if best_loss > avg_loss:
                best_loss = avg_loss
                state = {
                    'encoder_state_dict': self.encoder.state_dict(),
                    'loss': avg_loss,
                    'epoch': global_epoch,
                }
                # Make sure the checkpoint directory exists before writing.
                os.makedirs(self.args.save_dir, exist_ok=True)
                torch.save(state, os.path.join(self.args.save_dir, 'best.pth'))

                self.save_feature()

            print("{0}-th EPOCH Loss: {1:.4f}, BEST Loss: {2:.4f}".format(global_epoch, avg_loss, best_loss))

            if self.args.debug:
                break

    def save_feature(self):
        """Encode the train and test sets with the encoder and save the features.

        Writes ``train_features.npy`` and ``test_features.npy`` into
        ``args.feature_save_dir`` (created if missing). The encoder is left in
        eval mode when this method returns.
        """
        self.encoder.eval()
        os.makedirs(self.args.feature_save_dir, exist_ok=True)

        self._encode_and_save(
            self.TRAIN_DATA_PATH, "START ENCODING TRAINING SET", "train_features.npy"
        )
        self._encode_and_save(
            self.TEST_DATA_PATH, "START ENCODING TEST SET", "test_features.npy"
        )

    def _encode_and_save(self, data_path, banner, out_name):
        """Encode every sample of one dataset file and save the stacked features.

        Args:
            data_path: path passed to ``MimgnetDataset``.
            banner: message printed before encoding starts.
            out_name: file name written inside ``args.feature_save_dir``.
        """
        # The second positional argument presumably switches the dataset to
        # yield single, un-paired samples -- TODO confirm in MimgnetDataset.
        dataset = MimgnetDataset(data_path, False)
        loader = DataLoader(
            dataset=dataset,
            batch_size=self.args.batch_size,
            shuffle=False
        )

        features = []
        print(banner)
        # No gradients are needed for feature extraction; skipping autograd
        # bookkeeping avoids building a graph for every batch.
        with torch.no_grad(), tqdm(total=len(loader)) as pbar:
            for x in loader:
                x = x.to(self.args.device)
                f = self.encoder(x)
                features.append(f.detach().cpu().numpy())
                pbar.update(1)
        features = np.concatenate(features, axis=0)
        print("SAVE ({0}, {1}) shape array".format(features.shape[0], features.shape[1]))
        np.save(os.path.join(self.args.feature_save_dir, out_name), features)