import os
import torch
import wandb
import random
import numpy as np
from tqdm import tqdm
import torch.nn as nn
from copy import deepcopy
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim import SGD, Adam
from collections import OrderedDict
import torch.multiprocessing as torchmp
from torch.optim.lr_scheduler import ExponentialLR, StepLR

from utils import InfIterator
from datasets import get_dataset

__all__ = ['ProtoNetTrainer']

# Module-wide switch for Weights & Biases logging. Every wandb call made by
# ProtoNetTrainer is guarded by this flag, so with False the trainer never
# touches wandb.
use_wandb: bool = False

class ProtoNetTrainer(object):
    """Train and evaluate a prototypical-network model with a learned task interpolator.

    :meth:`fit` supports only models named ``'settaskinterpolator_protonet'``.
    For those models, the model parameters are trained with ``optimizer``.
    The interpolator parameters are updated with ``optimizer_interp`` using
    hypergradients of a validation loss. Those hypergradients are computed
    with a Neumann-series approximation of the inverse Hessian-vector product
    (see :meth:`_approx_inverse_hvp`).
    """

    def __init__(self, model=None, optimizer=None, interpolator=None, optimizer_interp=None, scheduler=None, trainloader=None, validloader=None, testloader=None, args=None):
        """Store the training components and create the checkpoint directories.

        ``scheduler`` is stored but not used by any method of this class.
        ``args`` must provide every attribute read by :meth:`create_logging_dirs`
        and :meth:`hyperparameter_settaskinterpolator`.
        """
        self.args = args
        self.model = model
        self.scheduler = scheduler
        self.optimizer = optimizer
        self.interpolator = interpolator
        self.optimizer_interp = optimizer_interp

        self.testloader = testloader
        self.trainloader = trainloader
        self.validloader = validloader

        self.create_logging_dirs()

    def create_logging_dirs(self):
        """Build the checkpoint and log name prefixes, create directories, optionally start wandb.

        Sets ``self.checkpoints_base``, a directory that is created on disk.
        Sets ``self.logs_base``, which is printed and used as the wandb run
        name; no directory is created for it. Also creates the top-level
        ``checkpoints`` and ``logs`` directories.
        """
        paths = ['checkpoints', 'logs']
        self.checkpoints_base = 'checkpoints/{}/{}/{}/{}/{}way{}shot'.format(self.model.name, self.trainloader.dataset.name, self.args.run, self.args.optimizer, self.args.ways, self.args.shots)
        self.logs_base = '{}/{}/{}/{}way{}shot'.format(self.model.name, self.trainloader.dataset.name, self.args.optimizer, self.args.ways, self.args.shots)

        if 'settaskinterpolator' in self.model.name:
            # The same hyper-parameter suffix is appended to both prefixes.
            suffix = 'layer_{}_{}_{}_{}_{}_{}T'.format(self.args.ilayer, self.interpolator.signature,
                    self.args.inner_episodes, self.args.outer_episodes, self.args.BS, self.args.num_tasks)
            self.checkpoints_base = '{}/{}'.format(self.checkpoints_base, suffix)
            self.logs_base = '{}/{}'.format(self.logs_base, suffix)
        elif 'mlti' in self.model.name:
            self.checkpoints_base = '{}_{}'.format(self.checkpoints_base, self.args.alpha)
            self.logs_base = '{}_{}'.format(self.logs_base, self.args.alpha)

        paths.append(self.checkpoints_base)
        for path in paths:
            # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
            os.makedirs(path, exist_ok=True)

        print(self.logs_base)
        if use_wandb:
            wandb.init(entity=self.args.wandb_entity, project=self.args.wandb_project, name=self.logs_base.replace('/', '_'), reinit=True)

    def test(self, model=None, dataloader=None, interpolator=None, device=None):
        """Evaluate ``model`` (and ``interpolator``, if given) on ``dataloader`` without gradients.

        Each batch is ``(support, slabel, query, qlabel)``. The model is called
        with those tensors and must return ``(loss, acc)``. ``loss`` must
        support ``.item()``, and ``acc`` must be an iterable of accuracies.

        Returns:
            ``(summed batch losses / summed query.size(0), mean of all accuracies)``.
            NOTE(review): if the model's loss is already a batch mean, the first
            value is a mean divided by the batch size. Confirm this is intended.
        """
        model.eval()
        if interpolator is not None:
            interpolator.eval()

        count, losses = 0.0, 0.0
        accuracies = []
        with torch.no_grad():
            for support, slabel, query, qlabel in tqdm(dataloader, total=len(dataloader), ncols=75, leave=False):
                support, slabel = support.to(device), slabel.to(device)
                query, qlabel = query.to(device), qlabel.to(device)

                loss, acc = model(support=support, slabel=slabel, query=query, qlabel=qlabel, interpolator=interpolator)

                count += query.size(0)
                losses += loss.item()
                accuracies.extend(acc)
        accuracy = np.mean(accuracies)
        return losses / count, accuracy

    def train(self, model, optimizer, interpolator, optimizer_interp, support, slabel, query, qlabel, device):
        """Run one joint training step of ``model`` and, optionally, ``interpolator``.

        The interpolator is put in train mode and stepped only when
        ``optimizer_interp`` is not None. Returns the model's ``(loss, acc)``.
        Not called by :meth:`fit`.
        """
        support, slabel = support.to(device), slabel.to(device)
        query, qlabel = query.to(device), qlabel.to(device)

        model.train()
        optimizer.zero_grad()
        if optimizer_interp is not None:
            interpolator.train()
            optimizer_interp.zero_grad()

        loss, acc = model(support=support, slabel=slabel, query=query, qlabel=qlabel, interpolator=interpolator)
        loss.backward()
        optimizer.step()
        if optimizer_interp is not None:
            optimizer_interp.step()
        return loss, acc

    @staticmethod
    def _approx_inverse_hvp(v, f, w, i=5, alpha=0.1):
        """Approximate ``H^{-1} v`` with a truncated Neumann series, where ``H = df/dw``.

        Args:
            v: Sequence of tensors shaped like the parameters returned by ``w()``.
            f: Gradients of the training loss w.r.t. ``w()``, built with
                ``create_graph=True`` so that ``f`` can be differentiated again.
            w: Zero-argument callable returning the parameters.
            i: Number of series terms after the first.
            alpha: Step size. The series converges when ``||I - alpha*H|| < 1``.

        Returns:
            List ``alpha * sum_{j=0..i} (I - alpha*H)^j v``.
        """
        params = list(w())
        p = [v_.clone().detach() for v_ in v]
        for _ in range(i):
            grad = torch.autograd.grad(f, params, grad_outputs=v, retain_graph=True)
            v = [v_ - alpha * g_ for v_, g_ in zip(v, grad)]
            p = [v_ + p_ for v_, p_ in zip(v, p)]
        return [alpha * p_ for p_ in p]

    @classmethod
    def _hypergradients(cls, L_V, L_T, lmbda, w, i=5, alpha=0.1):
        """Compute the implicit hypergradient ``dL_V/dlmbda - (d2L_T/dlmbda dw) H^{-1} dL_V/dw``.

        ``lmbda`` and ``w`` are zero-argument callables returning the
        hyper-parameters and the parameters, respectively.

        A hyper-parameter that does not influence ``L_V`` (or ``dL_T/dw``)
        contributes a zero term. Previously such a parameter made autograd
        raise an error.
        """
        w_params = list(w())
        l_params = list(lmbda())

        v1 = torch.autograd.grad(L_V, w_params, retain_graph=True)
        d_LT_dw = torch.autograd.grad(L_T, w_params, create_graph=True)
        v2 = cls._approx_inverse_hvp(v=v1, f=d_LT_dw, w=lambda: w_params, i=i, alpha=alpha)

        v3 = torch.autograd.grad(d_LT_dw, l_params, grad_outputs=v2, retain_graph=True, allow_unused=True)
        d_LV_dlmbda = torch.autograd.grad(L_V, l_params, allow_unused=True)

        return [
            (torch.zeros_like(lp) if d is None else d) - (torch.zeros_like(lp) if v is None else v)
            for lp, d, v in zip(l_params, d_LV_dlmbda, v3)
        ]

    @staticmethod
    def _apply_hypergradients(params, hgrads, episode, total_episodes, clip=5.0):
        """Write clipped, linearly decayed hypergradients into ``p.grad`` for each param.

        Each gradient is clamped to ``[-clip, clip]`` and scaled by
        ``1 - episode / total_episodes``. ``p.grad`` is assigned rather than
        copied into, because ``Optimizer.zero_grad()`` leaves it as ``None``
        by default in PyTorch >= 2.0.
        """
        decay = 1.0 - (episode / total_episodes)
        for p, g in zip(params, hgrads):
            p.grad = (torch.clamp(g, -clip, clip) * decay).detach()

    def _set_train(self, model, optimizer, interpolator, support, slabel, query, qlabel, device):
        """Compute the averaged 'double_forward' loss and optionally take one ``optimizer`` step.

        ``support``/``query`` are indexed as 5-D and ``slabel``/``qlabel`` as
        4-D tensors. Dimension 1 (size ``p``) is subsampled. For
        ``i = 0..p-1``, a random subset of ``i + 1`` entries along dim 1 is
        fed to ``model`` (the singleton dim is squeezed when ``i == 0``). Each
        pass is weighted ``1/p``.

        When ``optimizer`` is None, no backward pass is run, so the returned
        loss keeps its autograd graph.

        Returns:
            ``(train_loss, train_acc)``.

        Raises:
            NotImplementedError: if ``interpolator.name`` is not 'double_forward'.
        """
        model.train(); interpolator.train()

        if interpolator.name != 'double_forward':
            raise NotImplementedError('{} interpolator not implemented'.format(interpolator.name))

        p = support.size(1)
        loss_ratio = 1.0 / p
        train_loss, train_acc = 0.0, 0.0

        support, slabel, query, qlabel = support.to(device), slabel.to(device), query.to(device), qlabel.to(device)
        for i in range(p):
            index = torch.randperm(p)[:i+1]
            s, sl = support[:, index, :, :, :].contiguous(), slabel[:, index, :, :].contiguous()
            q, ql = query[:, index, :, :, :].contiguous(), qlabel[:, index, :, :].contiguous()

            if i == 0:
                s, sl = s.squeeze(1), sl.squeeze(1)
                q, ql = q.squeeze(1), ql.squeeze(1)
            t_l, t_a = model(support=s, slabel=sl, query=q, qlabel=ql, interpolator=interpolator)

            train_loss += t_l*loss_ratio
            train_acc += np.mean(t_a)*loss_ratio

        if optimizer is not None:
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
        return train_loss, train_acc

    def hyperparameter_settaskinterpolator(self):
        """Alternate model training with hypergradient updates of the interpolator.

        Each of the ``args.outer_episodes`` outer steps does the following:
          1. Train the model for ``args.inner_episodes`` episodes with
             ``self.optimizer``.
          2. Compute L_T on a training batch and L_V on a validation batch.
          3. Update the interpolator with the clipped, decayed hypergradient
             via ``self.optimizer_interp``.
          4. Evaluate on ``self.validloader`` and remember the best state.

        Afterwards the best-validation state (if any) is restored, evaluated
        on ``self.testloader``, and saved to ``<checkpoints_base>/state.pth``.

        Returns:
            ``(test_loss, test_accuracy)`` from :meth:`test`.
        """
        # Start below any accuracy so the first validation always snapshots a
        # state. With 0.0, an all-zero-accuracy run crashed on load_state_dict(None).
        best_valid_acc = -float('inf')
        best_valid_state_dict_model = None
        best_valid_state_dict_interpolator = None
        trainloader = InfIterator(self.trainloader)
        device = self.args.device

        for episode in tqdm(range(self.args.outer_episodes), ncols=75, leave=False):
            # 1) Ordinary model training. Interpolator grads are accumulated
            #    here but only stepped by the hypergradient update below.
            train_losses, train_accs = [], []
            for _ in tqdm(range(self.args.inner_episodes), ncols=75, leave=False):
                support, slabel, query, qlabel = next(trainloader)
                train_loss, train_acc = self._set_train(model=self.model, optimizer=self.optimizer, interpolator=self.interpolator,
                        support=support, slabel=slabel, query=query, qlabel=qlabel, device=device)
                train_losses.append(train_loss.item()); train_accs.append(np.mean(train_acc))
            if use_wandb:
                wandb.log({'train_loss': np.mean(train_losses), 'train_accuracy': np.mean(train_accs)}, step=episode)

            # 2) Losses for the hypergradient. optimizer=None keeps L_T's graph.
            support_t, slabel_t, query_t, qlabel_t = self.trainloader.dataset.get_batch(batch_size=50*self.args.train_batch_size)
            L_T, _ = self._set_train(model=self.model, optimizer=None, interpolator=self.interpolator, support=support_t, slabel=slabel_t, query=query_t, qlabel=qlabel_t, device=device)

            self.model.eval(); self.interpolator.eval()
            support_v, slabel_v, query_v, qlabel_v = self.validloader.dataset.get_batch(batch_size=self.args.BS*self.args.batch_size)
            L_V, _ = self.model(support=support_v.to(device), slabel=slabel_v.to(device), query=query_v.to(device),
                    qlabel=qlabel_v.to(device), interpolator=self.interpolator)

            # 3) Hypergradient step on the interpolator only.
            hgrads = self._hypergradients(L_V=L_V, L_T=L_T, lmbda=self.interpolator.parameters, w=self.model.parameters, i=5, alpha=self.args.lr)
            del L_T, L_V  # release the retained autograd graphs before validation

            self.optimizer_interp.zero_grad()
            self._apply_hypergradients(self.interpolator.parameters(), hgrads, episode, self.args.outer_episodes)
            self.optimizer_interp.step()

            # 4) Model selection on the full validation loader.
            valid_loss, valid_acc = self.test(self.model, self.validloader, interpolator=self.interpolator, device=device)
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                best_valid_state_dict_model = deepcopy(self.model.state_dict())
                best_valid_state_dict_interpolator = deepcopy(self.interpolator.state_dict())
            if use_wandb:
                wandb.log({'valid_loss': valid_loss, 'valid_accuracy': valid_acc}, step=episode)

        # No snapshot exists only if no outer episode ran or validation accuracy
        # was never comparable (e.g. NaN). In that case keep the current weights.
        if best_valid_state_dict_model is not None:
            self.model.load_state_dict(best_valid_state_dict_model)
            self.interpolator.load_state_dict(best_valid_state_dict_interpolator)
        test_loss, test_acc = self.test(self.model, self.testloader, interpolator=self.interpolator, device=device)
        state_dict = {'model': self.model.state_dict(), 'interpolator': self.interpolator.state_dict()}
        torch.save(state_dict, '{}/state.pth'.format(self.checkpoints_base))
        if use_wandb:
            # Equals the old `episode + 1`, and stays defined when outer_episodes == 0.
            wandb.log({'test_loss': test_loss, 'test_accuracy': test_acc}, step=self.args.outer_episodes)
        return test_loss, test_acc

    def fit(self):
        """Dispatch to the training routine for ``self.model.name``.

        Returns:
            ``(test_loss, test_accuracy)``.

        Raises:
            NotImplementedError: for model names other than 'settaskinterpolator_protonet'.
        """
        settaskinterpolator_models = ['settaskinterpolator_protonet']
        if self.model.name in settaskinterpolator_models:
            test_loss, test_acc = self.hyperparameter_settaskinterpolator()
        else:
            raise NotImplementedError('{} not implemented'.format(self.model.name))
        if use_wandb:
            # wandb.finish() supersedes the deprecated wandb.join().
            wandb.finish()
        return test_loss, test_acc
