import os
import argparse
import sys; sys.path.append("./ANODE") # import hack
from tqdm import tqdm
import torch
import torch.nn.functional as F
from models.model import OurModel, Simulator
from models.conv_models import ConvODENet, MNISTConvODENet
from models.mlp_model import ODENet
from omegaconf import OmegaConf
import plotly.express as px
from utils import *
import wandb
from torch.func import vmap, jacrev, jacfwd, functional_call
import time
import copy
from diffusion_utils import *

def learnable(z0, z1, t):
    """Placeholder dynamics for a learned (non-closed-form) interpolation path.

    There is no analytic path to evaluate, so ``None`` is returned for any
    ``z0``, ``z1`` and ``t``.
    """
    path = None
    return path

def linear(z0, z1, t):
    """Straight-line interpolation: returns ``z0`` at ``t == 0`` and ``z1`` at ``t == 1``."""
    weight_start = 1 - t
    return weight_start * z0 + t * z1

def half_circle(z0, z1, t):
    """Interpolate from ``z0`` to ``z1`` along a half circle as ``t`` goes 0 -> 1.

    The angle is ``t * pi``. The circle is centred at the midpoint of ``z0`` and
    ``z1``; its first axis is ``(z0 - z1) / 2`` and its second axis is the
    midpoint vector made orthogonal to the first (one Gram-Schmidt step),
    rescaled to the same length. Each input is viewed as 2-D rows (rank > 2 is
    flattened per leading index, rank < 2 becomes a single row) for the
    computation, and the result is reshaped back to ``z0``'s original shape.

    NOTE(review): divides by the norms of (z0 - z1) and of the orthogonal
    component, so coincident points or a midpoint parallel to (z0 - z1) yield
    NaN/inf.
    """
    def _as_rows(x):
        # Flatten trailing dims for rank > 2; add a leading row dim for rank < 2.
        if x.dim() > 2:
            return x.view(x.shape[0], -1)
        if x.dim() < 2:
            return x.view(1, -1)
        return x

    out_shape = z0.shape
    start, end, frac = _as_rows(z0), _as_rows(z1), _as_rows(t)

    angle = frac * np.pi
    center = (start + end) / 2
    half_diff = (start - end) / 2
    half_len = torch.norm(half_diff, dim=-1, keepdim=True)
    axis = half_diff / half_len

    # Remove the component of the midpoint that lies along `axis`.
    along = (center * axis).sum(dim=-1, keepdim=True)
    perp = center - along * axis
    # Rescale the orthogonal direction so both circle axes share one radius.
    perp_axis = perp / torch.norm(perp, dim=-1, keepdim=True) * half_len

    point = center + half_diff * torch.cos(angle) + perp_axis * torch.sin(angle)
    return point.reshape(out_shape)


# def lin_cos(z0, z1, t):
#     degree = t * np.pi / 2.
#     return torch.cos(degree) * z0 + (1-torch.cos(degree)) * z1

def lin_sin(z0, z1, t):
    """Blend ``z0`` -> ``z1`` with weight ``sin(t * pi / 2)`` on ``z1``.

    Returns ``z0`` at ``t == 0`` and ``z1`` at ``t == 1``; the two weights
    always sum to one.
    """
    w_end = torch.sin(t * np.pi / 2.)
    return (1 - w_end) * z0 + w_end * z1


def cos(z0, z1, t):
    """Trigonometric blend ``cos(a) * z0 + sin(a) * z1`` with ``a = t * pi / 2``.

    Returns ``z0`` at ``t == 0`` and (up to rounding) ``z1`` at ``t == 1``.
    """
    angle = t * np.pi / 2.
    w_start = torch.cos(angle)
    w_end = torch.sin(angle)
    return w_start * z0 + w_end * z1

def inv_cos(z0, z1, t):
    """Blend ``(1 - cos(a)) * z0 + (1 - sin(a)) * z1`` with ``a = t * pi / 2``.

    NOTE(review): this returns ``z1`` at ``t == 0`` and ``z0`` at ``t == 1``,
    i.e. the endpoints are reversed relative to ``linear``/``cos``; presumably
    intentional given the name — confirm.
    """
    angle = t * np.pi / 2.
    w_start = 1 - torch.cos(angle)
    w_end = 1 - torch.sin(angle)
    return w_start * z0 + w_end * z1

def vp_ode(z0, z1, t):
    """Variance-preserving path ``alpha(t) * z1 + sqrt(1 - alpha(t)^2) * z0``.

    ``alpha(t) = exp(-0.25 * 19.9 * (1 - t)^2 - 0.5 * 0.1 * (1 - t))``, so
    ``alpha(1) == 1`` and the path ends exactly at ``z1``. The constants look
    like the common VP-SDE schedule (beta_min=0.1, beta_max=20) — confirm.
    """
    remaining = 1 - t
    log_alpha = -0.25 * 19.9 * remaining ** 2 - 0.5 * 0.1 * remaining
    alpha = torch.exp(log_alpha)
    sigma = torch.sqrt(1 - alpha ** 2)
    return alpha * z1 + sigma * z0

def const_vp_ode(z0, z1, t):
    """Variance-preserving path with ``alpha = t``: ``t * z1 + sqrt(1 - t^2) * z0``."""
    sigma = torch.sqrt(1 - t ** 2)
    return t * z1 + sigma * z0

@torch.inference_mode()
def test_metric(net, test_dataloader, method='dopri5', num_timesteps=1+1, return_mse=False, metric_key='accuracy', label_scaler=None):
    if metric_key == 'accuracy':
        metric, latent_mse, data_mse = test_accuracy(net, test_dataloader, method=method, num_timesteps=num_timesteps, return_mse=return_mse)
    elif metric_key == 'rmse':
        metric, latent_mse, data_mse = test_rmse(net, test_dataloader, method=method, num_timesteps=num_timesteps, return_mse=return_mse,
                                               label_scaler=label_scaler)
    if return_mse:
        return metric, latent_mse, data_mse
    return metric

@torch.inference_mode()
def test_accuracy(net, test_dataloader, method='dopri5', num_timesteps=1+1, return_mse=False):
    """Classification accuracy of sampled predictions over ``test_dataloader``.

    For each batch the final element of the sampled trajectory is taken as the
    prediction, and a hit is counted when its argmax over the last dim equals
    the argmax of ``Y`` (i.e. ``Y`` is expected to be one-hot/score-like).
    ``method == 'dopri5'`` uses ``net.ddpm_sample``; any other value uses
    ``net.ddim_sample`` with ``num_timesteps - 1`` steps.

    Returns:
        ``(accuracy, latent_mse, data_mse)``; the two mse values are always 0
        here and ``return_mse`` is ignored.
    """
    net.eval()
    n_correct, n_seen = 0, 0
    latent_mse = data_mse = 0
    batches = tqdm(enumerate(test_dataloader), leave=False, total=len(test_dataloader), desc='Measure metric')
    for _, (inputs, targets) in batches:
        inputs, targets = inputs.cuda(), targets.cuda()
        # TODO: RUN DDPM 1000 step.
        if method == 'dopri5':
            trajectory = net.ddpm_sample(inputs, targets)
        else:
            trajectory = net.ddim_sample(inputs, targets, num_timesteps - 1)
        prediction = trajectory[-1]
        hits = prediction.argmax(dim=-1) == targets.argmax(dim=-1)
        n_correct += hits.float().sum().item()
        n_seen += targets.size(0)
    return n_correct / n_seen, latent_mse, data_mse

@torch.inference_mode()
def test_rmse(net, test_dataloader, method='dopri5', num_timesteps=1+1, return_mse=False, label_scaler=None):
    """Root-mean-squared error of sampled predictions over ``test_dataloader``.

    Decorated with ``torch.inference_mode`` like ``test_accuracy`` so that
    direct calls (not only those routed through ``test_metric``) do not record
    autograd graphs through the sampling loop.

    The per-batch MSE is weighted by batch size, averaged over all samples and
    square-rooted. When ``label_scaler`` is given, both targets and predictions
    are mapped back with ``label_scaler.inverse_transform`` first, so the RMSE
    is in the unscaled label units.

    Returns:
        ``(rmse, latent_mse, data_mse)``; the two mse values are always 0 here
        and ``return_mse`` is ignored.
    """
    net.eval()
    count = 0
    latent_mse = data_mse = 0
    sq_err_sum = 0
    for _, (X, Y) in tqdm(enumerate(test_dataloader), leave=False, total=len(test_dataloader), desc='Measure rmse'):
        X, Y = X.cuda(), Y.cuda()
        if method == 'dopri5':
            # TODO: RUN DDPM 1000 step.
            traj = net.ddpm_sample(X, Y)
        else:
            traj = net.ddim_sample(X, Y, num_timesteps - 1)
        pred = traj[-1]
        count += Y.size(0)
        if label_scaler is not None:
            Y_unnorm = label_scaler.inverse_transform(Y.cpu().numpy())
            pred_unnorm = label_scaler.inverse_transform(pred.cpu().numpy())
            sq_err_sum += np.mean((Y_unnorm - pred_unnorm) ** 2) * Y.size(0)
        else:
            sq_err_sum += F.mse_loss(pred, Y).item() * Y.size(0)

    rmse = (sq_err_sum / count) ** 0.5
    return rmse, latent_mse, data_mse


@torch.no_grad()
def test_norm_avg_reg(net, val_dataloader=None):
    """Batch-size-weighted averages of ``norm_avg`` for input and label embeddings.

    ``z0 = net.in_projection(X)`` and ``z1 = net.label_projection(Y)`` are
    computed for every validation batch. The function runs under
    ``torch.no_grad`` so the forward passes do not build autograd graphs; if
    ``norm_avg`` returns tensors, the running sums would otherwise keep every
    batch's graph alive. ``no_grad`` rather than ``inference_mode`` keeps the
    returned values usable as ordinary tensors. ``norm_avg`` comes from
    ``utils``; presumably a mean per-sample norm — confirm.

    Returns:
        ``(z0_norm_avg, z1_norm_avg)``, or ``(0, 0)`` when ``val_dataloader``
        is None.
    """
    if val_dataloader is None:
        return 0, 0
    net.eval()
    z1_norm_avg = 0
    z0_norm_avg = 0
    count = 0
    for _, (X, Y) in tqdm(enumerate(val_dataloader), leave=False, total=len(val_dataloader), desc='Measure Norm Avg'):
        X = X.cuda()
        z0 = net.in_projection(X)
        z1 = net.label_projection(Y.cuda())
        z0_norm_avg += norm_avg(z0) * X.size(0)
        z1_norm_avg += norm_avg(z1) * X.size(0)
        count += X.size(0)
    z0_norm_avg /= count
    z1_norm_avg /= count
    return z0_norm_avg, z1_norm_avg


@torch.no_grad()
def test_norm_avg_cls(net, val_dataloader=None, num_classes=10):
    """``norm_avg`` of the class embeddings and of the input embeddings.

    ``z1`` is the label projection of the ``num_classes`` one-hot vectors,
    computed once. ``z0`` is ``net.in_projection(X)``, averaged over the
    validation set with batch-size weights. The function runs under
    ``torch.no_grad`` so the forward passes build no autograd graphs (and
    accumulated tensors do not pin every batch's graph). ``norm_avg`` comes
    from ``utils``; presumably a mean per-sample norm — confirm.

    Returns:
        ``(z0_norm_avg, z1_norm_avg)``, or ``(0, 0)`` when ``val_dataloader``
        is None.
    """
    if val_dataloader is None:
        return 0, 0
    net.eval()
    all_labels = F.one_hot(torch.arange(num_classes)).float().cuda()
    z1 = net.label_projection(all_labels)
    z1_norm_avg = norm_avg(z1)
    z0_norm_avg = 0
    count = 0
    for _, (X, _) in tqdm(enumerate(val_dataloader), leave=False, total=len(val_dataloader), desc='Measure Norm Avg'):
        X = X.cuda()
        z0 = net.in_projection(X)
        z0_norm_avg += norm_avg(z0) * X.size(0)
        count += X.size(0)
    z0_norm_avg /= count
    return z0_norm_avg, z1_norm_avg


class Trainer(object):
    """Train and evaluate a CARD-style conditional diffusion model.

    ``training_step`` noises the labels with ``net.q_sample`` at a random
    integer timestep and regresses the network output onto that noise (MSE).
    ``fit`` runs the step loop and periodically calls ``evaluate``. ``evaluate``
    measures metrics via ``test`` (DDIM with 1/2/10/20 steps plus a
    ``net.ddpm_sample`` run). It also writes best/last checkpoints under a
    ``ckpts`` directory derived from the active wandb run. An EMA copy of the
    network (``ema_net``) is maintained and used for evaluation when ``ema > 0``.

    Several constructor arguments (e.g. ``lambdas``, ``label_ae_noise``,
    ``sync_t``, ``steer``) are stored but not read by any method of this class.
    Presumably they are kept for config compatibility with other trainers.
    """

    def __init__(self, net, total_steps, optimizer='adam', scheduler='none', lr=1e-3, wd=None, loss_start_epoch=None,
                 lambdas=None, task_criterion=None, force_zero_prob=0., test_every=2,
                 label_ae_noise=0., method='ours', f_sg_target=False, augment_t=1, t_transform='identity', label_ae_mse=True,
                 task_dec=False, f_jac_clamp=(-1, -1), g_jac_clamp=(-1, -1), train_alter=False, alter_order=None, alter_epoch=None,
                 fgh_lr=None, sync_t=False, label_flow_noise=0., label_flow_noise_0=0., invert_transform_t=False, dynamics=linear,
                 dataset='mnist', label_scaler=None, save_every=24, patience=-1, steer=0, ema=0):
        """Store the configuration, set up the EMA copy and checkpoint dir, and build the optimizer.

        ``wd`` may be a scalar (shared by all modules) or a 3-sequence of weight
        decays for (f, g, h) = (in_projection, out_projection + label_projection,
        odefunc), mirroring ``fgh_lr``. List-valued defaults use ``None``
        sentinels, resolved below to the same values as before, so instances
        never share one mutable default object.
        """
        if wd is None:
            wd = [0., 0., 0.]
        if loss_start_epoch is None:
            loss_start_epoch = [0, 0, 0]
        if lambdas is None:
            lambdas = [1.0, 1.0, 0.0]
        if task_criterion is None:
            task_criterion = torch.nn.MSELoss()
        if alter_order is None:
            alter_order = ['fgh']
        if alter_epoch is None:
            alter_epoch = [1]

        self.net = net
        self.wd = wd
        self.lr = lr
        self.total_steps = total_steps
        self.loss_start_epoch = loss_start_epoch
        self.lambdas = lambdas
        self.task_criterion = task_criterion
        self.force_zero_prob = force_zero_prob
        self.test_every = test_every
        self.label_ae_noise = label_ae_noise
        self.method = method
        # NOTE(review): the default method='ours' fails this check, so callers must pass method='card'.
        assert method == 'card'
        self.f_sg_target = f_sg_target
        self.augment_t = augment_t
        self.t_transform = t_transform
        self.label_ae_mse = label_ae_mse
        self.task_dec = task_dec
        self.f_jac_clamp = f_jac_clamp
        self.g_jac_clamp = g_jac_clamp
        self.train_alter = train_alter
        self.alter_order = alter_order
        self.alter_epoch = alter_epoch
        self.fgh_lr = fgh_lr
        self.sync_t = sync_t
        self.label_flow_noise = label_flow_noise
        self.label_flow_noise_0 = label_flow_noise_0
        self.invert_transform_t = invert_transform_t
        self.dynamics = dynamics
        # Batched d(dynamics)/dt; not read elsewhere in this class.
        self.dyn_v = vmap(jacfwd(dynamics, argnums=2))
        self.ckpt_dir = wandb.run.dir.replace('wandb', 'ckpts')
        os.makedirs(self.ckpt_dir, exist_ok=True)
        # String names here; configure_optimizer replaces them with objects.
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.label_scaler = label_scaler
        self.dataset = dataset
        self.save_every = save_every
        self.time_threshold = save_every * 3600  # seconds between timed checkpoints
        self.early_stop_count = 0
        self.early_stopping_patience = patience
        self.steer = steer

        self.ema_alpha = ema
        self.ema = copy.deepcopy(net.state_dict())
        self.ema_net = copy.deepcopy(net)

        for p in self.ema_net.parameters():
            p.requires_grad_(False)

        if dataset in ['mnist', 'cifar10', 'svhn']:
            self.metric_type = 'accuracy'
        else:
            self.metric_type = 'rmse'

        self.configure_optimizer()

    def ema_update(self, alpha=0.999):
        """Update the parameter EMA: ``ema = alpha * ema + (1 - alpha) * param``.

        ``training_step`` calls this with ``self.ema_alpha``; with the default
        ``ema=0`` the EMA simply tracks the current parameters.
        """
        for n, p in self.net.named_parameters():
            self.ema[n] = alpha * self.ema[n] + (1 - alpha) * p.data

    def ema_restore(self):
        """Load the averaged parameters and the live network's buffers into ``ema_net``.

        Buffers (e.g. normalisation running statistics) are not averaged.
        Copying them from ``self.net`` keeps ``ema_net`` from being evaluated and
        checkpointed with the buffers it had at construction time.
        """
        for n, p in self.ema_net.named_parameters():
            p.data = self.ema[n]
        live_buffers = dict(self.net.named_buffers())
        with torch.no_grad():
            for n, b in self.ema_net.named_buffers():
                b.copy_(live_buffers[n])

    def _per_module_wd(self):
        """Return weight decays for (f, g, h); a scalar ``self.wd`` applies to all three."""
        if isinstance(self.wd, (int, float)):
            return self.wd, self.wd, self.wd
        wd = list(self.wd)
        if len(wd) != 3:
            raise ValueError(f'wd must be a scalar or a sequence of 3 values (f, g, h), got {self.wd}')
        return wd[0], wd[1], wd[2]

    def configure_optimizer(self):
        """Replace the optimizer/scheduler names with constructed objects.

        Parameter groups are built per ``alter_order`` entry: 'f' ->
        in_projection, 'g' -> out_projection and label_projection, 'h' ->
        odefunc. Their learning rates come from ``fgh_lr`` and their weight
        decays from ``wd``, both in (f, g, h) order. Only a single optimizer is
        supported.

        Raises:
            ValueError: for an unsupported optimizer or scheduler name, or a
                ``wd`` that is neither a scalar nor a 3-sequence.
        """
        if self.fgh_lr is None:
            self.fgh_lr = [self.lr, self.lr, self.lr]
        wd_f, wd_g, wd_h = self._per_module_wd()

        if self.optimizer == 'adam':
            opt = torch.optim.AdamW
        elif self.optimizer == 'radam':
            opt = torch.optim.RAdam
        else:
            raise ValueError(f'Optimizer {self.optimizer} not supported')

        if not self.train_alter:
            assert self.alter_order == ['fgh'] and self.alter_epoch == [1]

        optimizers = []
        for target in self.alter_order:
            assert isinstance(target, str), 'alter_order should be a list of strings'
            params = []
            if 'f' in target:
                params.append({'params': self.net.in_projection.parameters(), 'lr': self.fgh_lr[0], 'weight_decay': wd_f})
            if 'g' in target:
                params.append({'params': self.net.out_projection.parameters(), 'lr': self.fgh_lr[1], 'weight_decay': wd_g})
                params.append({'params': self.net.label_projection.parameters(), 'lr': self.fgh_lr[1], 'weight_decay': wd_g})
            if 'h' in target:
                params.append({'params': self.net.odefunc.parameters(), 'lr': self.fgh_lr[2], 'weight_decay': wd_h})
            optimizers.append(opt(params, lr=self.lr))
        assert len(optimizers) == len(self.alter_epoch), 'optimizers and train_alter_epoch should have same length'

        if self.scheduler == 'none':
            scheduler = None
        elif self.scheduler == 'cos':
            assert len(optimizers) == 1, 'Cosine annealing scheduler only supports single optimizer'
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizers[0], T_max=self.total_steps, eta_min=0)
        elif self.scheduler == 'step':
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizers[0], [5000, 10000], gamma=0.1)
        else:
            raise ValueError(f'Scheduler {self.scheduler} not supported')
        assert len(optimizers) == 1

        self.optimizer = optimizers[0]
        self.scheduler = scheduler

    @torch.inference_mode()
    def evaluate(self, test_dataloader, val_dataloader, subset_loader, current_step, best_test_metric=0, best_val_metric=0):
        """Run ``test`` on the test/val (and optional train-subset) loaders, save checkpoints and log.

        Higher is better for accuracy, lower for rmse. ``best*.ckpt`` is saved
        when the test 'dopri' metric improves, ``best_val*.ckpt`` when the val
        metric improves, and ``last*.ckpt`` on every call; each has a matching
        ``*_ema`` file for ``ema_net``. A non-improving val metric increments
        ``early_stop_count``; an improvement resets it.

        Returns:
            ``(best_test_metric, best_val_metric)`` after this evaluation.
        """
        metric_dict = self.test(test_dataloader)
        valtric_dict = self.test(val_dataloader, metric_key='val')
        if subset_loader is not None:
            subset_dict = self.test(subset_loader, metric_key='train_subset')
            metric_dict.update(subset_dict)
        metric_dict.update(valtric_dict)
        metric_val = metric_dict[f'val/{self.metric_type}_dopri']
        metric1 = metric_dict[f'test/{self.metric_type}_1']
        metric2 = metric_dict[f'test/{self.metric_type}_2']
        metricinf = metric_dict[f'test/{self.metric_type}_dopri']

        if self.metric_type == 'accuracy':
            condition = metricinf > best_test_metric
            val_condition = metric_val > best_val_metric
        else:
            condition = metricinf < best_test_metric
            val_condition = metric_val < best_val_metric

        if condition:
            best_test_metric = metricinf
            torch.save(self.net.state_dict(), os.path.join(self.ckpt_dir, 'best.ckpt'))
            torch.save(self.ema_net.state_dict(), os.path.join(self.ckpt_dir, 'best_ema.ckpt'))
        torch.save(self.net.state_dict(), os.path.join(self.ckpt_dir, 'last.ckpt'))
        torch.save(self.ema_net.state_dict(), os.path.join(self.ckpt_dir, 'last_ema.ckpt'))

        if val_condition:
            best_val_metric = metric_val
            torch.save(self.net.state_dict(), os.path.join(self.ckpt_dir, 'best_val.ckpt'))
            torch.save(self.ema_net.state_dict(), os.path.join(self.ckpt_dir, 'best_val_ema.ckpt'))
            # Re-log the test metrics under a separate prefix for the best-val checkpoint.
            best_val_log = {}
            for k, v in metric_dict.items():
                if 'val' in k:
                    continue
                if 'test' in k:
                    best_val_log[k.replace('test', 'test_on_best_val')] = v
            wandb.log(best_val_log, commit=False)
            self.early_stop_count = 0
        else:
            self.early_stop_count += 1

        if self.metric_type == 'accuracy':
            metric_dict['test/best_acc'] = best_test_metric
            metric_dict['val/best_acc'] = best_val_metric
        else:
            metric_dict['test/best_rmse'] = best_test_metric
            metric_dict['val/best_rmse'] = best_val_metric
        print(f'Step {current_step}/{self.total_steps}, Val {self.metric_type} {metric_val:.4f}, Test {self.metric_type} 1/2/inf {metric1:.4f}/{metric2:.4f}/{metricinf:.4f}')
        wandb.log(metric_dict, commit=False)
        return best_test_metric, best_val_metric

    def compute_jacobian(self, x, y, mode='f'):
        """Penalise finite-difference Jacobian-norm ratios outside ``[lam_min, lam_max]``.

        ``Q`` is ||y[i+1] - y[i]|| / ||x[i+1] - x[i]|| over consecutive batch
        rows; ratios <= 1e-9 are dropped. The bounds come from ``f_jac_clamp``
        (``mode='f'``) or ``g_jac_clamp`` (``mode='g'``); if either bound is
        negative the penalty is disabled and a zero is returned.

        Raises:
            ValueError: if ``mode`` is not 'f' or 'g'.
        """
        if mode == 'f':
            lam_min, lam_max = self.f_jac_clamp
        elif mode == 'g':
            lam_min, lam_max = self.g_jac_clamp
        else:
            raise ValueError(f"mode must be 'f' or 'g', got {mode!r}")

        delta_x = x[1:] - x[:-1]
        delta_y = y[1:] - y[:-1]
        Q = delta_y.reshape(delta_y.size(0), -1).norm(dim=-1) / delta_x.reshape(delta_x.size(0), -1).norm(dim=-1)

        # Skip pairs whose output did not move.
        mask = Q > 1e-9
        Q = Q[mask]

        if mode == 'f':
            wandb.log({"train/f_jac_Q": Q.mean().item()}, commit=False)
        else:
            wandb.log({"train/g_jac_Q": Q.mean().item()}, commit=False)

        # Same dtype/device as Q instead of unconditionally on CUDA.
        jac_clamp_loss = Q.new_zeros(1)
        if lam_min >= 0 and lam_max >= 0:
            target_min = torch.ones_like(Q) * lam_min
            target_max = torch.ones_like(Q) * lam_max
            L_min = F.mse_loss(torch.minimum(Q, target_min), target_min, reduction='mean')
            L_max = F.mse_loss(torch.maximum(Q, target_max), target_max, reduction='mean')
            jac_clamp_loss = L_min + L_max
        return jac_clamp_loss

    def sample_timestep(self, z0, device):
        """Draw one integer diffusion timestep in ``[0, net.T)`` per row of ``z0``."""
        t = torch.randint(low=0, high=self.net.T, size=(z0.size(0), ), device=device)
        return t

    def fit(self, train_dataloader, val_dataloader, test_dataloader, subset_loader=None,):
        """Train for ``self.total_steps`` optimizer steps, evaluating periodically.

        ``evaluate`` runs every ``test_every`` steps. It also runs each time
        ``save_every`` hours have elapsed, and those timed runs additionally save
        ``{N}hr.ckpt`` / ``{N}hr_opt.ckpt``. Training stops early once
        ``early_stop_count`` reaches a positive ``patience``; otherwise a final
        ``evaluate`` runs after the last step.

        Returns:
            ``self.net`` (trained in place).

        Raises:
            ValueError: if ``train_dataloader`` yields no batches before
                ``total_steps`` is reached (the loop would otherwise never end).
        """
        self.current_step = 0
        best_test_metric = 0 if self.metric_type == 'accuracy' else 1e9
        best_val_metric = 0 if self.metric_type == 'accuracy' else 1e9
        epoch = 0
        next_thresh = time.time() + self.time_threshold
        hour_count = 0
        self.early_stop_count = 0

        pbar = tqdm(total=self.total_steps, desc='Training')
        flow_loss_timestep_bin = torch.zeros(11) if self.method == 'ours' else None
        timestep_bin_count = torch.zeros(11) if self.method == 'ours' else None

        while True:
            logs = {}
            count = 0
            for X, Y in train_dataloader:
                log, flow_loss_timestep_bin, timestep_bin_count = self.training_step(X, Y, flow_loss_timestep_bin, timestep_bin_count)
                pbar.update(1)

                # Running per-epoch sums of the step logs.
                for k, v in log.items():
                    logs[k] = logs[k] + v if k in logs else v
                count += 1

                cur_log = {k: v / count for k, v in logs.items()}
                pbar.set_description(', '.join(f'{k}: {v:.3e}' for k, v in cur_log.items()))

                self.current_step += 1

                if self.current_step % self.test_every == 0:
                    best_test_metric, best_val_metric = self.evaluate(test_dataloader, val_dataloader, subset_loader, self.current_step, best_test_metric, best_val_metric)
                    if self.early_stopping_patience > 0 and self.early_stop_count >= self.early_stopping_patience:
                        print(f'Early stopping at step {self.current_step}')
                        pbar.close()
                        return self.net

                if time.time() > next_thresh:
                    hour_count += 1
                    best_test_metric, best_val_metric = self.evaluate(test_dataloader, val_dataloader, subset_loader, self.current_step, best_test_metric, best_val_metric)
                    next_thresh = time.time() + self.time_threshold
                    hours = hour_count * self.save_every
                    torch.save(self.net.state_dict(), os.path.join(self.ckpt_dir, f'{hours}hr.ckpt'))
                    torch.save(self.optimizer.state_dict(), os.path.join(self.ckpt_dir, f'{hours}hr_opt.ckpt'))

                if self.current_step >= self.total_steps:
                    break

            if self.current_step >= self.total_steps:
                break
            if count == 0:
                pbar.close()
                raise ValueError('train_dataloader yielded no batches; training would never reach total_steps')

            for k in logs:
                logs[k] /= count

            wandb.log({k + '_epoch': v for k, v in logs.items()}, commit=False)
            wandb.log({'epoch': epoch}, commit=False)
            epoch += 1

        pbar.close()
        # Final evaluation also refreshes the last*.ckpt files.
        self.evaluate(test_dataloader, val_dataloader, subset_loader, self.current_step, best_test_metric, best_val_metric)
        return self.net

    def training_step(self, X, Y, flow_loss_timestep_bin=None, timestep_bin_count=None):
        """One optimizer step of noise-prediction training.

        Samples ``z ~ N(0, I)`` shaped like ``Y`` and integer timesteps, forms
        ``yt = net.q_sample(Y, t, z)``, and minimises the MSE between
        ``net(X, yt, zeros_like(yt), t)`` and ``z``. Gradients are clipped to
        norm 1.0; the scheduler (if any) and the EMA are stepped afterwards.

        Returns:
            ``(log, flow_loss_timestep_bin, timestep_bin_count)``; the two bins
            are passed through unchanged.
        """
        self.net.train()
        device = self.net.device
        X, Y = X.to(device), Y.to(device)

        z = torch.randn_like(Y)
        t = self.sample_timestep(z, device)
        yt = self.net.q_sample(Y, t, z)

        self.net.zero_grad()
        pred = self.net(X, yt, torch.zeros_like(yt), t)
        loss = F.mse_loss(pred, z)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.0)
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()

        wandb.log({
            'learning_rate': self.optimizer.param_groups[0]['lr'],
            'train/loss': loss.item(),
        })

        log = {
            'loss': loss.item(),
        }
        self.ema_update(alpha=self.ema_alpha)
        return log, flow_loss_timestep_bin, timestep_bin_count

    def test(self, test_dataloader, metric_key='test', do_dopri=True):
        """Evaluate on ``test_dataloader`` and return a dict of metrics.

        Keys (prefix = ``metric_key``, metric = ``self.metric_type``):
        - ``{prefix}/{metric}_1``, ``_2``, ``_10``, ``_20``: metric from
          ``net.ddim_sample`` with that many steps.
        - ``{prefix}/{metric}_dopri`` and ``{prefix}/dopri_nfe`` (only when
          ``do_dopri``): metric from ``net.ddpm_sample``; the NFE is a fixed 1000.

        ``ema_net`` is refreshed first (so saved ``*_ema`` checkpoints are
        current) and is used for evaluation when ``ema_alpha > 0``.
        """
        ret = {}
        self.ema_restore()
        net = self.ema_net if self.ema_alpha > 0 else self.net
        for steps in (1, 2, 10, 20):
            ret[f'{metric_key}/{self.metric_type}_{steps}'] = test_metric(
                net, test_dataloader, method='euler', num_timesteps=steps + 1,
                metric_key=self.metric_type, label_scaler=self.label_scaler)
        if do_dopri:
            ret[f'{metric_key}/{self.metric_type}_dopri'] = test_metric(
                net, test_dataloader, method='dopri5',
                metric_key=self.metric_type, label_scaler=self.label_scaler)
            # Hard-coded; presumably the DDPM chain length of ddpm_sample — confirm against net.T.
            ret[f'{metric_key}/dopri_nfe'] = 1000
        return ret