import os
from collections import OrderedDict
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

from model import GMVAE
from data import Data

class Trainer(object):
    """Train a ``GMVAE`` on pre-extracted features and track few-shot accuracy.

    The constructor builds the data source, the training loader, the model,
    the Adam optimizer and a TensorBoard writer from a single ``args``
    namespace.  ``train`` alternates blocks of ``args.freq_iters``
    optimisation steps with an evaluation pass over every entry of
    ``EVAL_SHOTS``, saving a separate best checkpoint per shot setting.

    Fields read from ``args``: ``feature_dir``, ``batch_size``,
    ``sample_size``, ``unsupervised_em_iters``, ``semisupervised_em_iters``,
    ``fix_pi``, ``way``, ``latent_size``, ``train_mc_sample_size``,
    ``test_mc_sample_size``, ``device``, ``lr``, ``save_dir``,
    ``freq_iters``, ``train_iters``, ``debug``, ``query`` and
    ``eval_episodes``.
    """

    # Support-set sizes ("shots") evaluated after every training epoch.
    EVAL_SHOTS = (1, 5, 20, 50)

    def __init__(self, args):
        """Build data, loader, model, optimizer and TensorBoard writer.

        Args:
            args: namespace-like object carrying the fields listed in the
                class docstring.
        """
        self.args = args
        self.data = Data(args.feature_dir)

        dataset = TensorDataset(torch.from_numpy(self.data.x_mtr).float())
        # Sampling with replacement; each loader batch holds
        # batch_size * sample_size rows and is reshaped to
        # (batch_size, sample_size, feature_dim) in the training step.
        sampler = RandomSampler(dataset, replacement=True)
        self.trloader = DataLoader(
            dataset=dataset,
            batch_size=args.batch_size * args.sample_size,
            sampler=sampler,
            drop_last=True,  # the reshape needs full batches
        )

        # Size of the last axis of the training features.
        self.input_shape = self.data.x_mtr.shape[-1]

        self.model = GMVAE(
            input_shape=self.input_shape,
            unsupervised_em_iters=args.unsupervised_em_iters,
            semisupervised_em_iters=args.semisupervised_em_iters,
            fix_pi=args.fix_pi,
            component_size=args.way,
            latent_size=args.latent_size,
            train_mc_sample_size=args.train_mc_sample_size,
            test_mc_sample_size=args.test_mc_sample_size
        ).to(self.args.device)

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=args.lr
        )

        self.writer = SummaryWriter(
            log_dir=os.path.join(args.save_dir, "tb_log")
        )

    def train(self):
        """Run the full training schedule with periodic evaluation.

        An "epoch" here is a block of ``args.freq_iters`` optimisation
        steps; epochs repeat while
        ``global_epoch * args.freq_iters < args.train_iters``.  After each
        epoch every shot setting in ``EVAL_SHOTS`` is evaluated under
        ``torch.no_grad()``, the results are logged, and a checkpoint is
        written for each setting whose mean accuracy improved.

        When ``args.debug`` is truthy only one optimisation step is run per
        epoch.
        """
        global_epoch = 0
        global_step = 0
        best = {shot: 0.0 for shot in self.EVAL_SHOTS}
        iterator = iter(self.trloader)

        while global_epoch * self.args.freq_iters < self.args.train_iters:
            with tqdm(total=self.args.freq_iters) as pbar:
                for _ in range(self.args.freq_iters):
                    # NOTE: a step learning-rate decay (x0.1 at 1/3 and 2/3
                    # of train_iters) was sketched here previously and is
                    # currently disabled.

                    # Restart the loader transparently once it is exhausted.
                    try:
                        H = next(iterator)[0]
                    except StopIteration:
                        iterator = iter(self.trloader)
                        H = next(iterator)[0]

                    rec_value, kld_value = self._train_step(H)

                    pbar.set_postfix(
                        rec='{0:.4f}'.format(rec_value),
                        kld='{0:.4f}'.format(kld_value),
                    )
                    self.writer.add_scalars(
                        'train',
                        {'rec': rec_value, 'kld': kld_value},
                        global_step
                    )

                    pbar.update(1)
                    global_step += 1

                    if self.args.debug:
                        break

            with torch.no_grad():
                results = OrderedDict(
                    (shot, self.eval(shot=shot)) for shot in self.EVAL_SHOTS
                )

            self._record_results(results, best, global_epoch)
            global_epoch += 1

        # Make sure buffered TensorBoard events reach disk.
        self.writer.flush()

    def _train_step(self, H):
        """Run one optimisation step on a flat loader batch.

        Args:
            H: tensor with ``batch_size * sample_size`` rows of features.

        Returns:
            ``(rec, kld)`` as Python floats: the two loss terms returned by
            the model for this batch.
        """
        self.model.train()
        self.model.zero_grad()

        H = H.to(self.args.device).float()
        H = H.view(self.args.batch_size, self.args.sample_size, self.input_shape)

        rec_loss, kl_loss = self.model(H)
        loss = rec_loss + kl_loss

        loss.backward()
        self.optimizer.step()

        # Read each loss once as a plain float for display and logging.
        return rec_loss.item(), kl_loss.item()

    def _record_results(self, results, best, global_epoch):
        """Log evaluation results and save improved checkpoints.

        Args:
            results: mapping ``shot -> (mean_accuracy, std_accuracy)``.
            best: mapping ``shot -> best mean accuracy so far``; updated in
                place when a setting improves.
            global_epoch: epoch index used as the TensorBoard step and
                stored in the checkpoint.
        """
        # One distinct key per (shot, statistic); reusing a key would
        # silently drop all but the last value written under it.
        scalars = {}
        for shot, (mean_acc, std_acc) in results.items():
            scalars['{0}shot-acc-mean'.format(shot)] = mean_acc
            scalars['{0}shot-acc-std'.format(shot)] = std_acc
        self.writer.add_scalars('test', scalars, global_epoch)

        for shot, (mean_acc, _) in results.items():
            if best[shot] < mean_acc:
                best[shot] = mean_acc
                state = {
                    'state_dict': self.model.state_dict(),
                    'accuracy': mean_acc,
                    'epoch': global_epoch,
                }
                torch.save(
                    state,
                    os.path.join(
                        self.args.save_dir, '{0}shot_best.pth'.format(shot)
                    )
                )

            print("{0}shot {1}-th EPOCH Accuracy: {2:.4f}, BEST Accuracy: {3:.4f}".format(
                shot, global_epoch, mean_acc, best[shot]))

    def eval(self, shot):
        """Estimate few-shot accuracy for a given support-set size.

        Test episodes are generated in batches of ``args.batch_size`` until
        at least ``args.eval_episodes`` per-episode accuracies have been
        collected; surplus episodes from the last batch are discarded.

        When ``args.debug`` is truthy exactly one batch is evaluated, so the
        prediction path is still exercised and finite statistics are
        returned.

        The caller decides whether gradients are tracked; ``train`` wraps
        this call in ``torch.no_grad()``.

        Args:
            shot: value passed as ``shot`` to the episode generator.

        Returns:
            ``(mean, std)`` of the per-episode accuracies.  Both are ``nan``
            if no episode was evaluated (``args.eval_episodes <= 0``).
        """
        self.model.eval()
        target = self.args.eval_episodes
        chunks = []
        n_done = 0

        # Check the stop condition first so no episode is generated and
        # then thrown away.
        while n_done < target:
            H_tr, y_tr, H_te, y_te = self.data.generate_test_episode(
                way=self.args.way,
                shot=shot,
                query=self.args.query,
                n_episodes=self.args.batch_size
            )
            H_tr = torch.from_numpy(H_tr).to(self.args.device).float()
            y_tr = torch.from_numpy(y_tr).to(self.args.device)
            H_te = torch.from_numpy(H_te).to(self.args.device).float()
            y_te = torch.from_numpy(y_te).to(self.args.device)

            y_te_pred = self.model.prediction(H_tr, y_tr, H_te)
            # Average the hits over the last axis: one accuracy per episode.
            accuracies = torch.mean(
                torch.eq(y_te_pred, y_te).float(), dim=-1
            ).cpu().numpy()
            chunks.append(accuracies)
            n_done += len(accuracies)

            if self.args.debug:
                break

        if not chunks:
            return float('nan'), float('nan')

        # Accumulate in float64 so the statistics do not depend on the
        # precision of the per-batch arrays.
        all_accuracies = np.concatenate(chunks, axis=0).astype(np.float64)
        all_accuracies = all_accuracies[:target]
        return np.mean(all_accuracies), np.std(all_accuracies)
    