import os
from collections import OrderedDict
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

from model import GMVAE
from data import Data
from utils import make_grid_with_labels

class Trainer(object):
    """Meta-train a GMVAE and evaluate it on few-shot test episodes.

    Training images are drawn with replacement from ``Data(args.data_dir).x_mtr``;
    each loader batch of ``batch_size * sample_size`` images is reshaped to
    ``(batch_size, sample_size, *input_shape)`` and the model is optimised on
    reconstruction + KL loss. Every ``freq_iters`` steps the model is evaluated
    on 1-shot and 5-shot episodes, and the best checkpoint plus sample images for
    each setting are written under ``args.save_dir``. When ``args.state_dir`` is
    given, a pretrained model is loaded and only evaluation is run.
    """

    def __init__(self, args):
        """Build the data pipeline, model, optimiser, TensorBoard writer and output dirs.

        Args:
            args: configuration namespace; every ``args.*`` attribute read by
                this class must be present.
        """
        self.args = args
        self.data = Data(args.data_dir)

        dataset = TensorDataset(torch.from_numpy(self.data.x_mtr).float())
        # Draw training images uniformly at random, with replacement.
        sampler = RandomSampler(dataset, replacement=True)
        # One loader batch holds every set for one optimisation step; it is
        # reshaped into (batch_size, sample_size, ...) in _next_batch.
        self.trloader = DataLoader(
            dataset=dataset,
            batch_size=args.batch_size * args.sample_size,
            sampler=sampler,
            drop_last=True,
        )

        # NOTE(review): fixed single-channel 28x28 images — presumably matches
        # the arrays provided by Data; confirm if a new dataset is added.
        self.input_shape = [1, 28, 28]

        self.model = GMVAE(
            input_shape=self.input_shape,
            unsupervised_em_iters=args.unsupervised_em_iters,
            semisupervised_em_iters=args.semisupervised_em_iters,
            fix_pi=args.fix_pi,
            hidden_size=args.hidden_size,
            component_size=args.way,
            latent_size=args.latent_size,
            train_mc_sample_size=args.train_mc_sample_size,
            test_mc_sample_size=args.test_mc_sample_size,
        ).to(self.args.device)

        if args.state_dir:
            # map_location lets a checkpoint written on one device (e.g. a GPU)
            # be restored on whichever device this run uses.
            checkpoint = torch.load(args.state_dir, map_location=self.args.device)
            self.model.load_state_dict(checkpoint["state_dict"])
            print("Pretrained Model Loaded.")
            print("Only Meta-testing is available.")
            # Uniform mixture weights, one per component. The model is built
            # with component_size=args.way, so size the vector to match rather
            # than hard-coding a component count.
            n_components = self.args.way
            self.model.uniform_pi = torch.ones(n_components).to(self.args.device) / n_components

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=args.lr,
        )

        self.writer = SummaryWriter(
            log_dir=os.path.join(args.save_dir, "tb_log")
        )

        self.tr_image_save_dir = os.path.join(self.args.save_dir, 'tr_samples')
        os.makedirs(self.tr_image_save_dir, exist_ok=True)
        self.te_image_save_dir = os.path.join(self.args.save_dir, 'te_samples')
        os.makedirs(self.te_image_save_dir, exist_ok=True)

    def train(self):
        """Run meta-training (unless a pretrained model was loaded), then report final accuracy.

        The final 1-shot and 5-shot accuracies are printed. The TensorBoard
        writer is closed on return, even if training raises, so buffered
        scalars are flushed to disk.
        """
        try:
            # A loaded pretrained model is only evaluated, never trained further.
            if not self.args.state_dir:
                self._meta_train()

            with torch.no_grad():
                mean_1shot, _, _, _, _ = self.eval(shot=1)
                mean_5shot, _, _, _, _ = self.eval(shot=5)
            print("1shot Final Accuracy: {0:.4f}".format(mean_1shot))
            print("5shot Final Accuracy: {0:.4f}".format(mean_5shot))
        finally:
            self.writer.close()

    def _meta_train(self):
        """Alternate ``freq_iters`` optimisation steps with evaluation and checkpointing."""
        global_epoch = 0
        global_step = 0
        best_1shot = 0.0
        best_5shot = 0.0
        iterator = iter(self.trloader)

        while global_epoch * self.args.freq_iters < self.args.train_iters:
            with tqdm(total=self.args.freq_iters) as pbar:
                for _ in range(self.args.freq_iters):
                    X, iterator = self._next_batch(iterator)
                    rec_loss, kl_loss = self._train_step(X)

                    pbar.set_postfix(OrderedDict(
                        rec='{0:.4f}'.format(rec_loss),
                        kld='{0:.4f}'.format(kl_loss),
                    ))
                    self.writer.add_scalars(
                        'train',
                        {'rec': rec_loss, 'kld': kl_loss},
                        global_step,
                    )

                    pbar.update(1)
                    global_step += 1

                    if self.args.debug:
                        break

            # Save images of the first set of the last batch, grouped by the
            # model's prediction. This is visualisation only, so no autograd
            # graph is needed.
            with torch.no_grad():
                y = self.model.tr_prediction(X)
            self.save_tr_img(X[0], y[0])

            with torch.no_grad():
                result_1shot = self.eval(shot=1)
                result_5shot = self.eval(shot=5)

            self.writer.add_scalars(
                'test',
                {'1shot-acc-mean': result_1shot[0], '1shot-acc-std': result_1shot[1],
                 '5shot-acc-mean': result_5shot[0], '5shot-acc-std': result_5shot[1]},
                global_epoch,
            )

            best_1shot = self._update_best(1, result_1shot, best_1shot, global_epoch)
            best_5shot = self._update_best(5, result_5shot, best_5shot, global_epoch)

            global_epoch += 1

    def _next_batch(self, iterator):
        """Return ``(X, iterator)``: the next batch reshaped into sets, restarting the loader when exhausted."""
        try:
            X = next(iterator)[0]
        except StopIteration:
            iterator = iter(self.trloader)
            X = next(iterator)[0]
        X = X.to(self.args.device).float()
        X = X.view(self.args.batch_size, self.args.sample_size, *self.input_shape)
        return X, iterator

    def _train_step(self, X):
        """Take one optimiser step on ``X``; return ``(rec_loss, kl_loss)`` as Python floats."""
        self.model.train()
        self.model.zero_grad()

        rec_loss, kl_loss = self.model(X)
        loss = rec_loss + kl_loss
        loss.backward()
        self.optimizer.step()

        return float(rec_loss), float(kl_loss)

    @staticmethod
    def _make_checkpoint(state_dict, accuracy, epoch):
        """Build a checkpoint dict holding only tensors and plain Python scalars.

        ``accuracy`` typically arrives as a NumPy scalar; storing it as a plain
        ``float`` avoids pickling NumPy objects, which the ``weights_only``
        loader of recent PyTorch versions refuses by default.
        """
        return {
            'state_dict': state_dict,
            'accuracy': float(accuracy),
            'epoch': int(epoch),
        }

    def _update_best(self, shot, result, best, epoch):
        """Checkpoint and save test images when ``result`` beats ``best``; print and return the new best."""
        mean, _, True_X, Recon_X, Component_X = result
        if best < mean:
            best = mean
            torch.save(
                self._make_checkpoint(self.model.state_dict(), mean, epoch),
                os.path.join(self.args.save_dir, '{0}shot_best.pth'.format(shot)),
            )
            self.save_te_img(
                shot=shot,
                True_X=True_X,
                Recon_X=Recon_X,
                Component_X=Component_X,
            )

        print("{0}shot {1}-th EPOCH Accuracy: {2:.4f}, BEST Accuracy: {3:.4f}".format(shot, epoch, mean, best))
        return best

    def _test_episode_batch(self, shot):
        """Sample ``args.batch_size`` test episodes and move them to the training device."""
        X_tr, y_tr, X_te, y_te = self.data.generate_test_episode(
            way=self.args.way,
            shot=shot,
            query=self.args.query,
            n_episodes=self.args.batch_size,
        )
        device = self.args.device
        return (
            torch.from_numpy(X_tr).to(device).float(),
            torch.from_numpy(y_tr).to(device),
            torch.from_numpy(X_te).to(device).float(),
            torch.from_numpy(y_te).to(device),
        )

    def eval(self, shot):
        """Estimate ``args.way``-way ``shot``-shot test accuracy.

        Episodes are drawn in batches of ``args.batch_size`` until at least
        ``args.eval_episodes`` accuracies are collected (or immediately in
        debug mode). One further batch is then passed to ``model.generate``
        for visualisation.

        Returns:
            ``(mean, std, True_X, Recon_X, Component_X)``. ``mean`` and ``std``
            are taken over the first ``args.eval_episodes`` per-episode
            accuracies, and are NaN when none were collected. The last three
            values are the outputs of ``model.generate``.
        """
        self.model.eval()
        chunks = []
        n_collected = 0
        while True:
            X_tr, y_tr, X_te, y_te = self._test_episode_batch(shot)

            if n_collected >= self.args.eval_episodes or self.args.debug:
                True_X, Recon_X, Component_X = self.model.generate(X_tr, y_tr, X_te)
                break

            y_te_pred = self.model.te_prediction(X_tr, y_tr, X_te)
            accuracies = torch.mean(torch.eq(y_te_pred, y_te).float(), dim=-1).cpu().numpy()
            chunks.append(accuracies)
            n_collected += len(accuracies)

        if not chunks:
            # Debug mode can stop before any accuracy is measured.
            return float('nan'), float('nan'), True_X, Recon_X, Component_X

        # Concatenate once (not per batch); keep float64 for the statistics.
        all_accuracies = np.concatenate(chunks, axis=0).astype(np.float64)
        all_accuracies = all_accuracies[:self.args.eval_episodes]
        return np.mean(all_accuracies), np.std(all_accuracies), True_X, Recon_X, Component_X

    def save_tr_img(self, True_X, y):
        """Save, for each of the ``args.way`` components, a grid of the images assigned to it.

        Args:
            True_X: images of one training set, indexable by the boolean mask ``y == i``.
            y: predicted component index for each image in ``True_X``.

        A component with no assigned images gets a single blank placeholder tile.
        """
        for i in range(self.args.way):
            members = True_X[y == i]
            if len(members) == 0:
                members = torch.zeros((1, *self.input_shape))
                nrow = 1
            else:
                nrow = 20
            grid = torchvision.utils.make_grid(members, nrow=nrow, padding=2, pad_value=255)
            torchvision.utils.save_image(
                grid,
                os.path.join(self.tr_image_save_dir, '{0}th_True.png'.format(i + 1)),
            )

    def save_te_img(self, shot, True_X, Recon_X, Component_X):
        """Save 8-column grids of the three image batches returned by ``model.generate``."""
        named_batches = (
            ('True', True_X),
            ('Recon', Recon_X),
            ('Component', Component_X),
        )
        for suffix, images in named_batches:
            grid = torchvision.utils.make_grid(images, nrow=8, padding=2, pad_value=255)
            torchvision.utils.save_image(
                grid,
                os.path.join(self.te_image_save_dir, '{0}shot_{1}.png'.format(shot, suffix)),
            )
