import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import os
from data_prepare import *
from collections import defaultdict
from typing import Optional
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import CosineAnnealingLR
from typing import List, Tuple
from time import time
##########################################################################################################################
# Training image completion models in meta learning set-up
##########################################################################################################################
class EarlyStopping:
    """Track a monitored score, checkpoint the best model, and flag when to stop.

    Args:
        patience: number of consecutive non-improving calls after which
            ``early_stop`` is set to ``True``.
        min_delta: minimum change beyond ``best_score`` that counts as an
            improvement.
        mode: ``'min'`` if lower scores are better; any other value is treated
            as "higher is better".
        save_path: file the best model's ``state_dict`` is written to. Its parent
            directory is created on construction if it does not exist.
        verbose: if true, print a line when a new best checkpoint is saved and
            when early stopping triggers.
    """

    def __init__(self, patience=5, min_delta=1e-4, mode='min', save_path='best_model.pth', verbose=False):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.save_path = save_path
        self.verbose = verbose

        self.counter = 0          # consecutive calls without improvement
        self.best_score = None    # best score seen so far (None before first call)
        self.early_stop = False   # set once `counter` reaches `patience`
        self.best_epoch = 0       # epoch at which `best_score` was observed

        save_dir = os.path.dirname(save_path)
        if save_dir and not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)

    def __call__(self, current_score, model, epoch):
        """Update state with ``current_score`` for ``epoch``; checkpoint on improvement."""
        if self.best_score is None:
            # The first observed score is the best so far by definition.
            self.best_score = current_score
            self.best_epoch = epoch
            self.save_checkpoint(model, epoch)
            return

        if self.mode == 'min':
            improved = current_score < (self.best_score - self.min_delta)
        else:
            improved = current_score > (self.best_score + self.min_delta)

        if improved:
            self.best_score = current_score
            self.best_epoch = epoch
            self.save_checkpoint(model, epoch)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print(f"Early stopping at epoch {epoch}; best epoch {self.best_epoch} "
                          f"(score {self.best_score})")

    def save_checkpoint(self, model, epoch):
        """Write ``model.state_dict()`` to ``save_path``."""
        torch.save(model.state_dict(), self.save_path)
        if self.verbose:
            print(f"Saved best model (epoch {epoch}) to {self.save_path}")

    def load_best_model(self, model):
        """Load the saved best weights into ``model`` and return it.

        Tensors are loaded onto the CPU first; ``load_state_dict`` then copies
        them into the model's existing parameters on whatever device they live,
        so a checkpoint written from a GPU run can be restored on a CPU-only host.

        Raises:
            FileNotFoundError: if no checkpoint exists at ``save_path``.
        """
        if os.path.exists(self.save_path):
            model.load_state_dict(torch.load(self.save_path, map_location='cpu'))
        else:
            raise FileNotFoundError(f"file don't exist：{self.save_path}")
        return model
def _os_nps_weighted_loss(meta_net, y_, mask):
    """Return the adaptively weighted NLL of an OS-NPs model and update its weight state.

    Reads and writes ``meta_net.his_nll``, ``meta_net.l_w0``, ``meta_net.w_z`` and
    ``meta_net.delta``. When ``his_nll`` is ``None`` (or not yet set) the weight
    logits ``l_w0`` are re-initialised from the current NLL. Otherwise they take
    one step driven by the NLL change since the previous call.
    """
    _, log_y_t = meta_net(y_, mask)
    # NLL averaged over axis 0 of the model's log-likelihood output.
    # NOTE(review): assumes this is a 1-D vector (one entry per weighted
    # component), as the softmax/diag usage below requires -- confirm.
    nll_z = -log_y_t.mean(0)
    if getattr(meta_net, 'his_nll', None) is None:
        # (Re)initialise logits as the centred negative NLL.
        meta_net.l_w0 = -nll_z.detach() - torch.mean(-nll_z).detach()
        w_z = meta_net.l_w0.softmax(-1)
        meta_net.w_z = w_z.detach()
    else:
        # NLL decrease since the previous step, flattened to a vector.
        meta_net.delta = (meta_net.his_nll - nll_z.detach()).view(-1)
        w_prev = meta_net.w_z
        # Jacobian of softmax w.r.t. its logits: diag(w) - w w^T (outer product).
        # For a 1-D ``w``, ``w @ w.T`` would be a scalar dot product and would
        # subtract ||w||^2 from every entry instead.
        softmax_jac = torch.diag(w_prev) - torch.outer(w_prev, w_prev)
        g1 = softmax_jac @ meta_net.delta
        g0 = meta_net.l_w0
        # Step size 0.01 plus a 0.01-scaled decay term on the logits themselves.
        meta_net.l_w0 = meta_net.l_w0 - 0.01 * (g1 + 0.01 * g0)
        w_z = meta_net.l_w0.softmax(0).detach()
        meta_net.w_z = w_z
    meta_net.his_nll = nll_z.detach()
    return torch.sum(nll_z * w_z.detach())


def train_meta_net(
                   meta_net,
                   dim,
                   net_optim,
                   train_loader,
                   ):
    """Run one training epoch of ``meta_net`` over ``train_loader``.

    Each batch is flattened to ``(batch, dim*dim, channels)``: 1 channel when
    ``dim == 28``, 3 when ``dim == 32``. It is then gathered via ``idx_to_y`` and
    reshaped back to ``(batch, channels, dim, dim)``. A random Bernoulli context
    mask (shared across channels) is drawn, and one optimizer step is taken.

    For ``meta_net.type == 'OS-NPs'`` the loss is an adaptively re-weighted NLL
    (see ``_os_nps_weighted_loss``) and gradients are not clipped. For every other
    type the model's output is averaged into the loss, and gradients are clipped
    to norm 1.0.

    Args:
        meta_net: model exposing ``type`` and called as ``meta_net(images, mask)``.
        dim: image side length (28 or 32).
        net_optim: optimizer over ``meta_net``'s parameters.
        train_loader: iterable of ``(images, labels)`` batches; labels are ignored.
            NOTE(review): images assumed to be ``(batch, channels, dim, dim)``.

    Returns:
        Sum of the per-batch losses divided by ``len(train_loader)``.
    """
    meta_net.train()
    epoch_train_nll = []
    n_total = dim * dim
    # Indices of every pixel; identical for all batches, so build it once.
    idx_all_tensor = torch.arange(n_total, dtype=torch.long).cuda()

    for y_all, _ in train_loader:
        batch_size = y_all.shape[0]
        if dim == 28:
            y_all = y_all.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 1).cuda()
        elif dim == 32:
            y_all = y_all.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 3).cuda()
        y = idx_to_y(idx_all_tensor, y_all).cuda()
        # Back to image layout (batch, channels, dim, dim) for the model.
        y_ = y.permute(0, 2, 1).view(y.size(0), y.size(2), dim, dim)

        # Draw a context fraction, then keep each pixel independently with that
        # probability (one mask per image, broadcast over channels).
        num_context = int(torch.empty(1).uniform_(1, n_total - 1).item())
        mask = y_.new_empty(y_.size(0), 1, y_.size(2), y_.size(3)).bernoulli_(p=num_context / n_total)

        net_optim.zero_grad()
        if meta_net.type == 'OS-NPs':
            loss = _os_nps_weighted_loss(meta_net, y_, mask)
            loss.backward()
        else:
            loss = meta_net(y_, mask).mean()
            loss.backward()
            clip_grad_norm_(meta_net.parameters(), max_norm=1.0)
        net_optim.step()
        epoch_train_nll.append(loss.item())

    avg_tr_nll = np.array(epoch_train_nll).sum() / len(train_loader)
    return avg_tr_nll


def eval_meta_net(
                  meta_net,
                  dim,
                  eval_loader,
                  num_c_points=None):
    """Evaluate ``meta_net`` on ``eval_loader`` without gradient tracking.

    Each batch is flattened to ``(batch, dim*dim, channels)``: 1 channel when
    ``dim == 28``, 3 when ``dim == 32``. It is gathered via ``idx_to_y`` and
    reshaped to ``(batch, channels, dim, dim)``. A Bernoulli context mask is
    drawn and passed to ``meta_net.conditional_predict``. The returned NLL is
    divided by ``dim * dim``, and the MSE between the predicted mean and the
    gathered targets is computed.

    Args:
        meta_net: model exposing ``conditional_predict(images, mask)`` that
            returns ``(mu, logvar, nll, y_mean)``.
        dim: image side length (28 or 32).
        eval_loader: iterable of ``(images, labels)`` batches; labels are ignored.
        num_c_points: expected number of context pixels per image. Pixels are
            kept independently with probability ``num_c_points / (dim*dim)``
            (clamped to ``[0, 1]``), so the count is exact only in expectation.
            ``None`` (default) draws a fresh count per batch, uniformly from
            ``[1, dim*dim - 1)`` and truncated to an int.

    Returns:
        ``(avg_nll, avg_mse)``: per-batch values summed and divided by
        ``len(eval_loader)``.
    """
    meta_net.eval()
    epoch_test_nll = []
    epoch_test_mse = []
    n_total = dim * dim
    # Indices of every pixel; identical for all batches, so build it once.
    idx_all_tensor = torch.arange(n_total, dtype=torch.long).cuda()

    with torch.no_grad():
        for y_all, _ in eval_loader:
            batch_size = y_all.shape[0]
            if dim == 28:
                y_all = y_all.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 1).cuda()
            elif dim == 32:
                y_all = y_all.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 3).cuda()
            y = idx_to_y(idx_all_tensor, y_all).cuda()
            y_ = y.permute(0, 2, 1).view(y.size(0), y.size(2), dim, dim)

            if num_c_points is None:
                num_context = int(torch.empty(1).uniform_(1, n_total - 1).item())
            else:
                num_context = num_c_points
            # bernoulli_ requires p in [0, 1].
            p_context = min(max(num_context / n_total, 0.0), 1.0)
            mask = y_.new_empty(y_.size(0), 1, y_.size(2), y_.size(3)).bernoulli_(p=p_context)

            mu, logvar, b_nll, y_mean = meta_net.conditional_predict(y_, mask)
            # Normalise the NLL by the number of pixels (mask height * width).
            b_avg_nll = b_nll / (mask.size(-1) * mask.size(-2))
            b_avg_mse = F.mse_loss(y_mean, y)

            epoch_test_nll.append(b_avg_nll.cpu())
            epoch_test_mse.append(b_avg_mse.cpu())

    avg_te_nll = np.array(epoch_test_nll).sum() / len(eval_loader)
    avg_te_mse = np.array(epoch_test_mse).sum() / len(eval_loader)

    return avg_te_nll, avg_te_mse


def run_tr_te(args,
              meta_net,
              dim,
              net_optim,
              train_loader,
              val_loader,
              eval_loader,
              check_lvm,
              data,
              rand_eval,
              writer, seed):
    """Train ``meta_net`` for up to ``args.epochs`` epochs with early stopping.

    Each epoch trains once over ``train_loader`` and records the loss. When
    ``rand_eval == True`` the model is also evaluated on ``eval_loader``; the
    test NLL/MSE are recorded and the NLL is printed. Validation NLL on
    ``val_loader`` is computed every epoch and drives ``EarlyStopping``, which
    writes the best weights to ``<save_path>/<check_lvm>.pth``. The loss curves
    are rewritten as CSV files under
    ``./final_results_<data>/<check_lvm>/<seed>/<writer>`` after every epoch.
    Every 5th epoch ``meta_net.his_nll`` is reset to ``None``.

    Args:
        args: namespace providing ``epochs`` and ``patient_epochs``.
        meta_net, dim, net_optim: forwarded to ``train_meta_net``/``eval_meta_net``.
        train_loader, val_loader, eval_loader: data loaders for each split.
        check_lvm, data, writer, seed: used to build the output directory and
            checkpoint file name.
        rand_eval: evaluate on ``eval_loader`` each epoch when ``== True``.

    Returns:
        ``(train_losses, test_nlls, test_mses)`` as NumPy arrays. The test arrays
        are empty when ``rand_eval`` is not ``True``.
    """
    meta_tr_results, meta_te_nll_results, meta_te_mse_results = [], [], []
    save_path = './final_results_' + data + '/' + check_lvm + '/' + str(seed) + '/' + str(writer)
    os.makedirs(save_path, exist_ok=True)

    early_stopping = EarlyStopping(
        patience=int(args.patient_epochs),
        min_delta=1e-2,
        mode='min',
        save_path=save_path + '/' + check_lvm + '.pth',
        verbose=False,
    )
    meta_net.T_step = args.epochs * len(train_loader)
    for epoch in range(1, args.epochs + 1):
        avg_tr_nll = train_meta_net(meta_net, dim, net_optim, train_loader)
        meta_tr_results.append(avg_tr_nll)

        # Explicit comparison kept so non-bool flag values behave as before.
        if rand_eval == True:  # noqa: E712
            avg_te_nll, avg_te_mse = eval_meta_net(meta_net, dim, eval_loader,
                                                   num_c_points=None)
            meta_te_nll_results.append(avg_te_nll)
            meta_te_mse_results.append(avg_te_mse)
            print(avg_te_nll)
        # Early stopping needs a validation score every epoch, regardless of
        # whether the test split is evaluated (previously this name was only
        # bound inside the branch above).
        avg_val_nll, _ = eval_meta_net(meta_net, dim, val_loader,
                                       num_c_points=None)

        np.savetxt(os.path.join(save_path, 'tr_loss_list.csv'), np.array(meta_tr_results))
        np.savetxt(os.path.join(save_path, 'te_nll_list.csv'), np.array(meta_te_nll_results))
        np.savetxt(os.path.join(save_path, 'te_mse_list.csv'), np.array(meta_te_mse_results))
        if epoch % 5 == 0:
            # Forces OS-NPs training to re-initialise its NLL weighting next step.
            meta_net.his_nll = None
        early_stopping(avg_val_nll, meta_net, epoch)
        if early_stopping.early_stop:
            break

    # Built after the loop so the return is defined even when args.epochs == 0.
    return (np.array(meta_tr_results),
            np.array(meta_te_nll_results),
            np.array(meta_te_mse_results))

    
