import os
from tempfile import NamedTemporaryFile
from tempfile import TemporaryDirectory

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from torch.optim.lr_scheduler import CosineAnnealingLR

from models.utils import evaluate_model_v2



class Noise2NoiseEncoder(pl.LightningModule):
    """Fine-tune a denoising model against the logits of a frozen reference model.

    For each training batch the frozen ``fix_model`` (the point2sdf model) is
    evaluated on one noisy copy of the input and the trainable ``model`` (the
    denoising model) on a second noisy copy, both at the same query points; the
    loss is the MSE between the two ``'logits'`` outputs.

    Args:
        fix_model: frozen reference network, called as ``fix_model(points, queries)``
            and indexed with ``['logits']``.
        finetune_model: network being fine-tuned; same call convention.
        trainable_layers: name substrings; a parameter of ``finetune_model`` is
            trainable iff its name contains at least one of them.
        dataset: dataset passed to ``evaluate_model_v2`` during validation.
        output_dir: directory that receives checkpoints and exported meshes.
        samples_to_evaluate: stored on the instance; not read by the methods below.
    """

    def __init__(self, fix_model, finetune_model, trainable_layers, dataset, output_dir, samples_to_evaluate):
        super().__init__()
        # fix model is the point2sdf model. It is placed on CPU here; as a
        # registered submodule it may later be moved together with the rest of
        # this LightningModule.
        self.fix_model = fix_model.to('cpu')
        # finetune model is the denoising model
        self.model = finetune_model
        self.criterion = nn.MSELoss()
        self.dataset = dataset
        self.output_dir = output_dir
        self.samples_to_evaluate = samples_to_evaluate

        for param in self.fix_model.parameters():
            param.requires_grad = False

        # Unfreeze only the parameters whose name matches one of trainable_layers.
        for name, param in self.model.named_parameters():
            param.requires_grad = any(layer in name for layer in trainable_layers)

        # validate all the unfrozen layers
        print("unfrozen layers for finetuning:")
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                print(name)

        # Best validation scores so far: IoU / normal consistency are
        # higher-is-better, Chamfer distance (cd) is lower-is-better.
        self.best_iou = 0.0
        self.best_normal_consistency = 0.0
        self.best_cd = 1000

    def training_step(self, batch, batch_idx):
        """Return the MSE between the frozen and the fine-tuned model's logits.

        ``batch`` is unpacked as ``(_, noise1, noise2, queries)``: the frozen
        model sees ``noise1``, the fine-tuned model sees ``noise2``, both at
        the same ``queries``.
        """
        _, noise1, noise2, queries = batch

        # The training loop may switch the whole module, fix_model included,
        # into train mode; keep the frozen model in eval mode so dropout /
        # batch-norm layers (if it has any) neither perturb the target nor
        # update their running statistics.
        self.fix_model.eval()
        with torch.no_grad():
            target = self.fix_model(noise1, queries)['logits']

        prediction = self.model(noise2, queries)['logits']
        # Put the target next to the prediction instead of hard-coding
        # `.cuda()`, so CPU and multi-device runs work too.
        target = target.to(prediction.device)

        loss = self.criterion(target, prediction)

        self.log('train_loss', loss, on_step=True, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """Create AdamW (lr=1e-4) over the fine-tuned model's trainable parameters only."""
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable_params, lr=1e-4)

    def on_validation_start(self):
        """Synchronise all ranks before validation when a process group exists."""
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.barrier()

    def validation_step(self, batch, batch_idx):
        # No per-batch work: the full evaluation happens once in
        # on_validation_epoch_end.
        return None

    def on_validation_epoch_end(self):
        """Evaluate on global rank 0, checkpoint new bests, and log metrics and meshes.

        Raises:
            ValueError: if the trainer's datamodule has no ``dataset`` or it is empty.
            AttributeError: if that dataset lacks ``val_model_list``.
        """
        datamodule = self.trainer.datamodule
        if not hasattr(datamodule, 'dataset') or len(datamodule.dataset) == 0:
            raise ValueError("data module dataset is empty or not set")

        if not hasattr(datamodule.dataset, 'val_model_list'):
            raise AttributeError("Dataset missing required val_model_list attribute")

        if not self.trainer.is_global_zero:
            self.log_dict({}, sync_dist=True)
            return

        epoch = self.current_epoch
        save_path = f'{self.output_dir}/mesh_{epoch}.obj'

        meshs, metrics = evaluate_model_v2(
            self.fix_model,
            self.model,
            self.device,
            self.dataset,
            save_path,
            use_train_data=False,
            use_one_shape_only=False,
            num_samples=None
        )

        current_iou = metrics['denoised']['iou']
        current_nc = metrics['denoised']['normal_consistency']
        # NOTE(review): assumes metrics['cd'] is a scalar (float, numpy scalar
        # or 0-d tensor) — confirm against evaluate_model_v2.
        current_cd = float(metrics['cd'])

        # Lower Chamfer distance is better.
        if current_cd < self.best_cd:
            self.best_cd = current_cd

        if current_iou > self.best_iou:
            self.best_iou = current_iou
            checkpoint_path = os.path.join(self.output_dir, f'best_iou_model_{self.best_iou:.4f}.pth')
            self._save_checkpoint(checkpoint_path, epoch, current_iou, current_nc, current_cd)
            print(f"best iou: {current_iou:.4f}, Normal Consistency: {current_nc:.4f}")

        if current_nc > self.best_normal_consistency:
            self.best_normal_consistency = current_nc
            # Embed the Chamfer distance of the model actually being saved.
            checkpoint_name = (
                f'best_normal_consistency_model_{self.best_normal_consistency:.4f}'
                f'_{current_cd:.4f}_{epoch}.pth'
            )
            checkpoint_path = os.path.join(self.output_dir, checkpoint_name)
            self._save_checkpoint(checkpoint_path, epoch, current_iou, current_nc, current_cd)
            print(f"best normal, Normal Consistency: {current_nc:.4f}")

        self.log_dict(self._metric_log_data(metrics), sync_dist=False)

        # print all mesh indices
        mesh_indices = [mesh_object['index'] for mesh_object in meshs]
        print(f"mesh indices: {mesh_indices}")

        self._log_meshes(meshs, epoch)

    def _save_checkpoint(self, checkpoint_path, epoch, iou, normal_consistency, cd):
        """Save the fine-tuned model's state dict plus its validation scores to ``checkpoint_path``."""
        torch.save({
            'epoch': epoch,
            'model': self.model.state_dict(),
            'iou': iou,
            'normal_consistency': normal_consistency,
            'cd': cd,
        }, checkpoint_path)

    @staticmethod
    def _metric_log_data(metrics):
        """Flatten metrics into ``metrics/<name>/<variant>`` keys plus deltas relative to the noisy input."""
        log_data = {}
        for metric in ("mse", "mae", "iou", "normal_consistency"):
            denoised = metrics['denoised'][metric]
            noise = metrics['noise'][metric]
            gt = metrics['gt'][metric]
            log_data[f"metrics/{metric}/denoised"] = denoised
            log_data[f"metrics/{metric}/noise"] = noise
            log_data[f"metrics/{metric}/gt"] = gt
            log_data[f"metrics/{metric}/denoised_delta"] = denoised - noise
            log_data[f"metrics/{metric}/gt_delta"] = gt - noise
        return log_data

    def _log_meshes(self, meshs, epoch):
        """Export every sample's denoised / noise / gt mesh and log them to the experiment in one call.

        A single ``log`` call keeps all meshes of an epoch on one logger step
        instead of advancing the step once per mesh.
        """
        if not meshs or self.logger is None:
            return

        # wandb.Object3D is given a file path; keep the directory alive until
        # the log call has consumed the files. Files are named by position so
        # arbitrary 'index' values never end up in a filesystem path.
        with TemporaryDirectory() as tmp_dir:
            payload = {}
            for position, mesh_object in enumerate(meshs):
                index = mesh_object['index']
                for kind in ("denoised", "noise", "gt"):
                    file_path = os.path.join(tmp_dir, f"{position}_{kind}.obj")
                    mesh_object[f'{kind}_mesh'].export(file_path)
                    payload[f"mesh/epoch{epoch}_sample{index}_{kind}"] = wandb.Object3D(file_path)
            self.logger.experiment.log(payload)
