import argparse
#import wandb
import random
import torch
from huggingface_download import apply_patch; apply_patch()
try:
    # Optional Ascend NPU support. NOTE(review): importing `transfer_to_npu`
    # presumably redirects the torch.cuda.* calls below to the NPU — confirm
    # against the torch_npu documentation.
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
except Exception:
    # Deliberate best effort: any failure here (not only ImportError) just
    # means we run on plain CUDA/CPU.
    print("Not npu case")
# Select the accelerator if available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)
from torchvision.utils import make_grid, save_image
import torch.nn as nn
import copy
import os, gc
import pickle
import torch.nn.functional as F
from cleanfid import fid
import torch
from absl import app, flags
from torchvision import datasets, transforms
from tqdm import trange, tqdm
from utils_cifar import ema, infiniteloop, TensorBoardWriter, WandBWriter
import numpy as np
from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
    pad_t_like_x
)
from torchcfm.models.unet.unet import UNetModelWrapper


####Parameters


# Seed every random number generator used by the script so runs are repeatable.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    # Seed the current device and every visible device (multi-GPU).
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN kernels over autotuned (possibly faster) ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


is_log = True  # log scalars/images through `logger` or not
COND = True  # conditional or unconditional
parallel = False
logger_name = "tensorboard"  # "tensorboard" -> TensorBoardWriter, anything else -> WandBWriter

# UNet
num_channel = 128

# Training
lr = 3e-5  # 1e-5 for finetuning

parser = argparse.ArgumentParser(description="alpha gamma parser")
parser.add_argument('--alpha', type=float, default=1.0, help='alpha')
parser.add_argument('--gamma', type=float, default=1.0, help='gamma')
parser.add_argument('--gen_coef', type=float, default=None, help='gen_coef')
parser.add_argument('--disc_coef', type=float, default=None, help='disc_coef')
parser.add_argument('--with_gan_loss', action="store_true", help='with_gan_loss')
parser.add_argument('--use_u_star', action="store_true")
parser.add_argument('--ckpt_path', type=str, default=None)

args = parser.parse_args()

alpha = args.alpha  # 0.94 - 1.00
gamma = args.gamma  # 0.94 - 1.00 with a step 0.02
use_u_star = args.use_u_star

with_gan_loss = args.with_gan_loss
gen_coef = args.gen_coef  # 0.3 1.0 5.0 25.0  the best set is (5.0, 15.0)
disc_coef = args.disc_coef  # 1.0 3.0 15.0 75.0
ckpt_path = args.ckpt_path

# With the GAN loss on, both coefficients multiply loss terms in the training
# loop (a missing one would crash there as `None * tensor`). With it off, any
# given coefficient would be silently ignored. parser.error (not assert) keeps
# the check active under `python -O`.
if with_gan_loss:
    if gen_coef is None or disc_coef is None:
        parser.error("Initialize gan coefs! --with_gan_loss needs both --gen_coef and --disc_coef.")
elif gen_coef is not None or disc_coef is not None:
    parser.error("If gan coefs are not None add --with_gan_loss!")

if with_gan_loss:
    exp_name = f'alpha{alpha}_gamma{gamma}_use_gan_gen_coef_{gen_coef}_disc_coef_{disc_coef}'
else:
    exp_name = f'alpha{alpha}_gamma{gamma}'

if ckpt_path is not None:
    exp_name = exp_name + "_continue"

# `savedir` keeps its trailing slash: file names are appended to it by concatenation.
output_dir = "./result_cifar/"
savedir = output_dir + exp_name + "/"
os.makedirs(savedir, exist_ok=True)
logger = TensorBoardWriter(savedir) if logger_name == "tensorboard" else WandBWriter()

# do not change parameters below
grad_clip = 1.0
total_steps = 400001
warmup = 500
batch_size = 256
num_workers = 1
ema_decay = 0.999
adv_step = 6  # one generator update every `adv_step` steps, critic updates otherwise

# Evaluation
num_gen = 50000
fid_step = 6000
save_image_step = 2000
save_model_step = 6000

# Finetune
finetune = False
finetune_path = 'ft_path'




### Model for GAN with extra head

NUM_CLASSES = 1000
from torchcfm.models.unet.nn import timestep_embedding


class UNetModelWrapperWithHead(UNetModelWrapper):
    """``UNetModelWrapper`` plus a small realism-logit head on the middle block.

    ``forward`` is the unchanged prediction of the parent class.
    ``forward_head`` runs only the encoder half of the UNet (input blocks and
    middle block) and maps the middle-block features through
    ``cls_pred_branch`` to a single-channel logit map. ``GANloss`` uses that
    map as the discriminator output.
    """

    def __init__(
        self,
        dim,
        num_channels,
        num_res_blocks,
        channel_mult=None,
        learn_sigma=False,
        class_cond=False,
        num_classes=NUM_CLASSES,
        use_checkpoint=False,
        attention_resolutions="16",
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        dropout=0,
        resblock_updown=False,
        use_fp16=False,
        use_new_attention_order=False,
    ):
        super().__init__(
            dim,
            num_channels,
            num_res_blocks,
            channel_mult,
            learn_sigma,
            class_cond,
            num_classes,
            use_checkpoint,
            attention_resolutions,
            num_heads,
            num_head_channels,
            num_heads_upsample,
            use_scale_shift_norm,
            dropout,
            resblock_updown,
            use_fp16,
            use_new_attention_order,
        )

        # NOTE(review): in_channels=256 assumes the middle block outputs
        # num_channels * channel_mult[-1] = 128 * 2 channels, as configured in
        # this script; confirm before changing the UNet configuration.
        # Two stride-2 convolutions shrink the spatial size by 4x, then a 1x1
        # convolution produces one logit channel.
        self.cls_pred_branch = nn.Sequential(
            nn.Conv2d(kernel_size=2, in_channels=256, out_channels=256, stride=2, padding=0),
            nn.GroupNorm(num_groups=32, num_channels=256),
            nn.SiLU(),
            nn.Conv2d(kernel_size=2, in_channels=256, out_channels=256, stride=2, padding=0),
            nn.GroupNorm(num_groups=32, num_channels=256),
            nn.SiLU(),
            nn.Conv2d(kernel_size=1, in_channels=256, out_channels=1, stride=1, padding=0),
        )

    def forward(self, t, x, y=None, *args, **kwargs):
        """Delegate to the parent UNet; extra positional/keyword args are ignored."""
        return super().forward(t, x, y=y)

    def forward_head(self, t, x, y=None, *args, **kwargs):
        """Return the realism logits of ``x`` at time ``t`` (optionally class ``y``).

        Args:
            t: time tensor. A 0-dim value is broadcast to every sample; tensors
                with extra trailing dims are reduced by taking index 0 along them.
            x: input batch fed to the UNet input blocks.
            y: class labels of shape ``(batch,)``; required iff the model is
                class-conditional.

        Returns:
            Output of ``cls_pred_branch`` on the middle-block features.
        """
        timesteps = t
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        while timesteps.dim() > 1:
            timesteps = timesteps[:, 0]
        if timesteps.dim() == 0:
            timesteps = timesteps.repeat(x.shape[0])

        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        # Only the encoder path is needed; skip connections are not collected
        # because the decoder (output blocks) is never run here.
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
        h = self.middle_block(h, emb)
        return self.cls_pred_branch(h)
    
def compute_cls_logits(u, x, t, y = None):
    """Query the realism head of ``u`` and return its logits cast to float32.

    Args:
        u: model exposing ``forward_head(t, x, y)``.
        x: input batch.
        t: time tensor forwarded to the head.
        y: optional class labels.
    """
    raw_logits = u.forward_head(t, x, y)
    # Cast to float32 so the loss is computed in full precision.
    return raw_logits.float()
    

def GANloss(u,  x1_gen, x1_data, y = None, generator_turn = True, t = None, x0 = None):
    """Logistic (softplus) GAN loss using the realism head of ``u`` as the discriminator.

    If ``t`` is given, samples are first moved to ``x_t = t * x1 + (1 - t) * x0``.
    Otherwise they are scored as-is and the head is queried at ``t = 1``. Only
    the interpolants needed for the current turn are computed.

    Args:
        u: model exposing ``forward_head`` (see ``compute_cls_logits``).
        x1_gen: generated samples.
        x1_data: real samples (only used on the critic turn).
        y: optional class labels.
        generator_turn: select the generator or the critic objective.
        t: optional per-sample times; requires ``x0`` when given.
        x0: noise endpoint of the interpolation.

    Returns:
        Generator turn: ``mean(softplus(-D(fake)))``, the non-saturating
        generator loss (to be minimised).
        Critic turn: ``-mean(softplus(D(fake))) - mean(softplus(-D(real)))``, the
        negated discriminator loss. ``optim_u`` maximises it (``maximize=True``).
        Inputs are detached on this turn.
    """
    if t is not None:
        t_padded = pad_t_like_x(t, x0)

        def _to_time(x1):
            return x1 * t_padded + (1 - t_padded) * x0
    else:
        t = torch.ones(x1_gen.shape[0]).type_as(x1_gen)

        def _to_time(x1):
            return x1

    if generator_turn:
        # Gradients must flow back into the generator through x1_gen.
        pred_realism_on_fake_with_grad = compute_cls_logits(u, _to_time(x1_gen), t, y)
        return F.softplus(-pred_realism_on_fake_with_grad).mean()

    pred_realism_on_real = compute_cls_logits(u, _to_time(x1_data).detach(), t, y)
    pred_realism_on_fake = compute_cls_logits(u, _to_time(x1_gen).detach(), t, y)
    return -F.softplus(pred_realism_on_fake).mean() - F.softplus(-pred_realism_on_real).mean()

        
###Adv loss


def dist_loss(u, u_star, t, x0, x1_gen, x1_data, y = None, alpha = 1.0, gamma = 1.0, generator_turn = True):
    """Distillation objective between the critic ``u`` and the teacher ``u_star``.

    Samples are interpolated as ``x_t = t * x1 + (1 - t) * x0``. Let
    ``target = gamma / alpha * (x1_gen - x0)``.

    * Generator turn (the training loop minimises it w.r.t. the generator):
      ``alpha * ||u_star(x_t_gen) - target||^2 - alpha * ||u(x_t_gen) - target||^2``.
    * Critic turn (``optim_u`` maximises it w.r.t. ``u``):
      ``-alpha * ||u(x_t_gen) - target||^2``. When ``alpha < 1`` it also adds
      ``-(1 - alpha) * ||u(x_t_data) - d * v||^2``, where
      ``d = (1 - gamma) / (1 - alpha)`` and ``v`` is ``u_star(x_t_data)`` if the
      module-level flag ``use_u_star`` is set, else ``x1_data - x0``.

    Only the network evaluations that the current turn actually uses are run.
    Unused ``u``/``u_star`` passes would otherwise cost full UNet forwards and,
    for ``u``, an autograd graph.

    Returns:
        Scalar tensor: summed squared errors divided by the batch size ``x0.shape[0]``.
    """
    t_padded = pad_t_like_x(t, x0)

    xt_gen = x1_gen * t_padded + (1 - t_padded) * x0
    target_gen = gamma / alpha * (x1_gen - x0)
    u_xt_gen = u(t, xt_gen, y)

    if generator_turn:
        u_star_xt_gen = u_star(t, xt_gen, y)
        loss = alpha * torch.sum((u_star_xt_gen - target_gen) ** 2) - alpha * torch.sum((u_xt_gen - target_gen) ** 2)
    else:
        loss = -alpha * torch.sum((u_xt_gen - target_gen) ** 2)

        if alpha < 1.0:
            xt_data = t_padded * x1_data + (1 - t_padded) * x0
            u_xt_data = u(t, xt_data, y)
            data_scale = (1 - gamma) / (1 - alpha)
            if use_u_star:
                target_data = data_scale * u_star(t, xt_data, y)
            else:
                target_data = data_scale * (x1_data - x0)
            loss = loss - (1 - alpha) * torch.sum((u_xt_data - target_data) ** 2)

    return loss / x0.shape[0]

    
    
    
###GEN and EVAL functions  

def gen_function(generator, z, y = None):
    """One-step generation: ``z + generator(0, z, y)``.

    The network is evaluated at time 0 for every sample, and its output is
    added to the input noise ``z`` as a residual.
    """
    t_zero = torch.zeros(z.shape[0]).type_as(z)
    displacement = generator(t_zero, z, y)
    return z + displacement
 

def eval_model(generator):
    """Compute the FID of ``generator`` against the CIFAR-10 train split with clean-fid.

    Images are produced 100 at a time by one-step generation (``gen_function``),
    clipped to [-1, 1] and converted to uint8 pixel values. When ``COND`` is set,
    every batch holds 10 samples of each of the 10 class labels.

    The model runs in eval mode during the evaluation. Its previous train/eval
    mode is restored afterwards, even if the FID computation raises.

    Args:
        generator: model passed to ``gen_function``.

    Returns:
        The score returned by ``fid.compute_fid``.
    """
    # The labels are identical for every batch, so build them once.
    labels = torch.from_numpy(np.repeat(np.arange(10), 10)).to(device) if COND else None

    def gen_1_img(unused_latent):
        # The argument supplied by clean-fid is not used; fresh noise is drawn instead.
        with torch.no_grad():
            z = torch.randn(100, 3, 32, 32, device=device)
            images = gen_function(generator, z, labels).clip(-1, 1)
            # Rescale [-1, 1] to [0, 255] pixel values.
            return (images * 127.5 + 128).clip(0, 255).to(torch.uint8)

    was_training = generator.training
    generator.eval()
    try:
        print("Start computing FID")
        score = fid.compute_fid(
            gen=gen_1_img,
            dataset_name="cifar10",
            batch_size=100,
            dataset_res=32,
            num_gen=num_gen,
            dataset_split="train",
            mode="legacy_tensorflow",
            use_dataparallel=False,  # to avoid error for NPU
        )
    finally:
        generator.train(was_training)
    print()
    print("FID has been computed")
    print()
    print("FID: ", score)
    return score



def generate_samples(generator, savedir, step, net_="normal",  log = False):
    """Save (and optionally log) a 10x10 grid of one-step samples from ``generator``.

    With ``COND`` set, the labels are ``[0]*10 + [1]*10 + ... + [9]*10``, so row
    ``i`` of the grid holds class ``i``. The image is written to
    ``savedir + f"{net_}_generated_gen_images_step_{step}.png"``.

    The model runs in eval mode while sampling. Its previous train/eval mode is
    restored afterwards, even if sampling raises.

    Args:
        generator: model passed to ``gen_function``.
        savedir: output directory prefix (expected to end with a separator).
        step: training step, used in the file name and the log tag.
        net_: tag distinguishing the models (e.g. "normal", "ema").
        log: also send the grid to ``logger``.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            y = torch.from_numpy(np.repeat(np.arange(10), 10)).to(device) if COND else None
            z = torch.randn(100, 3, 32, 32, device=device)
            # Clip to [-1, 1], then rescale to the [0, 1] range expected by save_image.
            images = gen_function(generator, z, y).clip(-1, 1) / 2 + 0.5
    finally:
        generator.train(was_training)

    img_path = savedir + f"{net_}_generated_gen_images_step_{step}.png"
    save_image(images, img_path, nrow=10)
    if log:
        logger.add_image(step, f"{net_}_generated_gen_images_step_{step}", make_grid(images, nrow=10))

    
    
    
#LOAD REAL DATA    
# Training images: random horizontal flips, then ToTensor's [0, 1] range is
# normalised to [-1, 1] per channel.
augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
trans = transforms.Compose(augmentations)

dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=trans)

# drop_last=True keeps every batch at exactly `batch_size` samples.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)

# Iterator consumed with next() in the training loop (helper from utils_cifar;
# presumably cycles over `dataloader` indefinitely).
datalooper = infiniteloop(dataloader)



###model for dist

net_model_for_dist = UNetModelWrapperWithHead(
    dim=(3, 32, 32),
    num_res_blocks=2,
    num_channels=num_channel,
    channel_mult=[1, 2, 2, 2],
    num_heads=4,
    num_head_channels=64,
    attention_resolutions="16",
    dropout=0.0,
    class_cond=COND,
).to(device)


def _strip_dataparallel_prefix(state_dict):
    """Return ``state_dict`` without the ``module.`` key prefix added by nn.DataParallel.

    Keys are rewritten only when *every* key carries the prefix, so an ordinary
    state dict is returned unchanged.
    """
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


# Pretrained teacher: conditional or unconditional checkpoint.
if COND:
    PATH = 'cifar10_cond_field.pt'
    state_dict_key = "ema_net"
else:
    PATH = 'cfm_cifar10_weights_step_400000.pt'
    state_dict_key = "ema_model"
print("path: ", PATH)
checkpoint = torch.load(PATH, map_location=device)
state_dict = _strip_dataparallel_prefix(checkpoint[state_dict_key])

# strict=False is needed because the checkpoint has no weights for the new
# `cls_pred_branch` head. strict=False also never raises on key mismatches, so
# inspect the result: anything beyond the head's keys means a bad checkpoint.
load_result = net_model_for_dist.load_state_dict(state_dict, strict=False)
missing_backbone_keys = [key for key in load_result.missing_keys if not key.startswith("cls_pred_branch.")]
if missing_backbone_keys or load_result.unexpected_keys:
    print(
        "WARNING: teacher checkpoint does not match the model. "
        f"Missing keys: {missing_backbone_keys}; unexpected keys: {load_result.unexpected_keys}"
    )
net_model_for_dist.eval()
    


    
#### Init all models

# Critic: same architecture as the teacher, with its own random initialisation.
u = UNetModelWrapperWithHead(
    dim=(3, 32, 32),
    num_res_blocks=2,
    num_channels=num_channel,
    channel_mult=[1, 2, 2, 2],
    num_heads=4,
    num_head_channels=64,
    attention_resolutions="16",
    dropout=0.0,
    class_cond=COND,
).to(device)

# The one-step generator and its EMA copy start from the teacher's weights.
generator = copy.deepcopy(net_model_for_dist)
ema_gen = copy.deepcopy(generator)

if finetune:
    checkpoint = torch.load(finetune_path, map_location=device)
    generator.load_state_dict(checkpoint["ema_gen"], strict=False)
    u = copy.deepcopy(net_model_for_dist)
    ema_gen.load_state_dict(checkpoint["ema_gen"], strict=False)

# Build the optimizers only after `u` may have been replaced above, so that
# optim_u updates the parameters of the critic actually used in training.
optim_gen = torch.optim.Adam(generator.parameters(), betas=(0.0, 0.999), lr=lr)
# maximize=True: the critic performs gradient ascent on its objective.
optim_u = torch.optim.Adam(u.parameters(), betas=(0.0, 0.999), maximize=True, lr=lr)




def warmup_lr(step, warmup_steps=None):
    """Learning-rate multiplier for ``LambdaLR``: linear ramp from 0 to 1, then constant 1.

    Args:
        step: scheduler step count passed by ``LambdaLR``.
        warmup_steps: length of the ramp. Defaults to the module-level ``warmup``.

    Returns:
        ``min(step, warmup_steps) / warmup_steps``, or ``1.0`` when warm-up is
        disabled (``warmup_steps <= 0``) instead of dividing by zero.
    """
    if warmup_steps is None:
        warmup_steps = warmup
    if warmup_steps <= 0:
        return 1.0
    return min(step, warmup_steps) / warmup_steps
sched_gen = torch.optim.lr_scheduler.LambdaLR(optim_gen, lr_lambda=warmup_lr)
sched_u = torch.optim.lr_scheduler.LambdaLR(optim_u, lr_lambda=warmup_lr)


init_step = 0
if ckpt_path is not None:
    checkpoint = torch.load(ckpt_path, map_location=device)
    generator.load_state_dict(checkpoint["gen"], strict=False)
    ema_gen.load_state_dict(checkpoint["ema_gen"], strict=False)

    u.load_state_dict(checkpoint["u"], strict=False)

    optim_gen.load_state_dict(checkpoint["optim_gen"])
    optim_u.load_state_dict(checkpoint["optim_u"])

    sched_gen.load_state_dict(checkpoint["sched"])
    # sched_u is stepped on every critic step, far more often than sched_gen,
    # so it needs its own state. Checkpoints without "sched_u" only stored the
    # generator scheduler under "sched"; fall back to it for those.
    sched_u.load_state_dict(checkpoint.get("sched_u", checkpoint["sched"]))
    # The checkpoint for `step` is written after that step's update, sampling
    # and FID evaluation have finished, so training resumes at the next step.
    init_step = checkpoint["step"] + 1

##TRAINING

# if is_log:
#     PROJECT_NAME = '.'
#     OUTPUT_PATH = '.' 
#     wandb.init(project= PROJECT_NAME, name= OUTPUT_PATH )
# Evaluation histories, pickled into `savedir` after every FID evaluation.
fids = []  # mean FID of `generator`
ema_fids = []  # mean FID of `ema_gen`
ema_fids_std = []  # std of the repeated `ema_gen` FID evaluations

# The teacher stays frozen; only `generator` and `u` are optimised.
net_model_for_dist.requires_grad_(False)

def _mean_and_std_fid(model, repeats=3):
    """Run ``eval_model(model)`` ``repeats`` times and return ``(mean, std)`` of the scores."""
    scores = [eval_model(model) for _ in range(repeats)]
    return np.mean(scores), np.std(scores)


with trange(init_step, total_steps, dynamic_ncols=True) as pbar:
    for step in pbar:
        # NOTE(review): a full garbage collection every iteration is costly;
        # presumably added to bound memory growth — confirm it is still needed.
        gc.collect()
        x1_data, y = next(datalooper)
        x1_data = x1_data.to(device)
        y = y.to(device) if COND else None

        x0 = torch.randn_like(x1_data)
        t = torch.rand(x0.shape[0]).type_as(x0)
        z = torch.randn_like(x0)
        # Times for the GAN term, drawn from (0.8, 1.0] (the best set - 0.8 - 1.0).
        t_gan = 1.0 - 0.2 * torch.rand(x0.shape[0]).type_as(x0)

        if step % adv_step == adv_step - 1:
            # Generator update: once every `adv_step` steps.
            u.eval()
            optim_gen.zero_grad()

            x1_gen = gen_function(generator, z, y)
            loss = dist_loss(u, net_model_for_dist, t, x0, x1_gen, x1_data, y, alpha, gamma, True)

            if with_gan_loss:
                gan_loss_gen = GANloss(u, x1_gen, x1_data, y, True, t_gan, x0)
                loss = loss + gen_coef * gan_loss_gen
                if is_log:
                    logger.add_scalar(step, "GAN loss gen", gan_loss_gen.detach().cpu().numpy())
            if is_log:
                logger.add_scalar(step, "loss gen", loss.detach().cpu().numpy())

            loss.backward()
            torch.nn.utils.clip_grad_norm_(generator.parameters(), grad_clip)
            optim_gen.step()
            sched_gen.step()

            u.train()
            ema(generator, ema_gen, ema_decay)
        else:
            # Critic update: the generator output is a fixed input here, so no
            # autograd graph is built for the generator forward pass.
            generator.eval()
            optim_u.zero_grad()
            with torch.no_grad():
                x1_gen = gen_function(generator, z, y)

            loss = dist_loss(u, net_model_for_dist, t, x0, x1_gen, x1_data, y, alpha, gamma, False)

            if with_gan_loss:
                gan_loss = GANloss(u, x1_gen, x1_data, y, False, t_gan, x0)
                loss = loss + disc_coef * gan_loss
                if is_log:
                    logger.add_scalar(step, "GAN loss", gan_loss.detach().cpu().numpy())
            if is_log:
                logger.add_scalar(step, "loss", loss.detach().cpu().numpy())

            loss.backward()
            torch.nn.utils.clip_grad_norm_(u.parameters(), grad_clip)
            optim_u.step()
            sched_u.step()
            generator.train()

        # Sample grids.
        if save_image_step > 0 and step % save_image_step == 0:
            with torch.no_grad():
                generate_samples(generator, savedir, step, net_="normal", log=is_log)
                generate_samples(ema_gen, savedir, step, net_="ema", log=is_log)

        # FID of the raw and EMA generators, each averaged over 3 evaluations.
        if fid_step > 0 and step % fid_step == 0:
            fid_mean, _ = _mean_and_std_fid(generator)
            fids.append(fid_mean)
            with open(savedir + "fids.pkl", 'wb') as file:
                pickle.dump(fids, file)
            if is_log:
                logger.add_scalar(step, "fid", fid_mean)

            ema_fid_mean, ema_fid_std = _mean_and_std_fid(ema_gen)
            ema_fids.append(ema_fid_mean)
            ema_fids_std.append(ema_fid_std)
            with open(savedir + "ema_fids.pkl", 'wb') as file:
                pickle.dump(ema_fids, file)
            with open(savedir + "ema_fids_std.pkl", 'wb') as file:
                pickle.dump(ema_fids_std, file)
            if is_log:
                logger.add_scalar(step, "ema fid", ema_fid_mean)

        # Checkpoint. "sched" holds the generator scheduler (kept for older
        # resume code); "sched_u" holds the critic scheduler.
        if save_model_step > 0 and step % save_model_step == 0:
            torch.save(
                {
                    "gen": generator.state_dict(),
                    "u": u.state_dict(),
                    "ema_gen": ema_gen.state_dict(),
                    "sched": sched_gen.state_dict(),
                    "sched_u": sched_u.state_dict(),
                    "optim_gen": optim_gen.state_dict(),
                    "optim_u": optim_u.state_dict(),
                    "step": step,
                },
                savedir + f"cifar10_dist_step_{step}.pt",
            )
