import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from betty.engine import Engine
from betty.problems import ImplicitProblem
from betty.configs import Config, EngineConfig
from networks import EstimatorNet, RegressorNet
from scipy.stats import norm
from torchvision.utils import save_image
import os
from tqdm import tqdm
import pandas as pd
from utils import Logger

def pin_ball_loss(eps, tau):
    """Return the element-wise pinball (quantile) loss.

    Args:
        eps: Tensor of residuals.
        tau: Quantile level(s); a Python scalar or a tensor that broadcasts
            against ``eps``.

    Returns:
        Unreduced loss tensor equal to ``tau * eps`` where ``eps >= 0`` and
        ``(tau - 1) * eps`` where ``eps < 0``.
    """
    # Select the active branch instead of multiplying both branches by 0/1
    # float masks: a masked-out branch evaluated at an infinite residual would
    # give 0 * inf = NaN, and a float32 mask would upcast half-precision input.
    return torch.where(eps >= 0, tau * eps, (tau - 1) * eps)

class MyEngine(Engine):
    """Engine that trains, evaluates and checkpoints a two-problem setup.

    ``self.problems[0]`` is treated as the quantile estimator and
    ``self.problems[1]`` as the regressor (see the checkpoint keys used in
    ``save_model``). All artefacts (log, checkpoints, plots, CSVs, image
    grids) are written under ``args.run_dir``.
    """

    def __init__(self, problems, args, config=None, dependencies=None, env=None):
        """Store ``args`` and open the run logger.

        Args:
            problems: Problems forwarded to ``Engine``; indices 0 and 1 are
                used as estimator and regressor respectively.
            args: Namespace providing at least ``device`` and ``run_dir``
                here; ``run`` additionally reads ``decay_every``,
                ``decay_rate``, ``check_every``, ``test_every``,
                ``save_every`` and ``exp_num``.
            config: Forwarded to ``Engine``.
            dependencies: Forwarded to ``Engine``.
            env: Forwarded to ``Engine``.
        """
        super().__init__(problems, config, dependencies, env)
        self.args = args
        self.device = args.device
        # The logger file lives inside run_dir, so make sure the directory
        # exists before the logger (or save_model) touches it.
        os.makedirs(args.run_dir, exist_ok=True)
        self.logger = Logger(os.path.join(args.run_dir, 'log.txt'))

    def save_model(self, fp):
        """Save module and optimizer state of both problems.

        Args:
            fp: Checkpoint file name, joined onto ``args.run_dir``.
        """
        torch.save({'estimator': self.problems[0].module.state_dict(),
                    'regressor': self.problems[1].module.state_dict(),
                    'estimator_optimizer': self.problems[0].optimizer.state_dict(),
                    'regressor_optimizer': self.problems[1].optimizer.state_dict(),
                    }, os.path.join(self.args.run_dir, fp))

    def load_model(self, fp):
        """Restore module and optimizer state written by ``save_model``.

        Args:
            fp: Path to a checkpoint. If it does not exist as given, it is
                looked up relative to ``args.run_dir``.
        """
        path = fp if os.path.exists(fp) else os.path.join(self.args.run_dir, fp)
        print('loading model from ', path)
        # map_location lets a checkpoint saved on one device be restored on
        # whatever device this engine is configured for.
        ckpt = torch.load(path, map_location=self.device)
        self.problems[0].module.load_state_dict(ckpt['estimator'])
        self.problems[1].module.load_state_dict(ckpt['regressor'])
        self.problems[0].optimizer.load_state_dict(ckpt['estimator_optimizer'])
        self.problems[1].optimizer.load_state_dict(ckpt['regressor_optimizer'])

    @torch.no_grad()
    def test(self, loader, tag=None, it=None):
        """Evaluate factual reconstruction and counterfactual prediction.

        Args:
            loader: Iterable of dict batches with keys ``covariate``,
                ``treatment``, ``outcome``, ``cf_treatment`` and
                ``cf_outcome``.
            tag: If given (and outcomes are 2-D), a scatter plot and a
                ``<tag>.csv`` of true vs. predicted counterfactuals are
                written to ``args.run_dir``.
            it: Identifier used in the plot file name; falls back to ``tag``
                when omitted.

        Returns:
            Tuple ``(rmse, recon_rmse)``: counterfactual RMSE and factual
            reconstruction RMSE.
        """
        # NOTE(review): the modules are evaluated in whatever mode they are
        # currently in; run() calls self.train() beforehand. Confirm whether
        # eval mode is wanted here.
        estimator = self.problems[0].module
        regressor = self.problems[1].module
        cf_true = []
        cf_pred = []
        cf_treats = []
        f_trues = []
        f_preds = []
        for batch in tqdm(loader):
            covariate = batch['covariate'].to(self.device)
            treatment = batch['treatment'].to(self.device)
            outcome = batch['outcome'].to(self.device)
            cf_treatment = batch['cf_treatment'].to(self.device)
            cf_outcome = batch['cf_outcome'].to(self.device)
            # Quantile is inferred from the factual triple and reused for both
            # the counterfactual and the factual (reconstruction) prediction.
            quantile = estimator(covariate, treatment, outcome)
            cf_treats.append(cf_treatment)
            cf_true.append(cf_outcome)
            cf_pred.append(regressor(covariate, cf_treatment, quantile))
            f_trues.append(outcome)
            f_preds.append(regressor(covariate, treatment, quantile))
        cf_true = torch.cat(cf_true)
        cf_pred = torch.cat(cf_pred)
        f_trues = torch.cat(f_trues)
        f_preds = torch.cat(f_preds)
        cf_treats = torch.cat(cf_treats).cpu().numpy()
        assert cf_true.size()==cf_pred.size()

        recon_mse = F.mse_loss(f_trues, f_preds)
        recon_rmse = np.sqrt(recon_mse.item())
        print('Recon Performance: %.4f  MSE, RMSE: %.4f ' % (recon_mse.item(), recon_rmse), len(cf_true))

        # Counterfactual error: squared error summed over all non-batch
        # dimensions, then averaged over samples.
        mse = ((cf_true-cf_pred)**2).view(len(cf_true),-1).sum(dim=1).mean()
        rmse = np.sqrt(mse.item())
        print('Counterfactual Inference Performance: %.4f  MSE, RMSE: %.4f ' % (mse.item(), rmse), len(cf_true))

        if tag is not None and len(cf_true.size())==2:
            y_cf = cf_true.cpu().numpy().reshape(-1)
            y_cf_hat = cf_pred.cpu().numpy().reshape(-1)
            plt.scatter(cf_treats, y_cf, label='true')
            plt.scatter(cf_treats, y_cf_hat, label='pred')
            plt.legend()
            # Fall back to the tag so calls without `it` do not all collide
            # on 'plotNone.png'.
            plot_id = tag if it is None else it
            plt.savefig(os.path.join(self.args.run_dir, 'plot%s.png' % (plot_id,)))
            plt.close()
            csv_path = os.path.join(self.args.run_dir, '%s.csv' % (tag))
            df = pd.DataFrame({'y_cf': y_cf, 'y_cf_hat': y_cf_hat})
            df.to_csv(csv_path, index=False)
            print('saved to ', csv_path)
        return rmse, recon_rmse

    @torch.no_grad()
    def check_progress(self, train_batch, batch, it):
        """Print sample quantiles and, for image data, save preview grids.

        Args:
            train_batch: Dict batch from the training loader.
            batch: Dict batch from the evaluation loader (also needs
                ``cf_treatment`` and ``cf_outcome``).
            it: Iteration number used in output file names.
        """
        num_display = 20
        train_quantile = self.problems[0].module(
            train_batch['covariate'].to(self.device),
            train_batch['treatment'].to(self.device),
            train_batch['outcome'].to(self.device))
        # 4-D covariates are treated as images.
        is_image = len(batch['covariate'][:num_display].size())==4
        if is_image:
            save_image(torch.cat([train_batch['covariate'][:num_display], train_batch['outcome'][:num_display]], dim=0),
                     os.path.join(self.args.run_dir, 'train_%06d.png' % (it)),
                     nrow=num_display
            )
        covariate = batch['covariate'][:num_display].to(self.device)
        treatment = batch['treatment'][:num_display].to(self.device)
        outcome = batch['outcome'][:num_display].to(self.device)
        cf_treatment = batch['cf_treatment'][:num_display].to(self.device)
        cf_outcome = batch['cf_outcome'][:num_display].to(self.device)
        quantile = self.problems[0].module(covariate, treatment, outcome)
        print(train_quantile.view(-1)[:10], ' >>> train ')
        print(quantile.view(-1)[:10], ' >>>> test ')
        pred = self.problems[1].module(covariate, cf_treatment, quantile)

        if is_image:
            # Grid rows: covariate, factual outcome, predicted counterfactual,
            # true counterfactual.
            all_images = torch.cat([covariate[:num_display], outcome[:num_display], pred[:num_display], cf_outcome[:num_display]], dim=0)
            save_image(all_images, os.path.join(self.args.run_dir, 'progress_%06d.png' % (it)),
                   nrow=num_display)
            # Debug hook: indices of individual samples to dump as separate
            # images. Empty by default, so the loop below is a no-op.
            new_k = []
            print(new_k)
            for i in new_k:
                save_image(covariate[i], os.path.join(self.args.run_dir, 'cov_%d_color0.9.png' % (i)))
                save_image(outcome[i], os.path.join(self.args.run_dir, 'out_%d_color0.9.png' % (i)))
                save_image(cf_outcome[i], os.path.join(self.args.run_dir, 'cf_%d_color0.9.png' % (i)))
                save_image(pred[i], os.path.join(self.args.run_dir, 'pred_%d_color0.9.png' % (i)))

    def _decay_learning_rates(self):
        """Scale every param-group learning rate by ``args.decay_rate`` (floor 1e-8)."""
        decay_rate = self.args.decay_rate
        for problem in (self.problems[0], self.problems[1]):
            for param_group in problem.optimizer.param_groups:
                param_group['lr'] = max(1e-8, param_group['lr'] * decay_rate)
        # Reports the regressor's last group and the estimator's first group.
        print('>>>lr ', self.problems[1].optimizer.param_groups[-1]['lr'],
              self.problems[0].optimizer.param_groups[0]['lr'])

    def run(self, train_loader, loader):
        """Run the training loop with periodic LR decay, evaluation and checkpoints.

        Args:
            train_loader: Loader evaluated as the "train" split; its first
                batch is also reused for progress previews.
            loader: Loader evaluated as the "test" split; its reconstruction
                RMSE selects the best checkpoint.
        """
        self.train()
        fixed_batch = next(iter(loader))
        fixed_train_batch = next(iter(train_loader))
        args = self.args
        os.makedirs(args.run_dir, exist_ok=True)

        best_recon = 1e8
        for it in tqdm(range(1, self.train_iters + 1)):
            self.global_step += 1
            self.train_step()
            if it % args.decay_every == 0:
                self._decay_learning_rates()
            if it % args.check_every == 0:
                self.check_progress(fixed_train_batch, fixed_batch, it)
            if it % args.test_every == 0:
                # NOTE(review): both calls share the same tag, so the test
                # split's plot/CSV overwrite the train split's files.
                print('===========', it)
                train_rmse, recon_train_rmse = self.test(train_loader, it=it, tag='%s'%it)
                print('===========', it)
                test_rmse, recon_test_rmse = self.test(loader, it=it, tag='%s'%it)
                print('===========', it)
                self.logger.write({'train_rmse': train_rmse, 'test_rmse': test_rmse}, it)
                if recon_test_rmse < best_recon:
                    # Track the same metric that is compared above; storing
                    # the train value here would corrupt the threshold.
                    best_recon = recon_test_rmse
                    print('>>>> Best Model saved <<<<')
                    self.save_model('model_best.pt')
                    # Re-run on the test split to refresh the 'best' plot/CSV.
                    test_rmse, recon_test_rmse = self.test(loader, it='best', tag='best')
                    print('===========', it)
                    self.logger.write({'best_train_rmse': train_rmse, 'best_test_rmse': test_rmse}, it)
                    self.check_progress(fixed_train_batch, fixed_batch, it)

            if it % args.save_every == 0:
                self.save_model('model_%06d.pt' % (it))
        self.test(train_loader, tag='train_exp%03d' % self.args.exp_num)
        self.test(loader, tag='test_exp%03d' % self.args.exp_num)



class QuantileEstimator(ImplicitProblem):
    """Problem whose module maps (covariate, treatment, outcome) to a quantile."""

    def training_step(self, batch):
        """Return the MSE between the inner module's reconstruction and the outcome.

        The quantile produced by this problem's module is fed, together with
        the covariate and treatment, to ``self.inner.module``; the loss is how
        well that prediction matches the observed outcome.
        """
        device = self.device
        covariate = batch['covariate'].to(device)
        treatment = batch['treatment'].to(device)
        outcome = batch['outcome'].to(device)

        estimated_quantile = self.module(covariate, treatment, outcome)
        reconstruction = self.inner.module(covariate, treatment, estimated_quantile)
        return F.mse_loss(reconstruction, outcome)

class QuantileRegressor(ImplicitProblem):
    """Problem whose module predicts an outcome from (covariate, treatment, quantile)."""

    def training_step(self, batch):
        """Return the mean pinball loss over all (sample, quantile) pairings.

        Quantiles come from ``self.outer.module`` applied to the next batch of
        ``self.test_iterator``. Every sample of ``batch`` is paired with every
        one of those quantiles before the loss is computed.
        """
        device = self.device
        covariate, treatment, outcome = (
            batch[key].to(device) for key in ('covariate', 'treatment', 'outcome')
        )
        aux_batch = next(self.test_iterator)
        test_cov, test_treat, test_out = (
            aux_batch[key].to(device) for key in ('covariate', 'treatment', 'outcome')
        )
        quantile = self.outer.module(test_cov, test_treat, test_out)

        num_quantiles = test_cov.shape[0]
        num_samples = covariate.shape[0]

        def tile(tensor):
            # Stack `num_quantiles` copies of the batch along dim 0, leaving
            # all remaining dimensions untouched.
            return tensor.repeat(num_quantiles, *([1] * (tensor.dim() - 1)))

        flat_covariate = tile(covariate)
        flat_treatment = tile(treatment)
        flat_outcome = tile(outcome)
        # Each quantile is repeated for one full copy of the batch so that it
        # lines up with the tiled tensors above.
        flat_quantile = torch.repeat_interleave(quantile, repeats=num_samples, dim=0)

        flat_pred = self.module(flat_covariate, flat_treatment, flat_quantile)
        residual = (flat_outcome - flat_pred).view(len(flat_outcome), -1)
        per_element_loss = pin_ball_loss(residual, flat_quantile)
        self.global_step += 1
        return torch.mean(per_element_loss)
